From 7423b6635cca8cc287d18ccefff41e7a57f9064d Mon Sep 17 00:00:00 2001 From: chenxiangtong Date: Wed, 15 Apr 2026 16:37:15 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=B8=8B=E6=94=B9=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 37 +++++++++------------------ src/services/subscription.py | 49 ++++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 47 deletions(-) diff --git a/main.py b/main.py index 3caba08..99f39bb 100644 --- a/main.py +++ b/main.py @@ -98,6 +98,7 @@ class BangumiPlugin(Star): service=self.service, config_manager=self.config_manager, session=self.session, + context=self.context, ) # 5. 添加定时更新任务 @@ -117,11 +118,8 @@ class BangumiPlugin(Star): # --- 命令处理区 --- @staticmethod - def _resolve_session_key(event: AstrMessageEvent) -> str | None: - session_key: str | None = getattr(event, "session_id", None) - if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"): - session_key = event.message_obj.group_id - return session_key + def _resolve_session_key(event: AstrMessageEvent) -> str: + return event.unified_msg_origin @staticmethod def _parse_subscribe_selection(raw_text: str) -> int | None: @@ -203,10 +201,7 @@ class BangumiPlugin(Star): yield event.plain_result("❌ 订阅服务未就绪") return - group_id = self._resolve_session_key(event) - if not group_id: - yield event.plain_result("❌ 无法获取群组ID") - return + session_key = self._resolve_session_key(event) ( error_msg, @@ -224,7 +219,7 @@ class BangumiPlugin(Star): if len(candidates) == 1: result = await self.subscription_service.subscribe_by_subject_id( - group_id=group_id, + session_id=session_key, subject_id=candidates[0]["subject_id"], ) yield event.plain_result(result) @@ -246,12 +241,11 @@ class BangumiPlugin(Star): "today", "弃坑", } - session_key = group_id + session_key = self._resolve_session_key(event) - class GroupSessionFilter(SessionFilter): + class ConversationSessionFilter(SessionFilter): def filter(self, wait_event: AstrMessageEvent) -> str: - wait_session_key = BangumiPlugin._resolve_session_key(wait_event) - return wait_session_key or wait_event.unified_msg_origin + return BangumiPlugin._resolve_session_key(wait_event) @session_waiter(timeout=300) async def subscribe_confirm_waiter( @@ -292,7 +286,7 @@ class BangumiPlugin(Star): selected = candidates[selected_index - 1] result = await self.subscription_service.subscribe_by_subject_id( - group_id=session_key, + session_id=session_key, subject_id=selected["subject_id"], ) await wait_event.send(MessageChain([Comp.Plain(result)])) @@ -302,7 +296,7 @@ class BangumiPlugin(Star): try: await subscribe_confirm_waiter( event, - session_filter=GroupSessionFilter(), + session_filter=ConversationSessionFilter(), ) except TimeoutError: yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。") @@ -316,15 +310,8 @@ class BangumiPlugin(Star): yield event.plain_result("❌ 订阅服务未就绪") return - group_id: str | None = getattr(event, "session_id", None) - if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"): - group_id = event.message_obj.group_id - - if not group_id: - yield event.plain_result("❌ 无法获取群组ID") - return - - result = await self.subscription_service.unsubscribe(group_id, query) + session_key = self._resolve_session_key(event) + result = await self.subscription_service.unsubscribe(session_key, query) yield event.plain_result(result) async def terminate(self) -> None: diff --git a/src/services/subscription.py b/src/services/subscription.py index d86e314..9949314 100644 --- a/src/services/subscription.py +++ b/src/services/subscription.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, cast import aiohttp from astrbot.api import logger -from astrbot.api.star import StarTools +from astrbot.api.star import Context, StarTools from astrbot.core.message.message_event_result import MessageChain from ..config import ConfigManager @@ -24,10 +24,12 @@ class SubscriptionService: service: "BangumiService", config_manager: ConfigManager, session: aiohttp.ClientSession | None = None, + context: Context | None = None, ) -> None: self.storage = repository self.service = service self.config_manager = config_manager + self.context = context self.renderer = EpisodeRenderer(session=session) async def get_subscribe_candidates( @@ -126,7 +128,7 @@ class SubscriptionService: return "🔍 未找到相关番剧", None return await self._build_subscribable_subject(candidates[0]["subject_id"]) - async def subscribe_by_subject_id(self, group_id: str, subject_id: str) -> str: + async def subscribe_by_subject_id(self, session_id: str, subject_id: str) -> str: """ 基于明确 subject_id 完成订阅。 """ @@ -138,7 +140,7 @@ class SubscriptionService: return "❌ 未知错误:未能获取番剧信息" success = self.storage.subscribe_subject( - group_id=group_id, + group_id=session_id, subject_id=subject_info["subject_id"], name=subject_info["name"], air_date=subject_info["air_date"], @@ -146,18 +148,18 @@ class SubscriptionService: ) if success: return ( - f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本群。" + f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本会话。" ) return "❌ 订阅失败,数据库错误。" except (BangumiApiError, DatabaseError, SubscriptionError) as e: logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}") return f"❌ 处理失败: {e}" - async def subscribe(self, group_id: str, query: str) -> str: + async def subscribe(self, session_id: str, query: str) -> str: """ 处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。 """ - logger.info(f"处理追番请求: {query}, group_id={group_id}") + logger.info(f"处理追番请求: {query}, session_id={session_id}") try: # 1. 匹配条目 (调用内部迁移后的逻辑) error_msg, subject_info = await self._match_subscribable_subject(query) @@ -171,27 +173,27 @@ class SubscriptionService: # 2 & 3. 原子性地写入条目信息并建立订阅关系 success = self.storage.subscribe_subject( - group_id=group_id, + group_id=session_id, subject_id=subject_id, name=name, air_date=subject_info["air_date"], total_episodes=subject_info["total_episodes"], ) if success: - return f"✅ 成功订阅《{name}》!\n如有更新将推送到本群。" + return f"✅ 成功订阅《{name}》!\n如有更新将推送到本会话。" else: return "❌ 订阅失败,数据库错误。" except (BangumiApiError, DatabaseError, SubscriptionError) as e: logger.error(f"SubscriptionService.subscribe 失败: {e}") return f"❌ 处理失败: {e}" - async def unsubscribe(self, group_id: str, query: str) -> str: + async def unsubscribe(self, session_id: str, query: str) -> str: """ 取消订阅逻辑。 """ - logger.info(f"处理取消追番请求: {query}, group_id={group_id}") + logger.info(f"处理取消追番请求: {query}, session_id={session_id}") try: - error_msg, subject_info = self._match_local_subscription(group_id, query) + error_msg, subject_info = self._match_local_subscription(session_id, query) if error_msg: return error_msg if not subject_info: @@ -200,7 +202,7 @@ class SubscriptionService: subject_id = subject_info["subject_id"] name = subject_info["name"] - success = self.storage.remove_subscription(group_id, subject_id) + success = self.storage.remove_subscription(session_id, subject_id) if success: return f"✅ 已成功取消订阅《{name}》。" else: @@ -210,10 +212,10 @@ class SubscriptionService: return f"❌ 处理失败: {e}" def _match_local_subscription( - self, group_id: str, query: str + self, session_id: str, query: str ) -> tuple[str | None, UnsubscribeMatch | None]: """ - 在当前群组的本地订阅中做模糊匹配。 + 在当前会话的本地订阅中做模糊匹配。 """ normalized_query = str(query).strip() if not normalized_query: @@ -221,10 +223,10 @@ class SubscriptionService: # 取 6 条用于判断是否超过默认展示上限(5 条) candidates = self.storage.find_group_subscription_candidates( - group_id=group_id, keyword=normalized_query, limit=6 + group_id=session_id, keyword=normalized_query, limit=6 ) if not candidates: - return f"❌ 未找到与「{normalized_query}」匹配的本群订阅番剧。", None + return f"❌ 未找到与「{normalized_query}」匹配的本会话订阅番剧。", None if len(candidates) == 1: subject = candidates[0] @@ -318,13 +320,16 @@ class SubscriptionService: f"🔔 番剧《{subject_name}》更新啦!\n第 {episode.ep} 集:{episode.name_cn or episode.name}" ) - for group_id in subscribed_groups: + for session_id in subscribed_groups: try: - await StarTools.send_message_by_id( - type="GroupMessage", id=group_id, message_chain=chain - ) - logger.info(f"向群组 {group_id} 发送《{subject_name}》更新通知成功。") + if self.context: + await self.context.send_message(session_id, chain) + else: + await StarTools.send_message_by_id( + type="GroupMessage", id=session_id, message_chain=chain + ) + logger.info(f"向会话 {session_id} 发送《{subject_name}》更新通知成功。") except Exception as e: logger.error( - f"向群组 {group_id} 发送《{subject_name}》更新通知失败: {e}" + f"向会话 {session_id} 发送《{subject_name}》更新通知失败: {e}" )