Устранено дублирование кода в AiAgent.

Добавлена возможность пересылки сообщений в личных чатах.
Обновлены зависимости.
Добавлен requirements.txt.
Исправлены предупреждения PyCharm 2026.1.
This commit is contained in:
Kirill Kirilenko 2026-04-07 01:01:34 +03:00
parent d2c10fa0a5
commit 4b265b5405
21 changed files with 278 additions and 197 deletions

View file

@ -5,14 +5,16 @@ Message = ai.agent.Message
Agent = ai.agent.AiAgent
# Глобальный экземпляр агента
agent: ai.agent.AiAgent
agent_instance: ai.agent.AiAgent
def create_ai_agent(openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, platform: str):
global agent
agent = ai.agent.AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
global agent_instance
agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model,
fal_token, replicate_token, tavily_token,
db, platform)
__all__ = ["agent", "Agent", "Message", "create_ai_agent"]
__all__ = ["agent_instance", "Agent", "Message", "create_ai_agent"]

View file

@ -2,11 +2,12 @@ import datetime
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, \
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict, ToolDefinitionJSONTypedDict
from openrouter.components import ChatAssistantMessage, ChatAssistantMessageTypedDict, \
ChatToolCall, ChatResult, ChatSystemMessageTypedDict, ChatMessagesTypedDict, ChatAssistantMessageContent
from openrouter.errors import ResponseValidationError, OpenRouterError
from openrouter.utils import BackoffStrategy
@ -16,7 +17,6 @@ from database import BasicDatabase
from ai.utils import *
from ai.tools import *
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
@ -27,11 +27,11 @@ MAX_OUTPUT_TOKENS = 500
@dataclass()
class Message:
user_name: str = None
text: str = None
image: bytes = None
image_hires: bytes = None
message_id: int = None
user_name: Optional[str] = None
text: Optional[str] = None
image: Optional[bytes] = None
image_hires: Optional[bytes] = None
message_id: Optional[int] = None
class AiAgent:
@ -51,7 +51,8 @@ class AiAgent:
self._load_prompts()
self.client_openrouter = OpenRouter(api_key=openrouter_token,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
x_open_router_title=OPENROUTER_X_TITLE,
http_referer=OPENROUTER_HTTP_REFERER,
retry_config=retry_config)
# Создание наборов инструментов
@ -63,88 +64,33 @@ class AiAgent:
# Сбор всех инструментов
self.tools: list[ai.tool.Tool] = []
self.tools_descriptions: list[ToolDefinitionJSONTypedDict] = []
self.tools_descriptions: list = []
for toolset in self.toolsets:
self.tools.extend(toolset.functions)
self.tools_descriptions.extend(toolset.get_all_tools_description())
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
message.text = _add_message_prefix(message.text, message.user_name)
return await self._handle_chat_reply(
bot_id=bot_id,
chat_id=chat_id,
message=message,
forwarded_messages=forwarded_messages,
is_group_chat=True,
max_messages=GROUP_CHAT_MAX_MESSAGES
)
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
for fwd_message in forwarded_messages:
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
if fwd_message.text is not None:
message_text += '\n' + fwd_message.text
fwd_message.text = message_text
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
ai_response = response.content
tools_artifacts = {}
if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
for fwd_message in forwarded_messages:
self.db.context_add_message(bot_id, chat_id,
role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response,
image=tools_artifacts.get("generated_image"),
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
image_hires=tools_artifacts.get("generated_image_hires")), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
message.text = _add_message_prefix(message.text)
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
context.append(_serialize_assistant_message(response))
ai_response = response.content
tools_artifacts = {}
if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant",
text=ai_response, image=tools_artifacts.get("generated_image"),
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
image_hires=tools_artifacts.get("generated_image_hires")), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
async def get_private_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message] = None) \
-> Tuple[Message, bool]:
return await self._handle_chat_reply(
bot_id=bot_id,
chat_id=chat_id,
message=message,
forwarded_messages=forwarded_messages or [],
is_group_chat=False,
max_messages=PRIVATE_CHAT_MAX_MESSAGES
)
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
@ -157,15 +103,71 @@ class AiAgent:
####################################################################################
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
context: List[MessageTypedDict] = [
async def _handle_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message],
is_group_chat: bool, max_messages: int) -> Tuple[Message, bool]:
# 1. Подготовка текста сообщения (префикс)
if is_group_chat:
message.text = _add_message_prefix(message.text, message.user_name)
else:
message.text = _add_message_prefix(message.text)
# 2. Сбор контекста из БД
context = self._get_chat_context(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
# 3. Обработка пересланных сообщений
for fwd_message in forwarded_messages:
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
if fwd_message.text is not None:
message_text += '\n' + fwd_message.text
fwd_message.text = message_text
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
# 4. Генерация ответа с поддержкой инструментов
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
context.append(_serialize_assistant_message(response))
ai_response: Optional[ChatAssistantMessageContent] = response.content
tools_artifacts = {}
while response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
context.append(_serialize_assistant_message(response))
ai_response = response.content
# 5. Сохранение истории в БД
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=max_messages)
for fwd_message in forwarded_messages:
self.db.context_add_message(bot_id, chat_id,
role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=max_messages)
self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response,
image=tools_artifacts.get("generated_image"),
message_id=None, max_messages=max_messages)
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
image_hires=tools_artifacts.get("generated_image_hires")), True
except Exception as e:
if "Rate limit exceeded" in str(e):
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[ChatMessagesTypedDict]:
context: List[ChatMessagesTypedDict] = [
self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
]
for message in self.db.context_get_messages(bot_id, chat_id):
context.append(_serialize_message(message["role"], message["text"], message["image"]))
return context
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> ChatSystemMessageTypedDict:
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
@ -186,7 +188,7 @@ class AiAgent:
return {"role": "system", "content": prompt}
async def _generate_reply(self, bot_id: int, chat_id: int,
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
context: List[ChatMessagesTypedDict], allow_tools: bool = False) -> ChatAssistantMessage:
response = await self._async_chat_completion_request(
model=self.openrouter_model,
messages=context,
@ -197,8 +199,8 @@ class AiAgent:
)
return self._filter_response(response.choices[0].message)
async def _process_tool_calls(self, tool_calls: List[ChatMessageToolCall],
context: List[MessageTypedDict]) -> dict:
async def _process_tool_calls(self, tool_calls: List[ChatToolCall],
context: List[ChatMessagesTypedDict]) -> dict:
artifacts = {}
if tool_calls is None:
return artifacts
@ -221,8 +223,9 @@ class AiAgent:
return artifacts
async def _async_chat_completion_request(self, **kwargs):
async def _async_chat_completion_request(self, **kwargs) -> ChatResult:
try:
# noinspection PyTypeChecker
return await self.client_openrouter.chat.send_async(**kwargs)
except ResponseValidationError as e:
# Костыль для OpenRouter SDK:
@ -248,7 +251,7 @@ class AiAgent:
raise e
@staticmethod
def _filter_response(response: AssistantMessage) -> AssistantMessage:
def _filter_response(response: ChatAssistantMessage) -> ChatAssistantMessage:
text = str(response.content)
text = text.replace("<image>", "")
response.content = text
@ -271,11 +274,12 @@ def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -
return {"role": role, "content": serialize_message_content(text, image)}
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
def _serialize_assistant_message(message: ChatAssistantMessage) -> ChatAssistantMessageTypedDict:
# noinspection PyTypeChecker
return _remove_none_recursive(message.model_dump(by_alias=True))
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
def _remove_none_recursive(data: Union[Dict, List, Any]) -> Union[Dict, List, Any]:
if isinstance(data, dict):
return {
k: _remove_none_recursive(v)

View file

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from openrouter.components import ChatMessageContentItemTypedDict, ToolDefinitionJSONTypedDict
from openrouter.components import ChatToolMessageContentTypedDict, ChatFunctionToolFunctionTypedDict
class Tool(ABC):
@ -26,7 +26,7 @@ class Tool(ABC):
"""Описание параметров функции"""
pass
def to_dict(self) -> ToolDefinitionJSONTypedDict:
def to_dict(self) -> ChatFunctionToolFunctionTypedDict:
"""JSON-представление инструмента для OpenRouter"""
return {
"type": "function",
@ -38,7 +38,7 @@ class Tool(ABC):
}
@abstractmethod
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
"""Вызов функции.
:param args: Параметры из JSON
:param artifacts: Словарь для хранения артефактов
@ -61,6 +61,6 @@ class ToolSet:
"""Поиск инструмента по имени"""
return next((t for t in self.functions if t.name == name), None)
def get_all_tools_description(self) -> List[ToolDefinitionJSONTypedDict]:
def get_all_tools_description(self) -> List[ChatFunctionToolFunctionTypedDict]:
"""Получить JSON-описание всех инструментов"""
return [tool.to_dict() for tool in self.functions]

View file

@ -1,7 +1,8 @@
from fal_client import AsyncClient as FalClient
from openrouter.components import ChatMessageContentItemTypedDict
from typing import Any, Dict, List
from openrouter.components import ChatToolMessageContentTypedDict
from fal_client import AsyncClient as FalClient
from ai.tool import Tool
from ai.utils import *
@ -39,9 +40,9 @@ class GenerateImageTool(Tool):
"required": ["prompt"]
}
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio = args.get("aspect_ratio", "4:3")
aspect_ratio_size_map = {
"1:1": "square",

View file

@ -1,7 +1,8 @@
from openrouter.components import ChatMessageContentItemTypedDict
from replicate import Client as ReplicateClient
from typing import Any, Dict, List
from openrouter.components import ChatToolMessageContentTypedDict
from replicate import Client as ReplicateClient
from ai.tool import Tool
from ai.utils import *
@ -42,10 +43,10 @@ class GenerateImageAnimeTool(Tool):
"required": ["prompt", "negative_prompt"]
}
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
prompt = args.get("prompt", "")
negative_prompt = args.get("negative_prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio = args.get("aspect_ratio", "4:3")
aspect_ratio_resolution_map = {
"1:1": (1280, 1280),
@ -72,7 +73,7 @@ class GenerateImageAnimeTool(Tool):
}
try:
outputs = await self._client.async_run(REPLICATE_MODEL, input=arguments)
outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments)
artifacts["generated_image_hires"] = await outputs[0].aread()
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)

View file

@ -1,7 +1,7 @@
from tavily import TavilyClient
from typing import Any, Dict, List
from openrouter.components import ChatMessageContentItemTypedDict
from openrouter.components import ChatToolMessageContentTypedDict
from tavily import TavilyClient
from ai.tool import Tool
from ai.utils import *
@ -32,7 +32,7 @@ class TavilySearchTool(Tool):
"required": ["query"]
}
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
query = args.get("query", "")
print(f"Веб-поиск: {query}")

View file

@ -21,7 +21,7 @@ def encode_image(image: bytes) -> str:
def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
img = Image.open(BytesIO(image)).convert("RGB")
if img.width > max_side or img.height > max_side:
if max_side is not None and (img.width > max_side or img.height > max_side):
scale = min(max_side / img.width, max_side / img.height)
new_width = int(img.width * scale)
new_height = int(img.height * scale)

View file

@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Optional, Union
from typing import Dict, List, Optional
from pyodbc import connect, SQL_CHAR, SQL_WCHAR, Row
@ -14,7 +14,7 @@ class BasicDatabase:
def get_bots(self):
self.cursor.execute("SELECT * FROM bots")
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def get_bot(self, bot_id: int):
self.cursor.execute("SELECT * FROM bots WHERE id = ?", bot_id)
@ -22,7 +22,7 @@ class BasicDatabase:
def get_chats(self, bot_id: int):
self.cursor.execute("SELECT * FROM chats WHERE bot_id = ?", bot_id)
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def get_chat(self, bot_id: int, chat_id: int):
self.cursor.execute("SELECT * FROM chats WHERE bot_id = ? AND chat_id = ?", bot_id, chat_id)
@ -45,7 +45,7 @@ class BasicDatabase:
def get_users(self, bot_id: int, chat_id: int):
self.cursor.execute("SELECT * FROM users WHERE bot_id = ? AND chat_id = ?", bot_id, chat_id)
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def add_user(self, bot_id: int, chat_id: int, user_id: int):
self.cursor.execute("INSERT INTO users (bot_id, chat_id, user_id) VALUES (?, ?, ?)",
@ -79,7 +79,7 @@ class BasicDatabase:
WHERE bot_id = ? AND chat_id = ? AND messages_today > 0
ORDER BY messages_today DESC
""", bot_id, chat_id)
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def get_top_messages_month(self, bot_id: int, chat_id: int):
self.cursor.execute("""
@ -87,7 +87,7 @@ class BasicDatabase:
WHERE bot_id = ? AND chat_id = ? AND messages_month > 0
ORDER BY messages_month DESC
""", bot_id, chat_id)
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def get_top_silent(self, bot_id: int, chat_id: int, threshold_days: int):
current_time = int(datetime.now().timestamp())
@ -98,7 +98,7 @@ class BasicDatabase:
WHERE bot_id = ? AND chat_id = ? AND last_message <= ?
ORDER BY last_message ASC
""", current_time, bot_id, chat_id, threshold)
result = self._to_dict(self.cursor.fetchall())
result = self._to_list(self.cursor.fetchall())
for row in result:
if row['value'] > 3650:
row['value'] = 'никогда'
@ -110,7 +110,7 @@ class BasicDatabase:
WHERE bot_id = ? AND chat_id = ? AND warnings > 0
ORDER BY warnings DESC
""", bot_id, chat_id)
return self._to_dict(self.cursor.fetchall())
return self._to_list(self.cursor.fetchall())
def reset_messages_today(self, bot_id: int):
self.cursor.execute("UPDATE users SET messages_today = 0 WHERE bot_id = ?", bot_id)
@ -118,13 +118,14 @@ class BasicDatabase:
def reset_messages_month(self, bot_id: int):
self.cursor.execute("UPDATE users SET messages_month = 0 WHERE bot_id = ?", bot_id)
def context_get_messages(self, bot_id: int, chat_id: int) -> list[dict]:
def context_get_messages(self, bot_id: int, chat_id: int) -> List[Dict]:
self.cursor.execute("""
SELECT role, text, image FROM contexts
WHERE bot_id = ? AND chat_id = ?
ORDER BY id
""", bot_id, chat_id)
return self._to_dict(self.cursor.fetchall())
result = self._to_list(self.cursor.fetchall())
return result
def context_get_count(self, bot_id: int, chat_id: int) -> int:
self.cursor.execute("SELECT COUNT(*) FROM contexts WHERE bot_id = ? AND chat_id = ?", bot_id, chat_id)
@ -195,22 +196,22 @@ class BasicDatabase:
user = self.get_user(bot_id, chat_id, user_id)
return user
def _to_dict(self, args: Union[Row, List[Row], None]):
def _to_dict(self, args: Optional[Row]) -> Optional[Dict]:
columns = [column[0] for column in self.cursor.description]
if args is None:
return None
elif isinstance(args, Row):
if args is not None:
result = {}
for i, column in enumerate(columns):
result[column] = args[i]
return result
elif isinstance(args, list) and all(isinstance(item, Row) for item in args):
results: list[dict] = []
for row in args:
row_dict = {}
for i, column in enumerate(columns):
row_dict[column] = row[i]
results.append(row_dict)
return results
else:
raise TypeError("unexpected type")
return None
def _to_list(self, args: List[Row]) -> List[Dict]:
columns = [column[0] for column in self.cursor.description]
results: list[dict] = []
for row in args:
row_dict = {}
for i, column in enumerate(columns):
row_dict[column] = row[i]
results.append(row_dict)
return results

11
requirements.txt Normal file
View file

@ -0,0 +1,11 @@
aiogram~=3.27.0
aiohttp~=3.13.3
vkbottle~=4.8.1
vkbottle-types~=5.199.99.18
pyodbc~=5.3.0
openrouter==0.8.1
replicate~=1.0.7
fal_client~=0.13.2
tavily~=1.1.0
pillow~=12.2.0
pymorphy3~=2.0.6

View file

@ -15,7 +15,7 @@ async def user_join_handler(message: Message, bot: Bot):
if chat['active'] == 0:
return
for member in message.new_chat_members:
for member in message.new_chat_members or []:
if member.is_bot:
continue
@ -32,7 +32,7 @@ async def user_join_handler(message: Message, bot: Bot):
return
member = message.left_chat_member
if member.is_bot:
if not member or member.is_bot:
return
database.DB.delete_user(bot.id, chat_id, member.id)
@ -40,6 +40,9 @@ async def user_join_handler(message: Message, bot: Bot):
@router.message(F.content_type == ContentType.MIGRATE_TO_CHAT_ID)
async def migration_handler(message: Message, bot: Bot):
if message.migrate_to_chat_id is None:
return
old_id, new_id = message.chat.id, message.migrate_to_chat_id
database.DB.chat_delete(bot.id, new_id)
database.DB.chat_update(bot.id, old_id, chat_id=new_id)

View file

@ -23,7 +23,7 @@ async def start_handler(message: Message, bot: Bot):
chat_id = message.chat.id
database.DB.create_chat_if_not_exists(bot.id, chat_id)
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
@ -46,7 +46,7 @@ async def rules_handler(message: Message, bot: Bot):
else:
await message.answer(MESSAGE_DEFAULT_RULES)
else:
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
@ -62,7 +62,7 @@ async def set_greeting_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
@ -82,7 +82,7 @@ async def set_ai_prompt_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
@ -102,11 +102,11 @@ async def clear_context_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
ai.agent.clear_chat_context(bot.id, chat_id)
ai.agent_instance.clear_chat_context(bot.id, chat_id)
await message.answer("Контекст очищен.")
@ -118,11 +118,11 @@ async def warning_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
if not await tg_user_is_admin(chat_id, message.from_user.id, bot):
if message.from_user is None or not await tg_user_is_admin(chat_id, message.from_user.id, bot):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
if message.reply_to_message is None:
if message.reply_to_message is None or message.reply_to_message.from_user is None:
await message.answer(MESSAGE_NEED_REPLY)
return
@ -133,9 +133,15 @@ async def warning_handler(message: Message, bot: Bot):
user = database.DB.get_user(bot.id, chat_id, user_id)
user_info = message.reply_to_message.from_user
# TODO: родительный падеж имени и фамилии, если возможно
await message.answer('У {} {} {}.'.format(
utils.full_name(user_info.first_name, user_info.last_name),
user['warnings'],
utils.make_word_agree_with_number(user['warnings'], 'предупреждение'))
)
if user_info is not None:
# TODO: родительный падеж имени и фамилии, если возможно
await message.answer('У {} {} {}.'.format(
utils.full_name(user_info.first_name, user_info.last_name),
user['warnings'],
utils.make_word_agree_with_number(user['warnings'], 'предупреждение'))
)
else:
await message.answer('У пользователя {} {}.'.format(
user['warnings'],
utils.make_word_agree_with_number(user['warnings'], 'предупреждение'))
)

View file

@ -41,7 +41,7 @@ async def any_message_handler(message: Message, bot: Bot):
return
# Игнорировать ботов
if message.from_user.is_bot:
if message.from_user is None or message.from_user.is_bot:
return
user_id = message.from_user.id
@ -55,6 +55,7 @@ async def any_message_handler(message: Message, bot: Bot):
try:
message_text = get_message_text(message)
assert bot_user.username is not None
bot_username_mention = '@' + bot_user.username
if message_text is not None and message_text.find(bot_username_mention) != -1:
# Сообщение содержит @bot_username
@ -62,9 +63,10 @@ async def any_message_handler(message: Message, bot: Bot):
if message.reply_to_message:
# Сообщение также является ответом -> переслать оригинальное сообщение
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
elif (message.reply_to_message and
message.reply_to_message.from_user is not None and message.reply_to_message.from_user.id == bot_user.id):
# Ответ на сообщение бота
last_id = ai.agent.get_last_assistant_message_id(bot.id, chat_id)
last_id = ai.agent_instance.get_last_assistant_message_id(bot.id, chat_id)
if message.reply_to_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
@ -80,15 +82,18 @@ async def any_message_handler(message: Message, bot: Bot):
answer: ai.agent.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(message.bot.send_chat_action, chat_id, 'typing'),
partial(ai.agent_instance.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(bot.send_chat_action, chat_id, 'typing'),
interval=4)
if answer.image is not None:
if answer.image is not None and answer.image_hires is not None:
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image),
caption=trim_caption(answer.text))).message_id
await message.reply_document(document=wrap_document(answer.image_hires, 'image', 'png'))
else:
elif answer.text is not None:
answer_id = (await message.reply(answer.text)).message_id
else:
return
if success:
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -38,7 +38,7 @@ async def reset_context_handler(message: Message, bot: Bot):
chat_id = message.chat.id
database.DB.create_chat_if_not_exists(bot.id, chat_id)
ai.agent.clear_chat_context(bot.id, chat_id)
ai.agent_instance.clear_chat_context(bot.id, chat_id)
await message.answer("Контекст очищен.")
@ -52,18 +52,33 @@ async def any_message_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
return
ai_fwd_messages: list[ai.Message] = []
if message.reply_to_message:
last_id = ai.agent_instance.get_last_assistant_message_id(bot.id, chat_id)
if message.reply_to_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
fwd_message = await create_ai_message(message.reply_to_message, bot)
if (message.reply_to_message.from_user is not None and message.from_user is not None and
message.reply_to_message.from_user.id == message.from_user.id):
# Замаскировать реальное имя пользователя
fwd_message.user_name = "Пользователь"
ai_fwd_messages = [fwd_message]
answer: ai.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
partial(message.bot.send_chat_action, chat_id, 'typing'),
partial(ai.agent_instance.get_private_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(bot.send_chat_action, chat_id, 'typing'),
interval=4)
if answer.image is not None:
if answer.image is not None and answer.image_hires is not None:
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image),
caption=trim_caption(answer.text))).message_id
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
else:
elif answer.text is not None:
answer_id = (await message.answer(answer.text)).message_id
else:
return
if success:
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -81,6 +81,7 @@ async def about_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
assert message.from_user is not None
target_user = message.from_user
user = database.DB.create_user_if_not_exists(bot.id, chat_id, target_user.id)
if message.reply_to_message is None:
@ -105,7 +106,7 @@ async def whois_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_CHAT_NOT_ACTIVE)
return
if message.reply_to_message is None:
if message.reply_to_message is None or message.reply_to_message.from_user is None:
await message.answer(MESSAGE_NEED_REPLY)
return
@ -222,10 +223,14 @@ async def check_rules_violation_handler(message: Message, bot: Bot):
answer: ai.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(ai.agent_instance.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(bot.send_chat_action, chat_id, 'typing'),
interval=4)
answer_id = (await message.answer(answer.text)).message_id
if answer.text is not None:
answer_id = (await message.answer(answer.text)).message_id
else:
return
if success:
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -10,8 +10,10 @@ import ai
import utils
async def get_user_name_for_ai(user: User):
if user.first_name and user.last_name:
async def get_user_name_for_ai(user: Optional[User]) -> str:
if user is None:
return "Неизвестный пользователь"
elif user.first_name and user.last_name:
return "{} {}".format(user.first_name, user.last_name)
elif user.first_name:
return user.first_name
@ -39,10 +41,10 @@ def get_message_text(message: Message) -> Optional[str]:
async def create_ai_message(message: Message, bot: Bot) -> ai.Message:
ai_message = ai.Message()
ai_message.message_id = message.message_id
ai_message.user_name = await get_user_name_for_ai(message.from_user)
if message.content_type == ContentType.TEXT:
ai_message.user_name = await get_user_name_for_ai(message.from_user) if message.from_user else "Неизвестный"
if message.text is not None:
ai_message.text = message.text
elif message.content_type == ContentType.PHOTO:
elif message.photo is not None:
if message.media_group_id is None:
ai_message.text = message.caption
ai_message.image = await download_photo(message.photo[-1], bot)
@ -62,8 +64,11 @@ def wrap_document(document: bytes, name_prefix: str, extension: str) -> Buffered
return BufferedInputFile(document, name)
def trim_caption(caption: str) -> str:
return caption[:1024]
def trim_caption(caption: Optional[str]) -> Optional[str]:
if caption is not None:
return caption[:1024]
else:
return None
__all__ = [

View file

@ -170,7 +170,7 @@ async def clear_context_handler(message: Message):
await message.answer(MESSAGE_PERMISSION_DENIED)
return
ai.agent.clear_chat_context(bot_id, chat_id)
ai.agent_instance.clear_chat_context(bot_id, chat_id)
await message.answer("Контекст очищен.")

View file

@ -44,6 +44,7 @@ async def any_message_handler(message: Message):
if bot_user is None:
bot_user = (await message.ctx_api.groups.get_by_id()).groups[0]
assert bot_user is not None and bot_user.screen_name is not None
bot_username_mention = '@' + bot_user.screen_name
pattern = r"\[club" + str(bot_user.id) + r"\|(.+)]"
bot_mentioned = False
@ -54,7 +55,7 @@ async def any_message_handler(message: Message):
message_text = re.sub(pattern, r'\1', message_text)
bot_mentioned = True
if len(message_text) > 0 and message.text.find(bot_username_mention) != -1:
if len(message_text) > 0 and message.text.find(bot_username_mention) != -1 and bot_user.name is not None:
# Сообщение содержит @bot_username
message_text = message_text.replace(bot_username_mention, bot_user.name)
bot_mentioned = True
@ -73,7 +74,7 @@ async def any_message_handler(message: Message):
ai_fwd_messages.append(await create_ai_message(fwd_message))
elif message.reply_message and message.reply_message.from_id == -bot_user.id:
# Ответ на сообщение бота
last_id = ai.agent.get_last_assistant_message_id(bot_id, chat_id)
last_id = ai.agent_instance.get_last_assistant_message_id(bot_id, chat_id)
if message.reply_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
ai_fwd_messages = [await create_ai_message(message.reply_message)]
@ -89,15 +90,15 @@ async def any_message_handler(message: Message):
answer: ai.agent.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(ai.agent_instance.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4)
if answer.image is not None:
if answer.image is not None and answer.image_hires is not None:
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id
else:
answer_id = (await message.reply(answer.text)).conversation_message_id
if success:
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -39,7 +39,7 @@ async def reset_context_handler(message: Message):
chat_id = message.peer_id
database.DB.create_chat_if_not_exists(bot_id, chat_id)
ai.agent.clear_chat_context(bot_id, chat_id)
ai.agent_instance.clear_chat_context(bot_id, chat_id)
await message.answer("Контекст очищен.")
@ -54,18 +54,29 @@ async def any_message_handler(message: Message):
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
return
ai_fwd_messages: list[ai.Message] = []
if message.reply_message:
last_id = ai.agent_instance.get_last_assistant_message_id(bot_id, chat_id)
if message.reply_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
fwd_message = await create_ai_message(message.reply_message)
if message.reply_message.from_id == message.from_id:
# Замаскировать реальное имя пользователя
fwd_message.user_name = "Пользователь"
ai_fwd_messages = [fwd_message]
answer: ai.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_private_chat_reply, bot_id, chat_id, ai_message),
partial(ai.agent_instance.get_private_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4)
if answer.image is not None:
if answer.image is not None and answer.image_hires is not None:
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id
else:
answer_id = (await message.answer(answer.text)).message_id
if success:
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -267,10 +267,10 @@ async def check_rules_violation_handler(message: Message):
answer: ai.Message
success: bool
answer, success = await utils.run_with_progress(
partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(ai.agent_instance.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4)
answer_id = (await message.answer(answer.text)).message_id
if success:
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
ai.agent_instance.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -67,6 +67,8 @@ async def reset_counters(reset_month: bool, api: API):
async def check_birthdays(api: API):
bot_id = get_bot_id(api)
chats = database.DB.get_chats(bot_id)
today = datetime.datetime.today()
for chat in chats:
if chat['active'] == 0:
continue
@ -74,7 +76,8 @@ async def check_birthdays(api: API):
chat_id = chat['chat_id']
# noinspection PyTypeChecker
members = await api.messages.get_conversation_members(peer_id=chat_id, fields=['bdate'])
today = datetime.datetime.today()
if members.profiles is None:
break
for item in members.items:
user_id = item.member_id
@ -148,5 +151,9 @@ async def daily_maintenance_task(api: API):
async def startup_task(api: API):
me = (await api.groups.get_by_id()).groups[0]
groups = (await api.groups.get_by_id()).groups
if groups is None:
print("Не удалось получить информацию о боте")
return
me = groups[0]
print(f"Бот '{me.name}' (id={me.id}) запущен.")

View file

@ -33,13 +33,16 @@ async def get_user_name_for_ai(api: API, user_id: int):
async def download_photo(photos: List[PhotosPhotoSizes]) -> bytes:
max_photo_size = 16*1024*1024
async with aiohttp.ClientSession() as session:
async with (aiohttp.ClientSession() as session):
for size_type in ['w', 'z', 'y', 'x', 'm', 's']:
for photo in photos:
if photo.type != size_type:
if photo.url is None or photo.type != size_type:
continue
async with session.head(photo.url) as response:
if response.status != 200 or response.content_length > max_photo_size:
if response.status != 200:
break
if response.content_length is not None and response.content_length > max_photo_size:
print("Размер изображения превышает установленное ограничение.")
break
async with session.get(photo.url) as response:
if response.status == 200: