测试下改会话

This commit is contained in:
chenxiangtong
2026-04-15 16:37:15 +08:00
parent 2e0b9336b0
commit 7423b6635c
2 changed files with 39 additions and 47 deletions

37
main.py
View File

@@ -98,6 +98,7 @@ class BangumiPlugin(Star):
service=self.service, service=self.service,
config_manager=self.config_manager, config_manager=self.config_manager,
session=self.session, session=self.session,
context=self.context,
) )
# 5. 添加定时更新任务 # 5. 添加定时更新任务
@@ -117,11 +118,8 @@ class BangumiPlugin(Star):
# --- 命令处理区 --- # --- 命令处理区 ---
@staticmethod @staticmethod
def _resolve_session_key(event: AstrMessageEvent) -> str | None: def _resolve_session_key(event: AstrMessageEvent) -> str:
session_key: str | None = getattr(event, "session_id", None) return event.unified_msg_origin
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"):
session_key = event.message_obj.group_id
return session_key
@staticmethod @staticmethod
def _parse_subscribe_selection(raw_text: str) -> int | None: def _parse_subscribe_selection(raw_text: str) -> int | None:
@@ -203,10 +201,7 @@ class BangumiPlugin(Star):
yield event.plain_result("❌ 订阅服务未就绪") yield event.plain_result("❌ 订阅服务未就绪")
return return
group_id = self._resolve_session_key(event) session_key = self._resolve_session_key(event)
if not group_id:
yield event.plain_result("❌ 无法获取群组ID")
return
( (
error_msg, error_msg,
@@ -224,7 +219,7 @@ class BangumiPlugin(Star):
if len(candidates) == 1: if len(candidates) == 1:
result = await self.subscription_service.subscribe_by_subject_id( result = await self.subscription_service.subscribe_by_subject_id(
group_id=group_id, session_id=session_key,
subject_id=candidates[0]["subject_id"], subject_id=candidates[0]["subject_id"],
) )
yield event.plain_result(result) yield event.plain_result(result)
@@ -246,12 +241,11 @@ class BangumiPlugin(Star):
"today", "today",
"弃坑", "弃坑",
} }
session_key = group_id session_key = self._resolve_session_key(event)
class GroupSessionFilter(SessionFilter): class ConversationSessionFilter(SessionFilter):
def filter(self, wait_event: AstrMessageEvent) -> str: def filter(self, wait_event: AstrMessageEvent) -> str:
wait_session_key = BangumiPlugin._resolve_session_key(wait_event) return BangumiPlugin._resolve_session_key(wait_event)
return wait_session_key or wait_event.unified_msg_origin
@session_waiter(timeout=300) @session_waiter(timeout=300)
async def subscribe_confirm_waiter( async def subscribe_confirm_waiter(
@@ -292,7 +286,7 @@ class BangumiPlugin(Star):
selected = candidates[selected_index - 1] selected = candidates[selected_index - 1]
result = await self.subscription_service.subscribe_by_subject_id( result = await self.subscription_service.subscribe_by_subject_id(
group_id=session_key, session_id=session_key,
subject_id=selected["subject_id"], subject_id=selected["subject_id"],
) )
await wait_event.send(MessageChain([Comp.Plain(result)])) await wait_event.send(MessageChain([Comp.Plain(result)]))
@@ -302,7 +296,7 @@ class BangumiPlugin(Star):
try: try:
await subscribe_confirm_waiter( await subscribe_confirm_waiter(
event, event,
session_filter=GroupSessionFilter(), session_filter=ConversationSessionFilter(),
) )
except TimeoutError: except TimeoutError:
yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。") yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")
@@ -316,15 +310,8 @@ class BangumiPlugin(Star):
yield event.plain_result("❌ 订阅服务未就绪") yield event.plain_result("❌ 订阅服务未就绪")
return return
group_id: str | None = getattr(event, "session_id", None) session_key = self._resolve_session_key(event)
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"): result = await self.subscription_service.unsubscribe(session_key, query)
group_id = event.message_obj.group_id
if not group_id:
yield event.plain_result("❌ 无法获取群组ID")
return
result = await self.subscription_service.unsubscribe(group_id, query)
yield event.plain_result(result) yield event.plain_result(result)
async def terminate(self) -> None: async def terminate(self) -> None:

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, cast
import aiohttp import aiohttp
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.star import StarTools from astrbot.api.star import Context, StarTools
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from ..config import ConfigManager from ..config import ConfigManager
@@ -24,10 +24,12 @@ class SubscriptionService:
service: "BangumiService", service: "BangumiService",
config_manager: ConfigManager, config_manager: ConfigManager,
session: aiohttp.ClientSession | None = None, session: aiohttp.ClientSession | None = None,
context: Context | None = None,
) -> None: ) -> None:
self.storage = repository self.storage = repository
self.service = service self.service = service
self.config_manager = config_manager self.config_manager = config_manager
self.context = context
self.renderer = EpisodeRenderer(session=session) self.renderer = EpisodeRenderer(session=session)
async def get_subscribe_candidates( async def get_subscribe_candidates(
@@ -126,7 +128,7 @@ class SubscriptionService:
return "🔍 未找到相关番剧", None return "🔍 未找到相关番剧", None
return await self._build_subscribable_subject(candidates[0]["subject_id"]) return await self._build_subscribable_subject(candidates[0]["subject_id"])
async def subscribe_by_subject_id(self, group_id: str, subject_id: str) -> str: async def subscribe_by_subject_id(self, session_id: str, subject_id: str) -> str:
""" """
基于明确 subject_id 完成订阅。 基于明确 subject_id 完成订阅。
""" """
@@ -138,7 +140,7 @@ class SubscriptionService:
return "❌ 未知错误:未能获取番剧信息" return "❌ 未知错误:未能获取番剧信息"
success = self.storage.subscribe_subject( success = self.storage.subscribe_subject(
group_id=group_id, group_id=session_id,
subject_id=subject_info["subject_id"], subject_id=subject_info["subject_id"],
name=subject_info["name"], name=subject_info["name"],
air_date=subject_info["air_date"], air_date=subject_info["air_date"],
@@ -146,18 +148,18 @@ class SubscriptionService:
) )
if success: if success:
return ( return (
f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本" f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本会话"
) )
return "❌ 订阅失败,数据库错误。" return "❌ 订阅失败,数据库错误。"
except (BangumiApiError, DatabaseError, SubscriptionError) as e: except (BangumiApiError, DatabaseError, SubscriptionError) as e:
logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}") logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}")
return f"❌ 处理失败: {e}" return f"❌ 处理失败: {e}"
async def subscribe(self, group_id: str, query: str) -> str: async def subscribe(self, session_id: str, query: str) -> str:
""" """
处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。 处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。
""" """
logger.info(f"处理追番请求: {query}, group_id={group_id}") logger.info(f"处理追番请求: {query}, session_id={session_id}")
try: try:
# 1. 匹配条目 (调用内部迁移后的逻辑) # 1. 匹配条目 (调用内部迁移后的逻辑)
error_msg, subject_info = await self._match_subscribable_subject(query) error_msg, subject_info = await self._match_subscribable_subject(query)
@@ -171,27 +173,27 @@ class SubscriptionService:
# 2 & 3. 原子性地写入条目信息并建立订阅关系 # 2 & 3. 原子性地写入条目信息并建立订阅关系
success = self.storage.subscribe_subject( success = self.storage.subscribe_subject(
group_id=group_id, group_id=session_id,
subject_id=subject_id, subject_id=subject_id,
name=name, name=name,
air_date=subject_info["air_date"], air_date=subject_info["air_date"],
total_episodes=subject_info["total_episodes"], total_episodes=subject_info["total_episodes"],
) )
if success: if success:
return f"✅ 成功订阅《{name}》!\n如有更新将推送到本" return f"✅ 成功订阅《{name}》!\n如有更新将推送到本会话"
else: else:
return "❌ 订阅失败,数据库错误。" return "❌ 订阅失败,数据库错误。"
except (BangumiApiError, DatabaseError, SubscriptionError) as e: except (BangumiApiError, DatabaseError, SubscriptionError) as e:
logger.error(f"SubscriptionService.subscribe 失败: {e}") logger.error(f"SubscriptionService.subscribe 失败: {e}")
return f"❌ 处理失败: {e}" return f"❌ 处理失败: {e}"
async def unsubscribe(self, group_id: str, query: str) -> str: async def unsubscribe(self, session_id: str, query: str) -> str:
""" """
取消订阅逻辑。 取消订阅逻辑。
""" """
logger.info(f"处理取消追番请求: {query}, group_id={group_id}") logger.info(f"处理取消追番请求: {query}, session_id={session_id}")
try: try:
error_msg, subject_info = self._match_local_subscription(group_id, query) error_msg, subject_info = self._match_local_subscription(session_id, query)
if error_msg: if error_msg:
return error_msg return error_msg
if not subject_info: if not subject_info:
@@ -200,7 +202,7 @@ class SubscriptionService:
subject_id = subject_info["subject_id"] subject_id = subject_info["subject_id"]
name = subject_info["name"] name = subject_info["name"]
success = self.storage.remove_subscription(group_id, subject_id) success = self.storage.remove_subscription(session_id, subject_id)
if success: if success:
return f"✅ 已成功取消订阅《{name}》。" return f"✅ 已成功取消订阅《{name}》。"
else: else:
@@ -210,10 +212,10 @@ class SubscriptionService:
return f"❌ 处理失败: {e}" return f"❌ 处理失败: {e}"
def _match_local_subscription( def _match_local_subscription(
self, group_id: str, query: str self, session_id: str, query: str
) -> tuple[str | None, UnsubscribeMatch | None]: ) -> tuple[str | None, UnsubscribeMatch | None]:
""" """
在当前群组的本地订阅中做模糊匹配。 在当前会话的本地订阅中做模糊匹配。
""" """
normalized_query = str(query).strip() normalized_query = str(query).strip()
if not normalized_query: if not normalized_query:
@@ -221,10 +223,10 @@ class SubscriptionService:
# 取 6 条用于判断是否超过默认展示上限5 条) # 取 6 条用于判断是否超过默认展示上限5 条)
candidates = self.storage.find_group_subscription_candidates( candidates = self.storage.find_group_subscription_candidates(
group_id=group_id, keyword=normalized_query, limit=6 group_id=session_id, keyword=normalized_query, limit=6
) )
if not candidates: if not candidates:
return f"❌ 未找到与「{normalized_query}」匹配的本订阅番剧。", None return f"❌ 未找到与「{normalized_query}」匹配的本会话订阅番剧。", None
if len(candidates) == 1: if len(candidates) == 1:
subject = candidates[0] subject = candidates[0]
@@ -318,13 +320,16 @@ class SubscriptionService:
f"🔔 番剧《{subject_name}》更新啦!\n{episode.ep} 集:{episode.name_cn or episode.name}" f"🔔 番剧《{subject_name}》更新啦!\n{episode.ep} 集:{episode.name_cn or episode.name}"
) )
for group_id in subscribed_groups: for session_id in subscribed_groups:
try: try:
await StarTools.send_message_by_id( if self.context:
type="GroupMessage", id=group_id, message_chain=chain await self.context.send_message(session_id, chain)
) else:
logger.info(f"向群组 {group_id} 发送《{subject_name}》更新通知成功。") await StarTools.send_message_by_id(
type="GroupMessage", id=session_id, message_chain=chain
)
logger.info(f"向会话 {session_id} 发送《{subject_name}》更新通知成功。")
except Exception as e: except Exception as e:
logger.error( logger.error(
f"群组 {group_id} 发送《{subject_name}》更新通知失败: {e}" f"会话 {session_id} 发送《{subject_name}》更新通知失败: {e}"
) )