Files
astrbot_plugin_bangumi/main.py
chenxiangtong a05ce6e07e fork
2026-03-26 17:38:47 +08:00

340 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import os
import re
from collections.abc import AsyncGenerator
import aiohttp
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.all import AstrBotConfig
from astrbot.api.event import AstrMessageEvent, MessageChain, filter
# 导入配置与管理
from astrbot.api.star import Context, Star, StarTools, register
from astrbot.core.utils.session_waiter import (
SessionController,
SessionFilter,
session_waiter,
)
from .src.config import ConfigManager
from .src.db import BangumiRepository
# 导入逻辑服务
from .src.services import BangumiService, SearchService, SubscriptionService
from .src.utils import SchedulerManager
@register(
    "astrbot_plugin_bangumi_enhance",
    "united_pooh",
    "AstrBot Bangumi 增强版:为 AstrBot 打造的一站式 Bangumi 追番助手。支持番剧/漫画图文搜索、每日放送时刻表查看及集数更新自动提醒。",
    "v1.1.1",
    "https://github.com/united-pooh/astrbot_plugin_bangumi",
)
class BangumiPlugin(Star):
    """AstrBot plugin: Bangumi search, daily calendar and episode subscriptions.

    Wires together the config manager, the SQLite-backed repository, the
    Bangumi API service, the search/subscription services and an hourly
    scheduler job, and exposes them as chat commands.
    """

    # Commands that, if sent while a multi-candidate `/追番` confirmation is
    # pending, cancel the confirmation and are re-dispatched as normal commands.
    # `追番` itself is handled separately (a numeric argument is a selection).
    _CONFIRM_CANCEL_COMMANDS: frozenset[str] = frozenset(
        {"bgm", "bgm番剧", "bgm剧场版", "bgm漫画", "today", "弃坑"}
    )
    # Matches a confirmation reply such as "/追番 2" or "追番 2".
    _SELECTION_PATTERN = re.compile(r"^/?追番\s+(\d+)\s*$")

    def __init__(self, context: Context, config: AstrBotConfig) -> None:
        """Create the plugin and its configuration/scheduler helpers.

        Network, storage and service objects are created in
        :meth:`initialize`; until then they stay ``None``.

        Args:
            context: AstrBot runtime context, forwarded to ``Star``.
            config: Plugin configuration, wrapped by ``ConfigManager``.
        """
        super().__init__(context)
        self.config = config
        self.config_manager = ConfigManager(config)
        self.scheduler_manager = SchedulerManager()
        self.session: aiohttp.ClientSession | None = None
        self.storage: BangumiRepository | None = None
        self.service: BangumiService | None = None
        self.subscription_service: SubscriptionService | None = None
        self.search_service: SearchService | None = None

    async def initialize(self) -> None:
        """Build storage, HTTP session, services and the scheduled update job.

        Each stage logs and continues on failure. Services whose
        dependencies failed simply stay ``None``, and the command handlers
        report "not ready" in that case.
        """
        # 0. Resolve the plugin data directory first; later steps depend on it.
        plugin_data_dir = StarTools.get_data_dir()

        # 1. Database.
        try:
            db_path = os.path.join(plugin_data_dir, "data.db")
            self.storage = BangumiRepository(db_path=db_path)
        except (OSError, RuntimeError, ValueError, TypeError) as e:
            logger.error(f"数据库初始化失败: {e}")

        # 2. One aiohttp session shared by every service (closed in terminate()).
        self.session = aiohttp.ClientSession()

        # 3. Core Bangumi API service.
        try:
            proxy_url = None
            proxy_host = self.config_manager.get_proxy_http()
            proxy_port = self.config_manager.get_port()
            # The proxy is used only when both host and port are configured.
            if proxy_host and proxy_port:
                proxy_url = f"{proxy_host}:{proxy_port}"
            self.service = BangumiService(
                access_token=self.config_manager.get_access_token(),
                user_agent=self.config_manager.get_user_agent(),
                proxy=proxy_url,
                session=self.session,
            )
        except (RuntimeError, ValueError, TypeError) as e:
            logger.error(f"服务初始化失败: {e}")

        # 4. Business-logic services (dependency injection).
        if self.service:
            self.search_service = SearchService(
                service=self.service,
                config_manager=self.config_manager,
                session=self.session,
            )
            # Subscriptions additionally need persistent storage.
            if self.storage:
                self.subscription_service = SubscriptionService(
                    repository=self.storage,
                    service=self.service,
                    config_manager=self.config_manager,
                    session=self.session,
                )

        # 5. Cron job firing at minute 0 of every hour.
        if self.subscription_service:
            try:
                self.scheduler_manager.add_job(
                    func=self.subscription_service.check_updates,
                    trigger="cron",
                    minute=0,
                )
                logger.info("Bangumi 插件定时更新任务已启动")
            except (RuntimeError, ValueError, TypeError) as e:
                logger.error(f"添加定时任务失败: {e}")

        logger.info("Bangumi 插件初始化流程结束")

    # --- Command handlers ---

    @staticmethod
    def _resolve_session_key(event: AstrMessageEvent) -> str | None:
        """Return the key under which subscriptions for *event* are stored.

        Prefers the non-empty group id of the incoming message. Otherwise it
        falls back to the event's ``session_id``. An empty or ``None`` group
        id (e.g. a private chat) no longer overrides that fallback.

        Returns:
            The group id, else ``session_id``, else ``None``.
        """
        session_key: str | None = getattr(event, "session_id", None)
        message_obj = getattr(event, "message_obj", None)
        group_id = getattr(message_obj, "group_id", None)
        if group_id:
            session_key = group_id
        return session_key

    @staticmethod
    def _parse_subscribe_selection(raw_text: str) -> int | None:
        """Parse a confirmation reply like ``/追番 2`` into its 1-based index.

        Returns:
            The selected index, or ``None`` if *raw_text* is not a
            selection reply.
        """
        match = BangumiPlugin._SELECTION_PATTERN.match(raw_text.strip())
        if match is None:
            return None
        try:
            return int(match.group(1))
        except ValueError:
            # int() rejects digit strings longer than the interpreter's
            # integer string conversion limit.
            return None

    def _redispatch_and_stop(
        self, wait_event: AstrMessageEvent, controller: SessionController
    ) -> None:
        """Abort the pending confirmation and re-queue *wait_event* as new input.

        A shallow copy is put back on the event queue before the original is
        stopped, so the copy can be handled as an ordinary command once the
        waiter has ended.
        """
        self.context.get_event_queue().put_nowait(copy.copy(wait_event))
        wait_event.stop_event()
        controller.stop()

    # NOTE: the Chinese docstrings on command handlers are kept verbatim;
    # AstrBot may surface them as command help text — confirm before editing.

    # `/bgm <query> [top_k]`: search all subject types.
    @filter.command("bgm")
    async def search(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """全类别搜索 Bangumi 条目。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=None
        ):
            yield result

    # `/bgm番剧`: anime subjects (type 2) tagged "TV".
    @filter.command("bgm番剧")
    async def search_anime(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索 TV 动画条目。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[2], subject_tags=["TV"]
        ):
            yield result

    # `/bgm剧场版`: anime subjects (type 2) tagged "剧场版" (movie).
    @filter.command("bgm剧场版")
    async def search_movie(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索剧场版动画条目。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[2], subject_tags=["剧场版"]
        ):
            yield result

    # `/bgm漫画`: book subjects (type 1) tagged "漫画" (manga).
    @filter.command("bgm漫画")
    async def search_manga(
        self, event: AstrMessageEvent, query: str, top_k: int = 1
    ) -> AsyncGenerator[object, None]:
        """仅搜索漫画条目。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_subject_search(
            event, query, top_k, subject_type=[1], subject_tags=["漫画"]
        ):
            yield result

    # `/today`: today's airing schedule.
    @filter.command("today")
    async def calendar(self, event: AstrMessageEvent) -> AsyncGenerator[object, None]:
        """获取今日番剧放送表。"""
        if not self.search_service:
            yield event.plain_result("❌ 搜索服务未就绪")
            return
        async for result in self.search_service.handle_calendar(event):
            yield result

    # `/追番 <keyword>`: subscribe. A unique match is subscribed immediately.
    # Multiple matches start a 5-minute confirmation waiter that accepts
    # `/追番 <index>` from the same group/session.
    @filter.command("追番")
    async def subscribe(
        self, event: AstrMessageEvent, query: str
    ) -> AsyncGenerator[object, None]:
        """订阅番剧,更新时自动通知。"""
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        group_id = self._resolve_session_key(event)
        if not group_id:
            yield event.plain_result("❌ 无法获取群组ID")
            return

        error_msg, candidates = (
            await self.subscription_service.get_subscribe_candidates(
                keyword=query,
                limit=self.config_manager.get_max_fuzzy_results(),
            )
        )
        if error_msg:
            yield event.plain_result(error_msg)
            return
        if not candidates:
            yield event.plain_result("🔍 未找到相关番剧")
            return
        if len(candidates) == 1:
            result = await self.subscription_service.subscribe_by_subject_id(
                group_id=group_id,
                subject_id=candidates[0]["subject_id"],
            )
            yield event.plain_result(result)
            return

        candidate_lines = ["⚠️ 匹配到多个候选,请使用 `/追番 序号` 确认:"]
        candidate_lines.extend(
            f"{index}. {candidate['name']} (ID: {candidate['subject_id']})"
            for index, candidate in enumerate(candidates, start=1)
        )
        candidate_lines.append("5分钟内有效若发送其他命令将自动取消本次确认。")
        yield event.plain_result("\n".join(candidate_lines))

        class GroupSessionFilter(SessionFilter):
            # Key the waiter by group (any member may confirm). Fall back to
            # the unified message origin when no key can be resolved.
            def filter(self, wait_event: AstrMessageEvent) -> str:
                wait_session_key = BangumiPlugin._resolve_session_key(wait_event)
                return wait_session_key or wait_event.unified_msg_origin

        @session_waiter(timeout=300)
        async def subscribe_confirm_waiter(
            controller: SessionController,
            wait_event: AstrMessageEvent,
        ) -> None:
            incoming_text = wait_event.get_message_str().strip()
            first_token = incoming_text.split(maxsplit=1)[0] if incoming_text else ""
            normalized_token = first_token.removeprefix("/")

            # Another plugin command cancels the confirmation and runs normally.
            if normalized_token in self._CONFIRM_CANCEL_COMMANDS:
                self._redispatch_and_stop(wait_event, controller)
                return

            selected_index = self._parse_subscribe_selection(incoming_text)
            if selected_index is None:
                # `/追番 <keyword>` starts a fresh search instead.
                if normalized_token == "追番":
                    self._redispatch_and_stop(wait_event, controller)
                    return
                # Unrelated message: keep waiting on the remaining timeout.
                controller.keep(timeout=0)
                return

            if not 1 <= selected_index <= len(candidates):
                await wait_event.send(
                    MessageChain(
                        [Comp.Plain(f"❌ 序号超出范围,请输入 1-{len(candidates)}")]
                    )
                )
                controller.keep(timeout=0)
                return

            selected = candidates[selected_index - 1]
            result = await self.subscription_service.subscribe_by_subject_id(
                group_id=group_id,
                subject_id=selected["subject_id"],
            )
            await wait_event.send(MessageChain([Comp.Plain(result)]))
            wait_event.stop_event()
            controller.stop()

        try:
            await subscribe_confirm_waiter(
                event,
                session_filter=GroupSessionFilter(),
            )
        except TimeoutError:
            yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")

    # `/弃坑 <query>`: unsubscribe for the current group/session.
    @filter.command("弃坑")
    async def unsubscribe(
        self, event: AstrMessageEvent, query: str
    ) -> AsyncGenerator[object, None]:
        """取消订阅番剧。"""
        if not self.subscription_service:
            yield event.plain_result("❌ 订阅服务未就绪")
            return
        # Same key resolution as `/追番`, so unsubscribe targets the same rows.
        group_id = self._resolve_session_key(event)
        if not group_id:
            yield event.plain_result("❌ 无法获取群组ID")
            return
        result = await self.subscription_service.unsubscribe(group_id, query)
        yield event.plain_result(result)

    async def terminate(self) -> None:
        """Stop the scheduler and close the shared HTTP session on unload."""
        logger.info("正在清理 Bangumi 插件资源...")
        if self.scheduler_manager.scheduler.running:
            # Do not block unload waiting for a running job to finish.
            self.scheduler_manager.scheduler.shutdown(wait=False)
        if self.session and not self.session.closed:
            await self.session.close()
            logger.info("已关闭共享网络会话")
        await super().terminate()