Files
astrbot_plugin_bangumi/main.py
2026-04-15 19:12:56 +08:00

511 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import os
import re
from collections.abc import AsyncGenerator
import aiohttp
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.all import AstrBotConfig
from astrbot.api.event import AstrMessageEvent, MessageChain, filter
# 导入配置与管理
from astrbot.api.star import Context, Star, StarTools, register
from astrbot.core.utils.session_waiter import (
SessionController,
SessionFilter,
session_waiter,
)
from .src.config import ConfigManager
from .src.db import BangumiRepository
# 导入逻辑服务
from .src.services import BangumiService, SearchService, SubscriptionService
from .src.utils import SchedulerManager
@register(
    "astrbot_plugin_bangumi_enhance",
    "united_pooh",
    "AstrBot Bangumi 增强版:为 AstrBot 打造的一站式 Bangumi 追番助手。支持番剧/漫画图文搜索、每日放送时刻表查看及集数更新自动提醒。",
    "v1.1.1",
    "https://github.com/united-pooh/astrbot_plugin_bangumi",
)
class BangumiPlugin(Star):
    """Bangumi search, daily calendar and episode-subscription plugin for AstrBot.

    Services are created in :meth:`initialize`. Every command / LLM tool first
    checks that the service it needs exists, then delegates to it.

    Handler docstrings are deliberately kept verbatim (in Chinese): the
    ``@filter.llm_tool`` docstrings are written for the model (when to call the
    tool, plus an ``Args:`` section), and command docstrings presumably surface
    as command descriptions — both are runtime text, not just documentation.
    """

    # Commands that cancel a pending `/追番` candidate confirmation; the message
    # is re-queued so the command still runs. "追番列表" is included so listing
    # subscriptions during the confirmation window is not swallowed.
    _CONFIRM_CANCEL_COMMANDS = frozenset(
        {"bgm", "bgm番剧", "bgm剧场版", "bgm漫画", "today", "弃坑", "追番列表"}
    )
    # A confirmation reply such as "/追番 2" (the leading "/" is optional).
    _SUBSCRIBE_SELECTION_RE = re.compile(r"^/?追番\s+(\d+)\s*$")
    # How long a `/追番` candidate list stays selectable, in seconds.
    _CONFIRM_TIMEOUT_SECONDS = 300
    # Upper bound for LLM-tool ``top_k``, matching "最多 5" in the tool docs.
    _LLM_MAX_TOP_K = 5

    def __init__(self, context: Context, config: AstrBotConfig) -> None:
        """Create the plugin; network/storage services are built in initialize().

        Args:
            context: AstrBot runtime context supplied by the framework.
            config: Plugin configuration, wrapped in a ``ConfigManager``.
        """
        super().__init__(context)
        self.config = config
        self.config_manager = ConfigManager(config)
        self.scheduler_manager = SchedulerManager()
        # The attributes below stay None until initialize() manages to build
        # them; handlers check the one they need before use.
        self.session: aiohttp.ClientSession | None = None
        self.storage: BangumiRepository | None = None
        self.service: BangumiService | None = None
        self.subscription_service: SubscriptionService | None = None
        self.search_service: SearchService | None = None

    async def initialize(self) -> None:
        """Build storage, HTTP session, services and the periodic update job.

        Called when the plugin is loaded. The database, API-service and
        scheduler stages log errors and continue, leaving the corresponding
        attributes as ``None`` so handlers can answer "未就绪" instead of failing.
        """
        # 0. Resolve the plugin data directory first; later steps depend on it.
        plugin_data_dir = StarTools.get_data_dir()

        # 1. Database.
        try:
            db_path = os.path.join(plugin_data_dir, "data.db")
            self.storage = BangumiRepository(db_path=db_path)
        except (OSError, RuntimeError, ValueError, TypeError) as e:
            logger.error(f"数据库初始化失败: {e}")

        # 2. Shared HTTP session, passed to every service and closed in terminate().
        self.session = aiohttp.ClientSession()

        # 3. Core Bangumi API service (optionally behind an HTTP proxy).
        try:
            proxy_url = None
            proxy_host = self.config_manager.get_proxy_http()
            proxy_port = self.config_manager.get_port()
            if proxy_host and proxy_port:
                proxy_url = f"{proxy_host}:{proxy_port}"
            self.service = BangumiService(
                access_token=self.config_manager.get_access_token(),
                user_agent=self.config_manager.get_user_agent(),
                proxy=proxy_url,
                session=self.session,
            )
        except (RuntimeError, ValueError, TypeError) as e:
            logger.error(f"服务初始化失败: {e}")

        # 4. Business-logic services (dependency injection of shared pieces).
        if self.service:
            self.search_service = SearchService(
                service=self.service,
                config_manager=self.config_manager,
                session=self.session,
            )
            # Subscriptions additionally require persistent storage.
            if self.storage:
                self.subscription_service = SubscriptionService(
                    repository=self.storage,
                    service=self.service,
                    config_manager=self.config_manager,
                    session=self.session,
                    context=self.context,
                )

        # 5. Periodic update check: cron trigger with minute=0 (i.e. the top of
        #    every hour, assuming APScheduler-style cron semantics).
        if self.subscription_service:
            try:
                self.scheduler_manager.add_job(
                    func=self.subscription_service.check_updates,
                    trigger="cron",
                    minute=0,
                )
                logger.info("Bangumi 插件定时更新任务已启动")
            except (RuntimeError, ValueError, TypeError) as e:
                logger.error(f"添加定时任务失败: {e}")
        logger.info("Bangumi 插件初始化流程结束")

    # --- Helpers ---

    @staticmethod
    def _resolve_session_key(event: AstrMessageEvent) -> str:
        """Return the key subscriptions are stored under: the chat's unified origin."""
        return event.unified_msg_origin

    @staticmethod
    def _resolve_confirm_key(event: AstrMessageEvent) -> str:
        """Return the key that routes replies to a pending `/追番` confirmation.

        Scoped to chat *and* sender, so in a group chat only the user who ran
        `/追番` is captured by the confirmation waiter, and several members can
        have independent confirmations at the same time.
        """
        return f"{event.unified_msg_origin}#{event.get_sender_id()}"

    @staticmethod
    def _parse_subscribe_selection(raw_text: str) -> int | None:
        """Parse a "/追番 <n>" confirmation reply.

        Returns:
            The 1-based index the user typed, or ``None`` when ``raw_text`` is
            not a selection reply or the number cannot be converted.
        """
        match = BangumiPlugin._SUBSCRIBE_SELECTION_RE.match(raw_text.strip())
        if not match:
            return None
        try:
            return int(match.group(1))
        except ValueError:
            # e.g. digit strings longer than Python's int-conversion limit.
            return None

    @classmethod
    def _normalize_top_k(cls, top_k: object) -> int:
        """Coerce an LLM-supplied ``top_k`` to an int in ``[1, _LLM_MAX_TOP_K]``.

        Tool arguments are model-generated, so unparsable values fall back to 1
        instead of raising inside the tool call.
        """
        try:
            value = int(float(top_k))  # type: ignore[arg-type]
        except (TypeError, ValueError, OverflowError):
            return 1
        return max(1, min(value, cls._LLM_MAX_TOP_K))

    # --- Command handlers ---

    @filter.command("bgm")
    async def search(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """全类别搜索 Bangumi 条目。"""
        # Search every subject type.
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=None
        ):
            yield result

    @filter.command("bgm番剧")
    async def search_anime(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索 TV 动画条目。"""
        # subject_type 2 with the "TV" tag.
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[2], subject_tags=["TV"]
        ):
            yield result

    @filter.command("bgm剧场版")
    async def search_movie(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索剧场版动画条目。"""
        # subject_type 2 with the "剧场版" (theatrical) tag.
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[2], subject_tags=["剧场版"]
        ):
            yield result

    @filter.command("bgm漫画")
    async def search_manga(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索漫画条目。"""
        # subject_type 1 with the "漫画" (manga) tag.
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[1], subject_tags=["漫画"]
        ):
            yield result

    @filter.command("today")
    async def calendar(self, event: AstrMessageEvent) -> AsyncGenerator[object, None]:
        """获取今日番剧放送表。"""
        # Today's airing schedule.
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_calendar(event):
            yield result

    @filter.command("追番")
    async def subscribe(
        self, event: AstrMessageEvent, query: str
    ) -> AsyncGenerator[object, None]:
        """订阅番剧,更新时自动通知。"""
        # A unique match is subscribed immediately. With several matches the
        # candidates are listed and the same user in the same chat may reply
        # "/追番 <序号>" within the confirmation window.
        subscription_service = self.subscription_service
        if not subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        # Subscriptions themselves are stored per chat.
        session_key = self._resolve_session_key(event)
        error_msg, candidates = await subscription_service.get_subscribe_candidates(
            keyword=query,
            limit=self.config_manager.get_max_fuzzy_results(),
        )
        if error_msg:
            yield event.plain_result(error_msg)
            return
        if not candidates:
            yield event.plain_result("🔍 未找到相关番剧")
            return
        if len(candidates) == 1:
            result = await subscription_service.subscribe_by_subject_id(
                session_id=session_key,
                subject_id=candidates[0]["subject_id"],
            )
            yield event.plain_result(result)
            return

        candidate_lines = ["⚠️ 匹配到多个候选,请使用 `/追番 序号` 确认:"]
        candidate_lines.extend(
            f"{index}. {candidate['name']} (ID: {candidate['subject_id']})"
            for index, candidate in enumerate(candidates, start=1)
        )
        candidate_lines.append("5分钟内有效若发送其他命令将自动取消本次确认。")
        yield event.plain_result("\n".join(candidate_lines))

        cancel_commands = self._CONFIRM_CANCEL_COMMANDS

        class RequesterSessionFilter(SessionFilter):
            """Route follow-up messages by chat + sender of the original request."""

            def filter(self, wait_event: AstrMessageEvent) -> str:
                return BangumiPlugin._resolve_confirm_key(wait_event)

        def redispatch(
            wait_event: AstrMessageEvent, controller: SessionController
        ) -> None:
            """End the confirmation and re-queue a copy of ``wait_event`` so it
            is processed again as an ordinary message."""
            # Copy first, then stop the original event and end the waiter.
            self.context.get_event_queue().put_nowait(copy.copy(wait_event))
            wait_event.stop_event()
            controller.stop()

        @session_waiter(timeout=self._CONFIRM_TIMEOUT_SECONDS)
        async def subscribe_confirm_waiter(
            controller: SessionController,
            wait_event: AstrMessageEvent,
        ) -> None:
            incoming_text = wait_event.get_message_str().strip()
            first_token = incoming_text.split(maxsplit=1)[0] if incoming_text else ""
            normalized_token = first_token.removeprefix("/")

            # A different plugin command: cancel the confirmation and let it run.
            if normalized_token in cancel_commands:
                redispatch(wait_event, controller)
                return

            selected_index = self._parse_subscribe_selection(incoming_text)
            if selected_index is None:
                if normalized_token == "追番":
                    # "/追番 <new keyword>": start a fresh search instead.
                    redispatch(wait_event, controller)
                    return
                # Unrelated message: keep the session open for the next reply.
                controller.keep(timeout=0)
                return

            if not 1 <= selected_index <= len(candidates):
                await wait_event.send(
                    MessageChain(
                        [Comp.Plain(f"❌ 序号超出范围,请输入 1-{len(candidates)}")]
                    )
                )
                controller.keep(timeout=0)
                return

            selected = candidates[selected_index - 1]
            result = await subscription_service.subscribe_by_subject_id(
                session_id=session_key,
                subject_id=selected["subject_id"],
            )
            await wait_event.send(MessageChain([Comp.Plain(result)]))
            wait_event.stop_event()
            controller.stop()

        try:
            await subscribe_confirm_waiter(
                event,
                session_filter=RequesterSessionFilter(),
            )
        except TimeoutError:
            yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")

    @filter.command("弃坑")
    async def unsubscribe(
        self, event: AstrMessageEvent, query: str
    ) -> AsyncGenerator[object, None]:
        """取消订阅番剧。"""
        # Remove a subscription of the current chat; the service builds the reply.
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        session_key = self._resolve_session_key(event)
        result = await self.subscription_service.unsubscribe(session_key, query)
        yield event.plain_result(result)

    @filter.command("追番列表")
    async def list_subscriptions(
        self, event: AstrMessageEvent
    ) -> AsyncGenerator[object, None]:
        """列举当前会话的所有追番订阅。"""
        # list_subscriptions is called without await here (as in the LLM tool).
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        session_key = self._resolve_session_key(event)
        result = self.subscription_service.list_subscriptions(session_key)
        yield event.plain_result(result)

    # --- LLM tools ---
    # The docstrings below are the tool descriptions / argument schema handed to
    # the model, so their wording and "Args:" layout are kept as-is.

    @filter.llm_tool(name="bangumi_search")
    async def llm_search(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """在 Bangumi 数据库中全类别搜索番剧、漫画、游戏等条目,展示条目简介、评分等信息。当用户询问某个作品的详情、评价、评分时调用。

        Args:
            query(string): 搜索关键词,如作品名称
            top_k(number): 返回结果数量,默认 1最多 5
        """
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, self._normalize_top_k(top_k), subject_type=None
        ):
            yield result

    @filter.llm_tool(name="bangumi_search_anime")
    async def llm_search_anime(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """在 Bangumi 数据库中搜索 TV 番剧/动漫条目,展示动漫详情与评分。当用户明确询问某部 TV 动漫/番剧时调用。

        Args:
            query(string): 搜索关键词,如番剧名称
            top_k(number): 返回结果数量,默认 1最多 5
        """
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event,
            query,
            self._normalize_top_k(top_k),
            subject_type=[2],
            subject_tags=["TV"],
        ):
            yield result

    @filter.llm_tool(name="bangumi_search_movie")
    async def llm_search_movie(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """在 Bangumi 数据库中搜索剧场版动画条目。当用户询问某部剧场版/电影动画的信息时调用。

        Args:
            query(string): 搜索关键词,如剧场版名称
            top_k(number): 返回结果数量,默认 1最多 5
        """
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event,
            query,
            self._normalize_top_k(top_k),
            subject_type=[2],
            subject_tags=["剧场版"],
        ):
            yield result

    @filter.llm_tool(name="bangumi_search_manga")
    async def llm_search_manga(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """在 Bangumi 数据库中搜索漫画条目,展示漫画详情与评分。当用户询问某部漫画的信息时调用。

        Args:
            query(string): 搜索关键词,如漫画名称
            top_k(number): 返回结果数量,默认 1最多 5
        """
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event,
            query,
            self._normalize_top_k(top_k),
            subject_type=[1],
            subject_tags=["漫画"],
        ):
            yield result

    @filter.llm_tool(name="bangumi_today_calendar")
    async def llm_today_calendar(
        self, event: AstrMessageEvent
    ) -> AsyncGenerator[object, None]:
        """获取今日番剧放送时刻表,展示今天有哪些动漫新番更新。当用户询问"今天有什么番""今天更新了什么""今日放送"时调用。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_calendar(event):
            yield result

    @filter.llm_tool(name="bangumi_list_subscriptions")
    async def llm_list_subscriptions(
        self, event: AstrMessageEvent
    ) -> AsyncGenerator[object, None]:
        """列举当前会话已订阅的所有追番包含番剧名称、Bangumi ID 和当前更新集数。当用户询问"我订阅了哪些番""追番列表""我在追什么"时调用。"""
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        session_key = self._resolve_session_key(event)
        result = self.subscription_service.list_subscriptions(session_key)
        yield event.plain_result(result)

    @filter.llm_tool(name="bangumi_subscribe")
    async def llm_subscribe(
        self,
        event: AstrMessageEvent,
        query: str = "",
        subject_id: str = "",
    ) -> AsyncGenerator[object, None]:
        """订阅番剧更新通知。当用户表达想要追番、订阅、收到某番剧更新提醒时调用。

        若提供 subject_id 则直接订阅;否则按 query 关键词搜索,唯一匹配时自动订阅,多个候选时返回列表供用户确认(用户选定后再次调用并传入 subject_id

        Args:
            query(string): 要订阅的番剧名称关键词,与 subject_id 二选一
            subject_id(string): 番剧的 Bangumi ID优先级高于 query
        """
        # Unlike the /追番 command this tool never waits for a reply: with
        # several candidates it returns the list and relies on the model calling
        # it again with an explicit subject_id.
        subscription_service = self.subscription_service
        if not subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        session_key = self._resolve_session_key(event)
        if subject_id:
            result = await subscription_service.subscribe_by_subject_id(
                session_id=session_key,
                subject_id=subject_id,
            )
            yield event.plain_result(result)
            return
        if not query:
            yield event.plain_result("❌ 请提供番剧名称关键词或 Bangumi ID")
            return
        error_msg, candidates = await subscription_service.get_subscribe_candidates(
            keyword=query,
            limit=self.config_manager.get_max_fuzzy_results(),
        )
        if error_msg:
            yield event.plain_result(error_msg)
            return
        if not candidates:
            yield event.plain_result("🔍 未找到相关番剧")
            return
        if len(candidates) == 1:
            result = await subscription_service.subscribe_by_subject_id(
                session_id=session_key,
                subject_id=candidates[0]["subject_id"],
            )
            yield event.plain_result(result)
            return
        lines = ["⚠️ 匹配到多个候选,请告知序号或提供 Bangumi ID 以确认订阅:"]
        lines.extend(
            f"{index}. {candidate['name']} (ID: {candidate['subject_id']})"
            for index, candidate in enumerate(candidates, start=1)
        )
        yield event.plain_result("\n".join(lines))

    @filter.llm_tool(name="bangumi_unsubscribe")
    async def llm_unsubscribe(
        self, event: AstrMessageEvent, query: str
    ) -> AsyncGenerator[object, None]:
        """取消订阅番剧更新通知(弃坑)。当用户表达想要退订、弃坑、取消追某部番剧时调用。

        Args:
            query(string): 要取消订阅的番剧名称关键词或 Bangumi ID
        """
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        session_key = self._resolve_session_key(event)
        result = await self.subscription_service.unsubscribe(session_key, query)
        yield event.plain_result(result)

    async def terminate(self) -> None:
        """Release resources on unload: stop the scheduler, close the HTTP session."""
        logger.info("正在清理 Bangumi 插件资源...")
        if self.scheduler_manager.scheduler.running:
            # wait=False: presumably avoids blocking unload on a running job.
            self.scheduler_manager.scheduler.shutdown(wait=False)
        if self.session and not self.session.closed:
            await self.session.close()
            logger.info("已关闭共享网络会话")
        await super().terminate()