commit a05ce6e07e9b31b7945c88d80092d72ea37497f0 Author: chenxiangtong Date: Thu Mar 26 17:38:47 2026 +0800 fork diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..faf4130 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(del \"d:\\\\Project\\\\astrbot_plugin_bangumi\\\\src\\\\utils\\\\browser.py\")" + ] + } +} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ad0d3e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +__pycache__/ +.DS_Store +.idea/ +.pytest_cache/ +.ruff_cache +.vscode/ +settings.json +data/ +.idea/ +*.xml +*.iml +.env +.png diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..570dc9b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,37 @@ +# Changelog + +## v1.1.1 + +### 功能拓展 +- **搜索优化**: `/追番` 现在支持候选确认:当搜索到多个结果时,会先返回列表,可通过 `/追番 序号` 选择目标,减少订错番剧的情况。 +- **取消订阅优化**: `/弃坑` 的匹配更贴近群聊使用场景:优先在本群已订阅列表中匹配,取消订阅更准确。 + +## v1.1.0 + +### 新增功能 +- **取消订阅**: 新增 `/弃坑` 命令,支持群组移除已订阅的番剧更新提醒。 +- **更新卡片渲染**: 引入 `EpisodeRenderer`,在番剧更新时自动推送精美的单集图文通知卡片。 +- **命令别名**: 简化常用命令,支持使用 `/bgm` 代替 `/bgm搜索`。 + +### 核心优化 +- **更新检测逻辑**: 重构剧集更新判定算法,结合播出日期与评论互动数据(`comment > 0`),显著降低更新误报率。 +- **全链路 Base64 渲染**: 渲染引擎(放送表、剧集卡片)全面转向 Base64 内存流,移除临时文件 IO,提升并发性能。 +- **Playwright 鲁棒性**: 优化浏览器安装与初始化逻辑,支持非交互式环境安装,并提供实时状态日志。 +- **类型系统增强**: 引入完整的 `SubjectType`、`ImageSize` 等枚举类型,提升代码可维护性。 +- **代码重构**: 优化 `SubjectsService` 的数据解析流,通过 Pydantic 严格过滤异常 API 返回。 + +## v1.0.0 + +### 新增功能 +- **分类搜索**: 新增 `/bgm番剧`、`/bgm剧场版`、`/bgm漫画` 命令,支持更精准的类型过滤。 +- **每日放送**: 新增 `/today` 命令,渲染精美的每日番剧放送表图片。 +- **追番系统**: 新增 `/追番` 功能,支持订阅番剧并在有新集数更新时自动向群组推送通知。 +- **通用搜索优化**: `/bgm搜索` 命令现在支持更完善的参数处理和 top_k 结果返回。 + +### 代码优化 +- **渲染引擎重构**: 引入 `SubjectRenderer` 和 `CalendarRenderer`,基于 Playwright 实现更美观的图文卡片。 +- **数据库集成**: 引入 SQLAlchemy 驱动的 SQLite 存储,用于管理番剧信息和订阅关系。 +- **自动更新逻辑**: 新增定时任务,每小时自动检查订阅番剧的更新状态。 +- **重构逻辑**: 将搜索与渲染逻辑分离,提取出 `_render_subjects` 和 `_handle_subject` 核心方法,提高代码复用性。 +- **修复 Bug**: 修复了搜索命令中生成器未正确迭代导致无响应的问题。 +- **类型提示**: 为核心方法添加了完善的类型注解和文档说明。 diff --git a/LICENSE-2.0 b/LICENSE-2.0 new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE-2.0 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..1a85ec5 --- /dev/null +++ b/README.md @@ -0,0 +1,98 @@ +
+ +# Bangumi 搜索插件使用指南 +[![repo](https://img.shields.io/badge/repo-v1.1.1-blue.svg)](https://github.com/united-pooh/astrbot_plugin_bangumi) +[![License](https://img.shields.io/badge/license-Apache%202.0-green.svg)](LICENSE-2.0) +[![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.0.0-orange.svg)](https://github.com/Soulter/AstrBot) +[![Python](https://img.shields.io/badge/python-3.12%2B-blue.svg)](https://www.python.org/) + +**和群友一起追番** +
+ +> **astrbot-plugin-bangumi** 是一个基于 AstrBot 框架的 Bangumi (番组计划) 信息查询与追番插件。它通过对接 Bangumi API,为机器人用户提供精美的图文条目详情、实时放送时刻表,并具备自动化的订阅更新监控系统。无论是想快速查询评分,还是在群内实时接收番剧更新通知,它都能为您提供优雅的交互体验。 + + + +> [!NOTE] +> 本项目在 [astrbot_plugin_bangumi](https://github.com/Amatsutsumi/astrbot_plugin_bangumi) 的基础上进行二次开发 + +## 📌 核心命令 + +### 1. 基础搜索(图文卡片) + +| 命令 | 功能 | 参数 | 示例 | +|:-----|:-----|:-----|:-----| +| `/bgm` | 全类别搜索 | `<关键词\|ID> [top_k]` | `/bgm 进击的巨人 3` | +| `/bgm番剧` | 仅搜索 TV 动画 | `<关键词\|ID> [top_k]` | `/bgm番剧 命运石之门` | +| `/bgm剧场版` | 仅搜索剧场版动画 | `<关键词\|ID> [top_k]` | `/bgm剧场版 凉宫春日的消失` | +| `/bgm漫画` | 仅搜索漫画条目 | `<关键词\|ID> [top_k]` | `/bgm漫画 迷宫饭` | + +> `top_k`(可选):返回结果数量,默认为 `1`。 + +### 2. 放送与订阅 + +| 命令 | 功能 | 参数 | 示例 | +|:-----|:-----|:-----|:-----| +| `/today` | 获取今日番剧放送表 | 无 | `/today` | +| `/追番` | 订阅番剧,更新时自动通知 | `<关键词\|ID>` | `/追番 进击的巨人` | +| `/弃坑` | 取消订阅番剧 | `<关键词\|ID>` | `/弃坑 进击的巨人` | + +**功能亮点**: +- **精美卡片**:自动生成包含封面、评分、排名、简介及剧集进度的图文卡片。 +- **每日放送**:渲染精美的每日放送时刻表。 +- **自动追番**:订阅后自动监控集数更新并实时推送通知。 + +## 🛠️ 配置参数 + +在 AstrBot 的管理面板或配置文件中设置: + +| 参数名 | 类型 | 默认值 | 说明 | +|:-------|:----:|:------:|:-----| +| `access_token` | string | 无 | Bangumi API 访问令牌(部分接口需授权)[¹](#access-token-获取) | +| `user_agent` | string | 无 | 请求头 User-Agent 标识,为空时使用插件默认值 | +| `max_fuzzy_results` | int | `5` | 模糊搜索最大返回数量(范围:1–200) | +| `proxy_http` | string | 无 | HTTP 代理地址(仅 IP,例如 `192.168.0.1`) | +| `port` | string | 无 | HTTP 代理端口(例如 `7890`) | +| `max_retries` | int | `3` | 网络错误最大重试次数(范围:1–10) | +| `render_server_url` | string | `https://api.unitedpooh.top/rpc` | 远程渲染图片的 RPC 服务器地址 | + +### Access Token 获取 + +虽然不强制,但建议配置 Access Token 以避免 API 限流。 + +1. 注册/登录 [Bangumi](https://bgm.tv/) +2. 访问 [个人令牌页面](https://next.bgm.tv/demo/access-token/create) 创建新令牌 +3. 将生成的 Token 填入插件配置的 `access_token` 字段 + +## 📦 环境依赖 + +插件首次运行时会自动检查并安装以下依赖: +- **Playwright 浏览器内核**:用于渲染卡片图片。 + +如果遇到环境问题,可尝试手动安装: +```bash +pip install -r requirements.txt +playwright install chromium +``` + +## ✅ 强类型与本地检查 + +本项目已切换为 Python 3.12 风格类型写法,并在 CI 中启用阻断式质量门禁(`ruff + mypy + pytest`)。 + +### 本地执行命令 + +```bash +ruff check . +ruff format --check . +mypy src main.py +PYTHONPATH=. pytest tests/test_search_service.py tests/test_subscription_service.py +``` + +### 强类型编码规则 + +1. 禁止 `Optional[T]`,统一使用 `T | None`。 +2. 禁止 `typing.List/Dict/Tuple/Set`,统一使用 `list/dict/tuple/set`。 +3. 禁止新增 `Any`;优先使用 `TypedDict`、Pydantic 模型或明确类型别名。 +4. 公共方法必须显式标注参数和返回类型。 +5. 业务接口层禁止使用 `dict[str, Any]` 作为输入/输出类型。 +6. 需要可空时必须在类型中明确体现,禁止隐式可空。 diff --git a/_conf_schema.json b/_conf_schema.json new file mode 100644 index 0000000..2775053 --- /dev/null +++ b/_conf_schema.json @@ -0,0 +1,48 @@ +{ + "access_token": { + "description": "Bangumi API访问令牌(部分接口需授权)", + "type": "string", + "hint": "在https://next.bgm.tv/demo/access-token生成,格式为Bearer令牌", + "default": "" + }, + "user_agent": { + "description": "请求头User-Agent标识", + "type": "string", + "hint": "如果为空,则使用插件默认值", + "default": "" + }, + "max_fuzzy_results": { + "description": "模糊搜索返回的最大结果数量", + "type": "int", + "hint": "取值范围1-200,数值越大返回结果越多", + "default": 5, + "min": 1, + "max": 200 + }, + "proxy_http": { + "description": "代理地址", + "type": "string", + "hint": "IP, 例: 192.168.0.x", + "default": "" + }, + "port": { + "description": "端口", + "type": "string", + "hint": "代理端口, 例: 7890", + "default": "" + }, + "max_retries": { + "description": "最大重试次数", + "type": "int", + "hint": "网络错误时最大的重试次数", + "default": 3, + "min": 1, + "max": 10 + }, + "render_server_url": { + "description": "RPC 渲染服务器地址", + "type": "string", + "hint": "用于远程渲染图片的 RPC 服务器地址", + "default": "https://api.unitedpooh.top/rpc" + } +} \ No newline at end of file diff --git a/logo.png b/logo.png new file mode 100644 index 0000000..b070460 Binary files /dev/null and b/logo.png differ diff --git a/main.py b/main.py new file mode 100644 index 0000000..3caba08 --- /dev/null +++ b/main.py @@ -0,0 +1,339 @@ +import copy +import os +import re +from collections.abc import AsyncGenerator + +import aiohttp +import astrbot.api.message_components as Comp +from astrbot.api import logger +from astrbot.api.all import AstrBotConfig +from astrbot.api.event import AstrMessageEvent, MessageChain, filter + +# 导入配置与管理 +from astrbot.api.star import Context, Star, StarTools, register +from astrbot.core.utils.session_waiter import ( + SessionController, + SessionFilter, + session_waiter, +) + +from .src.config import ConfigManager +from .src.db import BangumiRepository + +# 导入逻辑服务 +from .src.services import BangumiService, SearchService, SubscriptionService +from .src.utils import SchedulerManager + + +@register( + "astrbot_plugin_bangumi_enhance", + "united_pooh", + "AstrBot Bangumi 增强版:为 AstrBot 打造的一站式 Bangumi 追番助手。支持番剧/漫画图文搜索、每日放送时刻表查看及集数更新自动提醒。", + "v1.1.1", + "https://github.com/united-pooh/astrbot_plugin_bangumi", +) +class BangumiPlugin(Star): + def __init__(self, context: Context, config: AstrBotConfig) -> None: + """ + 初始化 BangumiPlugin 插件。 + """ + super().__init__(context) + self.config = config + self.config_manager = ConfigManager(config) + self.scheduler_manager = SchedulerManager() + + self.session: aiohttp.ClientSession | None = None + self.storage: BangumiRepository | None = None + self.service: BangumiService | None = None + self.subscription_service: SubscriptionService | None = None + self.search_service: SearchService | None = None + + async def initialize(self) -> None: + """ + 插件加载时自动运行的初始化方法。 + """ + # 0. 提前获取插件数据目录(必须先于所有依赖 StarTools 的操作) + plugin_data_dir = StarTools.get_data_dir() + + # 1. 初始化数据库 + try: + db_path = os.path.join(plugin_data_dir, "data.db") + self.storage = BangumiRepository(db_path=db_path) + except (OSError, RuntimeError, ValueError, TypeError) as e: + logger.error(f"数据库初始化失败: {e}") + + # 2. 初始化网络会话 (Shared Session) + self.session = aiohttp.ClientSession() + + # 3. 初始化核心 API 服务 + try: + proxy_url = None + proxy_host = self.config_manager.get_proxy_http() + proxy_port = self.config_manager.get_port() + if proxy_host and proxy_port: + proxy_url = f"{proxy_host}:{proxy_port}" + + self.service = BangumiService( + access_token=self.config_manager.get_access_token(), + user_agent=self.config_manager.get_user_agent(), + proxy=proxy_url, + session=self.session, + ) + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"服务初始化失败: {e}") + + # 4. 初始化业务逻辑服务 (Dependency Injection) + if self.service: + # 搜索服务 + self.search_service = SearchService( + service=self.service, + config_manager=self.config_manager, + session=self.session, + ) + + # 订阅服务 + if self.storage: + self.subscription_service = SubscriptionService( + repository=self.storage, + service=self.service, + config_manager=self.config_manager, + session=self.session, + ) + + # 5. 添加定时更新任务 + if self.subscription_service: + try: + self.scheduler_manager.add_job( + func=self.subscription_service.check_updates, + trigger="cron", + minute=0, + ) + logger.info("Bangumi 插件定时更新任务已启动") + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"添加定时任务失败: {e}") + + logger.info("Bangumi 插件初始化流程结束") + + # --- 命令处理区 --- + + @staticmethod + def _resolve_session_key(event: AstrMessageEvent) -> str | None: + session_key: str | None = getattr(event, "session_id", None) + if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"): + session_key = event.message_obj.group_id + return session_key + + @staticmethod + def _parse_subscribe_selection(raw_text: str) -> int | None: + match = re.match(r"^/?追番\s+(\d+)\s*$", raw_text.strip()) + if not match: + return None + try: + return int(match.group(1)) + except ValueError: + return None + + @filter.command("bgm") + async def search( + self, event: AstrMessageEvent, query: str, top_k: int = 1 + ) -> AsyncGenerator[object, None]: + """全类别搜索 Bangumi 条目。""" + if not self.search_service: + yield event.plain_result("❌ 搜索服务未就绪") + return + async for result in self.search_service.handle_subject_search( + event, query, top_k, subject_type=None + ): + yield result + + @filter.command("bgm番剧") + async def search_anime( + self, event: AstrMessageEvent, query: str, top_k: int = 1 + ) -> AsyncGenerator[object, None]: + """仅搜索 TV 动画条目。""" + if not self.search_service: + yield event.plain_result("❌ 搜索服务未就绪") + return + async for result in self.search_service.handle_subject_search( + event, query, top_k, subject_type=[2], subject_tags=["TV"] + ): + yield result + + @filter.command("bgm剧场版") + async def search_movie( + self, event: AstrMessageEvent, query: str, top_k: int = 1 + ) -> AsyncGenerator[object, None]: + """仅搜索剧场版动画条目。""" + if not self.search_service: + yield event.plain_result("❌ 搜索服务未就绪") + return + async for result in self.search_service.handle_subject_search( + event, query, top_k, subject_type=[2], subject_tags=["剧场版"] + ): + yield result + + @filter.command("bgm漫画") + async def search_manga( + self, event: AstrMessageEvent, query: str, top_k: int = 1 + ) -> AsyncGenerator[object, None]: + """仅搜索漫画条目。""" + if not self.search_service: + yield event.plain_result("❌ 搜索服务未就绪") + return + async for result in self.search_service.handle_subject_search( + event, query, top_k, subject_type=[1], subject_tags=["漫画"] + ): + yield result + + @filter.command("today") + async def calendar(self, event: AstrMessageEvent) -> AsyncGenerator[object, None]: + """获取今日番剧放送表。""" + if not self.search_service: + yield event.plain_result("❌ 搜索服务未就绪") + return + async for result in self.search_service.handle_calendar(event): + yield result + + @filter.command("追番") + async def subscribe( + self, event: AstrMessageEvent, query: str + ) -> AsyncGenerator[object, None]: + """订阅番剧,更新时自动通知。""" + if not self.subscription_service: + yield event.plain_result("❌ 订阅服务未就绪") + return + + group_id = self._resolve_session_key(event) + if not group_id: + yield event.plain_result("❌ 无法获取群组ID") + return + + ( + error_msg, + candidates, + ) = await self.subscription_service.get_subscribe_candidates( + keyword=query, + limit=self.config_manager.get_max_fuzzy_results(), + ) + if error_msg: + yield event.plain_result(error_msg) + return + if not candidates: + yield event.plain_result("🔍 未找到相关番剧") + return + + if len(candidates) == 1: + result = await self.subscription_service.subscribe_by_subject_id( + group_id=group_id, + subject_id=candidates[0]["subject_id"], + ) + yield event.plain_result(result) + return + + candidate_lines = ["⚠️ 匹配到多个候选,请使用 `/追番 序号` 确认:"] + for index, candidate in enumerate(candidates, start=1): + candidate_lines.append( + f"{index}. {candidate['name']} (ID: {candidate['subject_id']})" + ) + candidate_lines.append("5分钟内有效;若发送其他命令将自动取消本次确认。") + yield event.plain_result("\n".join(candidate_lines)) + + cancel_commands = { + "bgm", + "bgm番剧", + "bgm剧场版", + "bgm漫画", + "today", + "弃坑", + } + session_key = group_id + + class GroupSessionFilter(SessionFilter): + def filter(self, wait_event: AstrMessageEvent) -> str: + wait_session_key = BangumiPlugin._resolve_session_key(wait_event) + return wait_session_key or wait_event.unified_msg_origin + + @session_waiter(timeout=300) + async def subscribe_confirm_waiter( + controller: SessionController, + wait_event: AstrMessageEvent, + ) -> None: + incoming_text = wait_event.get_message_str().strip() + + first_token = incoming_text.split(maxsplit=1)[0] if incoming_text else "" + normalized_token = ( + first_token[1:] if first_token.startswith("/") else first_token + ) + if normalized_token in cancel_commands: + new_event = copy.copy(wait_event) + self.context.get_event_queue().put_nowait(new_event) + wait_event.stop_event() + controller.stop() + return + + selected_index = self._parse_subscribe_selection(incoming_text) + if selected_index is None: + if normalized_token == "追番": + new_event = copy.copy(wait_event) + self.context.get_event_queue().put_nowait(new_event) + wait_event.stop_event() + controller.stop() + return + controller.keep(timeout=0) + return + if selected_index < 1 or selected_index > len(candidates): + await wait_event.send( + MessageChain( + [Comp.Plain(f"❌ 序号超出范围,请输入 1-{len(candidates)}。")] + ) + ) + controller.keep(timeout=0) + return + + selected = candidates[selected_index - 1] + result = await self.subscription_service.subscribe_by_subject_id( + group_id=session_key, + subject_id=selected["subject_id"], + ) + await wait_event.send(MessageChain([Comp.Plain(result)])) + wait_event.stop_event() + controller.stop() + + try: + await subscribe_confirm_waiter( + event, + session_filter=GroupSessionFilter(), + ) + except TimeoutError: + yield event.plain_result("⏰ 候选确认已过期,请重新使用 `/追番 关键词`。") + + @filter.command("弃坑") + async def unsubscribe( + self, event: AstrMessageEvent, query: str + ) -> AsyncGenerator[object, None]: + """取消订阅番剧。""" + if not self.subscription_service: + yield event.plain_result("❌ 订阅服务未就绪") + return + + group_id: str | None = getattr(event, "session_id", None) + if hasattr(event, "message_obj") and hasattr(event.message_obj, "group_id"): + group_id = event.message_obj.group_id + + if not group_id: + yield event.plain_result("❌ 无法获取群组ID") + return + + result = await self.subscription_service.unsubscribe(group_id, query) + yield event.plain_result(result) + + async def terminate(self) -> None: + logger.info("正在清理 Bangumi 插件资源...") + if self.scheduler_manager.scheduler.running: + self.scheduler_manager.scheduler.shutdown(wait=False) + + if self.session and not self.session.closed: + await self.session.close() + logger.info("已关闭共享网络会话") + + await super().terminate() diff --git a/metadata.yaml b/metadata.yaml new file mode 100644 index 0000000..a78e02a --- /dev/null +++ b/metadata.yaml @@ -0,0 +1,20 @@ +name: astrbot_plugin_bangumi_enhance +desc: AstrBot Bangumi 增强版:为 AstrBot 打造的一站式 Bangumi 追番助手。支持番剧/漫画图文搜索、每日放送时刻表查看及集数更新自动提醒。 +version: v1.1.1 +author: united_pooh +license: +repo: https://github.com/united-pooh/astrbot_plugin_bangumi +tags: + - bangumi + - 追番 + - 自动提醒 +keywords: + - bangumi + - anime + - manga + - 番剧 + - 漫画 + - 追番 + - 自动提醒 + - 时刻表 + - 集数更新 \ No newline at end of file diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..fc05326 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "astrbot_plugin_bangumi_enhance", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..62cdf34 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[tool.ruff] +target-version = "py311" +line-length = 88 + +[tool.ruff.lint] +select = ["E", "F", "I", "B", "UP", "SIM", "RUF"] +ignore = ["E501", "RUF001", "RUF002", "RUF003"] + +[tool.pytest.ini_options] +pythonpath = ["."] +addopts = "-q" + +[tool.mypy] +python_version = "3.11" +strict = true +warn_unused_ignores = true +warn_return_any = true +no_implicit_optional = true +show_error_codes = true +pretty = true +files = ["src", "main.py"] + +[[tool.mypy.overrides]] +module = ["astrbot.*", "playwright.*", "apscheduler.*", "pytz", "jinja2", "yaml"] +ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4761c9d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +jinja2 +pillow>=9.2.0 +aiohttp +apscheduler +pytz + +SQLAlchemy + +astrbot diff --git a/src/bangumi_types/__init__.py b/src/bangumi_types/__init__.py new file mode 100644 index 0000000..4ec5909 --- /dev/null +++ b/src/bangumi_types/__init__.py @@ -0,0 +1,3 @@ +from .json_types import JsonArray, JsonObject, JsonPrimitive, JsonValue + +__all__ = ["JsonArray", "JsonObject", "JsonPrimitive", "JsonValue"] diff --git a/src/bangumi_types/json_types.py b/src/bangumi_types/json_types.py new file mode 100644 index 0000000..e731627 --- /dev/null +++ b/src/bangumi_types/json_types.py @@ -0,0 +1,6 @@ +from typing import TypeAlias + +JsonPrimitive: TypeAlias = str | int | float | bool | None +JsonValue: TypeAlias = JsonPrimitive | list["JsonValue"] | dict[str, "JsonValue"] +JsonObject: TypeAlias = dict[str, JsonValue] +JsonArray: TypeAlias = list[JsonValue] diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..949f33a --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,3 @@ +from .config_manager import ConfigManager + +__all__ = ["ConfigManager"] diff --git a/src/config/config_manager.py b/src/config/config_manager.py new file mode 100644 index 0000000..613c1f0 --- /dev/null +++ b/src/config/config_manager.py @@ -0,0 +1,51 @@ +from pathlib import Path + +import yaml +from astrbot.api import AstrBotConfig, logger + + +class ConfigManager: + def __init__(self, config: AstrBotConfig) -> None: + self.config = config + + def get_access_token(self) -> str: + """ + 获取bangumi的access_token + """ + return self.config.get("access_token", "") + + def get_user_agent(self) -> str: + user_agent = self.config.get("user_agent", "") + if user_agent == "": + with open( + f"{Path(__file__).resolve().parent.parent.parent}/metadata.yaml", + encoding="utf-8", + ) as f: + metadata = yaml.safe_load(f) + user_agent = f"AstrBot-Bangumi-Plugin/{metadata['version']} (https://github.com/united-pooh/astrbot_plugin_bangumi)" + return user_agent + + def get_max_fuzzy_results(self) -> int: + return self.config.get("max_fuzzy_results", 5) + + def get_proxy_http(self) -> str: + return self.config.get("proxy_http", "127.0.0.1") + + def get_port(self) -> str: + return self.config.get("port", "7890") + + def get_max_retries(self) -> int: + return self.config.get("max_retries", 3) + + def get_render_server_url(self) -> str: + return self.config.get("render_server_url", "https://api.unitedpooh.top/rpc") + + def save_config(self) -> None: + """ + 保存bgm插件配置到配置文件中, 并重新加载配置 + """ + try: + self.config.save_config() + logger.info("配置已保存") + except (AttributeError, OSError, RuntimeError, ValueError, TypeError) as e: + logger.error(f"保存bgm插件配置失败: {e}") diff --git a/src/db/__init__.py b/src/db/__init__.py new file mode 100644 index 0000000..2a1289e --- /dev/null +++ b/src/db/__init__.py @@ -0,0 +1,11 @@ +""" +数据库层公共接口 + +导出 ORM 模型和数据访问层,供业务层使用。 + +""" + +from .models import BangumiSubject, Base, Subscription +from .repository import BangumiRepository + +__all__ = ["BangumiRepository", "BangumiSubject", "Base", "Subscription"] diff --git a/src/db/models.py b/src/db/models.py new file mode 100644 index 0000000..73cc8e4 --- /dev/null +++ b/src/db/models.py @@ -0,0 +1,60 @@ +""" +数据库 ORM 模型定义 + +此模块包含所有 SQLAlchemy ORM 模型,用于定义数据库表结构和关系。 + +""" + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class BangumiSubject(Base): + """ + 番剧条目模型 + """ + + __tablename__ = "bangumi_subjects" + + subject_id = Column(String, primary_key=True) + name = Column(String) + air_date = Column(String) # 开播日期/时间 + total_episodes = Column(Integer, default=0) + current_episode = Column(Integer, default=0) # 当前已更新/已通知集数 + updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) + + # 建立与 Subscription 的一对多关系 + subscriptions = relationship( + "Subscription", back_populates="subject", cascade="all, delete-orphan" + ) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return f"{self.name} ({self.subject_id}) [{self.current_episode}/{self.total_episodes}]" + + +class Subscription(Base): + """ + 订阅关系模型 + """ + + __tablename__ = "subscriptions" + + group_id = Column(String, primary_key=True) + subject_id = Column( + String, ForeignKey("bangumi_subjects.subject_id"), primary_key=True + ) + created_at = Column(DateTime, default=func.now()) + + # 建立与 BangumiSubject 的多对一关系 + subject = relationship("BangumiSubject", back_populates="subscriptions") + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return f"- 群 {self.group_id} 订阅了 {self.subject.name} ({self.subject.subject_id})" diff --git a/src/db/repository.py b/src/db/repository.py new file mode 100644 index 0000000..9ad5535 --- /dev/null +++ b/src/db/repository.py @@ -0,0 +1,392 @@ +""" +数据访问层(Repository 模式) + +此模块封装所有数据库操作,为业务层提供数据访问接口。 + +""" + +import os +from difflib import SequenceMatcher + +from astrbot.api import logger +from sqlalchemy import create_engine, or_ +from sqlalchemy.orm import joinedload, scoped_session, sessionmaker + +from ..services import DatabaseError +from .models import BangumiSubject, Base, Subscription + + +class BangumiRepository: + """ + 番剧数据访问层 + """ + + def __init__(self, db_path: str) -> None: + """ + 初始化数据访问层 + + Args: + db_path: 数据库文件路径 + """ + os.makedirs(os.path.dirname(db_path), exist_ok=True) + self.db_path = db_path + self._init_db() + + def _init_db(self) -> None: + """ + 初始化数据库连接和表结构 + """ + try: + # 使用 sqlite + engine = create_engine(f"sqlite:///{self.db_path}") + # 创建表 + Base.metadata.create_all(engine) + # 创建 session factory + self.Session = scoped_session(sessionmaker(bind=engine)) + except Exception as e: + raise DatabaseError(f"初始化数据库失败: {e}") from e + + def update_subject(self, subject_id: str, **kwargs) -> bool: + """ + 更新或保存番剧信息 + + Args: + subject_id: 番剧 ID + **kwargs: 支持传入 name, air_date, total_episodes, current_episode 等 + + Returns: + 操作是否成功 + + """ + session = self.Session() + try: + subject = ( + session.query(BangumiSubject) + .filter_by(subject_id=str(subject_id)) + .first() + ) + if not subject: + name = kwargs.pop("name", "未知番剧") + subject = BangumiSubject( + subject_id=str(subject_id), name=name, **kwargs + ) + session.add(subject) + else: + for key, value in kwargs.items(): + if hasattr(subject, key) and value is not None: + setattr(subject, key, value) + session.commit() + return True + except Exception as e: + logger.error(f"更新番剧信息失败: {e}") + session.rollback() + raise DatabaseError(f"更新番剧信息失败: {e}") from e + finally: + session.close() + + def add_subscription(self, group_id: str, subject_id: str) -> bool: + """ + 添加订阅关系 + + Args: + group_id: 群组 ID + subject_id: 番剧 ID + + Returns: + 操作是否成功 + + """ + session = self.Session() + try: + # 确保 Subject 存在 + subject = ( + session.query(BangumiSubject) + .filter_by(subject_id=str(subject_id)) + .first() + ) + if not subject: + subject = BangumiSubject(subject_id=str(subject_id), name="未知番剧") + session.add(subject) + + existing = ( + session.query(Subscription) + .filter_by(group_id=str(group_id), subject_id=str(subject_id)) + .first() + ) + + if not existing: + new_sub = Subscription( + group_id=str(group_id), subject_id=str(subject_id) + ) + session.add(new_sub) + + session.commit() # 单次 commit,保证原子性 + return True + except Exception as e: + logger.error(f"添加订阅失败: {e}") + session.rollback() + raise DatabaseError(f"添加订阅失败: {e}") from e + finally: + session.close() + + def remove_subscription(self, group_id: str, subject_id: str) -> bool: + """ + 移除订阅关系 + + Args: + group_id: 群组 ID + subject_id: 番剧 ID + + Returns: + 操作是否成功 + + """ + session = self.Session() + try: + sub = ( + session.query(Subscription) + .filter_by(group_id=str(group_id), subject_id=str(subject_id)) + .first() + ) + if sub: + session.delete(sub) + session.commit() + return True + return False # 订阅不存在 + except Exception as e: + logger.error(f"移除订阅失败: {e}") + session.rollback() + raise DatabaseError(f"移除订阅失败: {e}") from e + finally: + session.close() + + def get_subscriptions(self, group_id: str) -> list[str]: + """ + 获取指定群组的所有订阅 + + Args: + group_id: 群组 ID + + Returns: + 订阅的番剧 ID 列表 + + """ + session = self.Session() + try: + subs = session.query(Subscription).filter_by(group_id=str(group_id)).all() + return [sub.subject_id for sub in subs] + except Exception as e: + logger.error(f"获取订阅失败: {e}") + raise DatabaseError(f"获取订阅失败: {e}") from e + finally: + session.close() + + def get_monitored_subjects(self) -> list[BangumiSubject]: + """ + 获取所有已订阅的番剧列表,用于轮询更新 + + Returns: + 番剧对象列表 + + """ + session = self.Session() + try: + # Eager load subscriptions 避免 DetachedInstanceError + subjects = ( + session.query(BangumiSubject) + .options(joinedload(BangumiSubject.subscriptions)) + .all() + ) + return subjects + except Exception as e: + logger.error(f"获取监控番剧失败: {e}") + raise DatabaseError(f"获取监控番剧失败: {e}") from e + finally: + session.close() + + def update_subject_episode(self, subject_id: str, new_episode: int) -> bool: + """ + 更新番剧最新集数(快捷方法) + + Args: + subject_id: 番剧 ID + new_episode: 新的集数 + + Returns: + 操作是否成功 + + """ + return self.update_subject(subject_id, current_episode=new_episode) + + def subscribe_subject( + self, + group_id: str, + subject_id: str, + name: str, + air_date: str = "", + total_episodes: int = 0, + ) -> bool: + """ + 原子性地 upsert 番剧信息并建立订阅关系。 + + 将 update_subject + add_subscription 合并到单一事务中, + 避免两次独立调用之间发生异常导致脏数据。 + + Args: + group_id: 群组 ID + subject_id: 番剧 ID + name: 番剧名称 + air_date: 开播日期 + total_episodes: 总集数 + + Returns: + 操作是否成功 + """ + session = self.Session() + try: + # 1. upsert BangumiSubject + subject = ( + session.query(BangumiSubject) + .filter_by(subject_id=str(subject_id)) + .first() + ) + if not subject: + subject = BangumiSubject( + subject_id=str(subject_id), + name=name, + air_date=air_date, + total_episodes=total_episodes, + ) + session.add(subject) + else: + subject.name = name + if air_date: + subject.air_date = air_date + if total_episodes: + subject.total_episodes = total_episodes + + # 2. 添加订阅关系(若不存在) + existing = ( + session.query(Subscription) + .filter_by(group_id=str(group_id), subject_id=str(subject_id)) + .first() + ) + if not existing: + session.add( + Subscription(group_id=str(group_id), subject_id=str(subject_id)) + ) + + # 3. 单次 commit,保证 subject 与 subscription 同时成功或同时回滚 + session.commit() + return True + except Exception as e: + logger.error(f"原子订阅失败: {e}") + session.rollback() + raise DatabaseError(f"原子订阅失败: {e}") from e + finally: + session.close() + + def get_subject_subscribers(self, subject_id: str) -> list[str]: + """ + 获取订阅了某番剧的所有群组 ID + + Args: + subject_id: 番剧 ID + + Returns: + 群组 ID 列表 + + """ + session = self.Session() + try: + subs = ( + session.query(Subscription).filter_by(subject_id=str(subject_id)).all() + ) + return [sub.group_id for sub in subs] + except Exception as e: + logger.error(f"获取订阅群组失败: {e}") + raise DatabaseError(f"获取订阅群组失败: {e}") from e + finally: + session.close() + + def get_all_subscribed_groups(self) -> list[str]: + """ + 获取所有拥有订阅的群组 ID + + Returns: + 群组 ID 列表 + + """ + session = self.Session() + try: + groups = session.query(Subscription.group_id).distinct().all() + return [g[0] for g in groups] + except Exception as e: + logger.error(f"获取所有订阅群组失败: {e}") + raise DatabaseError(f"获取所有订阅群组失败: {e}") from e + finally: + session.close() + + def find_group_subscription_candidates( + self, group_id: str, keyword: str, limit: int = 5 + ) -> list[BangumiSubject]: + """ + 在指定群组的订阅中查找与关键词匹配的番剧候选。 + + 匹配优先级: + 1. subject_id 精确匹配 + 2. subject_id 前缀匹配 + 3. name 包含匹配(忽略大小写) + 4. name 相似度(SequenceMatcher) + """ + session = self.Session() + try: + normalized_keyword = str(keyword).strip() + if not normalized_keyword: + return [] + + keyword_lower = normalized_keyword.lower() + search_pattern = f"%{normalized_keyword}%" + + candidates = ( + session.query(BangumiSubject) + .join( + Subscription, Subscription.subject_id == BangumiSubject.subject_id + ) + .filter(Subscription.group_id == str(group_id)) + .filter( + or_( + BangumiSubject.subject_id == normalized_keyword, + BangumiSubject.subject_id.like(f"{normalized_keyword}%"), + BangumiSubject.name.ilike(search_pattern), + ) + ) + .all() + ) + + def score(subject: BangumiSubject) -> tuple[int, int, int, float, str]: + subject_id = str(subject.subject_id or "") + name = str(subject.name or "") + name_lower = name.lower() + exact_id = int(subject_id == normalized_keyword) + prefix_id = int(subject_id.startswith(normalized_keyword)) + name_contains = int(keyword_lower in name_lower) + similarity = SequenceMatcher(None, keyword_lower, name_lower).ratio() + return (exact_id, prefix_id, name_contains, similarity, subject_id) + + sorted_candidates = sorted( + candidates, + key=lambda subject: ( + -score(subject)[0], + -score(subject)[1], + -score(subject)[2], + -score(subject)[3], + score(subject)[4], + ), + ) + return sorted_candidates[:limit] + except Exception as e: + logger.error(f"查询群组订阅候选失败: {e}") + raise DatabaseError(f"查询群组订阅候选失败: {e}") from e + finally: + session.close() diff --git a/src/render/__init__.py b/src/render/__init__.py new file mode 100644 index 0000000..0f40bf4 --- /dev/null +++ b/src/render/__init__.py @@ -0,0 +1,5 @@ +from .calendar_renderer import CalendarRenderer +from .episode_renderer import EpisodeRenderer +from .subject_renderer import SubjectRenderer + +__all__ = ["CalendarRenderer", "EpisodeRenderer", "SubjectRenderer"] diff --git a/src/render/base_renderer.py b/src/render/base_renderer.py new file mode 100644 index 0000000..cd2048f --- /dev/null +++ b/src/render/base_renderer.py @@ -0,0 +1,149 @@ +from collections.abc import Awaitable, Callable +from pathlib import Path + +import aiohttp +import jinja2 +from astrbot.api import logger + +from ..services import RenderData + + +class BaseRenderer: + def __init__(self, session: aiohttp.ClientSession | None = None) -> None: + self.template_dir = Path(__file__).resolve().parent.parent / "templates" + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(str(self.template_dir)), autoescape=True + ) + self._session = session + + def _generate_html( + self, template_path: str, render_data: RenderData, sub_dir: str = "" + ) -> str: + """Render a Jinja2 template and inject a tag for static assets.""" + template = self.template_env.get_template(template_path) + html = template.render(**render_data) + base_path = self.template_dir / sub_dir if sub_dir else self.template_dir + base_url = base_path.as_uri() + "/" + if "" in html: + return html.replace("", f'', 1) + return f'{html}' + + async def _handle_rpc_response( + self, response: aiohttp.ClientResponse + ) -> str | None: + if response.status != 200: + logger.error(f"[-] RPC 渲染服务器返回错误状态码: {response.status}") + return None + try: + result = await response.json() + except aiohttp.ContentTypeError: + logger.error("[-] RPC 响应内容不是有效的 JSON") + return None + except (ValueError, TypeError, RuntimeError) as e: + logger.error(f"[-] 解析 RPC JSON 响应失败: {e}") + return None + + if not isinstance(result, dict): + logger.error(f"[-] RPC 响应格式错误: {type(result)}") + return None + if "error" in result: + logger.error(f"[-] RPC 渲染返回业务错误: {result['error']}") + return None + + res_obj = result.get("result") + if isinstance(res_obj, dict) and "image" in res_obj: + image = res_obj["image"] + return image if isinstance(image, str) else None + + logger.error(f"[-] RPC 响应中未找到 result.image: {result}") + return None + + async def _render_via_rpc( + self, + rpc_url: str, + html_content: str, + selector: str, + timeout: int = 30000, + wait_time: float = 0, + ) -> str | None: + """Send HTML to the remote RPC renderer and return a base64 image.""" + if not rpc_url: + return None + + import asyncio + + payload = { + "jsonrpc": "2.0", + "method": "screenshot", + "params": { + "html": html_content, + "selector": selector, + "wait_time": wait_time, + "timeout": timeout, + "scale": 3, + }, + "id": int(asyncio.get_event_loop().time() * 1000), + } + client_timeout = aiohttp.ClientTimeout(total=timeout / 1000.0) + + try: + if self._session and not self._session.closed: + async with self._session.post( + rpc_url, json=payload, timeout=client_timeout + ) as response: + return await self._handle_rpc_response(response) + else: + async with ( + aiohttp.ClientSession() as session, + session.post(rpc_url, json=payload, timeout=client_timeout) as response, + ): + return await self._handle_rpc_response(response) + + except aiohttp.ClientConnectorError as e: + logger.error(f"[-] RPC 渲染服务器连接失败: {e}") + except TimeoutError: + logger.error(f"[-] RPC 渲染请求超时 ({timeout}ms)") + except aiohttp.ClientResponseError as e: + logger.error(f"[-] RPC 渲染服务器响应异常: {e.status} {e.message}") + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"[-] RPC 渲染请求发生未知异常: {e}") + return None + + async def render( + self, + template_path: str, + render_data: RenderData, + selector: str, + local_render_func: Callable[[], Awaitable[str | None]], + rpc_url: str | None = None, + sub_dir: str = "", + timeout: int = 30000, + wait_time: float = 0, + ) -> str | None: + """ + Unified render entry point. + + Priority: + 1. Remote RPC server (if *rpc_url* is configured) + 2. Local Pillow rendering via *local_render_func* + """ + if rpc_url: + logger.debug(f"[+] 尝试通过 RPC 渲染: {template_path}") + html_content = self._generate_html(template_path, render_data, sub_dir) + result = await self._render_via_rpc( + rpc_url=rpc_url, + html_content=html_content, + selector=selector, + timeout=timeout, + wait_time=wait_time, + ) + if result: + return result + logger.warning(f"[-] RPC 渲染失败 ({template_path}),回退到本地 Pillow 渲染...") + + logger.debug(f"[+] 本地 Pillow 渲染: {template_path}") + try: + return await local_render_func() + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"[-] 本地 Pillow 渲染失败 ({template_path}): {e}") + return None diff --git a/src/render/calendar_renderer.py b/src/render/calendar_renderer.py new file mode 100644 index 0000000..72305a6 --- /dev/null +++ b/src/render/calendar_renderer.py @@ -0,0 +1,47 @@ +import datetime +from typing import cast + +from astrbot.api import logger + +from ..services import CalendarDay, CalendarWeekday, RenderData +from .base_renderer import BaseRenderer +from .pillow.calendar_card import draw_calendar_card + + +def reorder_days(calendar_data: list[CalendarDay]) -> list[CalendarDay]: + today_id = datetime.datetime.now().isoweekday() + today_index = 0 + for i, day in enumerate(calendar_data): + weekday: CalendarWeekday = day.get("weekday", {}) + if weekday.get("id") == today_id: + today_index = i + day["is_today"] = True + break + return calendar_data[today_index:] + calendar_data[:today_index] + + +class CalendarRenderer(BaseRenderer): + async def render_calendar( + self, + calendar_data: list[CalendarDay], + rpc_url: str | None = None, + headless: bool = True, + max_retries: int = 3, + ) -> str | None: + try: + reordered_days = reorder_days(calendar_data) + except (ValueError, TypeError, RuntimeError) as e: + logger.error(f"[-] 处理日历数据失败: {e}") + return None + + render_data = cast(RenderData, {"days": reordered_days}) + return await self.render( + template_path="calendar/calendar.html", + render_data=render_data, + selector=".container", + local_render_func=lambda: draw_calendar_card(render_data, self._session), + rpc_url=rpc_url, + sub_dir="calendar", + timeout=30000, + wait_time=0, + ) diff --git a/src/render/episode_renderer.py b/src/render/episode_renderer.py new file mode 100644 index 0000000..16e35fe --- /dev/null +++ b/src/render/episode_renderer.py @@ -0,0 +1,22 @@ +from ..services import Episode, RenderData +from .base_renderer import BaseRenderer +from .pillow.episode_card import draw_episode_card + + +class EpisodeRenderer(BaseRenderer): + async def render_episode( + self, + episode_data: Episode, + rpc_url: str | None = None, + headless: bool = True, + max_retries: int = 3, + ) -> str | None: + render_data: RenderData = episode_data.model_dump() + return await self.render( + template_path="update/episode.html", + render_data=render_data, + selector="#card-container", + local_render_func=lambda: draw_episode_card(render_data, self._session), + rpc_url=rpc_url, + timeout=30000, + ) diff --git a/src/render/pillow/__init__.py b/src/render/pillow/__init__.py new file mode 100644 index 0000000..34e80cb --- /dev/null +++ b/src/render/pillow/__init__.py @@ -0,0 +1,5 @@ +from .calendar_card import draw_calendar_card +from .episode_card import draw_episode_card +from .subject_card import draw_subject_card + +__all__ = ["draw_subject_card", "draw_calendar_card", "draw_episode_card"] diff --git a/src/render/pillow/calendar_card.py b/src/render/pillow/calendar_card.py new file mode 100644 index 0000000..17e7250 --- /dev/null +++ b/src/render/pillow/calendar_card.py @@ -0,0 +1,268 @@ +""" +Pillow-based calendar card renderer. + +Renders a 1400×auto image with a 7-column daily broadcast grid. +Today's column is highlighted in orange. +""" +from __future__ import annotations + +import asyncio + +import aiohttp +from PIL import Image, ImageDraw + +from .font_manager import get_font +from .image_utils import ( + fetch_image, + fit_cover, + image_to_base64, + placeholder_cover, + rounded_clip, + strip_supplementary, + text_line_height, + text_width, + wrap_text, +) + +# ── Palette ──────────────────────────────────────────────────────────────────── +BG_PAGE = (240, 242, 245) +WHITE = (255, 255, 255) +PRIMARY = (251, 140, 0) +TEXT_MAIN = (26, 26, 26) +TEXT_SUB = (133, 144, 166) +TEXT_LIGHT = (153, 153, 153) +BORDER = (234, 234, 234) +COL_BG = (255, 255, 255, 200) # slightly translucent column bg (drawn as solid) + +# ── Layout constants ─────────────────────────────────────────────────────────── +CANVAS_W = 1400 +CANVAS_PAD = 30 +GRID_GAP = 16 +COLS = 7 +COL_W = (CANVAS_W - 2 * CANVAS_PAD - (COLS - 1) * GRID_GAP) // COLS # ≈ 177 + +DAY_HEADER_H = 66 # day column header +ITEM_PAD = 10 # horizontal padding inside an anime item +COVER_ASPECT = 1.5 # height = width * COVER_ASPECT (2∶3) +COVER_W = COL_W - 2 * ITEM_PAD +COVER_H = int(COVER_W * COVER_ASPECT) +INFO_H = 78 # fixed height for title + score/rank under cover +ITEM_H = ITEM_PAD + COVER_H + ITEM_PAD + INFO_H + ITEM_PAD +SEPARATOR = 1 + +HEADER_AREA_H = 64 # "每日放送表" title bar +HEADER_GAP = 20 # gap between header and grid + +MAX_CONCURRENT_FETCH = 10 + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +async def _fetch_all( + items_by_col: list[list[dict]], + session: aiohttp.ClientSession | None, +) -> list[list[Image.Image | None]]: + """Fetch all cover images concurrently (capped at MAX_CONCURRENT_FETCH).""" + sem = asyncio.Semaphore(MAX_CONCURRENT_FETCH) + + async def _fetch_one(url: str) -> Image.Image | None: + async with sem: + return await fetch_image(url, session, timeout=10) + + tasks: list[asyncio.Task[Image.Image | None]] = [] + for items in items_by_col: + for item in items: + images = item.get("images") or {} + url = str( + images.get("common") or images.get("large") or images.get("medium") or "" + ) + tasks.append(asyncio.create_task(_fetch_one(url))) + + results_flat = await asyncio.gather(*tasks, return_exceptions=True) + + # Re-group by column + images_by_col: list[list[Image.Image | None]] = [] + idx = 0 + for items in items_by_col: + col_imgs: list[Image.Image | None] = [] + for _ in items: + r = results_flat[idx] + col_imgs.append(r if isinstance(r, Image.Image) else None) + idx += 1 + images_by_col.append(col_imgs) + return images_by_col + + +# ── Main renderer ────────────────────────────────────────────────────────────── + +async def draw_calendar_card( + data: dict, + session: aiohttp.ClientSession | None = None, +) -> str | None: + """Render the weekly broadcast calendar and return a base64-encoded PNG string.""" + try: + return await _render(data, session) + except Exception as exc: + try: + from astrbot.api import logger + logger.error(f"[-] Pillow calendar card render failed: {exc}") + except Exception: + pass + return None + + +async def _render(data: dict, session: aiohttp.ClientSession | None) -> str | None: # noqa: C901 + days: list[dict] = data.get("days") or [] + if not days: + return None + + # Normalise to exactly 7 columns (pad with empty if needed) + while len(days) < COLS: + days.append({"weekday": {"cn": "", "en": ""}, "items": [], "is_today": False}) + + items_by_col: list[list[dict]] = [ + [item for item in (day.get("items") or []) if isinstance(item, dict)] + for day in days + ] + + # ── Fetch all covers concurrently ────────────────────────────────────────── + images_by_col = await _fetch_all(items_by_col, session) + + # ── Compute column heights ───────────────────────────────────────────────── + col_heights = [ + DAY_HEADER_H + len(items) * (ITEM_H + SEPARATOR) + GRID_GAP + for items in items_by_col + ] + max_col_h = max(col_heights) if col_heights else DAY_HEADER_H + 200 + + # ── Canvas ───────────────────────────────────────────────────────────────── + canvas_h = CANVAS_PAD + HEADER_AREA_H + HEADER_GAP + max_col_h + CANVAS_PAD + canvas = Image.new("RGBA", (CANVAS_W, canvas_h), (*BG_PAGE, 255)) + d = ImageDraw.Draw(canvas) + + # ── Fonts ────────────────────────────────────────────────────────────────── + F = { + "h_title": get_font(28, bold=True), + "h_sub": get_font(14), + "day_cn": get_font(16, bold=True), + "day_en": get_font(11), + "item_title": get_font(13, bold=True), + "score": get_font(13, bold=True), + "rank": get_font(11), + } + + # ── Page header: "每日放送表" ───────────────────────────────────────────── + hx = CANVAS_PAD + 10 + hy = CANVAS_PAD + # Orange accent bar + d.rounded_rectangle([hx, hy + 4, hx + 7, hy + 36], radius=3, fill=PRIMARY) + d.text((hx + 18, hy), "每日放送表", font=F["h_title"], fill=TEXT_MAIN) + sub_x = hx + 18 + text_width("每日放送表", F["h_title"]) + 14 + d.text((sub_x, hy + 8), "Bangumi Calendar", font=F["h_sub"], fill=TEXT_SUB) + + # ── Day columns ──────────────────────────────────────────────────────────── + grid_top = CANVAS_PAD + HEADER_AREA_H + HEADER_GAP + + for col_idx, day in enumerate(days): + cx = CANVAS_PAD + col_idx * (COL_W + GRID_GAP) + cy = grid_top + is_today: bool = bool(day.get("is_today")) + items = items_by_col[col_idx] + col_imgs = images_by_col[col_idx] + col_h = max_col_h # all columns same height for visual alignment + + weekday = day.get("weekday") or {} + day_cn = strip_supplementary(str(weekday.get("cn") or "")) + day_en = str(weekday.get("en") or "").upper() + + # Column background + col_fill: tuple[int, int, int] = WHITE + col_outline = (180, 180, 180) if not is_today else PRIMARY + border_w = 1 if not is_today else 2 + d.rounded_rectangle( + [cx, cy, cx + COL_W, cy + col_h], + radius=14, + fill=col_fill, + outline=col_outline, + width=border_w, + ) + + # Day header + header_fill = PRIMARY if is_today else WHITE + header_text = WHITE if is_today else TEXT_MAIN + d.rounded_rectangle( + [cx, cy, cx + COL_W, cy + DAY_HEADER_H], + radius=14, + fill=header_fill, + ) + # Flatten bottom corners of header by overdrawing a rect + d.rectangle( + [cx, cy + DAY_HEADER_H - 14, cx + COL_W, cy + DAY_HEADER_H], + fill=header_fill, + ) + cn_w = text_width(day_cn, F["day_cn"]) + d.text((cx + (COL_W - cn_w) // 2, cy + 14), day_cn, + font=F["day_cn"], fill=header_text) + en_w = text_width(day_en, F["day_en"]) + d.text((cx + (COL_W - en_w) // 2, cy + 14 + text_line_height(F["day_cn"]) + 4), + day_en, font=F["day_en"], fill=(*header_text[:3], 180)) + + # Anime items + item_y = cy + DAY_HEADER_H + SEPARATOR + + if not items: + d.text( + (cx + 16, item_y + 40), + "今日无更新", + font=F["day_en"], + fill=TEXT_LIGHT, + ) + else: + for item_idx, item in enumerate(items): + img = col_imgs[item_idx] if item_idx < len(col_imgs) else None + + # Separator line between items (skip for first) + if item_idx > 0: + d.line([cx + 1, item_y, cx + COL_W - 1, item_y], fill=BORDER) + + # Cover image + cover_top = item_y + ITEM_PAD + if img: + cover = fit_cover(img, COVER_W, COVER_H) + else: + cover = placeholder_cover(COVER_W, COVER_H, color=(230, 230, 230)) + clipped = rounded_clip(cover, 8) + canvas.alpha_composite(clipped, (cx + ITEM_PAD, cover_top)) + + # Info section + info_top = cover_top + COVER_H + ITEM_PAD + name_cn = strip_supplementary(str(item.get("name_cn") or item.get("name") or "")) + name_lines = wrap_text(name_cn, F["item_title"], COL_W - 2 * ITEM_PAD)[:2] + LH = text_line_height(F["item_title"]) + 3 + for line_idx, line in enumerate(name_lines): + d.text((cx + ITEM_PAD, info_top + line_idx * LH), line, + font=F["item_title"], fill=TEXT_MAIN) + + # Score + rank + meta_y = info_top + 2 * LH + 6 + rating = item.get("rating") if isinstance(item.get("rating"), dict) else {} + score = rating.get("score") if rating else None + rank = item.get("rank") + meta_x = cx + ITEM_PAD + if score: + score_txt = f"★ {score}" + d.text((meta_x, meta_y), score_txt, font=F["score"], fill=PRIMARY) + meta_x += text_width(score_txt, F["score"]) + 10 + if rank: + rank_txt = f"#{rank}" + rank_tw = text_width(rank_txt, F["rank"]) + 12 + d.rounded_rectangle( + [meta_x, meta_y, meta_x + rank_tw, meta_y + text_line_height(F["rank"]) + 4], + radius=4, + fill=(245, 245, 245), + ) + d.text((meta_x + 6, meta_y + 2), rank_txt, font=F["rank"], fill=TEXT_LIGHT) + + item_y += ITEM_H + SEPARATOR + + return image_to_base64(canvas) diff --git a/src/render/pillow/episode_card.py b/src/render/pillow/episode_card.py new file mode 100644 index 0000000..cb1939c --- /dev/null +++ b/src/render/pillow/episode_card.py @@ -0,0 +1,154 @@ +""" +Pillow-based episode update card renderer. + +Renders a 768×(768*4/3) card with a full-bleed cover image, +gradient overlay, and episode metadata overlaid at the bottom. +""" +from __future__ import annotations + +import aiohttp +from PIL import Image, ImageDraw + +from .font_manager import get_font +from .image_utils import ( + fetch_image, + fit_cover, + image_to_base64, + make_gradient_overlay, + placeholder_cover, + strip_supplementary, + text_line_height, + text_width, + wrap_text, +) + +# ── Palette ──────────────────────────────────────────────────────────────────── +BLACK = (0, 0, 0) +WHITE = (255, 255, 255) +PINK = (236, 72, 153) # --accent-pink: #ec4899 +WHITE_80 = (255, 255, 255, 204) +WHITE_85 = (255, 255, 255, 217) + +# ── Layout constants ─────────────────────────────────────────────────────────── +CARD_W = 768 +CARD_H = int(CARD_W * 4 / 3) # 1024 + + +# ── Main renderer ────────────────────────────────────────────────────────────── + +async def draw_episode_card( + data: dict, + session: aiohttp.ClientSession | None = None, +) -> str | None: + """Render an episode update card and return a base64-encoded PNG string.""" + try: + return await _render(data, session) + except Exception as exc: + try: + from astrbot.api import logger + logger.error(f"[-] Pillow episode card render failed: {exc}") + except Exception: + pass + return None + + +async def _render(data: dict, session: aiohttp.ClientSession | None) -> str | None: + image_url = str(data.get("image_url") or "") + name = strip_supplementary(str(data.get("name") or "")) + name_cn = strip_supplementary(str(data.get("name_cn") or "")) + title = name_cn or name or "" + desc = strip_supplementary(str(data.get("desc") or "")) + sort_num = data.get("sort") or data.get("ep") or 1 + airdate = str(data.get("airdate") or "") + comment = int(data.get("comment") or 0) + + # ── Fonts ────────────────────────────────────────────────────────────────── + F = { + "ep_num": get_font(52, bold=True), + "title": get_font(44, bold=True), + "meta": get_font(16, bold=True), + "desc": get_font(16), + } + + # ── Fetch cover image ────────────────────────────────────────────────────── + cover_img = await fetch_image(image_url, session, timeout=10) + + # ── Base canvas (black) ──────────────────────────────────────────────────── + canvas = Image.new("RGBA", (CARD_W, CARD_H), (*BLACK, 255)) + + if cover_img: + cover = fit_cover(cover_img, CARD_W, CARD_H) + canvas.alpha_composite(cover, (0, 0)) + else: + # Fallback: dark radial-ish background + placeholder = placeholder_cover(CARD_W, CARD_H, color=(42, 42, 53)) + canvas.alpha_composite(placeholder) + + # ── Gradient overlay ─────────────────────────────────────────────────────── + overlay = make_gradient_overlay(CARD_W, CARD_H, start_pct=0.4) + canvas.alpha_composite(overlay) + + # ── Content overlay ──────────────────────────────────────────────────────── + d = ImageDraw.Draw(canvas) + + SIDE_PAD = 28 + content_x = SIDE_PAD + content_bottom = CARD_H - 28 + + # Build lines bottom-up + # 1. Description (max 3 lines) + desc_lines: list[str] = [] + if desc: + desc_lines = wrap_text(desc, F["desc"], CARD_W - 2 * SIDE_PAD)[:3] + + LH_desc = int(text_line_height(F["desc"]) * 1.7) + desc_block_h = len(desc_lines) * LH_desc + (24 if desc_lines else 0) + + # 2. Metadata row + meta_parts: list[str] = [] + if airdate and "-" in airdate: + meta_parts.append(airdate.split("-")[0]) + meta_parts.append("24min") + if comment > 0: + meta_parts.append(f"{comment} comments") + meta_txt = " | ".join(meta_parts) + meta_h = text_line_height(F["meta"]) + 20 + + # 3. EP + Title row + ep_txt = f"EP.{int(sort_num):02d}" + ep_h = text_line_height(F["ep_num"]) + LH_title = text_line_height(F["title"]) + 6 + title_lines = wrap_text(title, F["title"], CARD_W - 2 * SIDE_PAD - text_width(ep_txt, F["ep_num"]) - 16)[:2] + title_block_h = len(title_lines) * LH_title + + header_h = max(ep_h, title_block_h) + + # ── Draw from bottom up ──────────────────────────────────────────────────── + cur_y = content_bottom + + # Description + if desc_lines: + cur_y -= desc_block_h - LH_desc + for line in reversed(desc_lines): + d.text((content_x, cur_y), line, font=F["desc"], fill=WHITE_85) + cur_y -= LH_desc + cur_y -= 24 + + # Metadata row + cur_y -= meta_h + d.text((content_x, cur_y), meta_txt, font=F["meta"], fill=(*WHITE[:3], 200)) + cur_y -= 20 + + # EP + Title row + ep_w = text_width(ep_txt, F["ep_num"]) + ep_y = cur_y - header_h + d.text((content_x, ep_y), ep_txt, font=F["ep_num"], fill=(*PINK, 255)) + + title_x = content_x + ep_w + 16 + title_available_w = CARD_W - title_x - SIDE_PAD + # Re-wrap with accurate available width + title_lines = wrap_text(title, F["title"], title_available_w)[:2] + for i, line in enumerate(title_lines): + d.text((title_x, ep_y + i * LH_title), line, font=F["title"], fill=WHITE) + + return image_to_base64(canvas) diff --git a/src/render/pillow/font_manager.py b/src/render/pillow/font_manager.py new file mode 100644 index 0000000..e16067a --- /dev/null +++ b/src/render/pillow/font_manager.py @@ -0,0 +1,97 @@ +import sys +from pathlib import Path + +from PIL import ImageFont + +_cache: dict[tuple[bool, int], ImageFont.FreeTypeFont | ImageFont.ImageFont] = {} +_regular_path: str | None = None +_bold_path: str | None = None +_initialized: bool = False + +_CANDIDATES: dict[str, dict[bool, list[str]]] = { + "win32": { + False: [ + "C:/Windows/Fonts/msyh.ttc", + "C:/Windows/Fonts/simsun.ttc", + "C:/Windows/Fonts/simhei.ttf", + ], + True: [ + "C:/Windows/Fonts/msyhbd.ttc", + "C:/Windows/Fonts/simhei.ttf", + "C:/Windows/Fonts/msyh.ttc", + ], + }, + "darwin": { + False: [ + "/System/Library/Fonts/PingFang.ttc", + "/System/Library/Fonts/Supplemental/Arial Unicode MS.ttf", + ], + True: [ + "/System/Library/Fonts/PingFang.ttc", + ], + }, + "linux": { + False: [ + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/noto/NotoSansSC-Regular.otf", + "/usr/share/fonts/google-noto-cjk/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", + ], + True: [ + "/usr/share/fonts/opentype/noto/NotoSansCJK-Bold.ttc", + "/usr/share/fonts/noto-cjk/NotoSansCJK-Bold.ttc", + "/usr/share/fonts/truetype/noto/NotoSansSC-Bold.otf", + "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", + ], + }, +} + + +def _init() -> None: + global _regular_path, _bold_path, _initialized + if _initialized: + return + + platform = sys.platform if sys.platform in _CANDIDATES else "linux" + + def _find(bold: bool) -> str | None: + for p in _CANDIDATES[platform][bold]: + if Path(p).exists(): + return p + # Cross-platform fallback: try all platforms' regular fonts + for plat_cands in _CANDIDATES.values(): + for p in plat_cands[False]: + if Path(p).exists(): + return p + return None + + _regular_path = _find(False) + _bold_path = _find(True) or _regular_path + _initialized = True + + +def get_font(size: int, bold: bool = False) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: + """Return a CJK-compatible font at the requested size.""" + _init() + key = (bold, size) + if key in _cache: + return _cache[key] + + path = _bold_path if bold else _regular_path + if path: + try: + font = ImageFont.truetype(path, size) + _cache[key] = font + return font + except Exception: + pass + + # Ultimate fallback: PIL built-in bitmap font + try: + fb: ImageFont.FreeTypeFont | ImageFont.ImageFont = ImageFont.load_default(size=size) + except TypeError: + fb = ImageFont.load_default() + _cache[key] = fb + return fb diff --git a/src/render/pillow/image_utils.py b/src/render/pillow/image_utils.py new file mode 100644 index 0000000..3b650e7 --- /dev/null +++ b/src/render/pillow/image_utils.py @@ -0,0 +1,147 @@ +import base64 +import io + +import aiohttp +from PIL import Image, ImageDraw + + +async def fetch_image( + url: str, + session: aiohttp.ClientSession | None = None, + timeout: int = 10, +) -> Image.Image | None: + """Fetch an image from a URL or data-URI and return a PIL Image.""" + if not url: + return None + + # Handle data URIs (e.g. "data:image/png;base64,...") + if url.startswith("data:"): + try: + _, encoded = url.split(",", 1) + raw = base64.b64decode(encoded) + return Image.open(io.BytesIO(raw)).convert("RGBA") + except Exception: + return None + + try: + ct = aiohttp.ClientTimeout(total=timeout) + headers = {"Referer": "https://bgm.tv/"} + if session and not session.closed: + async with session.get(url, timeout=ct, headers=headers) as resp: + if resp.status == 200: + data = await resp.read() + return Image.open(io.BytesIO(data)).convert("RGBA") + else: + async with aiohttp.ClientSession() as s: + async with s.get(url, timeout=ct, headers=headers) as resp: + if resp.status == 200: + data = await resp.read() + return Image.open(io.BytesIO(data)).convert("RGBA") + except Exception: + pass + return None + + +def image_to_base64(img: Image.Image) -> str: + """Convert a PIL Image to a base64-encoded PNG string.""" + buf = io.BytesIO() + img.convert("RGB").save(buf, format="PNG", optimize=True) + return base64.b64encode(buf.getvalue()).decode("utf-8") + + +def fit_cover(img: Image.Image, width: int, height: int) -> Image.Image: + """Resize and center-crop to fill the target box (CSS object-fit: cover).""" + src_ratio = img.width / img.height + tgt_ratio = width / height + if src_ratio > tgt_ratio: + new_h, new_w = height, int(height * src_ratio) + else: + new_w, new_h = width, int(width / src_ratio) + img = img.resize((new_w, new_h), Image.LANCZOS) + left = (new_w - width) // 2 + top = (new_h - height) // 2 + return img.crop((left, top, left + width, top + height)) + + +def rounded_clip(img: Image.Image, radius: int) -> Image.Image: + """Apply rounded corners to an image via an alpha mask.""" + img = img.convert("RGBA") + mask = Image.new("L", img.size, 0) + ImageDraw.Draw(mask).rounded_rectangle( + [0, 0, img.width - 1, img.height - 1], radius=radius, fill=255 + ) + result = Image.new("RGBA", img.size, (0, 0, 0, 0)) + result.paste(img, mask=mask) + return result + + +def placeholder_cover(width: int, height: int, color: tuple[int, int, int] = (224, 224, 224)) -> Image.Image: + """Create a solid-color placeholder cover.""" + return Image.new("RGBA", (width, height), (*color, 255)) + + +def text_width(text: str, font: object) -> int: + """Return the rendered width of *text* in pixels.""" + if hasattr(font, "getlength"): + return int(font.getlength(text)) # type: ignore[union-attr] + try: + return font.getsize(text)[0] # type: ignore[union-attr] + except Exception: + size = getattr(font, "size", 12) + return len(text) * size + + +def text_line_height(font: object) -> int: + """Return the line height (ascent + descent) of a font.""" + try: + ascent, descent = font.getmetrics() # type: ignore[union-attr] + return ascent + descent + except Exception: + return getattr(font, "size", 12) + 4 + + +def wrap_text(text: str, font: object, max_width: int) -> list[str]: + """Word-wrap *text* to fit within *max_width* pixels (character-level wrap for CJK).""" + lines: list[str] = [] + for paragraph in text.split("\n"): + if not paragraph: + lines.append("") + continue + line = "" + for char in paragraph: + candidate = line + char + if text_width(candidate, font) <= max_width: + line = candidate + else: + if line: + lines.append(line) + line = char + if line: + lines.append(line) + return lines or [""] + + +def strip_supplementary(text: str) -> str: + """Remove supplementary-plane characters (emoji above U+FFFF) that most system fonts can't render.""" + return "".join(c for c in text if ord(c) <= 0xFFFF) + + +def make_gradient_overlay(width: int, height: int, start_pct: float = 0.4) -> Image.Image: + """ + Create a vertical black-to-opaque gradient overlay. + The first *start_pct* of the height is fully transparent; the remainder fades to ~95% opacity. + """ + strip = Image.new("L", (1, 256)) + for i in range(256): + pct = i / 255 + if pct < start_pct: + v = 0 + else: + progress = (pct - start_pct) / (1.0 - start_pct) + v = min(255, int(progress * 242)) + strip.putpixel((0, i), v) + alpha_strip = strip.resize((1, height), Image.BICUBIC) + alpha_full = alpha_strip.resize((width, height), Image.NEAREST) + overlay = Image.new("RGBA", (width, height), (0, 0, 0, 255)) + overlay.putalpha(alpha_full) + return overlay diff --git a/src/render/pillow/subject_card.py b/src/render/pillow/subject_card.py new file mode 100644 index 0000000..1217923 --- /dev/null +++ b/src/render/pillow/subject_card.py @@ -0,0 +1,441 @@ +""" +Pillow-based subject card renderer. + +Renders a 800×auto card with: + - Left column (210px): cover image · episode progress grid · rating histogram + - Right column (514px): title · score/rank · tags · summary · footer +""" +from __future__ import annotations + +import asyncio + +import aiohttp +from PIL import Image, ImageDraw, ImageFilter + +from .font_manager import get_font +from .image_utils import ( + fetch_image, + fit_cover, + image_to_base64, + placeholder_cover, + rounded_clip, + strip_supplementary, + text_line_height, + text_width, + wrap_text, +) + +# ── Palette ──────────────────────────────────────────────────────────────────── +WHITE = (255, 255, 255) +PRIMARY = (251, 140, 0) +SECONDARY_BG = (255, 243, 224) +TEXT_MAIN = (26, 26, 26) +TEXT_SUB = (133, 144, 166) +TEXT_LIGHT = (153, 153, 153) +BORDER = (234, 234, 234) +EP_GRAY_BG = (232, 232, 232) +EP_GRAY_TEXT = (102, 102, 102) +EP_ORANGE = (255, 152, 0) +ORANGE_DEEP = (230, 81, 0) +FOOTER_BG = (249, 249, 249) +HIST_LOW = (255, 224, 178) # low scores (1-6) + +# ── Layout constants ─────────────────────────────────────────────────────────── +CARD_W = 800 +CARD_PAD = 24 +LEFT_W = 210 +COL_GAP = 28 +RIGHT_W = CARD_W - CARD_PAD * 2 - LEFT_W - COL_GAP # = 514 +COVER_W = LEFT_W +COVER_H = 315 # 2∶3 default +SHADOW_PAD = 20 +SHADOW_BLUR = 12 + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _draw_rounded_rect( + draw: ImageDraw.ImageDraw, + box: tuple[int, int, int, int], + radius: int, + fill: tuple | None = None, + outline: tuple | None = None, + width: int = 1, +) -> None: + draw.rounded_rectangle(box, radius=radius, fill=fill, outline=outline, width=width) + + +def _make_shadow(card_w: int, card_h: int) -> Image.Image: + """Return an RGBA image containing only the blurred shadow.""" + sw = card_w + SHADOW_PAD * 2 + sh = card_h + SHADOW_PAD * 2 + shadow = Image.new("RGBA", (sw, sh), (0, 0, 0, 0)) + ImageDraw.Draw(shadow).rounded_rectangle( + [SHADOW_PAD, SHADOW_PAD + 8, SHADOW_PAD + card_w, SHADOW_PAD + card_h + 8], + radius=20, + fill=(0, 0, 0, 28), + ) + return shadow.filter(ImageFilter.GaussianBlur(SHADOW_BLUR)) + + +def _int_rating_count(rating_count: object, key: int) -> int: + if not isinstance(rating_count, dict): + return 0 + v = rating_count.get(str(key)) or rating_count.get(key) or 0 + return int(v) if isinstance(v, (int, float)) else 0 + + +# ── Main renderer ────────────────────────────────────────────────────────────── + +async def draw_subject_card( + data: dict, + session: aiohttp.ClientSession | None = None, +) -> str | None: + """Render a subject info card and return a base64-encoded PNG string.""" + try: + return await _render(data, session) + except Exception as exc: + try: + from astrbot.api import logger + logger.error(f"[-] Pillow subject card render failed: {exc}") + except Exception: + pass + return None + + +async def _render(data: dict, session: aiohttp.ClientSession | None) -> str | None: # noqa: C901 + # ── Extract fields ───────────────────────────────────────────────────────── + name_cn: str = strip_supplementary(str(data.get("name_cn") or "")) + name: str = strip_supplementary(str(data.get("name") or "")) + title = name_cn or name + subtitle = name if (name_cn and name and name != name_cn) else "" + + image_url = str(data.get("image_url") or "") + rating = data.get("rating") if isinstance(data.get("rating"), dict) else {} + score = rating.get("score") if rating else None + rating_total = rating.get("total") if rating else None + rating_rank = rating.get("rank") if rating else None + rating_count = rating.get("count") if rating else {} + rank = data.get("rank") or rating_rank + + tags_raw = data.get("tags") or [] + tags = [strip_supplementary(str(t.get("name", ""))) for t in tags_raw[:8] if isinstance(t, dict)] + tags = [t for t in tags if t] + + summary = strip_supplementary(str(data.get("summary") or "暂无简介")) + date_str = strip_supplementary(str(data.get("date") or "")) + platform = strip_supplementary(str(data.get("platform") or "")) + # strip leading emoji like "🎬 " from platform + platform = platform.lstrip() + + subject_id = data.get("id") + episode_list: list[dict] = [ep for ep in (data.get("episode_list") or []) if isinstance(ep, dict)] + air_weekday = strip_supplementary(str(data.get("air_weekday") or "")) + collection = data.get("collection") if isinstance(data.get("collection"), dict) else {} + collection_doing = collection.get("doing") if collection else None + + # ── Fonts ────────────────────────────────────────────────────────────────── + F = { + "title": get_font(26, bold=True), + "subtitle": get_font(13), + "score": get_font(34, bold=True), + "star": get_font(22), + "rank_tag": get_font(13, bold=True), + "count": get_font(13), + "tag": get_font(12), + "sum_label": get_font(13, bold=True), + "summary": get_font(14), + "footer": get_font(12), + "ep_label": get_font(10, bold=True), + "ep_cell": get_font(10, bold=True), + "chart_title": get_font(10, bold=True), + "weekday": get_font(20, bold=True), + "weekday_sub": get_font(9), + } + + # ── Fetch cover image ────────────────────────────────────────────────────── + cover_img = await fetch_image(image_url, session, timeout=10) + + # ── Compute left column height ───────────────────────────────────────────── + if cover_img: + ratio = cover_img.width / cover_img.height + cover_h = max(int(COVER_W / ratio), COVER_H) + cover_h = min(cover_h, 420) + else: + cover_h = COVER_H + + ep_block_h = 0 + CELLS_PER_ROW = 6 + CELL = 28 + CELL_GAP = 4 + if episode_list: + n_rows = (len(episode_list) + CELLS_PER_ROW - 1) // CELLS_PER_ROW + ep_block_h = 10 + text_line_height(F["ep_label"]) + 8 + n_rows * (CELL + CELL_GAP) - CELL_GAP + 10 + + hist_block_h = 0 + has_hist = isinstance(rating_count, dict) and any( + _int_rating_count(rating_count, i) for i in range(1, 11) + ) + if has_hist: + hist_block_h = 12 + text_line_height(F["chart_title"]) + 4 + 45 + 2 + text_line_height(F["chart_title"]) + 12 + + left_h = cover_h + if ep_block_h: + left_h += 12 + ep_block_h + if hist_block_h: + left_h += 12 + hist_block_h + + # ── Compute right column height ──────────────────────────────────────────── + LH_title = text_line_height(F["title"]) + 4 + title_lines = wrap_text(title, F["title"], RIGHT_W - 70)[:3] + title_block_h = len(title_lines) * LH_title + + subtitle_block_h = (text_line_height(F["subtitle"]) + 6) if subtitle else 0 + + score_row_h = 46 + + # Tags height (flex-wrap) + tag_row_h = text_line_height(F["tag"]) + 8 + tag_block_h = 0 + if tags: + row_x = 0 + rows = 1 + for tag in tags: + tw = text_width(tag, F["tag"]) + 24 + if row_x + tw + 8 > RIGHT_W: + rows += 1 + row_x = tw + 8 + else: + row_x += tw + 8 + tag_block_h = rows * tag_row_h + (rows - 1) * 8 + + # Summary + LH_sum = int(text_line_height(F["summary"]) * 1.75) + summary_lines = wrap_text(summary, F["summary"], RIGHT_W)[:7] + summary_block_h = ( + 20 # border-top gap + + text_line_height(F["sum_label"]) + 8 # label + + len(summary_lines) * LH_sum + ) + + footer_block_h = 16 + 32 + + right_h = ( + 16 + + title_block_h + + (subtitle_block_h if subtitle_block_h else 0) + + 16 + + score_row_h + + 16 + + (tag_block_h + 16 if tag_block_h else 0) + + summary_block_h + + footer_block_h + ) + + # ── Create card surface ──────────────────────────────────────────────────── + inner_h = max(left_h, right_h, 360) + card_h = inner_h + CARD_PAD * 2 + + card = Image.new("RGBA", (CARD_W, card_h), (255, 255, 255, 255)) + d = ImageDraw.Draw(card) + d.rounded_rectangle([0, 0, CARD_W - 1, card_h - 1], radius=20, fill=WHITE) + + # ── LEFT COLUMN ─────────────────────────────────────────────────────────── + lx = CARD_PAD + ly = CARD_PAD + + # Cover + if cover_img: + raw = fit_cover(cover_img, COVER_W, cover_h) + else: + raw = placeholder_cover(COVER_W, cover_h) + card.alpha_composite(rounded_clip(raw, 12), (lx, ly)) + + cur_ly = ly + cover_h + 12 + + # Episode grid block + if ep_block_h and episode_list: + _draw_rounded_rect(d, (lx, cur_ly, lx + LEFT_W, cur_ly + ep_block_h), radius=10, + fill=WHITE, outline=BORDER, width=1) + # Label row + label_y = cur_ly + 10 + d.text((lx + 10, label_y), "放送进度", font=F["ep_label"], fill=TEXT_LIGHT) + aired_count = sum(1 for ep in episode_list if ep.get("aired")) + prog = f"{aired_count} / {len(episode_list)}" + pw = text_width(prog, F["ep_label"]) + d.text((lx + LEFT_W - 10 - pw, label_y), prog, font=F["ep_label"], fill=PRIMARY) + + grid_y = label_y + text_line_height(F["ep_label"]) + 8 + for i, ep_item in enumerate(episode_list): + col = i % CELLS_PER_ROW + row = i // CELLS_PER_ROW + cx = lx + 10 + col * (CELL + CELL_GAP) + cy = grid_y + row * (CELL + CELL_GAP) + fill_c = EP_ORANGE if ep_item.get("aired") else EP_GRAY_BG + text_c = WHITE if ep_item.get("aired") else EP_GRAY_TEXT + d.rounded_rectangle([cx, cy, cx + CELL - 1, cy + CELL - 1], radius=5, fill=fill_c) + ep_num = str(ep_item.get("ep", "")) + ew = text_width(ep_num, F["ep_cell"]) + eh = text_line_height(F["ep_cell"]) + d.text((cx + (CELL - ew) // 2, cy + (CELL - eh) // 2), ep_num, + font=F["ep_cell"], fill=text_c) + cur_ly += ep_block_h + 12 + + # Rating histogram block + if hist_block_h and has_hist: + _draw_rounded_rect(d, (lx, cur_ly, lx + LEFT_W, cur_ly + hist_block_h), radius=10, + fill=WHITE, outline=BORDER, width=1) + ct_txt = "评分分布" + ct_w = text_width(ct_txt, F["chart_title"]) + d.text((lx + (LEFT_W - ct_w) // 2, cur_ly + 12), ct_txt, + font=F["chart_title"], fill=TEXT_LIGHT) + + bars_top = cur_ly + 12 + text_line_height(F["chart_title"]) + 4 + BAR_AREA_W = LEFT_W - 16 + BAR_GAP = 2 + bar_w = (BAR_AREA_W - 9 * BAR_GAP) // 10 + BARS_H = 45 + values = [_int_rating_count(rating_count, i) for i in range(1, 11)] + max_v = max(values) if values else 1 + if max_v == 0: + max_v = 1 + for i, v in enumerate(values): + bar_px = max(2, int(v / max_v * BARS_H)) + bx = lx + 8 + i * (bar_w + BAR_GAP) + by = bars_top + BARS_H - bar_px + bar_fill = PRIMARY if i >= 6 else HIST_LOW + d.rounded_rectangle([bx, by, bx + bar_w, bars_top + BARS_H], radius=2, fill=bar_fill) + + line_y = bars_top + BARS_H + 2 + d.line([lx + 8, line_y, lx + LEFT_W - 8, line_y], fill=BORDER) + lbl_y = line_y + 2 + d.text((lx + 8, lbl_y), "1", font=F["chart_title"], fill=(200, 200, 200)) + mid_x = lx + 8 + 4 * (bar_w + BAR_GAP) + d.text((mid_x, lbl_y), "5", font=F["chart_title"], fill=(200, 200, 200)) + right_x = lx + 8 + 9 * (bar_w + BAR_GAP) + d.text((right_x, lbl_y), "10", font=F["chart_title"], fill=(200, 200, 200)) + + # ── RIGHT COLUMN ────────────────────────────────────────────────────────── + rx = CARD_PAD + LEFT_W + COL_GAP + ry = CARD_PAD + 16 + + # Title + for line in title_lines: + d.text((rx, ry), line, font=F["title"], fill=TEXT_MAIN) + ry += LH_title + + # Subtitle + if subtitle: + d.text((rx, ry), subtitle[:60], font=F["subtitle"], fill=TEXT_SUB) + ry += subtitle_block_h + + ry += 16 # gap + + # Score row + if score is not None: + cur_rx = rx + # Star + score value + d.text((cur_rx, ry + 8), "★", font=F["star"], fill=PRIMARY) + cur_rx += text_width("★", F["star"]) + 4 + score_str = str(score) + d.text((cur_rx, ry), score_str, font=F["score"], fill=PRIMARY) + cur_rx += text_width(score_str, F["score"]) + 12 + + # Rank badge + if rank: + rank_txt = f"#{rank}" + rank_tw = text_width(rank_txt, F["rank_tag"]) + 20 + _draw_rounded_rect(d, (cur_rx, ry + 8, cur_rx + rank_tw, ry + 34), + radius=8, fill=SECONDARY_BG) + d.text((cur_rx + 10, ry + 10), rank_txt, font=F["rank_tag"], fill=ORANGE_DEEP) + cur_rx += rank_tw + 12 + + # Vote count + if rating_total: + ct = f"{rating_total} 人评分" + ct_w = text_width(ct, F["count"]) + 20 + _draw_rounded_rect(d, (cur_rx, ry + 10, cur_rx + ct_w, ry + 34), + radius=16, fill=(245, 245, 245)) + d.text((cur_rx + 10, ry + 12), ct, font=F["count"], fill=TEXT_SUB) + cur_rx += ct_w + 8 + + # Watching count + if collection_doing: + dt = f"{collection_doing} 人在看" + dt_w = text_width(dt, F["count"]) + 20 + _draw_rounded_rect(d, (cur_rx, ry + 10, cur_rx + dt_w, ry + 34), + radius=16, fill=(245, 245, 245)) + d.text((cur_rx + 10, ry + 12), dt, font=F["count"], fill=TEXT_SUB) + + ry += score_row_h + 16 + else: + d.text((rx, ry), "暂无评分", font=F["subtitle"], fill=TEXT_LIGHT) + ry += score_row_h + 16 + + # Tags + if tags: + tag_x, tag_y = rx, ry + for tag_txt in tags: + tw = text_width(tag_txt, F["tag"]) + 24 + if tag_x + tw > rx + RIGHT_W: + tag_x = rx + tag_y += tag_row_h + 8 + _draw_rounded_rect(d, (tag_x, tag_y, tag_x + tw, tag_y + tag_row_h), + radius=6, fill=WHITE, outline=(220, 220, 220), width=1) + d.text((tag_x + 12, tag_y + 4), tag_txt, font=F["tag"], fill=TEXT_SUB) + tag_x += tw + 8 + ry = tag_y + tag_row_h + 16 + + # Summary section (with dashed separator) + ry += 20 # space for dashed line + dash_y = ry - 12 + for dx in range(rx, rx + RIGHT_W, 12): + d.line([dx, dash_y, min(dx + 6, rx + RIGHT_W), dash_y], fill=(220, 220, 220)) + d.text((rx, ry), "简介", font=F["sum_label"], fill=(*TEXT_MAIN, 200)) + ry += text_line_height(F["sum_label"]) + 8 + + for line in summary_lines: + d.text((rx, ry), line, font=F["summary"], fill=(74, 74, 74)) + ry += LH_sum + + # Footer (pinned to bottom of card) + footer_y = card_h - CARD_PAD - 30 + d.line([rx, footer_y, rx + RIGHT_W, footer_y], fill=BORDER) + fi_x = rx + + if date_str: + d.text((fi_x, footer_y + 8), date_str, font=F["footer"], fill=TEXT_LIGHT) + fi_x += text_width(date_str, F["footer"]) + 20 + + if platform: + d.text((fi_x, footer_y + 8), platform, font=F["footer"], fill=TEXT_LIGHT) + + if subject_id is not None: + id_txt = f"ID: {subject_id}" + id_w = text_width(id_txt, F["footer"]) + d.text((rx + RIGHT_W - id_w, footer_y + 8), id_txt, font=F["footer"], fill=(200, 200, 200)) + + # ── Weekday badge ───────────────────────────────────────────────────────── + if air_weekday: + BADGE_W, BADGE_H = 56, 48 + badge = Image.new("RGBA", (BADGE_W, BADGE_H), (0, 0, 0, 0)) + bd = ImageDraw.Draw(badge) + bd.polygon( + [(0, 0), (BADGE_W, 0), (BADGE_W, BADGE_H), (0, BADGE_H)], + fill=(*EP_ORANGE, 255), + ) + wk_w = text_width(air_weekday, F["weekday"]) + bd.text(((BADGE_W - wk_w) // 2, 6), air_weekday, font=F["weekday"], fill=WHITE) + sub_txt = "曜日" + sub_w = text_width(sub_txt, F["weekday_sub"]) + bd.text(((BADGE_W - sub_w) // 2, 30), sub_txt, font=F["weekday_sub"], + fill=(255, 255, 255, 210)) + # Rounded left-bottom corner only: clip top-right corner of card + card.alpha_composite(badge, (CARD_W - BADGE_W, 0)) + + # ── Composite with shadow ───────────────────────────────────────────────── + shadow = _make_shadow(CARD_W, card_h) + result = shadow + result.alpha_composite(card, (SHADOW_PAD, SHADOW_PAD)) + + return image_to_base64(result) diff --git a/src/render/subject_renderer.py b/src/render/subject_renderer.py new file mode 100644 index 0000000..ca34ba7 --- /dev/null +++ b/src/render/subject_renderer.py @@ -0,0 +1,157 @@ +import asyncio +import datetime +from collections import Counter +from typing import cast + +import aiohttp +from astrbot.api import logger + +from ..services import EpisodeItem, RenderData, SubjectType +from .base_renderer import BaseRenderer +from .pillow.subject_card import draw_subject_card + + +def _process_images(data: RenderData) -> None: + if "image_url" in data: + return + images = data.get("images") + if not isinstance(images, dict): + return + images = cast(dict[str, object], images) + data["image_url"] = ( + images.get("large") or images.get("common") or images.get("medium") or "" + ) + + +def _process_dates(data: RenderData) -> None: + if "date" in data: + return + if "air_date" in data: + data["date"] = data["air_date"] + + +def _process_platform(data: RenderData) -> None: + if "platform" in data: + return + if "type" not in data: + return + try: + type_id = int(data["type"]) + data["platform"] = SubjectType(type_id).to_display() + except (ValueError, TypeError): + data["platform"] = "未知" + + +def _infer_air_weekday(aired_weekdays: list[int]) -> str: + if not aired_weekdays: + return "" + weekday_names = {1: "月", 2: "火", 3: "水", 4: "木", 5: "金", 6: "土", 7: "日"} + recent = aired_weekdays[-4:] + most_common = Counter(recent).most_common(1)[0][0] + return weekday_names.get(most_common, "") + + +def _parse_episode_list( + episodes: list[EpisodeItem], today: datetime.date +) -> tuple[list[dict[str, int | bool | None]], list[int]]: + episode_list: list[dict[str, int | bool | None]] = [] + aired_weekdays: list[int] = [] + for ep in episodes: + if ep.get("type", 0) != 0 or ep.get("ep", 0) == 0: + continue + aired = False + airdate_str = ep.get("airdate") + if airdate_str: + try: + airdate = datetime.datetime.strptime(airdate_str, "%Y-%m-%d").date() + aired = airdate <= today + if aired: + aired_weekdays.append(airdate.isoweekday()) + except ValueError: + pass + if ep.get("comment", 0) > 0: + aired = True + episode_list.append({"ep": ep.get("ep"), "aired": aired}) + return episode_list, aired_weekdays + + +def _process_episodes(data: RenderData) -> None: + episodes = data.get("episodes") + if not isinstance(episodes, list): + return + today = datetime.date.today() + normalized_episodes: list[EpisodeItem] = [] + for episode in episodes: + if isinstance(episode, dict): + normalized_episodes.append(cast(EpisodeItem, episode)) + episode_list, aired_weekdays = _parse_episode_list(normalized_episodes, today) + data["episode_list"] = episode_list + air_weekday = _infer_air_weekday(aired_weekdays) + if air_weekday: + data["air_weekday"] = air_weekday + + +def preprocess_data(data: RenderData) -> RenderData: + processed = data.copy() + _process_images(processed) + _process_dates(processed) + _process_platform(processed) + _process_episodes(processed) + return processed + + +class SubjectRenderer(BaseRenderer): + async def render_subject_card( + self, + data: RenderData, + rpc_url: str | None = None, + headless: bool = True, + wait_time: int = 0, + max_retries: int = 3, + timeout: int = 30000, + ) -> str | None: + render_data = preprocess_data(data) + return await self.render( + template_path="subject/subject.html", + render_data=render_data, + selector="#card", + local_render_func=lambda: draw_subject_card(render_data, self._session), + rpc_url=rpc_url, + sub_dir="subject", + timeout=timeout, + wait_time=wait_time, + ) + + async def render_batch_subject_cards_to_base64( + self, + data_list: list[RenderData], + rpc_url: str | None = None, + headless: bool = True, + wait_time: int = 0, + max_retries: int = 3, + timeout: int = 30000, + max_concurrency: int = 3, + ) -> list[str]: + semaphore = asyncio.Semaphore(max_concurrency) + + async def _limited_render(data: RenderData) -> str | None: + async with semaphore: + return await self.render_subject_card( + data=data, + rpc_url=rpc_url, + headless=headless, + wait_time=wait_time, + max_retries=max_retries, + timeout=timeout, + ) + + tasks = [_limited_render(data) for data in data_list] + results = await asyncio.gather(*tasks, return_exceptions=True) + + valid_results: list[str] = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.warning(f"批量渲染第 {i + 1} 项失败: {res}") + elif res: + valid_results.append(res) + return valid_results diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..0c49ac5 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,74 @@ +from typing import TYPE_CHECKING, Any + +import aiohttp + +from .calendar import CalendarService +from .contracts import ( + CalendarDay, + CalendarWeekday, + EpisodeItem, + MessageResult, + RenderData, +) +from .exceptions import ( + BangumiApiError, + BangumiRateLimitError, + DatabaseError, + NoSubjectFound, + SubscriptionError, +) +from .schemas import Episode +from .subjects import SubjectsService +from .types import ImageSize, SubjectType + +if TYPE_CHECKING: + from .search import SearchService + from .subscription import SubscriptionService + + +# 聚合类:继承所有子Service的功能 +class BangumiService(SubjectsService, CalendarService): + def __init__( + self, + access_token: str, + user_agent: str, + proxy: str | None = None, + session: aiohttp.ClientSession | None = None, + ) -> None: + # 初始化最基础的父类 (BaseBangumiService) + # 因为所有Service都继承自BaseBangumiService,super会自动处理MRO链 + super().__init__(access_token, user_agent, proxy, session=session) + + +def __getattr__(name: str) -> Any: + if name == "SearchService": + from .search import SearchService + + return SearchService + if name == "SubscriptionService": + from .subscription import SubscriptionService + + return SubscriptionService + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "BangumiApiError", + "BangumiRateLimitError", + "BangumiService", + "CalendarDay", + "CalendarService", + "CalendarWeekday", + "DatabaseError", + "Episode", + "EpisodeItem", + "ImageSize", + "MessageResult", + "NoSubjectFound", + "RenderData", + "SearchService", + "SubjectType", + "SubjectsService", + "SubscriptionError", + "SubscriptionService", +] diff --git a/src/services/base.py b/src/services/base.py new file mode 100644 index 0000000..fbd1c2a --- /dev/null +++ b/src/services/base.py @@ -0,0 +1,174 @@ +import asyncio +import json +import time +from typing import Literal, cast, overload + +import aiohttp +from astrbot.api import logger + +from ..bangumi_types import JsonArray, JsonObject +from .contracts import SearchSubjectsResponse +from .exceptions import BangumiApiError, BangumiRateLimitError, NoSubjectFound + + +class BaseBangumiService: + def __init__( + self, + access_token: str, + user_agent: str, + proxy: str | None = None, + max_retries: int = 3, + session: aiohttp.ClientSession | None = None, + ) -> None: + if not access_token: + raise ValueError("Bangumi access_token 未设置") + self.base_url = "https://api.bgm.tv" + self.headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + "User-Agent": user_agent, + } + self.proxy = proxy + self.last_request_time = 0.0 + self._rate_lock = asyncio.Lock() + self._timeout = aiohttp.ClientTimeout(total=30, connect=10) + # 兜底 session(惰性创建,避免每次新建 TCP 连接) + self._fallback_session: aiohttp.ClientSession | None = None + # 这里只放通用的缓存,或者具体业务的缓存放到具体类中 + self.search_cache: dict[str, SearchSubjectsResponse] = {} + self.max_retries = max_retries + self._session = session + + @overload + async def _request( + self, + url: str, + method: str = "GET", + params: JsonObject | None = None, + json_data: JsonObject | None = None, + is_json: Literal[True] = True, + ) -> JsonObject | JsonArray: ... + + @overload + async def _request( + self, + url: str, + method: str = "GET", + params: JsonObject | None = None, + json_data: JsonObject | None = None, + is_json: Literal[False] = False, + ) -> bytes: ... + + async def _request( + self, + url: str, + method: str = "GET", + params: JsonObject | None = None, + json_data: JsonObject | None = None, + is_json: bool = True, + ) -> JsonObject | JsonArray | bytes: + """ + 通用API请求函数, 带限流和重试处理 + """ + last_exception: Exception | None = None + + for attempt in range(self.max_retries): + async with self._rate_lock: + now = time.time() + gap = 1.1 - (now - self.last_request_time) + if gap > 0: + await asyncio.sleep(gap) + self.last_request_time = time.time() + + logger.info( + f"Bangumi API请求 (尝试 {attempt + 1}/{self.max_retries}): {method} {url}" + ) + + try: + # 优先使用外部注入的 Session + session = ( + self._session + if self._session and not self._session.closed + else await self._get_fallback_session() + ) + request_context = ( + session.post( + url, + json=json_data, + params=params, + proxy=self.proxy, + headers=self.headers, + timeout=self._timeout, + ) + if method.upper() == "POST" + else session.get( + url, + params=params, + proxy=self.proxy, + headers=self.headers, + timeout=self._timeout, + ) + ) + + async with request_context as response: + if response.status >= 500: + last_exception = BangumiApiError( + f"服务器错误 ({response.status}),尝试 {attempt + 1}/{self.max_retries}" + ) + logger.warning(f"服务器返回错误状态码: {response.status}") + await asyncio.sleep(1.5) + continue + return await self._handle_response(response, is_json=is_json) + + except aiohttp.ClientError as e: + logger.warning(f"网络请求失败: {e}") + last_exception = e + if attempt < self.max_retries - 1: + await asyncio.sleep(1.5) + else: + logger.error("达到最大重试次数,请求失败") + + raise BangumiApiError(f"请求失败,请稍后再试: {last_exception}") + + async def _get_fallback_session(self) -> aiohttp.ClientSession: + """惰性创建并复用兜底 ClientSession。""" + if self._fallback_session is None or self._fallback_session.closed: + self._fallback_session = aiohttp.ClientSession(headers=self.headers) + return self._fallback_session + + @overload + async def _handle_response( + self, response: aiohttp.ClientResponse, is_json: Literal[True] = True + ) -> JsonObject | JsonArray: ... + + @overload + async def _handle_response( + self, response: aiohttp.ClientResponse, is_json: Literal[False] + ) -> bytes: ... + + async def _handle_response( + self, response: aiohttp.ClientResponse, is_json: bool = True + ) -> JsonObject | JsonArray | bytes: + """ + 处理api响应 + + """ + if response.status == 200: + if is_json: + raw = await response.json() + if isinstance(raw, (dict, list)): + return cast(JsonObject | JsonArray, raw) + raise BangumiApiError("API 返回了非 JSON 对象/数组类型") + return await response.read() + if response.status == 404: + raise NoSubjectFound("未找到相关条目") + if response.status == 429: + raise BangumiRateLimitError("API请求过于频繁") + + try: + error_data = await response.json() + error_text = json.dumps(error_data, ensure_ascii=False) + except (aiohttp.ContentTypeError, ValueError, TypeError): + error_text = await response.text() + logger.error(f"API错误: {response.status} - {error_text}") + raise BangumiApiError(f"API服务异常 ({response.status})") diff --git a/src/services/calendar.py b/src/services/calendar.py new file mode 100644 index 0000000..4fc22e0 --- /dev/null +++ b/src/services/calendar.py @@ -0,0 +1,65 @@ +import asyncio +import copy +import time +from typing import cast + +from astrbot.api import logger + +from .base import BaseBangumiService +from .contracts import CalendarDay +from .exceptions import BangumiApiError + + +class CalendarService(BaseBangumiService): + CALENDAR_CACHE_TTL_SECONDS = 12 * 60 * 60 + + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self._calendar_cache: list[CalendarDay] | None = None + self._calendar_cache_expire_at: float = 0.0 + self._calendar_cache_lock = asyncio.Lock() + + def _is_calendar_cache_valid(self, now: float) -> bool: + return self._calendar_cache is not None and now < self._calendar_cache_expire_at + + def invalidate_calendar_cache(self) -> None: + self._calendar_cache = None + self._calendar_cache_expire_at = 0.0 + + async def get_calendar(self) -> list[CalendarDay]: + now = time.time() + if self._is_calendar_cache_valid(now): + return copy.deepcopy(self._calendar_cache) + + # 双重检查 + 锁,避免并发下重复请求远端 API + async with self._calendar_cache_lock: + now = time.time() + if self._is_calendar_cache_valid(now): + return copy.deepcopy(self._calendar_cache) + + url = f"{self.base_url}/calendar" + previous_cache = copy.deepcopy(self._calendar_cache) + try: + data = await self._request(url, method="GET") + except (BangumiApiError, RuntimeError, ValueError, TypeError) as e: + logger.error(f"get_calendar 刷新缓存失败: {e}") + if previous_cache is not None: + return previous_cache + return [] + + if not isinstance(data, list): + logger.warning(f"get_calendar 返回了非 list 类型: {type(data)}") + if previous_cache is not None: + return previous_cache + return [] + + normalized: list[CalendarDay] = [] + for item in data: + if isinstance(item, dict): + normalized.append(cast(CalendarDay, item)) + else: + logger.warning(f"get_calendar 列表元素类型异常: {type(item)}") + + self._calendar_cache = copy.deepcopy(normalized) + self._calendar_cache_expire_at = now + self.CALENDAR_CACHE_TTL_SECONDS + return copy.deepcopy(self._calendar_cache) diff --git a/src/services/characters.py b/src/services/characters.py new file mode 100644 index 0000000..c1f4df8 --- /dev/null +++ b/src/services/characters.py @@ -0,0 +1,30 @@ +from typing import cast + +from ..bangumi_types import JsonObject +from .base import BaseBangumiService +from .contracts import PersonDetailsResponse, PersonsSearchResponse + + +class CharactersService(BaseBangumiService): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + async def search_persons( + self, keyword: str, limit: int = 10 + ) -> PersonsSearchResponse: + """通过关键词搜索人物""" + url = f"{self.base_url}/v0/search/persons" + json_data: JsonObject = {"keyword": keyword} + params: JsonObject = {"limit": limit} + data = await self._request( + url, method="POST", json_data=json_data, params=params + ) + if isinstance(data, dict) and isinstance(data.get("data"), list): + return cast(PersonsSearchResponse, data) + return {"data": []} + + async def get_person_details(self, person_id: int) -> PersonDetailsResponse: + """获取单个人物的详细信息""" + url = f"{self.base_url}/v0/persons/{person_id}" + data = await self._request(url) + return cast(PersonDetailsResponse, data if isinstance(data, dict) else {}) diff --git a/src/services/contracts.py b/src/services/contracts.py new file mode 100644 index 0000000..d8bf07b --- /dev/null +++ b/src/services/contracts.py @@ -0,0 +1,111 @@ +from typing import TypeAlias, TypedDict + +from ..bangumi_types import JsonValue + + +class SearchSubjectItem(TypedDict, total=False): + id: int | str + name: str + name_cn: str + type: int + + +class SearchSubjectsResponse(TypedDict): + data: list[SearchSubjectItem] + + +class EpisodeItem(TypedDict, total=False): + id: int + subject_id: int + type: int + ep: int + sort: int + name: str + name_cn: str + airdate: str + comment: int + disc: int + duration: str + duration_seconds: int + + +class EpisodeListResponse(TypedDict): + data: list[EpisodeItem] + + +class SubjectDetailsResponse(TypedDict, total=False): + id: int | str + name: str + name_cn: str + date: str + air_date: str + eps: int + episodes: list[EpisodeItem] + platform: str + type: int + images: dict[str, JsonValue] + image_url: str + summary: str + tags: list[dict[str, JsonValue]] + infobox: list[dict[str, JsonValue]] + total_episodes: int + rating: dict[str, JsonValue] + episode_list: list[dict[str, JsonValue]] + air_weekday: str + + +class CalendarWeekday(TypedDict, total=False): + id: int + cn: str + en: str + ja: str + + +class CalendarItem(TypedDict, total=False): + id: int | str + name: str + name_cn: str + images: dict[str, JsonValue] + + +class CalendarDay(TypedDict, total=False): + weekday: CalendarWeekday + items: list[CalendarItem] + is_today: bool + + +class UserDetailsResponse(TypedDict, total=False): + id: int | str + username: str + nickname: str + + +class PersonDetailsResponse(TypedDict, total=False): + id: int | str + name: str + summary: str + + +class PersonsSearchResponse(TypedDict): + data: list[PersonDetailsResponse] + + +class SubscribeMatch(TypedDict): + subject_id: str + name: str + air_date: str + total_episodes: int + + +class SubscribeCandidate(TypedDict): + subject_id: str + name: str + + +class UnsubscribeMatch(TypedDict): + subject_id: str + name: str + + +RenderData: TypeAlias = dict[str, JsonValue] +MessageResult: TypeAlias = object diff --git a/src/services/exceptions.py b/src/services/exceptions.py new file mode 100644 index 0000000..3eb5c68 --- /dev/null +++ b/src/services/exceptions.py @@ -0,0 +1,28 @@ +class NoSubjectFound(Exception): + """找不到对应条目的异常类""" + + pass + + +class BangumiApiError(Exception): + """Bangumi API请求错误的异常类""" + + pass + + +class BangumiRateLimitError(Exception): + """API限流异常类""" + + pass + + +class DatabaseError(Exception): + """数据库操作异常:替换 repository 层宽泛的 except Exception,提供更精准的错误上下文。""" + + pass + + +class SubscriptionError(Exception): + """订阅业务异常:替换 subscription 服务层宽泛的 except Exception,提供更精准的错误反馈。""" + + pass diff --git a/src/services/persons.py b/src/services/persons.py new file mode 100644 index 0000000..7115118 --- /dev/null +++ b/src/services/persons.py @@ -0,0 +1,32 @@ +from typing import cast + +from ..bangumi_types import JsonObject +from .base import BaseBangumiService +from .contracts import PersonDetailsResponse, PersonsSearchResponse + + +class PersonsService(BaseBangumiService): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + # --- 新增人物相关方法 --- + + async def search_persons( + self, keyword: str, limit: int = 10 + ) -> PersonsSearchResponse: + """通过关键词搜索人物""" + url = f"{self.base_url}/v0/search/persons" + json_data: JsonObject = {"keyword": keyword} + params: JsonObject = {"limit": limit} + data = await self._request( + url, method="POST", json_data=json_data, params=params + ) + if isinstance(data, dict) and isinstance(data.get("data"), list): + return cast(PersonsSearchResponse, data) + return {"data": []} + + async def get_person_details(self, person_id: int) -> PersonDetailsResponse: + """获取单个人物的详细信息""" + url = f"{self.base_url}/v0/persons/{person_id}" + data = await self._request(url) + return cast(PersonDetailsResponse, data if isinstance(data, dict) else {}) diff --git a/src/services/schemas.py b/src/services/schemas.py new file mode 100644 index 0000000..fd8b6f3 --- /dev/null +++ b/src/services/schemas.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, ConfigDict, Field + + +class Episode(BaseModel): + """ + Bangumi Episode 数据模型 + 用于校验和解析从 API 返回的剧集信息 + """ + + airdate: str | None = Field(None, description="播出日期,格式: YYYY-MM-DD") + name: str = Field(..., description="剧集日文名称") + name_cn: str = Field(..., description="剧集中文名称") + duration: str | None = Field(None, description="时长,格式: HH:MM:SS") + desc: str = Field(default="", description="剧集简介") + ep: int = Field(..., description="集数", ge=0) + sort: int = Field(..., description="排序号", ge=0) + id: int = Field(..., description="剧集ID") + subject_id: int = Field(..., description="条目ID") + comment: int = Field(default=0, description="评论数", ge=0) + type: int = Field(..., description="剧集类型") + disc: int = Field(default=0, description="碟片号", ge=0) + duration_seconds: int | None = Field(None, description="时长(秒)", ge=0) + + # 允许额外字段,API 可能返回更多数据 + model_config = ConfigDict(extra="allow") diff --git a/src/services/search.py b/src/services/search.py new file mode 100644 index 0000000..bedc1a1 --- /dev/null +++ b/src/services/search.py @@ -0,0 +1,141 @@ +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, cast + +import aiohttp +import astrbot.api.message_components as Comp +from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent + +from ..config import ConfigManager +from ..render import CalendarRenderer, SubjectRenderer +from .contracts import ( + MessageResult, + RenderData, + SearchSubjectItem, + SubjectDetailsResponse, +) +from .exceptions import BangumiApiError + +if TYPE_CHECKING: + from . import BangumiService + + +class SearchService: + def __init__( + self, + service: "BangumiService", + config_manager: ConfigManager, + session: aiohttp.ClientSession | None = None, + ) -> None: + self.service = service + self.config_manager = config_manager + self.subject_renderer = SubjectRenderer(session=session) + self.calendar_renderer = CalendarRenderer(session=session) + + async def handle_subject_search( + self, + event: AstrMessageEvent, + query: str, + top_k: int = 1, + subject_type: list[int] | None = None, + subject_tags: list[str] | None = None, + ) -> AsyncGenerator[MessageResult, None]: + """ + 处理条目搜索的核心流程:搜索 -> 渲染 (Base64) -> 发送。 + """ + if not query: + yield event.plain_result("❌ 请提供搜索关键词") + return + + logger.info(f"搜索请求: {query}, type={subject_type}, top_k={top_k}") + + try: + # 1. 搜索条目 + search_res = await self.service.search_subjects( + keyword=query, subject_type=subject_type, subject_tags=subject_tags + ) + if not search_res or "data" not in search_res or not search_res["data"]: + yield event.plain_result("🔍 未找到相关条目") + return + + # 2. 渲染并获取 Base64 组件 + image_components = await self._prepare_subject_images_base64( + search_res["data"], top_k + ) + + # 3. 发送结果 + if image_components: + yield event.chain_result(image_components) + else: + yield event.plain_result("❌ 未能生成渲染图片") + + except (BangumiApiError, RuntimeError, ValueError) as e: + logger.error(f"SearchService.handle_subject_search 失败: {e}") + yield event.plain_result(f"❌ 处理失败: {e}") + + async def handle_calendar( + self, event: AstrMessageEvent + ) -> AsyncGenerator[MessageResult, None]: + """ + 处理每日放送逻辑。 + """ + try: + calendar_res = await self.service.get_calendar() + if not calendar_res: + yield event.plain_result("❌ 未获取到放送数据") + return + + base64_image = await self.calendar_renderer.render_calendar( + calendar_res, + rpc_url=self.config_manager.get_render_server_url(), + max_retries=self.config_manager.get_max_retries(), + ) + + if base64_image: + yield event.chain_result([Comp.Image.fromBase64(base64_image)]) + else: + yield event.plain_result("❌ 图片生成失败") + except (BangumiApiError, RuntimeError, ValueError) as e: + logger.error(f"SearchService.handle_calendar 失败: {e}") + yield event.plain_result(f"❌ 处理失败: {e}") + + async def _prepare_subject_images_base64( + self, subjects: list[SearchSubjectItem], top_k: int + ) -> list[Comp.Image]: + """ + 内部逻辑:准备渲染数据并生成 Base64 图片组件。 + """ + data_list: list[SubjectDetailsResponse] = [] + + for item in subjects[:top_k]: + subject_id = item.get("id") + if not subject_id: + continue + + # 获取详情 + subject_data = await self.service.get_subject_details(str(subject_id)) + if not subject_data: + continue + + # 补充剧集进度信息 + try: + episodes_data = await self.service.get_subject_episodes(int(subject_id)) + if episodes_data and "data" in episodes_data: + subject_data["episodes"] = episodes_data["data"] + except (BangumiApiError, ValueError, TypeError) as e: + logger.warning(f"获取剧集信息失败 (subject_id={subject_id}): {e}") + + data_list.append(subject_data) + + if not data_list: + return [] + + # 批量渲染为 Base64 + base64_list = await self.subject_renderer.render_batch_subject_cards_to_base64( + data_list=cast(list[RenderData], data_list), + rpc_url=self.config_manager.get_render_server_url(), + max_retries=self.config_manager.get_max_retries(), + ) + + # 包装成消息组件 + return [Comp.Image.fromBase64(b64) for b64 in base64_list] diff --git a/src/services/subjects.py b/src/services/subjects.py new file mode 100644 index 0000000..cd03d0b --- /dev/null +++ b/src/services/subjects.py @@ -0,0 +1,171 @@ +import base64 +import datetime +from typing import cast + +from astrbot.api import logger +from pydantic import ValidationError + +from ..bangumi_types import JsonObject +from .base import BaseBangumiService +from .contracts import ( + EpisodeItem, + EpisodeListResponse, + SearchSubjectItem, + SearchSubjectsResponse, + SubjectDetailsResponse, +) +from .schemas import Episode +from .types import ImageSize + + +class SubjectsService(BaseBangumiService): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + async def search_subjects( + self, + keyword: str, + limit: int = 5, + offset: int = 0, + subject_type: list[int] | None = None, + subject_tags: list[str] | None = None, + ) -> SearchSubjectsResponse: + cache_key = f"search:{keyword}:{limit}" + if cache_key in self.search_cache: + return self.search_cache[cache_key] + + url = f"{self.base_url}/v0/search/subjects" + filters: dict[str, object] = {} + json_data: JsonObject = { + "keyword": keyword, + "limit": limit, + "offset": offset, + "filter": filters, + } + if subject_type is not None: + filters["type"] = subject_type + if subject_tags is not None: + filters["tag"] = subject_tags + data = await self._request( + url, + method="POST", + json_data=json_data, + ) + if isinstance(data, dict): + raw_items = data.get("data") + if isinstance(raw_items, list): + normalized: SearchSubjectsResponse = {"data": []} + for item in raw_items: + if isinstance(item, dict): + normalized["data"].append(cast(SearchSubjectItem, item)) + self.search_cache[cache_key] = cast(SearchSubjectsResponse, normalized) + return normalized + + fallback: SearchSubjectsResponse = {"data": []} + self.search_cache[cache_key] = fallback + return fallback + + async def get_subject_details(self, subject_id: str) -> SubjectDetailsResponse: + """ + 获取条目的信息 + """ + url = f"{self.base_url}/v0/subjects/{subject_id}" + data = await self._request(url) + return cast(SubjectDetailsResponse, data if isinstance(data, dict) else {}) + + async def get_subject_image(self, subject_id: str, size: ImageSize) -> bytes: + """ + 获取条目的图片原始二进制数据 + """ + url = f"{self.base_url}/v0/subjects/{subject_id}/image" + params: JsonObject = {"type": size.value} + return await self._request(url, params=params, is_json=False) + + async def get_subject_base64image( + self, subject_id: str, size: ImageSize + ) -> str | None: + """ + 获取条目的图片并转换为 Base64 编码的字符串 + """ + try: + image_bytes = await self.get_subject_image(subject_id, size) + if image_bytes: + return base64.b64encode(image_bytes).decode("utf-8") + except (ValueError, TypeError, RuntimeError) as e: + logger.error(f"获取条目 {subject_id} 的 Base64 图片失败: {e}") + return None + + async def get_subject_episodes(self, subject_id: int) -> EpisodeListResponse: + """ + 获取条目的剧集信息 + + Args: + subject_id: 条目的id + Returns: + data: 剧集信息 + total: 总集数 + """ + url = f"{self.base_url}/v0/episodes" + params: JsonObject = {"subject_id": subject_id} + data = await self._request(url, params=params) + if isinstance(data, dict): + raw_items = data.get("data") + if isinstance(raw_items, list): + normalized: EpisodeListResponse = {"data": []} + for item in raw_items: + if isinstance(item, dict): + normalized["data"].append(cast(EpisodeItem, item)) + return normalized + return {"data": []} + + async def get_latest_episode(self, subject_id: int) -> Episode | None: + """ + 从 episodes 数据中提取最新一集的信息。 + 最新一集的定义:已播出且有互动(评论)的普通剧集。 + """ + episodes_data = await self.get_subject_episodes(subject_id) + raw_list = episodes_data.get("data", []) + if not raw_list: + return None + + # 解析并校验数据 + episodes = self._parse_episodes(raw_list) + + # 获取今天的日期用于比较 + today = datetime.date.today() + + # 逆序查找:从最后一集向前找第一个符合条件的 + for episode in reversed(episodes): + if episode.ep == 0: + continue + + # 检查播出状态 + is_aired = True + if episode.airdate: + try: + episode_date = datetime.datetime.strptime( + episode.airdate, "%Y-%m-%d" + ).date() + is_aired = episode_date <= today + except ValueError: + # 日期格式异常时,不因为日期判定为未播出 + pass + + # 核心业务逻辑:已播出且有评论互动 + if is_aired and episode.comment > 0: + return episode + + return None + + @staticmethod + def _parse_episodes(raw_data: list[EpisodeItem]) -> list[Episode]: + """ + 辅助函数:将原始字典列表解析为 Episode 模型列表,自动过滤校验失败的数据。 + """ + parsed_episodes: list[Episode] = [] + for item in raw_data: + try: + parsed_episodes.append(Episode(**item)) + except ValidationError as e: + logger.warning(f"解析剧集数据失败,已跳过: {e}, 原始数据: {item}") + return parsed_episodes diff --git a/src/services/subscription.py b/src/services/subscription.py new file mode 100644 index 0000000..d86e314 --- /dev/null +++ b/src/services/subscription.py @@ -0,0 +1,330 @@ +from typing import TYPE_CHECKING, cast + +import aiohttp +from astrbot.api import logger +from astrbot.api.star import StarTools +from astrbot.core.message.message_event_result import MessageChain + +from ..config import ConfigManager +from ..db import BangumiRepository +from ..render import EpisodeRenderer +from .contracts import SubscribeCandidate, SubscribeMatch, UnsubscribeMatch +from .exceptions import BangumiApiError, DatabaseError, SubscriptionError +from .schemas import Episode +from .types import ImageSize + +if TYPE_CHECKING: + from . import BangumiService + + +class SubscriptionService: + def __init__( + self, + repository: BangumiRepository, + service: "BangumiService", + config_manager: ConfigManager, + session: aiohttp.ClientSession | None = None, + ) -> None: + self.storage = repository + self.service = service + self.config_manager = config_manager + self.renderer = EpisodeRenderer(session=session) + + async def get_subscribe_candidates( + self, keyword: str, limit: int + ) -> tuple[str | None, list[SubscribeCandidate]]: + """ + 查询订阅候选,命中多条时由上层进行二次确认。 + """ + normalized_keyword = keyword.strip() + if not normalized_keyword: + return "❌ 请提供要订阅的番剧关键词或ID。", [] + + effective_limit = max(1, min(limit, 10)) + search_res = await self.service.search_subjects( + keyword=normalized_keyword, + limit=effective_limit, + subject_type=[2], + subject_tags=None, + ) + raw_items = search_res.get("data", []) + if not raw_items: + return "🔍 未找到相关番剧", [] + + candidates: list[SubscribeCandidate] = [] + seen: set[str] = set() + for item in raw_items: + subject_id_raw = item.get("id") + if subject_id_raw is None: + continue + subject_id = str(subject_id_raw) + if subject_id in seen: + continue + seen.add(subject_id) + raw_name = item.get("name_cn") or item.get("name") or f"ID:{subject_id}" + candidates.append({"subject_id": subject_id, "name": str(raw_name)}) + + if not candidates: + return "🔍 未找到相关番剧", [] + return None, candidates + + async def _build_subscribable_subject( + self, subject_id: str + ) -> tuple[str | None, SubscribeMatch | None]: + """ + 根据 subject_id 构建可订阅条目(详情 + 放送表校验)。 + """ + details = await self.service.get_subject_details(subject_id) + if not details: + return "❌ 获取番剧详情失败", None + + raw_name = details.get("name_cn") or details.get("name") + name = str(raw_name) if raw_name else "未知番剧" + + calendar_res = await self.service.get_calendar() + is_in_calendar = False + if calendar_res: + for day_item in calendar_res: + for item in day_item.get("items", []): + if str(item.get("id")) == subject_id: + is_in_calendar = True + break + if is_in_calendar: + break + + if not is_in_calendar: + return ( + f"⚠️ {name} 不在当前的每日放送列表中 (可能已完结或未开播),暂不支持自动追踪。", + None, + ) + + total_episodes_raw = details.get("eps", 0) + total_episodes = ( + int(total_episodes_raw) if isinstance(total_episodes_raw, (int, str)) else 0 + ) + air_date = str(details.get("date", "")) + result_data: SubscribeMatch = { + "subject_id": subject_id, + "name": name, + "air_date": air_date, + "total_episodes": total_episodes, + } + return None, cast(SubscribeMatch, result_data) + + async def _match_subscribable_subject( + self, keyword: str + ) -> tuple[str | None, SubscribeMatch | None]: + """ + 查找可订阅的番剧逻辑(从 API 层迁移至此)。 + """ + error_msg, candidates = await self.get_subscribe_candidates( + keyword=keyword, limit=1 + ) + if error_msg: + return error_msg, None + if not candidates: + return "🔍 未找到相关番剧", None + return await self._build_subscribable_subject(candidates[0]["subject_id"]) + + async def subscribe_by_subject_id(self, group_id: str, subject_id: str) -> str: + """ + 基于明确 subject_id 完成订阅。 + """ + try: + error_msg, subject_info = await self._build_subscribable_subject(subject_id) + if error_msg: + return error_msg + if not subject_info: + return "❌ 未知错误:未能获取番剧信息" + + success = self.storage.subscribe_subject( + group_id=group_id, + subject_id=subject_info["subject_id"], + name=subject_info["name"], + air_date=subject_info["air_date"], + total_episodes=subject_info["total_episodes"], + ) + if success: + return ( + f"✅ 成功订阅《{subject_info['name']}》!\n如有更新将推送到本群。" + ) + return "❌ 订阅失败,数据库错误。" + except (BangumiApiError, DatabaseError, SubscriptionError) as e: + logger.error(f"SubscriptionService.subscribe_by_subject_id 失败: {e}") + return f"❌ 处理失败: {e}" + + async def subscribe(self, group_id: str, query: str) -> str: + """ + 处理订阅逻辑:匹配条目 -> 存入数据库 -> 建立订阅关系。 + """ + logger.info(f"处理追番请求: {query}, group_id={group_id}") + try: + # 1. 匹配条目 (调用内部迁移后的逻辑) + error_msg, subject_info = await self._match_subscribable_subject(query) + if error_msg: + return error_msg + if not subject_info: + return "❌ 未知错误:未能获取番剧信息" + + subject_id = subject_info["subject_id"] + name = subject_info["name"] + + # 2 & 3. 原子性地写入条目信息并建立订阅关系 + success = self.storage.subscribe_subject( + group_id=group_id, + subject_id=subject_id, + name=name, + air_date=subject_info["air_date"], + total_episodes=subject_info["total_episodes"], + ) + if success: + return f"✅ 成功订阅《{name}》!\n如有更新将推送到本群。" + else: + return "❌ 订阅失败,数据库错误。" + except (BangumiApiError, DatabaseError, SubscriptionError) as e: + logger.error(f"SubscriptionService.subscribe 失败: {e}") + return f"❌ 处理失败: {e}" + + async def unsubscribe(self, group_id: str, query: str) -> str: + """ + 取消订阅逻辑。 + """ + logger.info(f"处理取消追番请求: {query}, group_id={group_id}") + try: + error_msg, subject_info = self._match_local_subscription(group_id, query) + if error_msg: + return error_msg + if not subject_info: + return "❌ 未知错误:未能获取番剧信息" + + subject_id = subject_info["subject_id"] + name = subject_info["name"] + + success = self.storage.remove_subscription(group_id, subject_id) + if success: + return f"✅ 已成功取消订阅《{name}》。" + else: + return f"❌ 取消订阅失败:你可能并没有订阅《{name}》。" + except (BangumiApiError, DatabaseError, SubscriptionError) as e: + logger.error(f"SubscriptionService.unsubscribe 失败: {e}") + return f"❌ 处理失败: {e}" + + def _match_local_subscription( + self, group_id: str, query: str + ) -> tuple[str | None, UnsubscribeMatch | None]: + """ + 在当前群组的本地订阅中做模糊匹配。 + """ + normalized_query = str(query).strip() + if not normalized_query: + return "❌ 请提供要取消订阅的番剧关键词或ID。", None + + # 取 6 条用于判断是否超过默认展示上限(5 条) + candidates = self.storage.find_group_subscription_candidates( + group_id=group_id, keyword=normalized_query, limit=6 + ) + if not candidates: + return f"❌ 未找到与「{normalized_query}」匹配的本群订阅番剧。", None + + if len(candidates) == 1: + subject = candidates[0] + return None, { + "subject_id": str(subject.subject_id), + "name": str(subject.name), + } + + display_limit = 5 + display_candidates = candidates[:display_limit] + lines = [ + "⚠️ 匹配到多个已订阅番剧,请提供更精确名称或直接使用 ID:", + ] + for idx, subject in enumerate(display_candidates, start=1): + lines.append(f"{idx}. {subject.name} (ID: {subject.subject_id})") + if len(candidates) > display_limit: + lines.append("(仅显示前 5 项)") + return "\n".join(lines), None + + async def check_updates(self) -> None: + """ + 定时任务核心逻辑:检查所有监控中的番剧是否有更新。 + """ + subjects = self.storage.get_monitored_subjects() + logger.info(f"开始更新 {len(subjects)} 个番剧的集数信息") + + for subject in subjects: + try: + # 获取最新集数 + latest_episode = await self.service.get_latest_episode( + int(subject.subject_id) + ) + if not latest_episode: + continue + + # 尝试获取封面图用于渲染 + try: + image_base64 = await self.service.get_subject_base64image( + subject.subject_id, size=ImageSize.LARGE + ) + if image_base64: + latest_episode.image_url = ( + f"data:image/png;base64,{image_base64}" + ) + except BangumiApiError as e: + logger.error(f"获取条目 {subject.name} 图片失败: {e}") + + # 比对更新 + if latest_episode.ep > subject.current_episode: + logger.info( + f"番剧《{subject.name}》有更新: {subject.current_episode} -> {latest_episode.ep}" + ) + + # 更新数据库 + # 显式转换为 str 以解决 Pylance 对 SQLAlchemy Column 对象的类型报错 + self.storage.update_subject_episode( + str(subject.subject_id), latest_episode.ep + ) + + # 发送通知 + await self._notify_subscribers( + latest_episode, str(subject.subject_id), str(subject.name) + ) + + except (BangumiApiError, DatabaseError) as e: + logger.error(f"更新番剧《{subject.name}》失败: {e}") + + async def _notify_subscribers( + self, episode: Episode, subject_id: str, subject_name: str + ) -> None: + """ + 渲染并发送更新通知。 + """ + subscribed_groups = self.storage.get_subject_subscribers(subject_id) + if not subscribed_groups: + return + + # 渲染图片 + base64_image = await self.renderer.render_episode( + episode, + rpc_url=self.config_manager.get_render_server_url(), + max_retries=self.config_manager.get_max_retries(), + ) + + chain = MessageChain() + if base64_image: + chain = chain.base64_image(base64_image) + else: + # 如果图片渲染失败,发送纯文本通知作为兜底 + chain = chain.message( + f"🔔 番剧《{subject_name}》更新啦!\n第 {episode.ep} 集:{episode.name_cn or episode.name}" + ) + + for group_id in subscribed_groups: + try: + await StarTools.send_message_by_id( + type="GroupMessage", id=group_id, message_chain=chain + ) + logger.info(f"向群组 {group_id} 发送《{subject_name}》更新通知成功。") + except Exception as e: + logger.error( + f"向群组 {group_id} 发送《{subject_name}》更新通知失败: {e}" + ) diff --git a/src/services/types.py b/src/services/types.py new file mode 100644 index 0000000..2a8fb82 --- /dev/null +++ b/src/services/types.py @@ -0,0 +1,57 @@ +from enum import Enum, IntEnum, StrEnum + + +class SubjectType(IntEnum): + """Bangumi 条目类型""" + + BOOK = 1 + ANIME = 2 + MUSIC = 3 + GAME = 4 + REAL = 6 + + def to_display(self) -> str: + """获取带 Emoji 的显示名称""" + _map = { + SubjectType.BOOK: "📚 书籍", + SubjectType.ANIME: "🎬 动画", + SubjectType.MUSIC: "🎵 音乐", + SubjectType.GAME: "🎮 游戏", + SubjectType.REAL: "🌐 三次元", + } + return _map.get(self, "未知") + + +class PersonType(IntEnum): + """Bangumi 人物类型""" + + INDIVIDUAL = 1 + COMPANY = 2 + GROUP = 3 + + def to_display(self) -> str: + """获取带 Emoji 的显示名称""" + _map = { + PersonType.INDIVIDUAL: "👤 个人", + PersonType.COMPANY: "🏢 公司", + PersonType.GROUP: "👥 组合", + } + return _map.get(self, "未知") + + +class ImageSize(Enum): + """图片尺寸规格""" + + SMALL = "small" + GRID = "grid" + LARGE = "large" + MEDIUM = "medium" + COMMON = "common" + + +class CommonTag(StrEnum): + """常用标签常量""" + + TV = "TV" + MOVIE = "剧场版" + MANGA = "漫画" diff --git a/src/services/users.py b/src/services/users.py new file mode 100644 index 0000000..473961b --- /dev/null +++ b/src/services/users.py @@ -0,0 +1,17 @@ +from typing import cast +from urllib.parse import quote + +from .base import BaseBangumiService +from .contracts import UserDetailsResponse + + +class UsersService(BaseBangumiService): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + async def get_user_details(self, username: str) -> UserDetailsResponse: + """获取用户详细信息""" + encoded_username = quote(username) + url = f"{self.base_url}/v0/users/{encoded_username}" + data = await self._request(url) + return cast(UserDetailsResponse, data if isinstance(data, dict) else {}) diff --git a/src/templates/calendar/calendar.html b/src/templates/calendar/calendar.html new file mode 100644 index 0000000..409bcff --- /dev/null +++ b/src/templates/calendar/calendar.html @@ -0,0 +1,266 @@ + + + + + + Bangumi Calendar + + + +
+
+

每日放送表 Bangumi Calendar +

+
+ +
+ {% for day in days %} +
+
+ {{ day.weekday.cn }} + {{ day.weekday.en }} +
+
+ {% for item in day['items'] %} +
+
+ {{ item.name_cn or item.name }} +
+
+
{{ item.name_cn or item.name }}
+
+ {% if item.rating and item.rating.score %} +
+ + {{ item.rating.score }} +
+ {% endif %} + {% if item.rank %} +
#{{ item.rank }}
+ {% endif %} +
+
+
+ {% else %} +
今日无更新内容
+ {% endfor %} +
+
+ {% endfor %} +
+
+ + \ No newline at end of file diff --git a/src/templates/subject/subject.html b/src/templates/subject/subject.html new file mode 100644 index 0000000..b1777e3 --- /dev/null +++ b/src/templates/subject/subject.html @@ -0,0 +1,596 @@ + + + + + + + Subject Card + + + + +
+ + {% if air_weekday %} +
+ {{ air_weekday }} + 曜日 +
+ {% endif %} + + +
+
+ {% if image_url %} + Cover + {% endif %} +
+ + + {% if episode_list %} +
+
+ 放送进度 + {% set aired_count = episode_list | selectattr('aired') | list | length %} + {{ aired_count }} / {{ episode_list | length }} +
+
+ {% for ep_item in episode_list %} +
+ {{ ep_item.ep }} +
+ {% endfor %} +
+
+ {% endif %} + + + {% if rating and rating.count %} +
+
评分分布
+
+ {% set max_count = rating.count.values() | max if rating.count else 1 %} + + {% for i in range(1, 11) %} + {% set count = rating.count[i|string] or 0 %} + {% set height_percent = (count / max_count * 100) if max_count > 0 else 0 %} + {% set final_height = height_percent if height_percent > 1 else 1 %} +
+
+
+ {% endfor %} +
+
+ 1 + + + + 5 + + + + + 10 +
+
+ {% endif %} +
+ + +
+ +
+

{{ name_cn or name }}

+ {% if name_cn and name != name_cn %} +

{{ name }}

+ {% endif %} +
+ + + {% if rating %} +
+
+ + {{ rating.score }} +
+ {% set display_rank = rating.rank or rank %} + {% if display_rank %} +
+ #{{ display_rank }} +
+ {% endif %} +
+ {{ rating.total }} 人评分 +
+ {% if collection and collection.doing %} +
+ {{ collection.doing }} 人在看 +
+ {% endif %} +
+ {% else %} +
+ 暂无评分 +
+ {% endif %} + + + {% if tags %} +
+ {% for tag in tags[:8] %} + {{ tag.name }} + {% endfor %} +
+ {% endif %} + + +
+
简介
+

+ {{ summary if summary else '暂无简介' }} +

+
+ + + +
+
+ + + \ No newline at end of file diff --git a/src/templates/update/episode.html b/src/templates/update/episode.html new file mode 100644 index 0000000..da4c567 --- /dev/null +++ b/src/templates/update/episode.html @@ -0,0 +1,237 @@ + + + + + + Episode Card Update + + + + + + + + +
+ +
+ {% if image_url %} + Cover + {% else %} +
+ 🎬 +
+ {% endif %} + + +
+
+ + +
+
+
+ EP.{{ '%02d' % sort if sort else '01' }} +

{{ name_cn or name or "第 " + (sort|string) + " 话" }}

+
+ + +
+ + {% if desc %} +

{{ desc }}

+ {% endif %} +
+
+ + + \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..a3aeee4 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,5 @@ +from .async_utils import retry +from .env_manager import EnvManager +from .scheduler import SchedulerManager + +__all__ = ["EnvManager", "SchedulerManager", "retry"] diff --git a/src/utils/async_utils.py b/src/utils/async_utils.py new file mode 100644 index 0000000..af5899b --- /dev/null +++ b/src/utils/async_utils.py @@ -0,0 +1,42 @@ +import asyncio +from collections.abc import Awaitable, Callable +from typing import TypeVar + +from astrbot.api import logger + +T = TypeVar("T") + + +async def retry( + func: Callable[..., Awaitable[T]], + retries: int = 3, + delay: float = 1.0, + label: str = "任务", + *args: object, + **kwargs: object, +) -> T: + """ + 通用异步重试方法 + :param func: 需要重试的异步函数 + :param retries: 最大重试次数 + :param delay: 重试间隔(秒) + :param label: 用于日志显示的标签 + :param args: 传递给 func 的位置参数 + :param kwargs: 传递给 func 的关键字参数 + :return: 异步函数的返回结果 + + """ + last_exception: Exception | None = None + for i in range(retries): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + logger.warning(f"{label} 执行失败 (尝试 {i + 1}/{retries}): {e}") + if i < retries - 1: + await asyncio.sleep(delay) + + logger.error(f"{label} 在 {retries} 次尝试后最终失败") + if last_exception is None: + raise RuntimeError(f"{label} 在 {retries} 次尝试后最终失败") + raise last_exception diff --git a/src/utils/env_manager.py b/src/utils/env_manager.py new file mode 100644 index 0000000..8ca55a6 --- /dev/null +++ b/src/utils/env_manager.py @@ -0,0 +1,18 @@ +from astrbot.api import logger + + +class EnvManager: + """ + Stub kept for API compatibility. Playwright has been removed; + local rendering now uses Pillow with no additional setup required. + """ + + def __init__(self, data_dir: str) -> None: + self.data_dir = data_dir + + def is_installed(self) -> bool: + """Always returns True — no external renderer needs to be installed.""" + return True + + async def install_dependencies(self) -> None: + logger.info("[+] Pillow renderer: no additional dependencies to install.") diff --git a/src/utils/scheduler.py b/src/utils/scheduler.py new file mode 100644 index 0000000..4a58992 --- /dev/null +++ b/src/utils/scheduler.py @@ -0,0 +1,83 @@ +""" +APScheduler 管理器 + +此模块提供了一个为 asyncio 和特定时区配置的 APScheduler 单例管理器。 +""" + +import asyncio +from collections.abc import Callable + +import pytz +from apscheduler.jobstores.base import JobLookupError +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from astrbot.api import logger + + +class SchedulerManager: + """ + APScheduler 的管理器类。 + 它使用 Asia/Shanghai 时区初始化调度器,并提供添加、删除和管理任务的方法。 + """ + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls, *args: object, **kwargs: object) -> "SchedulerManager": + # 伪单例实现,确保只存在一个调度器实例。 + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """ + 初始化 SchedulerManager。 + 每次调用 SchedulerManager() 时都会调用此方法,但调度器本身只创建一次。 + """ + if not hasattr(self, "scheduler"): + self.scheduler = AsyncIOScheduler(timezone=pytz.timezone("Asia/Shanghai")) + self.scheduler.start() + logger.info("调度器已初始化并在 Asia/Shanghai 时区启动.") + + def add_job( + self, func: Callable[..., object], trigger: str, **kwargs: object + ) -> str | None: + """ + 向调度器添加一个任务。 + + Args: + func (Callable): 要执行的异步函数。 + trigger (str): 触发器类型(例如:'interval'、'cron'、'date')。 + **kwargs: 触发器的参数(例如:seconds=30, hour=8, minute=0)。 + + Returns: + str | None: 添加的任务ID,如果失败则返回 None。 + """ + try: + job = self.scheduler.add_job(func, trigger, **kwargs) + return job.id + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"Error adding job: {e}") + return None + + def cancel_job(self, job_id: str) -> None: + """ + 根据任务ID取消任务。 + + Args: + job_id (str): 要取消的任务的ID。 + """ + try: + self.scheduler.remove_job(job_id) + logger.info(f"定时任务{job_id}已取消.") + except JobLookupError: + logger.warning(f"未找到定时任务{job_id}") + except (RuntimeError, ValueError, TypeError) as e: + logger.error(f"取消任务失败{job_id}: {e}") + + def shutdown(self) -> None: + """ + 关闭调度器。 + """ + if self.scheduler.running: + self.scheduler.shutdown() + logger.info("调度器已关闭.") diff --git a/tests/test_calendar_service.py b/tests/test_calendar_service.py new file mode 100644 index 0000000..9d4355f --- /dev/null +++ b/tests/test_calendar_service.py @@ -0,0 +1,107 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +import src.services.calendar as calendar_module +from src.services import BangumiApiError, CalendarService + + +@pytest.fixture +def service() -> CalendarService: + return CalendarService(access_token="token", user_agent="ua") + + +@pytest.mark.asyncio +async def test_calendar_cache_hit(service: CalendarService) -> None: + payload = [{"weekday": {"id": 1}, "items": [{"id": 1, "name": "A"}]}] + service._request = AsyncMock(return_value=payload) + + first = await service.get_calendar() + second = await service.get_calendar() + + assert first == second + assert service._request.await_count == 1 + + +@pytest.mark.asyncio +async def test_calendar_cache_expired_refresh( + service: CalendarService, monkeypatch: pytest.MonkeyPatch +) -> None: + now = 1_000_000.0 + + def fake_time() -> float: + return now + + monkeypatch.setattr(calendar_module.time, "time", fake_time) + + service._request = AsyncMock( + side_effect=[ + [{"weekday": {"id": 1}, "items": []}], + [{"weekday": {"id": 2}, "items": []}], + ] + ) + + first = await service.get_calendar() + now += service.CALENDAR_CACHE_TTL_SECONDS + 1 + second = await service.get_calendar() + + assert first != second + assert service._request.await_count == 2 + + +@pytest.mark.asyncio +async def test_calendar_cache_returns_deepcopy(service: CalendarService) -> None: + payload = [{"weekday": {"id": 1}, "items": []}] + service._request = AsyncMock(return_value=payload) + + first = await service.get_calendar() + first[0]["weekday"]["id"] = 7 + second = await service.get_calendar() + + assert second[0]["weekday"]["id"] == 1 + assert service._request.await_count == 1 + + +@pytest.mark.asyncio +async def test_calendar_cache_refresh_failed_fallback_stale( + service: CalendarService, monkeypatch: pytest.MonkeyPatch +) -> None: + now = 2_000_000.0 + + def fake_time() -> float: + return now + + monkeypatch.setattr(calendar_module.time, "time", fake_time) + + service._request = AsyncMock( + side_effect=[ + [{"weekday": {"id": 1}, "items": [{"id": 1}]}], + BangumiApiError("boom"), + ] + ) + + first = await service.get_calendar() + now += service.CALENDAR_CACHE_TTL_SECONDS + 1 + second = await service.get_calendar() + + assert second == first + assert service._request.await_count == 2 + + +@pytest.mark.asyncio +async def test_calendar_cache_concurrent_single_refresh( + service: CalendarService, +) -> None: + payload = [{"weekday": {"id": 3}, "items": []}] + + async def slow_fetch(*args: object, **kwargs: object) -> list[dict[str, object]]: + await asyncio.sleep(0.05) + return payload + + service._request = AsyncMock(side_effect=slow_fetch) + + first, second = await asyncio.gather(service.get_calendar(), service.get_calendar()) + + assert first == second + assert service._request.await_count == 1 diff --git a/tests/test_search_service.py b/tests/test_search_service.py new file mode 100644 index 0000000..c940b96 --- /dev/null +++ b/tests/test_search_service.py @@ -0,0 +1,70 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from astrbot.api.event import AstrMessageEvent + +from src.services import SearchService + + +@pytest.fixture +def mock_service() -> MagicMock: + service = MagicMock() + service.search_subjects = AsyncMock() + service.get_subject_details = AsyncMock() + service.get_subject_episodes = AsyncMock() + service.get_calendar = AsyncMock() + return service + + +@pytest.fixture +def mock_config_manager() -> MagicMock: + config_manager = MagicMock() + config_manager.get_render_server_url.return_value = "https://api.unitedpooh.top/rpc" + config_manager.get_max_retries.return_value = 1 + return config_manager + + +@pytest.mark.asyncio +async def test_handle_calendar_success( + mock_service: MagicMock, mock_config_manager: MagicMock +) -> None: + # 准备 Mock 数据 + mock_service.get_calendar.return_value = [{"weekday": {"id": 1}, "items": []}] + + search_service = SearchService( + service=mock_service, config_manager=mock_config_manager + ) + + # Mock 渲染器,避免进入模板渲染逻辑 + search_service.calendar_renderer.render_calendar = AsyncMock( + return_value="fake_base64" + ) + + event = MagicMock(spec=AstrMessageEvent) + event.chain_result = MagicMock(side_effect=lambda x: x) + + results: list[object] = [] + async for res in search_service.handle_calendar(event): + results.append(res) + + assert len(results) > 0 + mock_service.get_calendar.assert_called_once() + event.chain_result.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_subject_search_no_query( + mock_service: MagicMock, mock_config_manager: MagicMock +) -> None: + search_service = SearchService( + service=mock_service, config_manager=mock_config_manager + ) + event = MagicMock(spec=AstrMessageEvent) + event.plain_result = MagicMock(side_effect=lambda x: x) + + results: list[object] = [] + async for res in search_service.handle_subject_search(event, query=""): + results.append(res) + + assert len(results) > 0 + assert "❌ 请提供搜索关键词" in str(results[0]) diff --git a/tests/test_subject_renderer.py b/tests/test_subject_renderer.py new file mode 100644 index 0000000..10e9f51 --- /dev/null +++ b/tests/test_subject_renderer.py @@ -0,0 +1,72 @@ +import pytest +from loguru import logger + +from src.render import SubjectRenderer + + +@pytest.mark.asyncio +async def test_render_subject_card_success() -> None: + # 准备测试数据 + subject_data = { + "date": "2026-01-11", + "platform": "TV", + "images": { + "small": "https://lain.bgm.tv/r/200/pic/cover/l/71/50/525565_OxOv7.jpg", + "grid": "https://lain.bgm.tv/r/100/pic/cover/l/71/50/525565_OxOv7.jpg", + "large": "https://lain.bgm.tv/pic/cover/l/71/50/525565_OxOv7.jpg", + "medium": "https://lain.bgm.tv/r/800/pic/cover/l/71/50/525565_OxOv7.jpg", + "common": "https://lain.bgm.tv/r/400/pic/cover/l/71/50/525565_OxOv7.jpg", + }, + "summary": "总是活力充沛,却又很在意周遭目光的女孩:铃木实优\r\n以及个性文静,却能清楚表达自己意见的男生:谷悠介\r\n\r\n本次故事将讲述这两人的生活点滴。铃木喜欢着谷,却一直无法鼓起勇气告白。直到某天,两人放学回家时走在同一条路上并牵起了手。借由该契机,两人相互倾诉对彼此的好感并开始了交往。同学们虽然感到讶异,但也都很支持两人的恋情。\r\n这部恋爱喜剧描写的,正是这对个性截然相反的两人,在彼此尊重之下慢慢加深互相的理解,并与朋友们一同度过的校园生活点滴。如此温暖的故事就此开幕!\r\n\r\n\r\n\r\n[简介原文]\r\nいつも元気いっぱいだけど周りの目を気にしてしまう女子・鈴木と、\r\n物静かだけど自分の意見をしっかり言える男子・谷。\r\n正反対な二人が误解や勘違いをしながらもお互いを尊重し、\r\nゆっくりと理解を深めていく姿と、友人たちとの学校生活を描くラブコメディ。", + "name": "正反対な君と僕", + "name_cn": "相反的你和我", + "tags": [ + {"name": "恋爱", "count": 1356}, + {"name": "校园", "count": 1071}, + {"name": "2026年1月", "count": 1033}, + {"name": "漫画改", "count": 823}, + ], + "infobox": [ + {"key": "中文名", "value": "相反的你和我"}, + {"key": "别名", "value": [{"v": "正相反的你与我"}]}, + {"key": "话数", "value": "12"}, + {"key": "放送开始", "value": "2026年1月11日"}, + ], + "total_episodes": 12, + "id": 525565, + "type": 2, + "rating": { + "rank": 677, + "total": 2517, + "count": { + "1": 6, + "2": 3, + "3": 7, + "4": 13, + "5": 40, + "6": 167, + "7": 753, + "8": 1234, + "9": 194, + "10": 100, + }, + "score": 7.6, + }, + } + + renderer = SubjectRenderer() + + # 运行渲染器 + base64_image = await renderer.render_subject_card( + rpc_url="https://api.unitedpooh.top/rpc", + data=subject_data, + headless=True, + timeout=60000, + ) + + # 验证结果 + assert base64_image is not None, "[-] 渲染失败,未返回 Base64 字符串" + assert isinstance(base64_image, str), "返回值应为 Base64 字符串" + assert len(base64_image) > 100, "Base64 字符串过短" + + logger.info(f"[+] 渲染成功!图片长度: {len(base64_image)} 字符") diff --git a/tests/test_subscription_service.py b/tests/test_subscription_service.py new file mode 100644 index 0000000..5f23126 --- /dev/null +++ b/tests/test_subscription_service.py @@ -0,0 +1,206 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.services import SubscriptionService + + +@pytest.fixture +def mock_repo() -> MagicMock: + repo = MagicMock() + repo.subscribe_subject = MagicMock(return_value=True) + repo.remove_subscription = MagicMock(return_value=True) + repo.find_group_subscription_candidates = MagicMock(return_value=[]) + return repo + + +@pytest.fixture +def mock_service() -> MagicMock: + service = MagicMock() + service.search_subjects = AsyncMock() + service.get_subject_details = AsyncMock() + service.get_calendar = AsyncMock() + service.get_latest_episode = AsyncMock() + service.get_subject_base64image = AsyncMock() + return service + + +@pytest.mark.asyncio +async def test_subscribe_success(mock_repo, mock_service) -> None: + mock_service.search_subjects.return_value = {"data": [{"id": 123}]} + mock_service.get_subject_details.return_value = { + "id": 123, + "name": "Test Anime", + "name_cn": "测试番剧", + "date": "2024-01-01", + "eps": 12, + } + mock_service.get_calendar.return_value = [{"items": [{"id": 123}]}] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.subscribe("group_1", "Test Anime") + + assert "成功订阅《测试番剧》" in result + mock_repo.subscribe_subject.assert_called_once_with( + group_id="group_1", + subject_id="123", + name="测试番剧", + air_date="2024-01-01", + total_episodes=12, + ) + + +@pytest.mark.asyncio +async def test_unsubscribe_local_single_match_success(mock_repo, mock_service) -> None: + mock_repo.find_group_subscription_candidates.return_value = [ + SimpleNamespace(subject_id="123", name="测试番剧") + ] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.unsubscribe("group_1", "测") + + assert "已成功取消订阅《测试番剧》" in result + mock_repo.find_group_subscription_candidates.assert_called_once_with( + group_id="group_1", keyword="测", limit=6 + ) + mock_repo.remove_subscription.assert_called_once_with("group_1", "123") + mock_service.search_subjects.assert_not_called() + mock_service.get_subject_details.assert_not_called() + mock_service.get_calendar.assert_not_called() + + +@pytest.mark.asyncio +async def test_unsubscribe_local_no_match(mock_repo, mock_service) -> None: + mock_repo.find_group_subscription_candidates.return_value = [] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.unsubscribe("group_1", "不存在") + + assert "未找到与「不存在」匹配的本群订阅番剧" in result + mock_repo.remove_subscription.assert_not_called() + mock_service.search_subjects.assert_not_called() + mock_service.get_subject_details.assert_not_called() + mock_service.get_calendar.assert_not_called() + + +@pytest.mark.asyncio +async def test_unsubscribe_local_multi_match_returns_candidates( + mock_repo, mock_service +) -> None: + mock_repo.find_group_subscription_candidates.return_value = [ + SimpleNamespace(subject_id="1", name="进击的巨人"), + SimpleNamespace(subject_id="2", name="进击!巨人中学"), + SimpleNamespace(subject_id="3", name="巨人族的新娘"), + ] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.unsubscribe("group_1", "巨人") + + assert "匹配到多个已订阅番剧" in result + assert "1. 进击的巨人 (ID: 1)" in result + assert "2. 进击!巨人中学 (ID: 2)" in result + assert "3. 巨人族的新娘 (ID: 3)" in result + mock_repo.remove_subscription.assert_not_called() + mock_service.search_subjects.assert_not_called() + mock_service.get_subject_details.assert_not_called() + mock_service.get_calendar.assert_not_called() + + +@pytest.mark.asyncio +async def test_unsubscribe_local_remove_failed(mock_repo, mock_service) -> None: + mock_repo.find_group_subscription_candidates.return_value = [ + SimpleNamespace(subject_id="123", name="测试番剧") + ] + mock_repo.remove_subscription.return_value = False + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.unsubscribe("group_1", "测") + + assert "取消订阅失败:你可能并没有订阅《测试番剧》" in result + mock_repo.remove_subscription.assert_called_once_with("group_1", "123") + + +@pytest.mark.asyncio +async def test_get_subscribe_candidates_multi_match(mock_repo, mock_service) -> None: + mock_service.search_subjects.return_value = { + "data": [ + {"id": 1, "name_cn": "进击的巨人"}, + {"id": 2, "name": "进击!巨人中学"}, + {"id": 1, "name_cn": "进击的巨人"}, + {"name_cn": "无ID条目"}, + ] + } + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + error_msg, candidates = await sub_service.get_subscribe_candidates("巨人", 5) + + assert error_msg is None + assert candidates == [ + {"subject_id": "1", "name": "进击的巨人"}, + {"subject_id": "2", "name": "进击!巨人中学"}, + ] + mock_service.search_subjects.assert_awaited_once_with( + keyword="巨人", + limit=5, + subject_type=[2], + subject_tags=None, + ) + + +@pytest.mark.asyncio +async def test_subscribe_by_subject_id_success(mock_repo, mock_service) -> None: + mock_service.get_subject_details.return_value = { + "id": 456, + "name": "Test Name", + "name_cn": "测试番剧2", + "date": "2025-01-01", + "eps": 24, + } + mock_service.get_calendar.return_value = [{"items": [{"id": 456}]}] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.subscribe_by_subject_id("group_1", "456") + + assert "✅ 成功订阅《测试番剧2》" in result + mock_repo.subscribe_subject.assert_called_once_with( + group_id="group_1", + subject_id="456", + name="测试番剧2", + air_date="2025-01-01", + total_episodes=24, + ) + + +@pytest.mark.asyncio +async def test_subscribe_by_subject_id_not_in_calendar(mock_repo, mock_service) -> None: + mock_service.get_subject_details.return_value = { + "id": 789, + "name": "Not In Calendar", + "name_cn": "未放送番剧", + "date": "2025-06-01", + "eps": 12, + } + mock_service.get_calendar.return_value = [{"items": [{"id": 456}]}] + + sub_service = SubscriptionService( + repository=mock_repo, service=mock_service, config_manager=MagicMock() + ) + result = await sub_service.subscribe_by_subject_id("group_1", "789") + + assert "不在当前的每日放送列表中" in result + mock_repo.subscribe_subject.assert_not_called()