From ef1e2e8a3e5e4da0ece5ef120313758f1b5a7fce Mon Sep 17 00:00:00 2001 From: Kirill Kirilenko Date: Sun, 14 Jun 2026 19:52:45 +0300 Subject: [PATCH] =?UTF-8?q?=D0=9F=D0=B5=D1=80=D0=B5=D1=85=D0=BE=D0=B4=20?= =?UTF-8?q?=D1=81=20OpenRouter=20API=20=D0=BD=D0=B0=20OpenAI=20API.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 6 +- ai/__init__.py | 6 +- ai/agent.py | 260 +++++++++--------- ai/tool.py | 10 +- ai/tools/image_generation/generate_image.py | 23 +- .../image_generation/generate_image_anime.py | 29 +- ai/tools/web_search/tavily_search.py | 28 +- ai/utils.py | 12 +- database.py | 25 +- requirements.txt | 2 +- tg/__main__.py | 2 +- tg/tg_database.py | 3 +- vk/__main__.py | 2 +- vk/vk_database.py | 3 +- 14 files changed, 198 insertions(+), 213 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index baad300..e4be003 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,7 +4,7 @@ Это движок чат-бота для мессенджеров VK и Telegram, написанный на Python 3. Движок поддерживает некоторые полезные функции для групповых чатов (правила, приветствие новичков, статистика сообщений, вычисление молчунов и др.). В движок интегрирован модуль ИИ: пользователи могут общаться с ботом в личных сообщениях, а также в групповых чатах, если упомянут его. -Модуль ИИ использует API OpenRouter для генерации ответов, а также API fal.ai и Replicate для генерации изображений. +Модуль ИИ использует OpenAI-совместимый API для генерации ответов, а также API Replicate для генерации изображений. ## Архитектура проекта @@ -48,7 +48,7 @@ vk_chat_bot/ ## Основные технологии - **Асинхронность:** asyncio -- **ИИ:** OpenRouter (Grok 4.1 Fast), fal.ai (Seedream 4.5) для обычных изображений, Replicate (Nova Anime XL) для генерации изображений в стиле аниме. +- **ИИ:** OpenAI-compatible API для вызова LLM, Replicate для генерации изображений. - **Telegram:** aiogram 3.x - **VK:** vkbottle - **СУБД:** MariaDB (через pyodbc) @@ -78,7 +78,7 @@ python -m vk -c vk.json Основной класс, обрабатывающий: - `get_group_chat_reply()` - генерация ответа в групповом чате - `get_private_chat_reply()` - генерация ответа в личном чате -- `_generate_reply()` - вызов LLM через OpenRouter +- `_generate_reply()` - вызов LLM через OpenAI - `_process_tool_calls()` - обработка вызова функций ### Обработчики сообщений и событий мессенджера (handlers) diff --git a/ai/__init__.py b/ai/__init__.py index 828a267..9415ea1 100644 --- a/ai/__init__.py +++ b/ai/__init__.py @@ -1,3 +1,5 @@ +from typing import Optional + import ai.agent from database import BasicDatabase @@ -8,11 +10,11 @@ Agent = ai.agent.AiAgent agent_instance: ai.agent.AiAgent -def create_ai_agent(openrouter_token: str, openrouter_model: str, +def create_ai_agent(openai_url: Optional[str], openai_token: str, openai_model: str, replicate_token: str, tavily_token: str, db: BasicDatabase, platform: str): global agent_instance - agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model, + agent_instance = ai.agent.AiAgent(openai_url, openai_token, openai_model, replicate_token, tavily_token, db, platform) diff --git a/ai/agent.py b/ai/agent.py index 45535ea..641c8f5 100644 --- a/ai/agent.py +++ b/ai/agent.py @@ -2,14 +2,20 @@ import datetime import json from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, cast, Dict, Any -from openrouter import OpenRouter, RetryConfig -from openrouter.components import ChatAssistantMessage, ChatAssistantMessageTypedDict, \ - ChatToolCall, ChatResult, ChatSystemMessageTypedDict, ChatMessagesTypedDict, ChatAssistantMessageContent - -from openrouter.errors import ResponseValidationError, OpenRouterError -from openrouter.utils import BackoffStrategy +from openai import AsyncOpenAI, omit +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionFunctionToolParam, + ChatCompletionMessageToolCallUnion, + ChatCompletionMessageFunctionToolCall +) import ai.tool from database import BasicDatabase @@ -17,8 +23,11 @@ from database import BasicDatabase from ai.utils import * from ai.tools import * +OPENAI_DEFAULT_URL = "https://openrouter.ai/api/v1" + OPENROUTER_X_TITLE = "TG/VK Chat Bot" -OPENROUTER_HTTP_REFERER = "https://ultracoder.org" +OPENROUTER_HTTP_REFERER = "https://kiriru.cc" +OPENROUTER_CATEGORIES = "general-chat,roleplay" GROUP_CHAT_MAX_MESSAGES = 40 PRIVATE_CHAT_MAX_MESSAGES = 40 @@ -34,37 +43,70 @@ class Message: message_id: Optional[int] = None +class ChatContextManager: + def __init__(self, db: BasicDatabase, bot_id: int, chat_id: int, + system_prompt: ChatCompletionSystemMessageParam, max_messages: int): + self.db = db + self.bot_id = bot_id + self.chat_id = chat_id + self.max_messages = max_messages + + self.context: List[ChatCompletionMessageParam] = [system_prompt] + for message in self.db.context_get_messages(bot_id, chat_id): + # noinspection PyTypeChecker + self.context.append(message) + + self.pending_messages: List[ChatCompletionMessageParam] = [] + self.pending_messages_ids: List[Optional[int]] = [] + + def add_user_message(self, message: Message): + self._add_pending_message(_serialize_user_message(message.text, message.image), message.message_id) + + def add_assistant_message(self, message: ChatCompletionMessage): + self._add_pending_message(_serialize_assistant_message(message), None) + + def add_tool_message(self, message: ChatCompletionToolMessageParam): + self._add_pending_message(message, None) + + def get_current_context(self): + return self.context + + def commit(self): + for i, message in enumerate(self.pending_messages): + self.db.context_add_message(self.bot_id, self.chat_id, message["role"], message=dict(message), + message_id=self.pending_messages_ids[i], max_messages=self.max_messages) + + def _add_pending_message(self, message: ChatCompletionMessageParam, message_id: Optional[int] = None): + self.pending_messages.append(message) + self.pending_messages_ids.append(message_id) + self.context.append(message) + + class AiAgent: def __init__(self, - openrouter_token: str, openrouter_model: str, + openai_url: Optional[str], openai_token: str, openai_model: str, replicate_token: str, tavily_token: str, db: BasicDatabase, platform: str): - retry_config = RetryConfig(strategy="backoff", - backoff=BackoffStrategy( - initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000), - retry_connection_errors=True) self.db = db - self.openrouter_model = openrouter_model + self.openai_model = openai_model self.platform = platform self._load_prompts() - self.client_openrouter = OpenRouter(api_key=openrouter_token, - x_open_router_title=OPENROUTER_X_TITLE, - http_referer=OPENROUTER_HTTP_REFERER, - retry_config=retry_config) + self.client = AsyncOpenAI( + base_url=openai_url if openai_url is not None else OPENAI_DEFAULT_URL, + api_key=openai_token + ) # Создание наборов инструментов self.toolsets: list[ai.tool.ToolSet] = [] - self.toolsets.append( - ImageGenerationToolSet(replicate_token=replicate_token) - ) + self.toolsets.append(ImageGenerationToolSet(replicate_token=replicate_token)) self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token)) # Сбор всех инструментов self.tools: list[ai.tool.Tool] = [] - self.tools_descriptions: list = [] + self.tools_descriptions: list[ChatCompletionFunctionToolParam] = [] for toolset in self.toolsets: self.tools.extend(toolset.functions) self.tools_descriptions.extend(toolset.get_all_tools_description()) @@ -106,48 +148,44 @@ class AiAgent: async def _handle_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message], is_group_chat: bool, max_messages: int) -> Tuple[Message, bool]: - # 1. Подготовка текста сообщения (префикс) + context_manager = ChatContextManager( + db=self.db, bot_id=bot_id, chat_id=chat_id, max_messages=max_messages, + system_prompt=self._construct_system_prompt(is_group_chat, bot_id, chat_id) + ) + + # Добавление нового сообщения пользователя if is_group_chat: message.text = _add_message_prefix(message.text, message.user_name) else: message.text = _add_message_prefix(message.text) + context_manager.add_user_message(message) - # 2. Сбор контекста из БД - context = self._get_chat_context(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id) - context.append(_serialize_message(role="user", text=message.text, image=message.image)) - - # 3. Обработка пересланных сообщений + # Добавление пересланных сообщений for fwd_message in forwarded_messages: message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name) if fwd_message.text is not None: message_text += '\n' + fwd_message.text fwd_message.text = message_text - context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image)) + context_manager.add_user_message(fwd_message) - # 4. Генерация ответа с поддержкой инструментов + # Генерация ответа с поддержкой инструментов try: - response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) - context.append(_serialize_assistant_message(response)) - ai_response: Optional[ChatAssistantMessageContent] = response.content + response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(), + allow_tools=True) + context_manager.add_assistant_message(response) + ai_response: Optional[str] = response.content tools_artifacts = {} - while response.tool_calls is not None: - tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context) - response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) - context.append(_serialize_assistant_message(response)) + while response.tool_calls is not None and len(response.tool_calls) > 0: + tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, + context_manager=context_manager) + response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(), + allow_tools=True) + context_manager.add_assistant_message(response) ai_response = response.content - # 5. Сохранение истории в БД - self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, - message_id=message.message_id, max_messages=max_messages) - for fwd_message in forwarded_messages: - self.db.context_add_message(bot_id, chat_id, - role="user", text=fwd_message.text, image=fwd_message.image, - message_id=fwd_message.message_id, max_messages=max_messages) - self.db.context_add_message(bot_id, chat_id, - role="assistant", text=ai_response, - image=tools_artifacts.get("generated_image"), - message_id=None, max_messages=max_messages) + # Сохранение обновленного контекста в БД + context_manager.commit() return Message(text=ai_response, image=tools_artifacts.get("generated_image"), image_hires=tools_artifacts.get("generated_image_hires")), True @@ -159,15 +197,8 @@ class AiAgent: print(f"Ошибка выполнения запроса к ИИ: {e}") return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False - def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[ChatMessagesTypedDict]: - context: List[ChatMessagesTypedDict] = [ - self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id) - ] - for message in self.db.context_get_messages(bot_id, chat_id): - context.append(_serialize_message(message["role"], message["text"], message["image"])) - return context - - def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> ChatSystemMessageTypedDict: + def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) \ + -> ChatCompletionSystemMessageParam: prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK') @@ -188,19 +219,26 @@ class AiAgent: return {"role": "system", "content": prompt} async def _generate_reply(self, bot_id: int, chat_id: int, - context: List[ChatMessagesTypedDict], allow_tools: bool = False) -> ChatAssistantMessage: - response = await self._async_chat_completion_request( - model=self.openrouter_model, - messages=context, - tools=self.tools_descriptions if allow_tools else None, - tool_choice="auto" if allow_tools else None, - max_tokens=MAX_OUTPUT_TOKENS, - user=f'{self.platform}_{bot_id}_{chat_id}' - ) - return self._filter_response(response.choices[0].message) + context: List[ChatCompletionMessageParam], allow_tools: bool = False) \ + -> ChatCompletionMessage: - async def _process_tool_calls(self, tool_calls: List[ChatToolCall], - context: List[ChatMessagesTypedDict]) -> dict: + response = await self.client.chat.completions.create( + model=self.openai_model, + messages=context, + tools=self.tools_descriptions if allow_tools else omit, + tool_choice="auto" if allow_tools else omit, + max_tokens=MAX_OUTPUT_TOKENS, + user=f'{self.platform}_{bot_id}_{chat_id}', + extra_headers={ + "HTTP-Referer": OPENROUTER_HTTP_REFERER, + "X-OpenRouter-Title": OPENROUTER_X_TITLE, + "X-OpenRouter-Categories": OPENROUTER_CATEGORIES + } + ) + return response.choices[0].message + + async def _process_tool_calls(self, tool_calls: List[ChatCompletionMessageToolCallUnion], + context_manager: ChatContextManager) -> Dict[str, Any]: artifacts = {} if tool_calls is None: return artifacts @@ -208,50 +246,24 @@ class AiAgent: tools_map = {tool.name: tool for tool in self.tools} for tool_call in tool_calls: - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) + if tool_call.type != "function": + continue + func_call = cast(ChatCompletionMessageFunctionToolCall, tool_call) + tool_name = func_call.function.name + tool_args = json.loads(func_call.function.arguments) if tool_name in tools_map: tool = tools_map[tool_name] - # Вызов инструмента с передачей artifacts tool_result = await tool.execute(tool_args, artifacts) - context.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": tool_result - }) + context_manager.add_tool_message( + ChatCompletionToolMessageParam(role="tool", tool_call_id=tool_call.id, content=tool_result) + ) return artifacts - async def _async_chat_completion_request(self, **kwargs) -> ChatResult: - try: - # noinspection PyTypeChecker - return await self.client_openrouter.chat.send_async(**kwargs) - except ResponseValidationError as e: - # Костыль для OpenRouter SDK: - # https://github.com/OpenRouterTeam/python-sdk/issues/44 - body = json.loads(e.body) - if "error" in body: - try: - raw_response = json.loads(body["error"]["metadata"]["raw"]) - message = str(raw_response["error"]["message"]) - e = RuntimeError(message) - except Exception: - pass - raise e - except OpenRouterError as e: - if e.message == "Provider returned error": - body = json.loads(e.body) - try: - raw_response = json.loads(body["error"]["metadata"]["raw"]) - message = str(raw_response["error"]["message"]) - e = RuntimeError(message) - except Exception: - pass - raise e - + # TODO: удалить @staticmethod - def _filter_response(response: ChatAssistantMessage) -> ChatAssistantMessage: + def _filter_response(response: ChatCompletionMessage) -> ChatCompletionMessage: text = str(response.content) text = text.replace("", "") response.content = text @@ -270,27 +282,19 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> return f"{prefix}: {text}" if text is not None else f"{prefix}:" -def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: - return {"role": role, "content": serialize_message_content(text, image)} - - -def _serialize_assistant_message(message: ChatAssistantMessage) -> ChatAssistantMessageTypedDict: - # noinspection PyTypeChecker - return _remove_none_recursive(message.model_dump(by_alias=True)) - - -def _remove_none_recursive(data: Union[Dict, List, Any]) -> Union[Dict, List, Any]: - if isinstance(data, dict): - return { - k: _remove_none_recursive(v) - for k, v in data.items() - if v is not None - } - elif isinstance(data, list): - return [ - _remove_none_recursive(item) - for item in data - if item is not None - ] +def _serialize_user_message(text: Optional[str], image: Optional[bytes]) -> ChatCompletionUserMessageParam: + if image is None: + if text is not None: + content = text + else: + raise ValueError("Either text or image must be provided") else: - return data + content = [] + if text is not None: + content.append({"type": "text", "text": text}) + content.append({"type": "image_url", "image_url": {"url": encode_image(image), "detail": "high"}}) + return {"role": "user", "content": content} + + +def _serialize_assistant_message(message: ChatCompletionMessage) -> ChatCompletionAssistantMessageParam: + return message.model_dump(exclude_none=True) diff --git a/ai/tool.py b/ai/tool.py index 2539f78..fd7e7ae 100644 --- a/ai/tool.py +++ b/ai/tool.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional -from openrouter.components import ChatToolMessageContentTypedDict, ChatFunctionToolFunctionTypedDict +from openai.types.chat import ChatCompletionFunctionToolParam class Tool(ABC): @@ -26,7 +26,7 @@ class Tool(ABC): """Описание параметров функции""" pass - def to_dict(self) -> ChatFunctionToolFunctionTypedDict: + def to_dict(self) -> ChatCompletionFunctionToolParam: """JSON-представление инструмента для OpenRouter""" return { "type": "function", @@ -38,7 +38,7 @@ class Tool(ABC): } @abstractmethod - async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]: + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str: """Вызов функции. :param args: Параметры из JSON :param artifacts: Словарь для хранения артефактов @@ -61,6 +61,6 @@ class ToolSet: """Поиск инструмента по имени""" return next((t for t in self.functions if t.name == name), None) - def get_all_tools_description(self) -> List[ChatFunctionToolFunctionTypedDict]: + def get_all_tools_description(self) -> List[ChatCompletionFunctionToolParam]: """Получить JSON-описание всех инструментов""" - return [tool.to_dict() for tool in self.functions] + return [function.to_dict() for function in self.functions] diff --git a/ai/tools/image_generation/generate_image.py b/ai/tools/image_generation/generate_image.py index 9802119..931b35a 100644 --- a/ai/tools/image_generation/generate_image.py +++ b/ai/tools/image_generation/generate_image.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List +from typing import Any, Dict -from openrouter.components import ChatToolMessageContentTypedDict from replicate import Client as ReplicateClient from ai.tool import Tool @@ -12,15 +11,15 @@ REPLICATE_MODEL = "bytedance/seedream-4.5" class GenerateImageTool(Tool): def __init__(self, replicate_token: str): self._client = ReplicateClient(api_token=replicate_token) - + @property def name(self) -> str: return "generate_image" - + @property def description(self) -> str: return "Генерация изображения по описанию" - + @property def parameters(self) -> Dict[str, Any]: return { @@ -39,12 +38,12 @@ class GenerateImageTool(Tool): }, "required": ["prompt"] } - - async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]: + + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str: prompt = args.get("prompt", "") aspect_ratio = args.get("aspect_ratio", "4:3") print(f"Генерация изображения {aspect_ratio}: {prompt}") - + arguments = { "prompt": prompt, "aspect_ratio": aspect_ratio, @@ -55,11 +54,7 @@ class GenerateImageTool(Tool): outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments) artifacts["generated_image_hires"] = await outputs[0].aread() artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280) - - return serialize_message_content( - text="Изображение сгенерировано и будет показано пользователю.", - image=None - ) + return "Изображение сгенерировано и будет показано пользователю." except Exception as e: print(f"Ошибка генерации изображения: {e}") - return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") + return f"Не удалось сгенерировать изображение: {e}" diff --git a/ai/tools/image_generation/generate_image_anime.py b/ai/tools/image_generation/generate_image_anime.py index 9bb746e..8b72a85 100644 --- a/ai/tools/image_generation/generate_image_anime.py +++ b/ai/tools/image_generation/generate_image_anime.py @@ -1,26 +1,25 @@ -from typing import Any, Dict, List +from typing import Any, Dict -from openrouter.components import ChatToolMessageContentTypedDict from replicate import Client as ReplicateClient from ai.tool import Tool from ai.utils import * -REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b" +REPLICATE_MODEL = "kirirururu/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b" class GenerateImageAnimeTool(Tool): def __init__(self, replicate_token: str): self._client = ReplicateClient(api_token=replicate_token) - + @property def name(self) -> str: return "generate_image_anime" - + @property def description(self) -> str: return "Генерация изображения в стиле аниме по описанию" - + @property def parameters(self) -> Dict[str, Any]: return { @@ -42,12 +41,12 @@ class GenerateImageAnimeTool(Tool): }, "required": ["prompt", "negative_prompt"] } - - async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]: + + async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str: prompt = args.get("prompt", "") negative_prompt = args.get("negative_prompt", "") aspect_ratio = args.get("aspect_ratio", "4:3") - + aspect_ratio_resolution_map = { "1:1": (1280, 1280), "4:3": (1280, 1024), @@ -58,7 +57,7 @@ class GenerateImageAnimeTool(Tool): } width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") - + arguments = { "prompt": prompt, "negative_prompt": negative_prompt, @@ -71,16 +70,12 @@ class GenerateImageAnimeTool(Tool): "hires_num_inference_steps": 30, "disable_safety_checker": True } - + try: outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments) artifacts["generated_image_hires"] = await outputs[0].aread() artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280) - - return serialize_message_content( - text="Изображение сгенерировано и будет показано пользователю.", - image=None - ) + return "Изображение сгенерировано и будет показано пользователю." except Exception as e: print(f"Ошибка генерации изображения: {e}") - return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") + return f"Не удалось сгенерировать изображение: {e}" diff --git a/ai/tools/web_search/tavily_search.py b/ai/tools/web_search/tavily_search.py index 779e9b1..04ecdc6 100644 --- a/ai/tools/web_search/tavily_search.py +++ b/ai/tools/web_search/tavily_search.py @@ -1,24 +1,22 @@ -from typing import Any, Dict, List +from typing import Any, Dict -from openrouter.components import ChatToolMessageContentTypedDict from tavily import TavilyClient from ai.tool import Tool -from ai.utils import * class TavilySearchTool(Tool): def __init__(self, tavily_token: str): self._client = TavilyClient(api_key=tavily_token) - + @property def name(self) -> str: return "tavily_search" - + @property def description(self) -> str: return "Поиск информации в интернете" - + @property def parameters(self) -> Dict[str, Any]: return { @@ -31,26 +29,26 @@ class TavilySearchTool(Tool): }, "required": ["query"] } - - async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]: + + async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> str: query = args.get("query", "") print(f"Веб-поиск: {query}") - + try: results = self._client.search(query=query, max_results=5) - + if not results or "results" not in results: - return serialize_message_content(text="Не удалось получить результаты поиска.") - + return "Не удалось получить результаты поиска." + answer_parts = [] for i, result in enumerate(results["results"], 1): title = result.get("title", "Без названия") url = result.get("url", "") content = result.get("content", "") answer_parts.append(f"{i}. {title}\n {url}\n {content}\n") - + answer = "\n".join(answer_parts) - return serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}") + return f"По запросу \"{query}\" найдено:\n\n{answer}" except Exception as e: print(f"Ошибка веб-поиска: {e}") - return serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}") + return f"Не удалось выполнить веб-поиск: {e}" diff --git a/ai/utils.py b/ai/utils.py index 700c013..bf3d214 100644 --- a/ai/utils.py +++ b/ai/utils.py @@ -1,16 +1,7 @@ from base64 import b64encode from io import BytesIO from PIL import Image -from typing import Dict, List, Optional - - -def serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> List[Dict]: - content = [] - if text is not None: - content.append({"type": "text", "text": text}) - if image is not None: - content.append({"type": "image_url", "detail": "high", "image_url": {"url": encode_image(image)}}) - return content +from typing import Optional def encode_image(image: bytes) -> str: @@ -33,7 +24,6 @@ def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes: __all__ = [ - "serialize_message_content", "compress_image", "encode_image" ] diff --git a/database.py b/database.py index cb85f1b..30d7698 100644 --- a/database.py +++ b/database.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from typing import Dict, List, Optional, Any, Type @@ -161,11 +162,12 @@ class BasicDatabase: with self.pool.acquire() as conn: with conn.cursor() as cursor: cursor.execute(""" - SELECT role, text, image FROM contexts + SELECT message FROM contexts WHERE bot_id = ? AND chat_id = ? ORDER BY id - """, (bot_id, chat_id)) - return cursor.fetchall() + """, (bot_id, chat_id)) + result = cursor.fetchall() + return [json.loads(_to_val(str, item)) for item in result] def context_get_count(self, bot_id: int, chat_id: int) -> int: with self.pool.acquire() as conn: @@ -185,17 +187,14 @@ class BasicDatabase: return _to_val(int, cursor.fetchone()) def context_add_message(self, bot_id: int, chat_id: int, role: str, - text: Optional[str], image: Optional[bytes], - message_id: Optional[int], max_messages: int): - assert (text or image) - + message: Dict, message_id: Optional[int], max_messages: int): self._context_trim(bot_id, chat_id, max_messages) # Подготовка данных для вставки data = { "bot_id": bot_id, "chat_id": chat_id, "message_id": message_id, "role": role, - "text": text, "image": image + "message": json.dumps(message, ensure_ascii=False) } # Формирование SQL-запроса и параметров вставки @@ -211,9 +210,13 @@ class BasicDatabase: def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int): with self.pool.acquire() as conn: with conn.cursor() as cursor: - cursor.execute( - "UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL", - (message_id, bot_id, chat_id)) + cursor.execute(""" + UPDATE contexts + SET message_id = %s + WHERE bot_id = %s AND chat_id = %s AND message_id IS NULL + ORDER BY id DESC + LIMIT 1 + """, (message_id, bot_id, chat_id)) def _context_trim(self, bot_id: int, chat_id: int, max_messages: int): current_count = self.context_get_count(bot_id, chat_id) diff --git a/requirements.txt b/requirements.txt index ab7ac62..c8e4468 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ aiohttp~=3.13.5 vkbottle~=4.8.2 vkbottle-types~=5.199.99.20 mariadb[pool]~=2.0.0rc2 -openrouter==0.9.1 +openai~=2.41.1 replicate~=1.0.7 tavily~=1.1.0 pillow~=12.2.0 diff --git a/tg/__main__.py b/tg/__main__.py index 9652976..ca66ca0 100644 --- a/tg/__main__.py +++ b/tg/__main__.py @@ -24,7 +24,7 @@ async def main() -> None: database.create_database(config['db_hostname'], config['db_user'], config['db_password'], config['db_database']) - create_ai_agent(config['openrouter_token'], config['openrouter_model'], + create_ai_agent(config.get('openai_url', None), config['openai_token'], config['openai_model'], config['replicate_token'], config['tavily_token'], database.DB, 'tg') diff --git a/tg/tg_database.py b/tg/tg_database.py index 45f4f34..9e8f537 100644 --- a/tg/tg_database.py +++ b/tg/tg_database.py @@ -53,8 +53,7 @@ class TgDatabase(database.BasicDatabase): chat_id BIGINT NOT NULL, message_id BIGINT, role VARCHAR(16) NOT NULL, - text VARCHAR(4000), - image MEDIUMBLOB, + message MEDIUMTEXT NOT NULL, PRIMARY KEY (id), CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE diff --git a/vk/__main__.py b/vk/__main__.py index 83720a3..4a895ff 100644 --- a/vk/__main__.py +++ b/vk/__main__.py @@ -24,7 +24,7 @@ if __name__ == '__main__': database.create_database(config['db_hostname'], config['db_user'], config['db_password'], config['db_database']) - create_ai_agent(config['openrouter_token'], config['openrouter_model'], + create_ai_agent(config.get('openai_url', None), config['openai_token'], config['openai_model'], config['replicate_token'], config['tavily_token'], database.DB, 'vk') diff --git a/vk/vk_database.py b/vk/vk_database.py index 88f819c..680e5c7 100644 --- a/vk/vk_database.py +++ b/vk/vk_database.py @@ -56,8 +56,7 @@ class VkDatabase(database.BasicDatabase): chat_id BIGINT NOT NULL, message_id BIGINT, role VARCHAR(16) NOT NULL, - text VARCHAR(4000), - image MEDIUMBLOB, + message MEDIUMTEXT NOT NULL, PRIMARY KEY (id), CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE