From 1c58359e44e8477ced0080bea945e54f12564703 Mon Sep 17 00:00:00 2001 From: Kirill Kirilenko Date: Thu, 2 Apr 2026 02:26:00 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A0=D0=B5=D1=84=D0=B0=D0=BA=D1=82=D0=BE?= =?UTF-8?q?=D1=80=D0=B8=D0=BD=D0=B3.=20=D0=92=D1=81=D1=8F=20=D0=BB=D0=BE?= =?UTF-8?q?=D0=B3=D0=B8=D0=BA=D0=B0=20=D0=98=D0=98=20=D0=BF=D0=B5=D1=80?= =?UTF-8?q?=D0=B5=D0=BD=D0=B5=D1=81=D0=B5=D0=BD=D0=B0=20=D0=B2=20=D0=BC?= =?UTF-8?q?=D0=BE=D0=B4=D1=83=D0=BB=D1=8C=20ai.=20=D0=9B=D0=BE=D0=B3=D0=B8?= =?UTF-8?q?=D0=BA=D0=B0=20=D0=B8=D0=BD=D1=81=D1=82=D1=80=D1=83=D0=BC=D0=B5?= =?UTF-8?q?=D0=BD=D1=82=D0=BE=D0=B2=20=D0=B2=D1=8B=D0=B4=D0=B5=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B0=20=D0=B2=20=D0=BE=D1=82=D0=B4=D0=B5=D0=BB=D1=8C?= =?UTF-8?q?=D0=BD=D1=8B=D0=B5=20=D0=BF=D0=BE=D0=B4=D0=BC=D0=BE=D0=B4=D1=83?= =?UTF-8?q?=D0=BB=D0=B8.=20=D0=98=D1=81=D0=BF=D1=80=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=B5=D0=BD=D1=8B=20=D0=B2=D1=81=D0=B5=20=D0=BF=D1=80=D0=BE?= =?UTF-8?q?=D0=B1=D0=BB=D0=B5=D0=BC=D1=8B,=20=D0=BE=D0=B1=D0=BD=D0=B0?= =?UTF-8?q?=D1=80=D1=83=D0=B6=D0=B5=D0=BD=D0=BD=D1=8B=D0=B5=20PyCharm.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/__init__.py | 18 ++ ai_agent.py => ai/agent.py | 249 ++++-------------- {prompts => ai/prompts}/group_chat.md | 0 {prompts => ai/prompts}/private_chat.md | 0 ai/tool.py | 66 +++++ ai/tools/__init__.py | 7 + ai/tools/image_generation/__init__.py | 18 ++ ai/tools/image_generation/generate_image.py | 79 ++++++ .../image_generation/generate_image_anime.py | 85 ++++++ .../tools/image_generation/prompt.md | 7 - ai/tools/web_search/__init__.py | 14 + ai/tools/web_search/prompt.md | 4 + ai/tools/web_search/tavily_search.py | 56 ++++ ai/utils.py | 39 +++ database.py | 5 +- messages.py | 3 +- prompts/tools.json | 67 ----- tg/__main__.py | 2 +- tg/handlers/admin.py | 4 +- tg/handlers/default.py | 12 +- tg/handlers/private.py | 12 +- tg/handlers/user.py | 16 +- tg/tg_database.py | 10 +- tg/utils.py | 17 +- vk/__main__.py | 2 +- vk/handlers/admin.py | 4 +- vk/handlers/default.py | 12 +- vk/handlers/private.py | 10 +- vk/handlers/user.py | 24 +- vk/utils.py | 17 +- vk/vk_database.py | 10 +- 31 files changed, 534 insertions(+), 335 deletions(-) create mode 100644 ai/__init__.py rename ai_agent.py => ai/agent.py (55%) rename {prompts => ai/prompts}/group_chat.md (100%) rename {prompts => ai/prompts}/private_chat.md (100%) create mode 100644 ai/tool.py create mode 100644 ai/tools/__init__.py create mode 100644 ai/tools/image_generation/__init__.py create mode 100644 ai/tools/image_generation/generate_image.py create mode 100644 ai/tools/image_generation/generate_image_anime.py rename prompts/tools.md => ai/tools/image_generation/prompt.md (85%) create mode 100644 ai/tools/web_search/__init__.py create mode 100644 ai/tools/web_search/prompt.md create mode 100644 ai/tools/web_search/tavily_search.py create mode 100644 ai/utils.py delete mode 100644 prompts/tools.json diff --git a/ai/__init__.py b/ai/__init__.py new file mode 100644 index 0000000..809fef3 --- /dev/null +++ b/ai/__init__.py @@ -0,0 +1,18 @@ +import ai.agent +from database import BasicDatabase + +Message = ai.agent.Message +Agent = ai.agent.AiAgent + +# Глобальный экземпляр агента +agent: ai.agent.AiAgent + + +def create_ai_agent(openrouter_token: str, openrouter_model: str, + fal_token: str, replicate_token: str, tavily_token: str, + db: BasicDatabase, platform: str): + global agent + agent = ai.agent.AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform) + + +__all__ = ["agent", "Agent", "Message", "create_ai_agent"] diff --git a/ai_agent.py b/ai/agent.py similarity index 55% rename from ai_agent.py rename to ai/agent.py index 7c89e6c..54a057c 100644 --- a/ai_agent.py +++ b/ai/agent.py @@ -1,26 +1,22 @@ -import base64 import datetime import json -from collections.abc import Callable from dataclasses import dataclass -from io import BytesIO -from PIL import Image -from typing import List, Tuple, Optional, Union, Dict, Awaitable +from typing import List, Optional, Tuple, Union from openrouter import OpenRouter, RetryConfig -from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ - ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict +from openrouter.components import AssistantMessage, AssistantMessageTypedDict, \ + ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict, ToolDefinitionJSONTypedDict from openrouter.errors import ResponseValidationError, OpenRouterError from openrouter.utils import BackoffStrategy -from fal_client import AsyncClient as FalClient -from replicate import Client as ReplicateClient -from tavily import TavilyClient - -from utils import download_file +import ai.tool from database import BasicDatabase +from ai.utils import * +from ai.tools import * + + OPENROUTER_X_TITLE = "TG/VK Chat Bot" OPENROUTER_HTTP_REFERER = "https://ultracoder.org" @@ -28,9 +24,6 @@ GROUP_CHAT_MAX_MESSAGES = 40 PRIVATE_CHAT_MAX_MESSAGES = 40 MAX_OUTPUT_TOKENS = 500 -FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image" -REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b" - @dataclass() class Message: @@ -60,14 +53,20 @@ class AiAgent: self.client_openrouter = OpenRouter(api_key=openrouter_token, x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, retry_config=retry_config) - self.client_fal = FalClient(key=fal_token) - self.replicate_client = ReplicateClient(api_token=replicate_token) - self.tavily_client = TavilyClient(api_key=tavily_token) - @dataclass() - class _ToolsArtifacts: - generated_image: Optional[bytes] = None - generated_image_hires: Optional[bytes] = None + # Создание наборов инструментов + self.toolsets: list[ai.tool.ToolSet] = [] + self.toolsets.append( + ImageGenerationToolSet(fal_token=fal_token, replicate_token=replicate_token) + ) + self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token)) + + # Сбор всех инструментов + self.tools: list[ai.tool.Tool] = [] + self.tools_descriptions: list[ToolDefinitionJSONTypedDict] = [] + for toolset in self.toolsets: + self.tools.extend(toolset.functions) + self.tools_descriptions.extend(toolset.get_all_tools_description()) async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: @@ -87,10 +86,9 @@ class AiAgent: response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) ai_response = response.content - tools_artifacts = AiAgent._ToolsArtifacts() + tools_artifacts = {} if response.tool_calls is not None: - tools_artifacts = await self._process_tool_calls(bot_id, chat_id, - tool_calls=response.tool_calls, context=context) + tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content @@ -101,11 +99,12 @@ class AiAgent: role="user", text=fwd_message.text, image=fwd_message.image, message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, - role="assistant", text=ai_response, image=tools_artifacts.generated_image, + role="assistant", text=ai_response, + image=tools_artifacts.get("generated_image"), message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_artifacts.generated_image, - image_hires=tools_artifacts.generated_image_hires), True + return Message(text=ai_response, image=tools_artifacts.get("generated_image"), + image_hires=tools_artifacts.get("generated_image_hires")), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -125,21 +124,20 @@ class AiAgent: context.append(_serialize_assistant_message(response)) ai_response = response.content - tools_artifacts = AiAgent._ToolsArtifacts() + tools_artifacts = {} if response.tool_calls is not None: - tools_artifacts = await self._process_tool_calls(bot_id, chat_id, - tool_calls=response.tool_calls, context=context) + tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, role="assistant", - text=ai_response, image=tools_artifacts.generated_image, + text=ai_response, image=tools_artifacts.get("generated_image"), message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_artifacts.generated_image, - image_hires=tools_artifacts.generated_image_hires), True + return Message(text=ai_response, image=tools_artifacts.get("generated_image"), + image_hires=tools_artifacts.get("generated_image_hires")), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -170,7 +168,12 @@ class AiAgent: def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict: prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK') - prompt += '\n' + self.system_prompt_tools + + prompt += '\n# Доступные инструменты\n' + for toolset in self.toolsets: + prompt += '\n' + toolset.system_prompt + + prompt += '\n' + '# Дополнительные инструкции\n' bot = self.db.get_bot(bot_id) if bot['ai_prompt'] is not None: @@ -187,139 +190,37 @@ class AiAgent: response = await self._async_chat_completion_request( model=self.openrouter_model, messages=context, - tools=self.tools_description if allow_tools else None, + tools=self.tools_descriptions if allow_tools else None, tool_choice="auto" if allow_tools else None, max_tokens=MAX_OUTPUT_TOKENS, user=f'{self.platform}_{bot_id}_{chat_id}' ) return self._filter_response(response.choices[0].message) - async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], - context: List[MessageTypedDict]) -> _ToolsArtifacts: - artifacts = AiAgent._ToolsArtifacts() + async def _process_tool_calls(self, tool_calls: List[ChatMessageToolCall], + context: List[MessageTypedDict]) -> dict: + artifacts = {} if tool_calls is None: return artifacts - functions_map: Dict[str, - Callable[[int, int, Dict, AiAgent._ToolsArtifacts], - Awaitable[List[ChatMessageContentItemTypedDict]]]] = { - "generate_image": self._process_tool_generate_image, - "generate_image_anime": self._process_tool_generate_image_anime, - "tavily_search": self._process_tool_tavily_search - } + tools_map = {tool.name: tool for tool in self.tools} for tool_call in tool_calls: tool_name = tool_call.function.name tool_args = json.loads(tool_call.function.arguments) - if tool_name in functions_map: - tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts) + + if tool_name in tools_map: + tool = tools_map[tool_name] + # Вызов инструмента с передачей artifacts + tool_result = await tool.execute(tool_args, artifacts) context.append({ "role": "tool", "tool_call_id": tool_call.id, "content": tool_result }) + return artifacts - async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ - -> List[ChatMessageContentItemTypedDict]: - prompt = args.get("prompt", "") - aspect_ratio = args.get("aspect_ratio", None) - - aspect_ratio_size_map = { - "1:1": "square", - "4:3": "landscape_4_3", - "3:4": "portrait_4_3", - "16:9": "landscape_16_9", - "9:16": "portrait_16_9", - "9:20": "portrait_16_9" - } - image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3") - print(f"Генерация изображения {image_size}: {prompt}") - - arguments = { - "prompt": prompt, - "image_size": image_size, - "enable_safety_checker": False - } - - try: - result = await self.client_fal.run(FAL_MODEL, arguments=arguments) - if "images" not in result: - raise RuntimeError("Неожиданный ответ от сервера.") - image_url = result["images"][0]["url"] - artifacts.generated_image_hires = await download_file(image_url) - artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280) - return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", - image=None) - except Exception as e: - print(f"Ошибка генерации изображения: {e}") - return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") - - async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int, - args: dict, artifacts: _ToolsArtifacts) \ - -> List[ChatMessageContentItemTypedDict]: - prompt = args.get("prompt", "") - negative_prompt = args.get("negative_prompt", "") - aspect_ratio = args.get("aspect_ratio", None) - - aspect_ratio_resolution_map = { - "1:1": (1280, 1280), - "4:3": (1280, 1024), - "3:4": (1024, 1280), - "16:9": (1280, 720), - "9:16": (720, 1280), - "9:20": (720, 1600) - } - width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) - print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") - - arguments = { - "prompt": prompt, - "negative_prompt": negative_prompt, - "add_recommended_tags": False, - "width": width, - "height": height, - "guidance_scale": 4.5, - "num_inference_steps": 20, - "hires_enable": True, - "hires_num_inference_steps": 30, - "disable_safety_checker": True - } - - try: - outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments) - artifacts.generated_image_hires = await outputs[0].aread() - artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280) - return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", - image=None) - except Exception as e: - print(f"Ошибка генерации изображения: {e}") - return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") - - async def _process_tool_tavily_search(self, _bot_id: int, _chat_id: int, args: dict, - _artifacts: _ToolsArtifacts) -> List[ChatMessageContentItemTypedDict]: - query = args.get("query", "") - print(f"Веб-поиск: {query}") - - try: - results = self.tavily_client.search(query=query, max_results=5) - - if not results or "results" not in results: - return _serialize_message_content(text="Не удалось получить результаты поиска.") - - answer_parts = [] - for i, result in enumerate(results["results"], 1): - title = result.get("title", "Без названия") - url = result.get("url", "") - content = result.get("content", "") - answer_parts.append(f"{i}. {title}\n {url}\n {content}\n") - - answer = "\n".join(answer_parts) - return _serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}") - except Exception as e: - print(f"Ошибка веб-поиска: {e}") - return _serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}") - async def _async_chat_completion_request(self, **kwargs): try: return await self.client_openrouter.chat.send_async(**kwargs) @@ -354,24 +255,10 @@ class AiAgent: return response def _load_prompts(self): - with open("prompts/group_chat.md", "r") as f: + with open("ai/prompts/group_chat.md", "r") as f: self.system_prompt_group_chat = f.read() - with open("prompts/private_chat.md", "r") as f: + with open("ai/prompts/private_chat.md", "r") as f: self.system_prompt_private_chat = f.read() - with open("prompts/tools.md", "r") as f: - self.system_prompt_tools = f.read() - with open("prompts/tools.json", "r") as f: - self.tools_description = json.loads(f.read()) - - -agent: AiAgent - - -def create_ai_agent(openrouter_token: str, openrouter_model: str, - fal_token: str, replicate_token: str, tavily_token: str, - db: BasicDatabase, platform: str): - global agent - agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform) def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str: @@ -380,22 +267,12 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> return f"{prefix}: {text}" if text is not None else f"{prefix}:" -def _encode_image(image: bytes) -> str: - encoded_image = base64.b64encode(image).decode('utf-8') - return f"data:image/jpeg;base64,{encoded_image}" - - def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: - return {"role": role, "content": _serialize_message_content(text, image)} + return {"role": role, "content": serialize_message_content(text, image)} -def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]: - content = [] - if text is not None: - content.append({"type": "text", "text": text}) - if image is not None: - content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}}) - return content +def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict: + return _remove_none_recursive(message.model_dump(by_alias=True)) def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]: @@ -413,21 +290,3 @@ def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, an ] else: return data - - -def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict: - return _remove_none_recursive(message.model_dump(by_alias=True)) - - -def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes: - img = Image.open(BytesIO(image)).convert("RGB") - - if img.width > max_side or img.height > max_side: - scale = min(max_side / img.width, max_side / img.height) - new_width = int(img.width * scale) - new_height = int(img.height * scale) - img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - output = BytesIO() - img.save(output, format='JPEG', quality=87, optimize=True) - return output.getvalue() diff --git a/prompts/group_chat.md b/ai/prompts/group_chat.md similarity index 100% rename from prompts/group_chat.md rename to ai/prompts/group_chat.md diff --git a/prompts/private_chat.md b/ai/prompts/private_chat.md similarity index 100% rename from prompts/private_chat.md rename to ai/prompts/private_chat.md diff --git a/ai/tool.py b/ai/tool.py new file mode 100644 index 0000000..5d320e3 --- /dev/null +++ b/ai/tool.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from openrouter.components import ChatMessageContentItemTypedDict, ToolDefinitionJSONTypedDict + + +class Tool(ABC): + """Интерфейс функции""" + + @property + @abstractmethod + def name(self) -> str: + """Имя функции (snake_case)""" + pass + + @property + @abstractmethod + def description(self) -> str: + """Текстовое описание функции""" + pass + + @property + @abstractmethod + def parameters(self) -> Dict[str, Any]: + """Описание параметров функции""" + pass + + def to_dict(self) -> ToolDefinitionJSONTypedDict: + """JSON-представление инструмента для OpenRouter""" + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters + } + } + + @abstractmethod + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]: + """Вызов функции. + :param args: Параметры из JSON + :param artifacts: Словарь для хранения артефактов + :return: Содержимое JSON-ответа на вызов функции + """ + pass + + +@dataclass +class ToolSet: + """Набор логически объединенных функций""" + + functions: List[Tool] + """Список функций, входящих в набор""" + + system_prompt: str + """Дополнение к системному запросу, описывающее, как пользоваться функциями""" + + def get_function_by_name(self, name: str) -> Optional[Tool]: + """Поиск инструмента по имени""" + return next((t for t in self.functions if t.name == name), None) + + def get_all_tools_description(self) -> List[ToolDefinitionJSONTypedDict]: + """Получить JSON-описание всех инструментов""" + return [tool.to_dict() for tool in self.functions] diff --git a/ai/tools/__init__.py b/ai/tools/__init__.py new file mode 100644 index 0000000..4793b19 --- /dev/null +++ b/ai/tools/__init__.py @@ -0,0 +1,7 @@ +from ai.tools.image_generation import ImageGenerationToolSet +from ai.tools.web_search import TavilySearchToolSet + +__all__ = [ + "ImageGenerationToolSet", + "TavilySearchToolSet" +] diff --git a/ai/tools/image_generation/__init__.py b/ai/tools/image_generation/__init__.py new file mode 100644 index 0000000..42683be --- /dev/null +++ b/ai/tools/image_generation/__init__.py @@ -0,0 +1,18 @@ +from ai.tool import ToolSet + +from .generate_image import GenerateImageTool +from .generate_image_anime import GenerateImageAnimeTool + + +class ImageGenerationToolSet(ToolSet): + def __init__(self, fal_token: str, replicate_token: str): + functions = [ + GenerateImageTool(fal_token), + GenerateImageAnimeTool(replicate_token) + ] + with open("ai/tools/image_generation/prompt.md", "r") as f: + system_prompt = f.read() + super().__init__(functions=functions, system_prompt=system_prompt) + + +__all__ = ["GenerateImageTool", "GenerateImageAnimeTool", "ImageGenerationToolSet"] diff --git a/ai/tools/image_generation/generate_image.py b/ai/tools/image_generation/generate_image.py new file mode 100644 index 0000000..07a94ae --- /dev/null +++ b/ai/tools/image_generation/generate_image.py @@ -0,0 +1,79 @@ +from fal_client import AsyncClient as FalClient +from openrouter.components import ChatMessageContentItemTypedDict +from typing import Any, Dict, List + +from ai.tool import Tool +from ai.utils import * + +FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image" + + +class GenerateImageTool(Tool): + def __init__(self, fal_token: str): + self._client = FalClient(key=fal_token) + + @property + def name(self) -> str: + return "generate_image" + + @property + def description(self) -> str: + return "Генерация изображения по описанию" + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Подробное описание сцены на английском языке БЕЗ технических параметров " + "(соотношение сторон, разрешение)" + }, + "aspect_ratio": { + "type": "string", + "enum": ["1:1", "4:3", "3:4", "16:9", "9:16"], + "description": "Соотношение сторон" + } + }, + "required": ["prompt"] + } + + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]: + prompt = args.get("prompt", "") + aspect_ratio = args.get("aspect_ratio", None) + + aspect_ratio_size_map = { + "1:1": "square", + "4:3": "landscape_4_3", + "3:4": "portrait_4_3", + "16:9": "landscape_16_9", + "9:16": "portrait_16_9", + "9:20": "portrait_16_9" + } + image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3") + print(f"Генерация изображения {image_size}: {prompt}") + + arguments = { + "prompt": prompt, + "image_size": image_size, + "enable_safety_checker": False + } + + try: + result = await self._client.run(FAL_MODEL, arguments=arguments) + if "images" not in result: + raise RuntimeError("Неожиданный ответ от сервера.") + image_url = result["images"][0]["url"] + + from utils import download_file + artifacts["generated_image_hires"] = await download_file(image_url) + artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280) + + return serialize_message_content( + text="Изображение сгенерировано и будет показано пользователю.", + image=None + ) + except Exception as e: + print(f"Ошибка генерации изображения: {e}") + return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") diff --git a/ai/tools/image_generation/generate_image_anime.py b/ai/tools/image_generation/generate_image_anime.py new file mode 100644 index 0000000..2c29e75 --- /dev/null +++ b/ai/tools/image_generation/generate_image_anime.py @@ -0,0 +1,85 @@ +from openrouter.components import ChatMessageContentItemTypedDict +from replicate import Client as ReplicateClient +from typing import Any, Dict, List + +from ai.tool import Tool +from ai.utils import * + +REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b" + + +class GenerateImageAnimeTool(Tool): + def __init__(self, replicate_token: str): + self._client = ReplicateClient(api_token=replicate_token) + + @property + def name(self) -> str: + return "generate_image_anime" + + @property + def description(self) -> str: + return "Генерация изображения в стиле аниме по описанию" + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Положительный запрос" + }, + "negative_prompt": { + "type": "string", + "description": "Отрицательный запрос" + }, + "aspect_ratio": { + "type": "string", + "enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"], + "description": "Соотношение сторон" + } + }, + "required": ["prompt", "negative_prompt"] + } + + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]: + prompt = args.get("prompt", "") + negative_prompt = args.get("negative_prompt", "") + aspect_ratio = args.get("aspect_ratio", None) + + aspect_ratio_resolution_map = { + "1:1": (1280, 1280), + "4:3": (1280, 1024), + "3:4": (1024, 1280), + "16:9": (1280, 720), + "9:16": (720, 1280), + "9:20": (720, 1600) + } + width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) + print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") + + arguments = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "add_recommended_tags": False, + "width": width, + "height": height, + "guidance_scale": 4.5, + "num_inference_steps": 20, + "hires_enable": True, + "hires_num_inference_steps": 30, + "disable_safety_checker": True + } + + try: + outputs = await self._client.async_run(REPLICATE_MODEL, input=arguments) + artifacts["generated_image_hires"] = await outputs[0].aread() + artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280) + + return serialize_message_content( + text="Изображение сгенерировано и будет показано пользователю.", + image=None + ) + except Exception as e: + print(f"Ошибка генерации изображения: {e}") + return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") diff --git a/prompts/tools.md b/ai/tools/image_generation/prompt.md similarity index 85% rename from prompts/tools.md rename to ai/tools/image_generation/prompt.md index 6fadcb2..3580618 100644 --- a/prompts/tools.md +++ b/ai/tools/image_generation/prompt.md @@ -1,5 +1,3 @@ -# Доступные инструменты - ## Генерация изображений Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций. При вызове функции не нужно добавлять сообщение - оно будет отброшено. @@ -28,8 +26,3 @@ 2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`. 3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`. 4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`. - -## Веб-поиск -Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`. -- Вызывай функцию поиска, когда нужна актуальная информация из интернета. -- После получения результатов дай пользователю краткую сводку найденной информации. diff --git a/ai/tools/web_search/__init__.py b/ai/tools/web_search/__init__.py new file mode 100644 index 0000000..26d4e7c --- /dev/null +++ b/ai/tools/web_search/__init__.py @@ -0,0 +1,14 @@ +from ai.tool import ToolSet + +from .tavily_search import TavilySearchTool + + +class TavilySearchToolSet(ToolSet): + def __init__(self, tavily_token: str): + functions = [TavilySearchTool(tavily_token)] + with open("ai/tools/web_search/prompt.md", "r") as f: + system_prompt = f.read() + super().__init__(functions=functions, system_prompt=system_prompt) + + +__all__ = ["TavilySearchTool", "TavilySearchToolSet"] diff --git a/ai/tools/web_search/prompt.md b/ai/tools/web_search/prompt.md new file mode 100644 index 0000000..b651fa2 --- /dev/null +++ b/ai/tools/web_search/prompt.md @@ -0,0 +1,4 @@ +## Веб-поиск +Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`. +- Вызывай функцию поиска, когда нужна актуальная информация из интернета. +- После получения результатов дай пользователю краткую сводку найденной информации. diff --git a/ai/tools/web_search/tavily_search.py b/ai/tools/web_search/tavily_search.py new file mode 100644 index 0000000..0c890e5 --- /dev/null +++ b/ai/tools/web_search/tavily_search.py @@ -0,0 +1,56 @@ +from tavily import TavilyClient +from typing import Any, Dict, List + +from openrouter.components import ChatMessageContentItemTypedDict + +from ai.tool import Tool +from ai.utils import * + + +class TavilySearchTool(Tool): + def __init__(self, tavily_token: str): + self._client = TavilyClient(api_key=tavily_token) + + @property + def name(self) -> str: + return "tavily_search" + + @property + def description(self) -> str: + return "Поиск информации в интернете" + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Запрос для поиска (на русском или английском языке)" + } + }, + "required": ["query"] + } + + async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]: + query = args.get("query", "") + print(f"Веб-поиск: {query}") + + try: + results = self._client.search(query=query, max_results=5) + + if not results or "results" not in results: + return serialize_message_content(text="Не удалось получить результаты поиска.") + + answer_parts = [] + for i, result in enumerate(results["results"], 1): + title = result.get("title", "Без названия") + url = result.get("url", "") + content = result.get("content", "") + answer_parts.append(f"{i}. {title}\n {url}\n {content}\n") + + answer = "\n".join(answer_parts) + return serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}") + except Exception as e: + print(f"Ошибка веб-поиска: {e}") + return serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}") diff --git a/ai/utils.py b/ai/utils.py new file mode 100644 index 0000000..6d20932 --- /dev/null +++ b/ai/utils.py @@ -0,0 +1,39 @@ +from base64 import b64encode +from io import BytesIO +from PIL import Image +from typing import Dict, List, Optional + + +def serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> List[Dict]: + content = [] + if text is not None: + content.append({"type": "text", "text": text}) + if image is not None: + content.append({"type": "image_url", "detail": "high", "image_url": {"url": encode_image(image)}}) + return content + + +def encode_image(image: bytes) -> str: + encoded_image = b64encode(image).decode('utf-8') + return f"data:image/jpeg;base64,{encoded_image}" + + +def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes: + img = Image.open(BytesIO(image)).convert("RGB") + + if img.width > max_side or img.height > max_side: + scale = min(max_side / img.width, max_side / img.height) + new_width = int(img.width * scale) + new_height = int(img.height * scale) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + output = BytesIO() + img.save(output, format='JPEG', quality=87, optimize=True) + return output.getvalue() + + +__all__ = [ + "serialize_message_content", + "compress_image", + "encode_image" +] diff --git a/database.py b/database.py index c9211e9..1eaddff 100644 --- a/database.py +++ b/database.py @@ -161,8 +161,9 @@ class BasicDatabase: self.cursor.execute(query, values) def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int): - self.cursor.execute("UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL", - message_id, bot_id, chat_id) + self.cursor.execute( + "UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL", + message_id, bot_id, chat_id) def _context_trim(self, bot_id: int, chat_id: int, max_messages: int): current_count = self.context_get_count(bot_id, chat_id) diff --git a/messages.py b/messages.py index b4bf3dc..eaf27bf 100644 --- a/messages.py +++ b/messages.py @@ -1,7 +1,8 @@ MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работаю в этом чате.' MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.' MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.' -MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение или с пересылкой текстовых сообщений.' +MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение ' \ + 'или с пересылкой текстовых сообщений.' MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.' MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.' MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.' diff --git a/prompts/tools.json b/prompts/tools.json deleted file mode 100644 index d5523fa..0000000 --- a/prompts/tools.json +++ /dev/null @@ -1,67 +0,0 @@ -[ - { - "type": "function", - "function": { - "name": "generate_image", - "description": "Генерация изображения по описанию", - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "Подробное описание сцены на английском языке БЕЗ технических параметров (соотношение сторон, разрешение)" - }, - "aspect_ratio": { - "type": "string", - "enum": ["1:1", "4:3", "3:4", "16:9", "9:16"], - "description": "Соотношение сторон" - } - }, - "required": ["prompt"] - } - } - }, - { - "type": "function", - "function": { - "name": "generate_image_anime", - "description": "Генерация изображения в стиле аниме по описанию", - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "Положительный запрос" - }, - "negative_prompt": { - "type": "string", - "description": "Отрицательный запрос" - }, - "aspect_ratio": { - "type": "string", - "enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"], - "description": "Соотношение сторон" - } - }, - "required": ["prompt", "negative_prompt"] - } - } - }, - { - "type": "function", - "function": { - "name": "tavily_search", - "description": "Веб-поиск по теме запроса. Используй для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Запрос для поиска (на русском или английском языке)" - } - }, - "required": ["query"] - } - } - } -] diff --git a/tg/__main__.py b/tg/__main__.py index ef682bf..0b58ad2 100644 --- a/tg/__main__.py +++ b/tg/__main__.py @@ -4,7 +4,7 @@ import json from aiogram import Bot, Dispatcher -from ai_agent import create_ai_agent +from ai import create_ai_agent import tg.tg_database as database diff --git a/tg/handlers/admin.py b/tg/handlers/admin.py index db799a6..efb69f6 100644 --- a/tg/handlers/admin.py +++ b/tg/handlers/admin.py @@ -2,7 +2,7 @@ from aiogram import Bot, Router, F from aiogram.types import Message from aiogram.utils.formatting import Bold -import ai_agent +import ai import utils from messages import * @@ -106,7 +106,7 @@ async def clear_context_handler(message: Message, bot: Bot): await message.answer(MESSAGE_PERMISSION_DENIED) return - ai_agent.agent.clear_chat_context(bot.id, chat_id) + ai.agent.clear_chat_context(bot.id, chat_id) await message.answer("Контекст очищен.") diff --git a/tg/handlers/default.py b/tg/handlers/default.py index 6fe1165..47d1262 100644 --- a/tg/handlers/default.py +++ b/tg/handlers/default.py @@ -4,7 +4,7 @@ from aiogram import Router, F, Bot from aiogram.types import Message from aiogram.enums.content_type import ContentType -import ai_agent +import ai import utils from messages import * @@ -51,7 +51,7 @@ async def any_message_handler(message: Message, bot: Bot): bot_user = await bot.me() - ai_fwd_messages: list[ai_agent.Message] = [] + ai_fwd_messages: list[ai.Message] = [] try: message_text = get_message_text(message) @@ -64,7 +64,7 @@ async def any_message_handler(message: Message, bot: Bot): ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id: # Ответ на сообщение бота - last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id) + last_id = ai.agent.get_last_assistant_message_id(bot.id, chat_id) if message.reply_to_message.message_id != last_id: # Оригинального сообщения нет в контексте, или оно не последнее -> переслать его ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] @@ -77,10 +77,10 @@ async def any_message_handler(message: Message, bot: Bot): ai_message = await create_ai_message(message, bot) ai_message.text = message_text - answer: ai_agent.Message + answer: ai.agent.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), + partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(message.bot.send_chat_action, chat_id, 'typing'), interval=4) @@ -91,4 +91,4 @@ async def any_message_handler(message: Message, bot: Bot): else: answer_id = (await message.reply(answer.text)).message_id if success: - ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) + ai.agent.set_last_response_id(bot.id, chat_id, answer_id) diff --git a/tg/handlers/private.py b/tg/handlers/private.py index c69240c..3efb096 100644 --- a/tg/handlers/private.py +++ b/tg/handlers/private.py @@ -1,11 +1,11 @@ from functools import partial from aiogram import Router, F, Bot -from aiogram.enums import ChatType, ContentType +from aiogram.enums import ChatType from aiogram.filters import Command, CommandObject, CommandStart from aiogram.types import Message -import ai_agent +import ai import utils from messages import * @@ -38,7 +38,7 @@ async def reset_context_handler(message: Message, bot: Bot): chat_id = message.chat.id database.DB.create_chat_if_not_exists(bot.id, chat_id) - ai_agent.agent.clear_chat_context(bot.id, chat_id) + ai.agent.clear_chat_context(bot.id, chat_id) await message.answer("Контекст очищен.") @@ -52,10 +52,10 @@ async def any_message_handler(message: Message, bot: Bot): await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return - answer: ai_agent.Message + answer: ai.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message), + partial(ai.agent.get_private_chat_reply, bot.id, chat_id, ai_message), partial(message.bot.send_chat_action, chat_id, 'typing'), interval=4) @@ -66,4 +66,4 @@ async def any_message_handler(message: Message, bot: Bot): else: answer_id = (await message.answer(answer.text)).message_id if success: - ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) + ai.agent.set_last_response_id(bot.id, chat_id, answer_id) diff --git a/tg/handlers/user.py b/tg/handlers/user.py index adb9ba2..d92413b 100644 --- a/tg/handlers/user.py +++ b/tg/handlers/user.py @@ -6,7 +6,7 @@ from aiogram.enums import ContentType from aiogram.types import Message from aiogram.utils.formatting import Bold, Italic -import ai_agent +import ai import utils from messages import * @@ -214,18 +214,18 @@ async def check_rules_violation_handler(message: Message, bot: Bot): prompt += chat_rules + '\n\n' prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):' - ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.from_user), - text=prompt, message_id=message.message_id) - ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), - text=message.reply_to_message.text)] + ai_message = ai.Message(user_name=await get_user_name_for_ai(message.from_user), + text=prompt, message_id=message.message_id) + ai_fwd_messages = [ai.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), + text=message.reply_to_message.text)] - answer: ai_agent.Message + answer: ai.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), + partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(bot.send_chat_action, chat_id, 'typing'), interval=4) answer_id = (await message.answer(answer.text)).message_id if success: - ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) + ai.agent.set_last_response_id(bot.id, chat_id, answer_id) diff --git a/tg/tg_database.py b/tg/tg_database.py index c060608..7198a32 100644 --- a/tg/tg_database.py +++ b/tg/tg_database.py @@ -39,8 +39,9 @@ class TgDatabase(database.BasicDatabase): warnings TINYINT NOT NULL DEFAULT 0, about VARCHAR(1000), PRIMARY KEY (bot_id, chat_id, user_id), - CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) - """) + CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) + ON UPDATE CASCADE ON DELETE CASCADE + )""") self.cursor.execute(""" CREATE TABLE IF NOT EXISTS contexts ( @@ -52,8 +53,9 @@ class TgDatabase(database.BasicDatabase): text VARCHAR(4000), image MEDIUMBLOB, PRIMARY KEY (id), - CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) - """) + CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) + ON UPDATE CASCADE ON DELETE CASCADE + )""") self.conn.commit() diff --git a/tg/utils.py b/tg/utils.py index 4f0c854..029d03d 100644 --- a/tg/utils.py +++ b/tg/utils.py @@ -6,7 +6,7 @@ from aiogram import Bot from aiogram.enums import ContentType from aiogram.types import User, PhotoSize, Message, BufferedInputFile -import ai_agent +import ai import utils @@ -36,8 +36,8 @@ def get_message_text(message: Message) -> Optional[str]: return None -async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message: - ai_message = ai_agent.Message() +async def create_ai_message(message: Message, bot: Bot) -> ai.Message: + ai_message = ai.Message() ai_message.message_id = message.message_id ai_message.user_name = await get_user_name_for_ai(message.from_user) if message.content_type == ContentType.TEXT: @@ -64,3 +64,14 @@ def wrap_document(document: bytes, name_prefix: str, extension: str) -> Buffered def trim_caption(caption: str) -> str: return caption[:1024] + + +__all__ = [ + "create_ai_message", + "get_message_text", + "get_user_name_for_ai", + "trim_caption", + "wrap_photo", + "wrap_document", + "wrap_document" +] diff --git a/vk/__main__.py b/vk/__main__.py index 889ecf1..b304e72 100644 --- a/vk/__main__.py +++ b/vk/__main__.py @@ -3,7 +3,7 @@ import json from vkbottle.bot import Bot, run_multibot -from ai_agent import create_ai_agent +from ai import create_ai_agent import vk.vk_database as database diff --git a/vk/handlers/admin.py b/vk/handlers/admin.py index 070e3b4..1d5430a 100644 --- a/vk/handlers/admin.py +++ b/vk/handlers/admin.py @@ -3,7 +3,7 @@ from vkbottle.bot import Message from vkbottle.framework.labeler import BotLabeler from vkbottle_types.codegen.objects import MessagesGetConversationMembers -import ai_agent +import ai import utils from messages import * @@ -170,7 +170,7 @@ async def clear_context_handler(message: Message): await message.answer(MESSAGE_PERMISSION_DENIED) return - ai_agent.agent.clear_chat_context(bot_id, chat_id) + ai.agent.clear_chat_context(bot_id, chat_id) await message.answer("Контекст очищен.") diff --git a/vk/handlers/default.py b/vk/handlers/default.py index b9c33c7..9bef79b 100644 --- a/vk/handlers/default.py +++ b/vk/handlers/default.py @@ -6,7 +6,7 @@ from vkbottle.bot import Message from vkbottle.framework.labeler import BotLabeler from vkbottle_types.codegen.objects import GroupsGroup -import ai_agent +import ai import utils from messages import * @@ -59,7 +59,7 @@ async def any_message_handler(message: Message): message_text = message_text.replace(bot_username_mention, bot_user.name) bot_mentioned = True - ai_fwd_messages: list[ai_agent.Message] = [] + ai_fwd_messages: list[ai.agent.Message] = [] try: if bot_mentioned: @@ -73,7 +73,7 @@ async def any_message_handler(message: Message): ai_fwd_messages.append(await create_ai_message(fwd_message)) elif message.reply_message and message.reply_message.from_id == -bot_user.id: # Ответ на сообщение бота - last_id = ai_agent.agent.get_last_assistant_message_id(bot_id, chat_id) + last_id = ai.agent.get_last_assistant_message_id(bot_id, chat_id) if message.reply_message.message_id != last_id: # Оригинального сообщения нет в контексте, или оно не последнее -> переслать его ai_fwd_messages = [await create_ai_message(message.reply_message)] @@ -86,10 +86,10 @@ async def any_message_handler(message: Message): ai_message = await create_ai_message(message) ai_message.text = message_text - answer: ai_agent.Message + answer: ai.agent.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), + partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) @@ -100,4 +100,4 @@ async def any_message_handler(message: Message): answer_id = (await message.reply(answer.text)).conversation_message_id if success: - ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) + ai.agent.set_last_response_id(bot_id, chat_id, answer_id) diff --git a/vk/handlers/private.py b/vk/handlers/private.py index f6ea638..cbecaf6 100644 --- a/vk/handlers/private.py +++ b/vk/handlers/private.py @@ -4,7 +4,7 @@ from vkbottle.bot import Message from vkbottle.dispatch.rules.base import RegexRule from vkbottle.framework.labeler import BotLabeler -import ai_agent +import ai import utils from messages import * @@ -39,7 +39,7 @@ async def reset_context_handler(message: Message): chat_id = message.peer_id database.DB.create_chat_if_not_exists(bot_id, chat_id) - ai_agent.agent.clear_chat_context(bot_id, chat_id) + ai.agent.clear_chat_context(bot_id, chat_id) await message.answer("Контекст очищен.") @@ -54,10 +54,10 @@ async def any_message_handler(message: Message): await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return - answer: ai_agent.Message + answer: ai.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, ai_message), + partial(ai.agent.get_private_chat_reply, bot_id, chat_id, ai_message), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) @@ -68,4 +68,4 @@ async def any_message_handler(message: Message): answer_id = (await message.answer(answer.text)).message_id if success: - ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) + ai.agent.set_last_response_id(bot_id, chat_id, answer_id) diff --git a/vk/handlers/user.py b/vk/handlers/user.py index 3aacc19..9606728 100644 --- a/vk/handlers/user.py +++ b/vk/handlers/user.py @@ -1,11 +1,11 @@ from functools import partial from typing import List, Any -from vkbottle import bold, italic +from vkbottle import bold, italic, API from vkbottle.bot import Message from vkbottle.framework.labeler import BotLabeler -import ai_agent +import ai import utils from messages import * @@ -246,31 +246,31 @@ async def check_rules_violation_handler(message: Message): prompt += chat_rules + '\n\n' prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):' - ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id), - text=prompt, message_id=message.message_id) - ai_fwd_messages: list[ai_agent.Message] = [] + ai_message = ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id), + text=prompt, message_id=message.message_id) + ai_fwd_messages: list[ai.Message] = [] if message.reply_message is not None and len(message.reply_message.text) > 0: ai_fwd_messages.append( - ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id), - text=message.reply_message.text)) + ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id), + text=message.reply_message.text)) else: for fwd_message in message.fwd_messages: if len(fwd_message.text) > 0: ai_fwd_messages.append( - ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id), - text=fwd_message.text)) + ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id), + text=fwd_message.text)) if len(ai_fwd_messages) == 0: await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD) return - answer: ai_agent.Message + answer: ai.Message success: bool answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), + partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) answer_id = (await message.answer(answer.text)).message_id if success: - ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) + ai.agent.set_last_response_id(bot_id, chat_id, answer_id) diff --git a/vk/utils.py b/vk/utils.py index 223e15e..4183805 100644 --- a/vk/utils.py +++ b/vk/utils.py @@ -7,7 +7,7 @@ from vkbottle.bot import Message from vkbottle_types.codegen.objects import PhotosPhotoSizes from vkbottle_types.objects import MessagesMessageAttachmentType -import ai_agent +import ai import utils @@ -18,6 +18,7 @@ class MyAPI(API): def get_bot_id(api: API) -> int: + # noinspection PyTypeChecker my_api: MyAPI = api return my_api.bot_id @@ -48,8 +49,8 @@ async def download_photo(photos: List[PhotosPhotoSizes]) -> bytes: raise RuntimeError(f"Failed to download photo. Status code: {response.status}") -async def create_ai_message(message: Message) -> ai_agent.Message: - ai_message = ai_agent.Message() +async def create_ai_message(message: Message) -> ai.Message: + ai_message = ai.Message() ai_message.message_id = message.conversation_message_id ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id) if len(message.text) > 0: @@ -67,3 +68,13 @@ async def create_ai_message(message: Message) -> ai_agent.Message: async def upload_photo(image: bytes, chat_id: int, api: API) -> str: return await PhotoMessageUploader(api).upload(file_source=image, peer_id=chat_id) + + +__all__ = [ + "MyAPI", + "get_bot_id", + "get_user_name_for_ai", + "download_photo", + "create_ai_message", + "upload_photo" +] diff --git a/vk/vk_database.py b/vk/vk_database.py index c06d59e..47006b8 100644 --- a/vk/vk_database.py +++ b/vk/vk_database.py @@ -42,8 +42,9 @@ class VkDatabase(database.BasicDatabase): happy_birthday TINYINT NOT NULL DEFAULT 1, about VARCHAR(1000), PRIMARY KEY (bot_id, chat_id, user_id), - CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) - """) + CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) + ON UPDATE CASCADE ON DELETE CASCADE + )""") self.cursor.execute(""" CREATE TABLE IF NOT EXISTS contexts ( @@ -55,8 +56,9 @@ class VkDatabase(database.BasicDatabase): text VARCHAR(4000), image MEDIUMBLOB, PRIMARY KEY (id), - CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) - """) + CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) + ON UPDATE CASCADE ON DELETE CASCADE + )""") self.conn.commit()