This commit is contained in:
chenxiangtong
2026-03-26 17:38:47 +08:00
commit a05ce6e07e
54 changed files with 5779 additions and 0 deletions

339
main.py Normal file
View File

@@ -0,0 +1,339 @@
import copy
import os
import re
from collections.abc import AsyncGenerator
import aiohttp
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.all import AstrBotConfig
from astrbot.api.event import AstrMessageEvent, MessageChain, filter
# 导入配置与管理
from astrbot.api.star import Context, Star, StarTools, register
from astrbot.core.utils.session_waiter import (
SessionController,
SessionFilter,
session_waiter,
)
from .src.config import ConfigManager
from .src.db import BangumiRepository
# 导入逻辑服务
from .src.services import BangumiService, SearchService, SubscriptionService
from .src.utils import SchedulerManager
@register(
"astrbot_plugin_bangumi_enhance",
"united_pooh",
"AstrBot Bangumi 增强版:为 AstrBot 打造的一站式 Bangumi 追番助手。支持番剧/漫画图文搜索、每日放送时刻表查看及集数更新自动提醒。",
"v1.1.1",
"https://github.com/united-pooh/astrbot_plugin_bangumi",
)
class BangumiPlugin(Star):
def __init__(self, context: Context, config: AstrBotConfig) -> None:
"""
初始化 BangumiPlugin 插件。
"""
super().__init__(context)
self.config = config
self.config_manager = ConfigManager(config)
self.scheduler_manager = SchedulerManager()
self.session: aiohttp.ClientSession | None = None
self.storage: BangumiRepository | None = None
self.service: BangumiService | None = None
self.subscription_service: SubscriptionService | None = None
self.search_service: SearchService | None = None
async def initialize(self) -> None:
"""
插件加载时自动运行的初始化方法。
"""
# 0. 提前获取插件数据目录(必须先于所有依赖 StarTools 的操作)
plugin_data_dir = StarTools.get_data_dir()
# 1. 初始化数据库
try:
db_path = os.path.join(plugin_data_dir, "data.db")
self.storage = BangumiRepository(db_path=db_path)
except (OSError, RuntimeError, ValueError, TypeError) as e:
logger.error(f"数据库初始化失败: {e}")
# 2. 初始化网络会话 (Shared Session)
self.session = aiohttp.ClientSession()
# 3. 初始化核心 API 服务
try:
proxy_url = None
proxy_host = self.config_manager.get_proxy_http()
proxy_port = self.config_manager.get_port()
if proxy_host and proxy_port:
proxy_url = f"{proxy_host}:{proxy_port}"
self.service = BangumiService(
access_token=self.config_manager.get_access_token(),
user_agent=self.config_manager.get_user_agent(),
proxy=proxy_url,
session=self.session,
)
except (RuntimeError, ValueError, TypeError) as e:
logger.error(f"服务初始化失败: {e}")
# 4. 初始化业务逻辑服务 (Dependency Injection)
if self.service:
# 搜索服务
self.search_service = SearchService(
service=self.service,
config_manager=self.config_manager,
session=self.session,
)
# 订阅服务
if self.storage:
self.subscription_service = SubscriptionService(
repository=self.storage,
service=self.service,
config_manager=self.config_manager,
session=self.session,
)
# 5. 添加定时更新任务
if self.subscription_service:
try:
self.scheduler_manager.add_job(
func=self.subscription_service.check_updates,
trigger="cron",
minute=0,
)
logger.info("Bangumi 插件定时更新任务已启动")
except (RuntimeError, ValueError, TypeError) as e:
logger.error(f"添加定时任务失败: {e}")
logger.info("Bangumi 插件初始化流程结束")
# --- 命令处理区 ---
@staticmethod
def _resolve_session_key(event: AstrMessageEvent) -> str | None:
session_key: str | None = getattr(event, "session_id", None)
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"):
session_key = event.message_obj.group_id
return session_key
@staticmethod
def _parse_subscribe_selection(raw_text: str) -> int | None:
match = re.match(r"^/?追番\s+(\d+)\s*$", raw_text.strip())
if not match:
return None
try:
return int(match.group(1))
except ValueError:
return None
@filter.command("bgm")
async def search(
self, event: AstrMessageEvent, query: str, top_k: int = 1
) -> AsyncGenerator[object, None]:
"""全类别搜索 Bangumi 条目。"""
if not self.search_service:
yield event.plain_result("❌ 搜索服务未就绪")
return
async for result in self.search_service.handle_subject_search(
event, query, top_k, subject_type=None
):
yield result
@filter.command("bgm番剧")
async def search_anime(
self, event: AstrMessageEvent, query: str, top_k: int = 1
) -> AsyncGenerator[object, None]:
"""仅搜索 TV 动画条目。"""
if not self.search_service:
yield event.plain_result("❌ 搜索服务未就绪")
return
async for result in self.search_service.handle_subject_search(
event, query, top_k, subject_type=[2], subject_tags=["TV"]
):
yield result
@filter.command("bgm剧场版")
async def search_movie(
self, event: AstrMessageEvent, query: str, top_k: int = 1
) -> AsyncGenerator[object, None]:
"""仅搜索剧场版动画条目。"""
if not self.search_service:
yield event.plain_result("❌ 搜索服务未就绪")
return
async for result in self.search_service.handle_subject_search(
event, query, top_k, subject_type=[2], subject_tags=["剧场版"]
):
yield result
@filter.command("bgm漫画")
async def search_manga(
self, event: AstrMessageEvent, query: str, top_k: int = 1
) -> AsyncGenerator[object, None]:
"""仅搜索漫画条目。"""
if not self.search_service:
yield event.plain_result("❌ 搜索服务未就绪")
return
async for result in self.search_service.handle_subject_search(
event, query, top_k, subject_type=[1], subject_tags=["漫画"]
):
yield result
@filter.command("today")
async def calendar(self, event: AstrMessageEvent) -> AsyncGenerator[object, None]:
"""获取今日番剧放送表。"""
if not self.search_service:
yield event.plain_result("❌ 搜索服务未就绪")
return
async for result in self.search_service.handle_calendar(event):
yield result
@filter.command("追番")
async def subscribe(
self, event: AstrMessageEvent, query: str
) -> AsyncGenerator[object, None]:
"""订阅番剧,更新时自动通知。"""
if not self.subscription_service:
yield event.plain_result("❌ 订阅服务未就绪")
return
group_id = self._resolve_session_key(event)
if not group_id:
yield event.plain_result("❌ 无法获取群组ID")
return
(
error_msg,
candidates,
) = await self.subscription_service.get_subscribe_candidates(
keyword=query,
limit=self.config_manager.get_max_fuzzy_results(),
)
if error_msg:
yield event.plain_result(error_msg)
return
if not candidates:
yield event.plain_result("🔍 未找到相关番剧")
return
if len(candidates) == 1:
result = await self.subscription_service.subscribe_by_subject_id(
group_id=group_id,
subject_id=candidates[0]["subject_id"],
)
yield event.plain_result(result)
return
candidate_lines = ["⚠️ 匹配到多个候选,请使用 `/追番 序号` 确认:"]
for index, candidate in enumerate(candidates, start=1):
candidate_lines.append(
f"{index}. {candidate['name']} (ID: {candidate['subject_id']})"
)
candidate_lines.append("5分钟内有效若发送其他命令将自动取消本次确认。")
yield event.plain_result("\n".join(candidate_lines))
cancel_commands = {
"bgm",
"bgm番剧",
"bgm剧场版",
"bgm漫画",
"today",
"弃坑",
}
session_key = group_id
class GroupSessionFilter(SessionFilter):
def filter(self, wait_event: AstrMessageEvent) -> str:
wait_session_key = BangumiPlugin._resolve_session_key(wait_event)
return wait_session_key or wait_event.unified_msg_origin
@session_waiter(timeout=300)
async def subscribe_confirm_waiter(
controller: SessionController,
wait_event: AstrMessageEvent,
) -> None:
incoming_text = wait_event.get_message_str().strip()
first_token = incoming_text.split(maxsplit=1)[0] if incoming_text else ""
normalized_token = (
first_token[1:] if first_token.startswith("/") else first_token
)
if normalized_token in cancel_commands:
new_event = copy.copy(wait_event)
self.context.get_event_queue().put_nowait(new_event)
wait_event.stop_event()
controller.stop()
return
selected_index = self._parse_subscribe_selection(incoming_text)
if selected_index is None:
if normalized_token == "追番":
new_event = copy.copy(wait_event)
self.context.get_event_queue().put_nowait(new_event)
wait_event.stop_event()
controller.stop()
return
controller.keep(timeout=0)
return
if selected_index < 1 or selected_index > len(candidates):
await wait_event.send(
MessageChain(
[Comp.Plain(f"❌ 序号超出范围,请输入 1-{len(candidates)}")]
)
)
controller.keep(timeout=0)
return
selected = candidates[selected_index - 1]
result = await self.subscription_service.subscribe_by_subject_id(
group_id=session_key,
subject_id=selected["subject_id"],
)
await wait_event.send(MessageChain([Comp.Plain(result)]))
wait_event.stop_event()
controller.stop()
try:
await subscribe_confirm_waiter(
event,
session_filter=GroupSessionFilter(),
)
except TimeoutError:
yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。")
@filter.command("弃坑")
async def unsubscribe(
self, event: AstrMessageEvent, query: str
) -> AsyncGenerator[object, None]:
"""取消订阅番剧。"""
if not self.subscription_service:
yield event.plain_result("❌ 订阅服务未就绪")
return
group_id: str | None = getattr(event, "session_id", None)
if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"):
group_id = event.message_obj.group_id
if not group_id:
yield event.plain_result("❌ 无法获取群组ID")
return
result = await self.subscription_service.unsubscribe(group_id, query)
yield event.plain_result(result)
async def terminate(self) -> None:
logger.info("正在清理 Bangumi 插件资源...")
if self.scheduler_manager.scheduler.running:
self.scheduler_manager.scheduler.shutdown(wait=False)
if self.session and not self.session.closed:
await self.session.close()
logger.info("已关闭共享网络会话")
await super().terminate()