测试下改会话
This commit is contained in:
37
main.py
37
main.py
@@ -98,6 +98,7 @@ class BangumiPlugin(Star):
|
|||||||
service=self.service,
|
service=self.service,
|
||||||
config_manager=self.config_manager,
|
config_manager=self.config_manager,
|
||||||
session=self.session,
|
session=self.session,
|
||||||
|
context=self.context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 添加定时更新任务
|
# 5. 添加定时更新任务
|
||||||
@@ -117,11 +118,8 @@ class BangumiPlugin(Star):
|
|||||||
# --- 命令处理区 ---
|
# --- 命令处理区 ---
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_session_key(event: AstrMessageEvent) -> str | None:
|
def _resolve_session_key(event: AstrMessageEvent) -> str:
|
||||||
session_key: str | None = getattr(event, "session_id", None)
|
return event.unified_msg_origin
|
||||||
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"):
|
|
||||||
session_key = event.message_obj.group_id
|
|
||||||
return session_key
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_subscribe_selection(raw_text: str) -> int | None:
|
def _parse_subscribe_selection(raw_text: str) -> int | None:
|
||||||
@@ -203,10 +201,7 @@ class BangumiPlugin(Star):
|
|||||||
yield event.plain_result("❌ 订阅服务未就绪")
|
yield event.plain_result("❌ 订阅服务未就绪")
|
||||||
return
|
return
|
||||||
|
|
||||||
group_id = self._resolve_session_key(event)
|
session_key = self._resolve_session_key(event)
|
||||||
if not group_id:
|
|
||||||
yield event.plain_result("❌ 无法获取群组ID")
|
|
||||||
return
|
|
||||||
|
|
||||||
(
|
(
|
||||||
error_msg,
|
error_msg,
|
||||||
@@ -224,7 +219,7 @@ class BangumiPlugin(Star):
|
|||||||
|
|
||||||
if len(candidates) == 1:
|
if len(candidates) == 1:
|
||||||
result = await self.subscription_service.subscribe_by_subject_id(
|
result = await self.subscription_service.subscribe_by_subject_id(
|
||||||
group_id=group_id,
|
session_id=session_key,
|
||||||
subject_id=candidates[0]["subject_id"],
|
subject_id=candidates[0]["subject_id"],
|
||||||
)
|
)
|
||||||
yield event.plain_result(result)
|
yield event.plain_result(result)
|
||||||
@@ -246,12 +241,11 @@ class BangumiPlugin(Star):
|
|||||||
"today",
|
"today",
|
||||||
"弃坑",
|
"弃坑",
|
||||||
}
|
}
|
||||||
session_key = group_id
|
session_key = self._resolve_session_key(event)
|
||||||
|
|
||||||
class GroupSessionFilter(SessionFilter):
|
class ConversationSessionFilter(SessionFilter):
|
||||||
def filter(self, wait_event: AstrMessageEvent) -> str:
|
def filter(self, wait_event: AstrMessageEvent) -> str:
|
||||||
wait_session_key = BangumiPlugin._resolve_session_key(wait_event)
|
return BangumiPlugin._resolve_session_key(wait_event)
|
||||||
return wait_session_key or wait_event.unified_msg_origin
|
|
||||||
|
|
||||||
@session_waiter(timeout=300)
|
@session_waiter(timeout=300)
|
||||||
async def subscribe_confirm_waiter(
|
async def subscribe_confirm_waiter(
|
||||||
@@ -292,7 +286,7 @@ class BangumiPlugin(Star):
|
|||||||
|
|
||||||
selected = candidates[selected_index - 1]
|
selected = candidates[selected_index - 1]
|
||||||
result = await self.subscription_service.subscribe_by_subject_id(
|
result = await self.subscription_service.subscribe_by_subject_id(
|
||||||
group_id=session_key,
|
session_id=session_key,
|
||||||
subject_id=selected["subject_id"],
|
subject_id=selected["subject_id"],
|
||||||
)
|
)
|
||||||
await wait_event.send(MessageChain([Comp.Plain(result)]))
|
await wait_event.send(MessageChain([Comp.Plain(result)]))
|
||||||
@@ -302,7 +296,7 @@ class BangumiPlugin(Star):
|
|||||||
try:
|
try:
|
||||||
await subscribe_confirm_waiter(
|
await subscribe_confirm_waiter(
|
||||||
event,
|
event,
|
||||||
session_filter=GroupSessionFilter(),
|
session_filter=ConversationSessionFilter(),
|
||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")
|
yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")
|
||||||
@@ -316,15 +310,8 @@ class BangumiPlugin(Star):
|
|||||||
yield event.plain_result("❌ 订阅服务未就绪")
|
yield event.plain_result("❌ 订阅服务未就绪")
|
||||||
return
|
return
|
||||||
|
|
||||||
group_id: str | None = getattr(event, "session_id", None)
|
session_key = self._resolve_session_key(event)
|
||||||
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"):
|
result = await self.subscription_service.unsubscribe(session_key, query)
|
||||||
group_id = event.message_obj.group_id
|
|
||||||
|
|
||||||
if not group_id:
|
|
||||||
yield event.plain_result("❌ 无法获取群组ID")
|
|
||||||
return
|
|
||||||
|
|
||||||
result = await self.subscription_service.unsubscribe(group_id, query)
|
|
||||||
yield event.plain_result(result)
|
yield event.plain_result(result)
|
||||||
|
|
||||||
async def terminate(self) -> None:
|
async def terminate(self) -> None:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, cast
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.star import StarTools
|
from astrbot.api.star import Context, StarTools
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
from ..config import ConfigManager
|
from ..config import ConfigManager
|
||||||
@@ -24,10 +24,12 @@ class SubscriptionService:
|
|||||||
service: "BangumiService",
|
service: "BangumiService",
|
||||||
config_manager: ConfigManager,
|
config_manager: ConfigManager,
|
||||||
session: aiohttp.ClientSession | None = None,
|
session: aiohttp.ClientSession | None = None,
|
||||||
|
context: Context | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.storage = repository
|
self.storage = repository
|
||||||
self.service = service
|
self.service = service
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
self.context = context
|
||||||
self.renderer = EpisodeRenderer(session=session)
|
self.renderer = EpisodeRenderer(session=session)
|
||||||
|
|
||||||
async def get_subscribe_candidates(
|
async def get_subscribe_candidates(
|
||||||
@@ -126,7 +128,7 @@ class SubscriptionService:
|
|||||||
return "🔍 未找到相关番剧", None
|
return "🔍 未找到相关番剧", None
|
||||||
return await self._build_subscribable_subject(candidates[0]["subject_id"])
|
return await self._build_subscribable_subject(candidates[0]["subject_id"])
|
||||||
|
|
||||||
async def subscribe_by_subject_id(self, group_id: str, subject_id: str) -> str:
|
async def subscribe_by_subject_id(self, session_id: str, subject_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
基于明确 subject_id 完成订阅。
|
基于明确 subject_id 完成订阅。
|
||||||
"""
|
"""
|
||||||
@@ -138,7 +140,7 @@ class SubscriptionService:
|
|||||||
return "❌ 未知错误:未能获取番剧信息"
|
return "❌ 未知错误:未能获取番剧信息"
|
||||||
|
|
||||||
success = self.storage.subscribe_subject(
|
success = self.storage.subscribe_subject(
|
||||||
group_id=group_id,
|
group_id=session_id,
|
||||||
subject_id=subject_info["subject_id"],
|
subject_id=subject_info["subject_id"],
|
||||||
name=subject_info["name"],
|
name=subject_info["name"],
|
||||||
air_date=subject_info["air_date"],
|
air_date=subject_info["air_date"],
|
||||||
@@ -146,18 +148,18 @@ class SubscriptionService:
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return (
|
return (
|
||||||
f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本群。"
|
f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本会话。"
|
||||||
)
|
)
|
||||||
return "❌ 订阅失败,数据库错误。"
|
return "❌ 订阅失败,数据库错误。"
|
||||||
except (BangumiApiError, DatabaseError, SubscriptionError) as e:
|
except (BangumiApiError, DatabaseError, SubscriptionError) as e:
|
||||||
logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}")
|
logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}")
|
||||||
return f"❌ 处理失败: {e}"
|
return f"❌ 处理失败: {e}"
|
||||||
|
|
||||||
async def subscribe(self, group_id: str, query: str) -> str:
|
async def subscribe(self, session_id: str, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。
|
处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。
|
||||||
"""
|
"""
|
||||||
logger.info(f"处理追番请求: {query}, group_id={group_id}")
|
logger.info(f"处理追番请求: {query}, session_id={session_id}")
|
||||||
try:
|
try:
|
||||||
# 1. 匹配条目 (调用内部迁移后的逻辑)
|
# 1. 匹配条目 (调用内部迁移后的逻辑)
|
||||||
error_msg, subject_info = await self._match_subscribable_subject(query)
|
error_msg, subject_info = await self._match_subscribable_subject(query)
|
||||||
@@ -171,27 +173,27 @@ class SubscriptionService:
|
|||||||
|
|
||||||
# 2 & 3. 原子性地写入条目信息并建立订阅关系
|
# 2 & 3. 原子性地写入条目信息并建立订阅关系
|
||||||
success = self.storage.subscribe_subject(
|
success = self.storage.subscribe_subject(
|
||||||
group_id=group_id,
|
group_id=session_id,
|
||||||
subject_id=subject_id,
|
subject_id=subject_id,
|
||||||
name=name,
|
name=name,
|
||||||
air_date=subject_info["air_date"],
|
air_date=subject_info["air_date"],
|
||||||
total_episodes=subject_info["total_episodes"],
|
total_episodes=subject_info["total_episodes"],
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return f"✅ 成功订阅《{name}》!\n如有更新将推送到本群。"
|
return f"✅ 成功订阅《{name}》!\n如有更新将推送到本会话。"
|
||||||
else:
|
else:
|
||||||
return "❌ 订阅失败,数据库错误。"
|
return "❌ 订阅失败,数据库错误。"
|
||||||
except (BangumiApiError, DatabaseError, SubscriptionError) as e:
|
except (BangumiApiError, DatabaseError, SubscriptionError) as e:
|
||||||
logger.error(f"SubscriptionService.subscribe 失败: {e}")
|
logger.error(f"SubscriptionService.subscribe 失败: {e}")
|
||||||
return f"❌ 处理失败: {e}"
|
return f"❌ 处理失败: {e}"
|
||||||
|
|
||||||
async def unsubscribe(self, group_id: str, query: str) -> str:
|
async def unsubscribe(self, session_id: str, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
取消订阅逻辑。
|
取消订阅逻辑。
|
||||||
"""
|
"""
|
||||||
logger.info(f"处理取消追番请求: {query}, group_id={group_id}")
|
logger.info(f"处理取消追番请求: {query}, session_id={session_id}")
|
||||||
try:
|
try:
|
||||||
error_msg, subject_info = self._match_local_subscription(group_id, query)
|
error_msg, subject_info = self._match_local_subscription(session_id, query)
|
||||||
if error_msg:
|
if error_msg:
|
||||||
return error_msg
|
return error_msg
|
||||||
if not subject_info:
|
if not subject_info:
|
||||||
@@ -200,7 +202,7 @@ class SubscriptionService:
|
|||||||
subject_id = subject_info["subject_id"]
|
subject_id = subject_info["subject_id"]
|
||||||
name = subject_info["name"]
|
name = subject_info["name"]
|
||||||
|
|
||||||
success = self.storage.remove_subscription(group_id, subject_id)
|
success = self.storage.remove_subscription(session_id, subject_id)
|
||||||
if success:
|
if success:
|
||||||
return f"✅ 已成功取消订阅《{name}》。"
|
return f"✅ 已成功取消订阅《{name}》。"
|
||||||
else:
|
else:
|
||||||
@@ -210,10 +212,10 @@ class SubscriptionService:
|
|||||||
return f"❌ 处理失败: {e}"
|
return f"❌ 处理失败: {e}"
|
||||||
|
|
||||||
def _match_local_subscription(
|
def _match_local_subscription(
|
||||||
self, group_id: str, query: str
|
self, session_id: str, query: str
|
||||||
) -> tuple[str | None, UnsubscribeMatch | None]:
|
) -> tuple[str | None, UnsubscribeMatch | None]:
|
||||||
"""
|
"""
|
||||||
在当前群组的本地订阅中做模糊匹配。
|
在当前会话的本地订阅中做模糊匹配。
|
||||||
"""
|
"""
|
||||||
normalized_query = str(query).strip()
|
normalized_query = str(query).strip()
|
||||||
if not normalized_query:
|
if not normalized_query:
|
||||||
@@ -221,10 +223,10 @@ class SubscriptionService:
|
|||||||
|
|
||||||
# 取 6 条用于判断是否超过默认展示上限(5 条)
|
# 取 6 条用于判断是否超过默认展示上限(5 条)
|
||||||
candidates = self.storage.find_group_subscription_candidates(
|
candidates = self.storage.find_group_subscription_candidates(
|
||||||
group_id=group_id, keyword=normalized_query, limit=6
|
group_id=session_id, keyword=normalized_query, limit=6
|
||||||
)
|
)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
return f"❌ 未找到与「{normalized_query}」匹配的本群订阅番剧。", None
|
return f"❌ 未找到与「{normalized_query}」匹配的本会话订阅番剧。", None
|
||||||
|
|
||||||
if len(candidates) == 1:
|
if len(candidates) == 1:
|
||||||
subject = candidates[0]
|
subject = candidates[0]
|
||||||
@@ -318,13 +320,16 @@ class SubscriptionService:
|
|||||||
f"🔔 番剧《{subject_name}》更新啦!\n第 {episode.ep} 集:{episode.name_cn or episode.name}"
|
f"🔔 番剧《{subject_name}》更新啦!\n第 {episode.ep} 集:{episode.name_cn or episode.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for group_id in subscribed_groups:
|
for session_id in subscribed_groups:
|
||||||
try:
|
try:
|
||||||
await StarTools.send_message_by_id(
|
if self.context:
|
||||||
type="GroupMessage", id=group_id, message_chain=chain
|
await self.context.send_message(session_id, chain)
|
||||||
)
|
else:
|
||||||
logger.info(f"向群组 {group_id} 发送《{subject_name}》更新通知成功。")
|
await StarTools.send_message_by_id(
|
||||||
|
type="GroupMessage", id=session_id, message_chain=chain
|
||||||
|
)
|
||||||
|
logger.info(f"向会话 {session_id} 发送《{subject_name}》更新通知成功。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"向群组 {group_id} 发送《{subject_name}》更新通知失败: {e}"
|
f"向会话 {session_id} 发送《{subject_name}》更新通知失败: {e}"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user