Переход с OpenRouter API на OpenAI API.
This commit is contained in:
parent
5b9f5cd1d6
commit
ef1e2e8a3e
14 changed files with 198 additions and 213 deletions
|
|
@ -4,7 +4,7 @@
|
|||
Это движок чат-бота для мессенджеров VK и Telegram, написанный на Python 3.
|
||||
Движок поддерживает некоторые полезные функции для групповых чатов (правила, приветствие новичков, статистика сообщений, вычисление молчунов и др.).
|
||||
В движок интегрирован модуль ИИ: пользователи могут общаться с ботом в личных сообщениях, а также в групповых чатах, если упомянут его.
|
||||
Модуль ИИ использует API OpenRouter для генерации ответов, а также API fal.ai и Replicate для генерации изображений.
|
||||
Модуль ИИ использует OpenAI-совместимый API для генерации ответов, а также API Replicate для генерации изображений.
|
||||
|
||||
## Архитектура проекта
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ vk_chat_bot/
|
|||
## Основные технологии
|
||||
|
||||
- **Асинхронность:** asyncio
|
||||
- **ИИ:** OpenRouter (Grok 4.1 Fast), fal.ai (Seedream 4.5) для обычных изображений, Replicate (Nova Anime XL) для генерации изображений в стиле аниме.
|
||||
- **ИИ:** OpenAI-compatible API для вызова LLM, Replicate для генерации изображений.
|
||||
- **Telegram:** aiogram 3.x
|
||||
- **VK:** vkbottle
|
||||
- **СУБД:** MariaDB (через pyodbc)
|
||||
|
|
@ -78,7 +78,7 @@ python -m vk -c vk.json
|
|||
Основной класс, обрабатывающий:
|
||||
- `get_group_chat_reply()` - генерация ответа в групповом чате
|
||||
- `get_private_chat_reply()` - генерация ответа в личном чате
|
||||
- `_generate_reply()` - вызов LLM через OpenRouter
|
||||
- `_generate_reply()` - вызов LLM через OpenAI
|
||||
- `_process_tool_calls()` - обработка вызова функций
|
||||
|
||||
### Обработчики сообщений и событий мессенджера (handlers)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import ai.agent
|
||||
from database import BasicDatabase
|
||||
|
||||
|
|
@ -8,11 +10,11 @@ Agent = ai.agent.AiAgent
|
|||
agent_instance: ai.agent.AiAgent
|
||||
|
||||
|
||||
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||
def create_ai_agent(openai_url: Optional[str], openai_token: str, openai_model: str,
|
||||
replicate_token: str, tavily_token: str,
|
||||
db: BasicDatabase, platform: str):
|
||||
global agent_instance
|
||||
agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model,
|
||||
agent_instance = ai.agent.AiAgent(openai_url, openai_token, openai_model,
|
||||
replicate_token, tavily_token,
|
||||
db, platform)
|
||||
|
||||
|
|
|
|||
260
ai/agent.py
260
ai/agent.py
|
|
@ -2,14 +2,20 @@ import datetime
|
|||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, cast, Dict, Any
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.components import ChatAssistantMessage, ChatAssistantMessageTypedDict, \
|
||||
ChatToolCall, ChatResult, ChatSystemMessageTypedDict, ChatMessagesTypedDict, ChatAssistantMessageContent
|
||||
|
||||
from openrouter.errors import ResponseValidationError, OpenRouterError
|
||||
from openrouter.utils import BackoffStrategy
|
||||
from openai import AsyncOpenAI, omit
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionFunctionToolParam,
|
||||
ChatCompletionMessageToolCallUnion,
|
||||
ChatCompletionMessageFunctionToolCall
|
||||
)
|
||||
|
||||
import ai.tool
|
||||
from database import BasicDatabase
|
||||
|
|
@ -17,8 +23,11 @@ from database import BasicDatabase
|
|||
from ai.utils import *
|
||||
from ai.tools import *
|
||||
|
||||
OPENAI_DEFAULT_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
||||
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
|
||||
OPENROUTER_HTTP_REFERER = "https://kiriru.cc"
|
||||
OPENROUTER_CATEGORIES = "general-chat,roleplay"
|
||||
|
||||
GROUP_CHAT_MAX_MESSAGES = 40
|
||||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||
|
|
@ -34,37 +43,70 @@ class Message:
|
|||
message_id: Optional[int] = None
|
||||
|
||||
|
||||
class ChatContextManager:
|
||||
def __init__(self, db: BasicDatabase, bot_id: int, chat_id: int,
|
||||
system_prompt: ChatCompletionSystemMessageParam, max_messages: int):
|
||||
self.db = db
|
||||
self.bot_id = bot_id
|
||||
self.chat_id = chat_id
|
||||
self.max_messages = max_messages
|
||||
|
||||
self.context: List[ChatCompletionMessageParam] = [system_prompt]
|
||||
for message in self.db.context_get_messages(bot_id, chat_id):
|
||||
# noinspection PyTypeChecker
|
||||
self.context.append(message)
|
||||
|
||||
self.pending_messages: List[ChatCompletionMessageParam] = []
|
||||
self.pending_messages_ids: List[Optional[int]] = []
|
||||
|
||||
def add_user_message(self, message: Message):
|
||||
self._add_pending_message(_serialize_user_message(message.text, message.image), message.message_id)
|
||||
|
||||
def add_assistant_message(self, message: ChatCompletionMessage):
|
||||
self._add_pending_message(_serialize_assistant_message(message), None)
|
||||
|
||||
def add_tool_message(self, message: ChatCompletionToolMessageParam):
|
||||
self._add_pending_message(message, None)
|
||||
|
||||
def get_current_context(self):
|
||||
return self.context
|
||||
|
||||
def commit(self):
|
||||
for i, message in enumerate(self.pending_messages):
|
||||
self.db.context_add_message(self.bot_id, self.chat_id, message["role"], message=dict(message),
|
||||
message_id=self.pending_messages_ids[i], max_messages=self.max_messages)
|
||||
|
||||
def _add_pending_message(self, message: ChatCompletionMessageParam, message_id: Optional[int] = None):
|
||||
self.pending_messages.append(message)
|
||||
self.pending_messages_ids.append(message_id)
|
||||
self.context.append(message)
|
||||
|
||||
|
||||
class AiAgent:
|
||||
def __init__(self,
|
||||
openrouter_token: str, openrouter_model: str,
|
||||
openai_url: Optional[str], openai_token: str, openai_model: str,
|
||||
replicate_token: str, tavily_token: str,
|
||||
db: BasicDatabase,
|
||||
platform: str):
|
||||
retry_config = RetryConfig(strategy="backoff",
|
||||
backoff=BackoffStrategy(
|
||||
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
||||
retry_connection_errors=True)
|
||||
self.db = db
|
||||
self.openrouter_model = openrouter_model
|
||||
self.openai_model = openai_model
|
||||
self.platform = platform
|
||||
|
||||
self._load_prompts()
|
||||
|
||||
self.client_openrouter = OpenRouter(api_key=openrouter_token,
|
||||
x_open_router_title=OPENROUTER_X_TITLE,
|
||||
http_referer=OPENROUTER_HTTP_REFERER,
|
||||
retry_config=retry_config)
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=openai_url if openai_url is not None else OPENAI_DEFAULT_URL,
|
||||
api_key=openai_token
|
||||
)
|
||||
|
||||
# Создание наборов инструментов
|
||||
self.toolsets: list[ai.tool.ToolSet] = []
|
||||
self.toolsets.append(
|
||||
ImageGenerationToolSet(replicate_token=replicate_token)
|
||||
)
|
||||
self.toolsets.append(ImageGenerationToolSet(replicate_token=replicate_token))
|
||||
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))
|
||||
|
||||
# Сбор всех инструментов
|
||||
self.tools: list[ai.tool.Tool] = []
|
||||
self.tools_descriptions: list = []
|
||||
self.tools_descriptions: list[ChatCompletionFunctionToolParam] = []
|
||||
for toolset in self.toolsets:
|
||||
self.tools.extend(toolset.functions)
|
||||
self.tools_descriptions.extend(toolset.get_all_tools_description())
|
||||
|
|
@ -106,48 +148,44 @@ class AiAgent:
|
|||
async def _handle_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message],
|
||||
is_group_chat: bool, max_messages: int) -> Tuple[Message, bool]:
|
||||
# 1. Подготовка текста сообщения (префикс)
|
||||
context_manager = ChatContextManager(
|
||||
db=self.db, bot_id=bot_id, chat_id=chat_id, max_messages=max_messages,
|
||||
system_prompt=self._construct_system_prompt(is_group_chat, bot_id, chat_id)
|
||||
)
|
||||
|
||||
# Добавление нового сообщения пользователя
|
||||
if is_group_chat:
|
||||
message.text = _add_message_prefix(message.text, message.user_name)
|
||||
else:
|
||||
message.text = _add_message_prefix(message.text)
|
||||
context_manager.add_user_message(message)
|
||||
|
||||
# 2. Сбор контекста из БД
|
||||
context = self._get_chat_context(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
|
||||
context.append(_serialize_message(role="user", text=message.text, image=message.image))
|
||||
|
||||
# 3. Обработка пересланных сообщений
|
||||
# Добавление пересланных сообщений
|
||||
for fwd_message in forwarded_messages:
|
||||
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
|
||||
if fwd_message.text is not None:
|
||||
message_text += '\n' + fwd_message.text
|
||||
fwd_message.text = message_text
|
||||
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
||||
context_manager.add_user_message(fwd_message)
|
||||
|
||||
# 4. Генерация ответа с поддержкой инструментов
|
||||
# Генерация ответа с поддержкой инструментов
|
||||
try:
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
context.append(_serialize_assistant_message(response))
|
||||
ai_response: Optional[ChatAssistantMessageContent] = response.content
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(),
|
||||
allow_tools=True)
|
||||
context_manager.add_assistant_message(response)
|
||||
ai_response: Optional[str] = response.content
|
||||
|
||||
tools_artifacts = {}
|
||||
while response.tool_calls is not None:
|
||||
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
context.append(_serialize_assistant_message(response))
|
||||
while response.tool_calls is not None and len(response.tool_calls) > 0:
|
||||
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls,
|
||||
context_manager=context_manager)
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(),
|
||||
allow_tools=True)
|
||||
context_manager.add_assistant_message(response)
|
||||
ai_response = response.content
|
||||
|
||||
# 5. Сохранение истории в БД
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=max_messages)
|
||||
for fwd_message in forwarded_messages:
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
message_id=fwd_message.message_id, max_messages=max_messages)
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="assistant", text=ai_response,
|
||||
image=tools_artifacts.get("generated_image"),
|
||||
message_id=None, max_messages=max_messages)
|
||||
# Сохранение обновленного контекста в БД
|
||||
context_manager.commit()
|
||||
|
||||
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
|
||||
image_hires=tools_artifacts.get("generated_image_hires")), True
|
||||
|
|
@ -159,15 +197,8 @@ class AiAgent:
|
|||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||||
|
||||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[ChatMessagesTypedDict]:
|
||||
context: List[ChatMessagesTypedDict] = [
|
||||
self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
|
||||
]
|
||||
for message in self.db.context_get_messages(bot_id, chat_id):
|
||||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||
return context
|
||||
|
||||
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> ChatSystemMessageTypedDict:
|
||||
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) \
|
||||
-> ChatCompletionSystemMessageParam:
|
||||
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
||||
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
||||
|
||||
|
|
@ -188,19 +219,26 @@ class AiAgent:
|
|||
return {"role": "system", "content": prompt}
|
||||
|
||||
async def _generate_reply(self, bot_id: int, chat_id: int,
|
||||
context: List[ChatMessagesTypedDict], allow_tools: bool = False) -> ChatAssistantMessage:
|
||||
response = await self._async_chat_completion_request(
|
||||
model=self.openrouter_model,
|
||||
messages=context,
|
||||
tools=self.tools_descriptions if allow_tools else None,
|
||||
tool_choice="auto" if allow_tools else None,
|
||||
max_tokens=MAX_OUTPUT_TOKENS,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}'
|
||||
)
|
||||
return self._filter_response(response.choices[0].message)
|
||||
context: List[ChatCompletionMessageParam], allow_tools: bool = False) \
|
||||
-> ChatCompletionMessage:
|
||||
|
||||
async def _process_tool_calls(self, tool_calls: List[ChatToolCall],
|
||||
context: List[ChatMessagesTypedDict]) -> dict:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.openai_model,
|
||||
messages=context,
|
||||
tools=self.tools_descriptions if allow_tools else omit,
|
||||
tool_choice="auto" if allow_tools else omit,
|
||||
max_tokens=MAX_OUTPUT_TOKENS,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||
extra_headers={
|
||||
"HTTP-Referer": OPENROUTER_HTTP_REFERER,
|
||||
"X-OpenRouter-Title": OPENROUTER_X_TITLE,
|
||||
"X-OpenRouter-Categories": OPENROUTER_CATEGORIES
|
||||
}
|
||||
)
|
||||
return response.choices[0].message
|
||||
|
||||
async def _process_tool_calls(self, tool_calls: List[ChatCompletionMessageToolCallUnion],
|
||||
context_manager: ChatContextManager) -> Dict[str, Any]:
|
||||
artifacts = {}
|
||||
if tool_calls is None:
|
||||
return artifacts
|
||||
|
|
@ -208,50 +246,24 @@ class AiAgent:
|
|||
tools_map = {tool.name: tool for tool in self.tools}
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_call.type != "function":
|
||||
continue
|
||||
func_call = cast(ChatCompletionMessageFunctionToolCall, tool_call)
|
||||
tool_name = func_call.function.name
|
||||
tool_args = json.loads(func_call.function.arguments)
|
||||
|
||||
if tool_name in tools_map:
|
||||
tool = tools_map[tool_name]
|
||||
# Вызов инструмента с передачей artifacts
|
||||
tool_result = await tool.execute(tool_args, artifacts)
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result
|
||||
})
|
||||
context_manager.add_tool_message(
|
||||
ChatCompletionToolMessageParam(role="tool", tool_call_id=tool_call.id, content=tool_result)
|
||||
)
|
||||
|
||||
return artifacts
|
||||
|
||||
async def _async_chat_completion_request(self, **kwargs) -> ChatResult:
|
||||
try:
|
||||
# noinspection PyTypeChecker
|
||||
return await self.client_openrouter.chat.send_async(**kwargs)
|
||||
except ResponseValidationError as e:
|
||||
# Костыль для OpenRouter SDK:
|
||||
# https://github.com/OpenRouterTeam/python-sdk/issues/44
|
||||
body = json.loads(e.body)
|
||||
if "error" in body:
|
||||
try:
|
||||
raw_response = json.loads(body["error"]["metadata"]["raw"])
|
||||
message = str(raw_response["error"]["message"])
|
||||
e = RuntimeError(message)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
except OpenRouterError as e:
|
||||
if e.message == "Provider returned error":
|
||||
body = json.loads(e.body)
|
||||
try:
|
||||
raw_response = json.loads(body["error"]["metadata"]["raw"])
|
||||
message = str(raw_response["error"]["message"])
|
||||
e = RuntimeError(message)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
# TODO: удалить
|
||||
@staticmethod
|
||||
def _filter_response(response: ChatAssistantMessage) -> ChatAssistantMessage:
|
||||
def _filter_response(response: ChatCompletionMessage) -> ChatCompletionMessage:
|
||||
text = str(response.content)
|
||||
text = text.replace("<image>", "")
|
||||
response.content = text
|
||||
|
|
@ -270,27 +282,19 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) ->
|
|||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||||
|
||||
|
||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||
return {"role": role, "content": serialize_message_content(text, image)}
|
||||
|
||||
|
||||
def _serialize_assistant_message(message: ChatAssistantMessage) -> ChatAssistantMessageTypedDict:
|
||||
# noinspection PyTypeChecker
|
||||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||
|
||||
|
||||
def _remove_none_recursive(data: Union[Dict, List, Any]) -> Union[Dict, List, Any]:
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
k: _remove_none_recursive(v)
|
||||
for k, v in data.items()
|
||||
if v is not None
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
return [
|
||||
_remove_none_recursive(item)
|
||||
for item in data
|
||||
if item is not None
|
||||
]
|
||||
def _serialize_user_message(text: Optional[str], image: Optional[bytes]) -> ChatCompletionUserMessageParam:
|
||||
if image is None:
|
||||
if text is not None:
|
||||
content = text
|
||||
else:
|
||||
raise ValueError("Either text or image must be provided")
|
||||
else:
|
||||
return data
|
||||
content = []
|
||||
if text is not None:
|
||||
content.append({"type": "text", "text": text})
|
||||
content.append({"type": "image_url", "image_url": {"url": encode_image(image), "detail": "high"}})
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
|
||||
def _serialize_assistant_message(message: ChatCompletionMessage) -> ChatCompletionAssistantMessageParam:
|
||||
return message.model_dump(exclude_none=True)
|
||||
|
|
|
|||
10
ai/tool.py
10
ai/tool.py
|
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openrouter.components import ChatToolMessageContentTypedDict, ChatFunctionToolFunctionTypedDict
|
||||
from openai.types.chat import ChatCompletionFunctionToolParam
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
|
|
@ -26,7 +26,7 @@ class Tool(ABC):
|
|||
"""Описание параметров функции"""
|
||||
pass
|
||||
|
||||
def to_dict(self) -> ChatFunctionToolFunctionTypedDict:
|
||||
def to_dict(self) -> ChatCompletionFunctionToolParam:
|
||||
"""JSON-представление инструмента для OpenRouter"""
|
||||
return {
|
||||
"type": "function",
|
||||
|
|
@ -38,7 +38,7 @@ class Tool(ABC):
|
|||
}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str:
|
||||
"""Вызов функции.
|
||||
:param args: Параметры из JSON
|
||||
:param artifacts: Словарь для хранения артефактов
|
||||
|
|
@ -61,6 +61,6 @@ class ToolSet:
|
|||
"""Поиск инструмента по имени"""
|
||||
return next((t for t in self.functions if t.name == name), None)
|
||||
|
||||
def get_all_tools_description(self) -> List[ChatFunctionToolFunctionTypedDict]:
|
||||
def get_all_tools_description(self) -> List[ChatCompletionFunctionToolParam]:
|
||||
"""Получить JSON-описание всех инструментов"""
|
||||
return [tool.to_dict() for tool in self.functions]
|
||||
return [function.to_dict() for function in self.functions]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
from openrouter.components import ChatToolMessageContentTypedDict
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
from ai.tool import Tool
|
||||
|
|
@ -12,15 +11,15 @@ REPLICATE_MODEL = "bytedance/seedream-4.5"
|
|||
class GenerateImageTool(Tool):
|
||||
def __init__(self, replicate_token: str):
|
||||
self._client = ReplicateClient(api_token=replicate_token)
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "generate_image"
|
||||
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Генерация изображения по описанию"
|
||||
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
@ -39,12 +38,12 @@ class GenerateImageTool(Tool):
|
|||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str:
|
||||
prompt = args.get("prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", "4:3")
|
||||
print(f"Генерация изображения {aspect_ratio}: {prompt}")
|
||||
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
|
|
@ -55,11 +54,7 @@ class GenerateImageTool(Tool):
|
|||
outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments)
|
||||
artifacts["generated_image_hires"] = await outputs[0].aread()
|
||||
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
||||
|
||||
return serialize_message_content(
|
||||
text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None
|
||||
)
|
||||
return "Изображение сгенерировано и будет показано пользователю."
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
return f"Не удалось сгенерировать изображение: {e}"
|
||||
|
|
|
|||
|
|
@ -1,26 +1,25 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
from openrouter.components import ChatToolMessageContentTypedDict
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
from ai.tool import Tool
|
||||
from ai.utils import *
|
||||
|
||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
|
||||
REPLICATE_MODEL = "kirirururu/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
|
||||
|
||||
|
||||
class GenerateImageAnimeTool(Tool):
|
||||
def __init__(self, replicate_token: str):
|
||||
self._client = ReplicateClient(api_token=replicate_token)
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "generate_image_anime"
|
||||
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Генерация изображения в стиле аниме по описанию"
|
||||
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
@ -42,12 +41,12 @@ class GenerateImageAnimeTool(Tool):
|
|||
},
|
||||
"required": ["prompt", "negative_prompt"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> str:
|
||||
prompt = args.get("prompt", "")
|
||||
negative_prompt = args.get("negative_prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", "4:3")
|
||||
|
||||
|
||||
aspect_ratio_resolution_map = {
|
||||
"1:1": (1280, 1280),
|
||||
"4:3": (1280, 1024),
|
||||
|
|
@ -58,7 +57,7 @@ class GenerateImageAnimeTool(Tool):
|
|||
}
|
||||
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
|
|
@ -71,16 +70,12 @@ class GenerateImageAnimeTool(Tool):
|
|||
"hires_num_inference_steps": 30,
|
||||
"disable_safety_checker": True
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments)
|
||||
artifacts["generated_image_hires"] = await outputs[0].aread()
|
||||
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
||||
|
||||
return serialize_message_content(
|
||||
text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None
|
||||
)
|
||||
return "Изображение сгенерировано и будет показано пользователю."
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
return f"Не удалось сгенерировать изображение: {e}"
|
||||
|
|
|
|||
|
|
@ -1,24 +1,22 @@
|
|||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
from openrouter.components import ChatToolMessageContentTypedDict
|
||||
from tavily import TavilyClient
|
||||
|
||||
from ai.tool import Tool
|
||||
from ai.utils import *
|
||||
|
||||
|
||||
class TavilySearchTool(Tool):
|
||||
def __init__(self, tavily_token: str):
|
||||
self._client = TavilyClient(api_key=tavily_token)
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "tavily_search"
|
||||
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Поиск информации в интернете"
|
||||
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
@ -31,26 +29,26 @@ class TavilySearchTool(Tool):
|
|||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
||||
|
||||
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> str:
|
||||
query = args.get("query", "")
|
||||
print(f"Веб-поиск: {query}")
|
||||
|
||||
|
||||
try:
|
||||
results = self._client.search(query=query, max_results=5)
|
||||
|
||||
|
||||
if not results or "results" not in results:
|
||||
return serialize_message_content(text="Не удалось получить результаты поиска.")
|
||||
|
||||
return "Не удалось получить результаты поиска."
|
||||
|
||||
answer_parts = []
|
||||
for i, result in enumerate(results["results"], 1):
|
||||
title = result.get("title", "Без названия")
|
||||
url = result.get("url", "")
|
||||
content = result.get("content", "")
|
||||
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
|
||||
|
||||
|
||||
answer = "\n".join(answer_parts)
|
||||
return serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
|
||||
return f"По запросу \"{query}\" найдено:\n\n{answer}"
|
||||
except Exception as e:
|
||||
print(f"Ошибка веб-поиска: {e}")
|
||||
return serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
|
||||
return f"Не удалось выполнить веб-поиск: {e}"
|
||||
|
|
|
|||
12
ai/utils.py
12
ai/utils.py
|
|
@ -1,16 +1,7 @@
|
|||
from base64 import b64encode
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> List[Dict]:
|
||||
content = []
|
||||
if text is not None:
|
||||
content.append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
content.append({"type": "image_url", "detail": "high", "image_url": {"url": encode_image(image)}})
|
||||
return content
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def encode_image(image: bytes) -> str:
|
||||
|
|
@ -33,7 +24,6 @@ def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
|
|||
|
||||
|
||||
__all__ = [
|
||||
"serialize_message_content",
|
||||
"compress_image",
|
||||
"encode_image"
|
||||
]
|
||||
|
|
|
|||
25
database.py
25
database.py
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Type
|
||||
|
||||
|
|
@ -161,11 +162,12 @@ class BasicDatabase:
|
|||
with self.pool.acquire() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT role, text, image FROM contexts
|
||||
SELECT message FROM contexts
|
||||
WHERE bot_id = ? AND chat_id = ?
|
||||
ORDER BY id
|
||||
""", (bot_id, chat_id))
|
||||
return cursor.fetchall()
|
||||
""", (bot_id, chat_id))
|
||||
result = cursor.fetchall()
|
||||
return [json.loads(_to_val(str, item)) for item in result]
|
||||
|
||||
def context_get_count(self, bot_id: int, chat_id: int) -> int:
|
||||
with self.pool.acquire() as conn:
|
||||
|
|
@ -185,17 +187,14 @@ class BasicDatabase:
|
|||
return _to_val(int, cursor.fetchone())
|
||||
|
||||
def context_add_message(self, bot_id: int, chat_id: int, role: str,
|
||||
text: Optional[str], image: Optional[bytes],
|
||||
message_id: Optional[int], max_messages: int):
|
||||
assert (text or image)
|
||||
|
||||
message: Dict, message_id: Optional[int], max_messages: int):
|
||||
self._context_trim(bot_id, chat_id, max_messages)
|
||||
|
||||
# Подготовка данных для вставки
|
||||
data = {
|
||||
"bot_id": bot_id, "chat_id": chat_id,
|
||||
"message_id": message_id, "role": role,
|
||||
"text": text, "image": image
|
||||
"message": json.dumps(message, ensure_ascii=False)
|
||||
}
|
||||
|
||||
# Формирование SQL-запроса и параметров вставки
|
||||
|
|
@ -211,9 +210,13 @@ class BasicDatabase:
|
|||
def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int):
|
||||
with self.pool.acquire() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL",
|
||||
(message_id, bot_id, chat_id))
|
||||
cursor.execute("""
|
||||
UPDATE contexts
|
||||
SET message_id = %s
|
||||
WHERE bot_id = %s AND chat_id = %s AND message_id IS NULL
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
""", (message_id, bot_id, chat_id))
|
||||
|
||||
def _context_trim(self, bot_id: int, chat_id: int, max_messages: int):
|
||||
current_count = self.context_get_count(bot_id, chat_id)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ aiohttp~=3.13.5
|
|||
vkbottle~=4.8.2
|
||||
vkbottle-types~=5.199.99.20
|
||||
mariadb[pool]~=2.0.0rc2
|
||||
openrouter==0.9.1
|
||||
openai~=2.41.1
|
||||
replicate~=1.0.7
|
||||
tavily~=1.1.0
|
||||
pillow~=12.2.0
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ async def main() -> None:
|
|||
|
||||
database.create_database(config['db_hostname'], config['db_user'], config['db_password'], config['db_database'])
|
||||
|
||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||
create_ai_agent(config.get('openai_url', None), config['openai_token'], config['openai_model'],
|
||||
config['replicate_token'], config['tavily_token'],
|
||||
database.DB, 'tg')
|
||||
|
||||
|
|
|
|||
|
|
@ -53,8 +53,7 @@ class TgDatabase(database.BasicDatabase):
|
|||
chat_id BIGINT NOT NULL,
|
||||
message_id BIGINT,
|
||||
role VARCHAR(16) NOT NULL,
|
||||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
message MEDIUMTEXT NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ if __name__ == '__main__':
|
|||
|
||||
database.create_database(config['db_hostname'], config['db_user'], config['db_password'], config['db_database'])
|
||||
|
||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||
create_ai_agent(config.get('openai_url', None), config['openai_token'], config['openai_model'],
|
||||
config['replicate_token'], config['tavily_token'],
|
||||
database.DB, 'vk')
|
||||
|
||||
|
|
|
|||
|
|
@ -56,8 +56,7 @@ class VkDatabase(database.BasicDatabase):
|
|||
chat_id BIGINT NOT NULL,
|
||||
message_id BIGINT,
|
||||
role VARCHAR(16) NOT NULL,
|
||||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
message MEDIUMTEXT NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue