vk_chat_bot/ai/agent.py
2026-06-15 04:10:04 +03:00

300 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, cast, Dict, Any
from openai import AsyncOpenAI, omit
from openai.types.chat import (
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionFunctionToolParam,
ChatCompletionMessageToolCallUnion,
ChatCompletionMessageFunctionToolCall
)
import ai.tool
from database import BasicDatabase
from ai.utils import *
from ai.tools import *
# Base URL passed to the OpenAI-compatible client when no explicit URL is configured.
OPENAI_DEFAULT_URL = "https://openrouter.ai/api/v1"
# Values sent as extra HTTP headers with every chat completion request.
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://kiriru.cc"
OPENROUTER_CATEGORIES = "general-chat,roleplay"
# `max_messages` value handed to the context storage for group and private chats.
GROUP_CHAT_MAX_MESSAGES = 40
PRIVATE_CHAT_MAX_MESSAGES = 40
# `max_tokens` value used for every completion request.
MAX_OUTPUT_TOKENS = 500
@dataclass
class Message:
    """A chat message passed between the platform layer and the AI agent.

    Every field is optional and defaults to ``None``; the positional field
    order is ``user_name, text, image, image_hires, message_id``.
    """

    # Author name; used for the group-chat prefix and the quoted-message header.
    user_name: Optional[str] = None
    # Text body of the message.
    text: Optional[str] = None
    # Image bytes attached to the message (or produced by a tool for a reply).
    image: Optional[bytes] = None
    # presumably a higher-resolution variant of `image` - confirm with the tool that fills it
    image_hires: Optional[bytes] = None
    # Identifier stored together with the message when it is saved to the context.
    message_id: Optional[int] = None
class ChatContextManager:
    """Holds the message list for one chat and buffers new messages until commit.

    On construction the context is the system prompt followed by the messages
    the database returns for ``(bot_id, chat_id)``. Messages added afterwards
    are appended to the in-memory context immediately but are only written to
    the database by :meth:`commit`.
    """

    def __init__(self, db: BasicDatabase, bot_id: int, chat_id: int,
                 system_prompt: ChatCompletionSystemMessageParam, max_messages: int):
        """Load the stored context of the chat and start with empty pending buffers.

        Args:
            db: Storage used to read the saved context and persist new messages.
            bot_id: Bot identifier forwarded to the storage calls.
            chat_id: Chat identifier forwarded to the storage calls.
            system_prompt: Message placed first in the context.
            max_messages: Limit forwarded to ``db.context_add_message`` on commit.
        """
        self.db = db
        self.bot_id = bot_id
        self.chat_id = chat_id
        self.max_messages = max_messages
        # noinspection PyTypeChecker
        self.context: List[ChatCompletionMessageParam] = [
            system_prompt,
            *self.db.context_get_messages(bot_id, chat_id),
        ]
        # Parallel lists: pending_messages_ids[i] belongs to pending_messages[i].
        self.pending_messages: List[ChatCompletionMessageParam] = []
        self.pending_messages_ids: List[Optional[int]] = []

    def add_user_message(self, message: Message):
        """Serialize a user message (text and/or image) and queue it with its message id."""
        self._add_pending_message(_serialize_user_message(message.text, message.image), message.message_id)

    def add_assistant_message(self, message: ChatCompletionMessage):
        """Serialize an assistant response and queue it (no message id)."""
        self._add_pending_message(_serialize_assistant_message(message), None)

    def add_tool_message(self, message: ChatCompletionToolMessageParam):
        """Queue a tool result message (no message id)."""
        self._add_pending_message(message, None)

    def get_current_context(self):
        """Return the live context list (system prompt, stored and pending messages)."""
        return self.context

    def commit(self):
        """Persist all pending messages to the database, in the order they were added.

        Messages that were written are removed from the pending buffers, so a
        repeated call (or a retry after a storage error part-way through) does
        not store the same message twice.
        """
        persisted = 0
        try:
            for message, message_id in zip(self.pending_messages, self.pending_messages_ids):
                self.db.context_add_message(self.bot_id, self.chat_id, message["role"], message=dict(message),
                                            message_id=message_id, max_messages=self.max_messages)
                persisted += 1
        finally:
            # Drop only what actually reached the database.
            del self.pending_messages[:persisted]
            del self.pending_messages_ids[:persisted]

    def _add_pending_message(self, message: ChatCompletionMessageParam, message_id: Optional[int] = None):
        """Append a message to both the pending buffers and the in-memory context."""
        self.pending_messages.append(message)
        self.pending_messages_ids.append(message_id)
        self.context.append(message)
class AiAgent:
    """Chat agent that produces replies through an OpenAI-compatible API with tool calls.

    The agent keeps per-chat context in the database, builds the system prompt
    from prompt files plus per-bot / per-chat additions, and runs the
    request -> tool calls -> request loop until the model answers without
    requesting tools.
    """

    def __init__(self,
                 openai_url: Optional[str], openai_token: str, openai_model: str,
                 replicate_token: str, tavily_token: str,
                 db: BasicDatabase,
                 platform: str):
        """Load prompts, create the API client and collect the available tools.

        Args:
            openai_url: Base URL of the API; ``OPENAI_DEFAULT_URL`` is used when ``None``.
            openai_token: API key for the client.
            openai_model: Model name sent with every completion request.
            replicate_token: Token passed to the image generation tool set.
            tavily_token: Token passed to the search tool set.
            db: Storage for chat context and bot / chat settings.
            platform: Platform tag; ``'tg'`` is rendered as Telegram in the
                system prompt, anything else as VK.
        """
        self.db = db
        self.openai_model = openai_model
        self.platform = platform
        self._load_prompts()
        self.client = AsyncOpenAI(
            base_url=openai_url if openai_url is not None else OPENAI_DEFAULT_URL,
            api_key=openai_token
        )
        # Create the tool sets
        self.toolsets: list[ai.tool.ToolSet] = [
            ImageGenerationToolSet(replicate_token=replicate_token),
            TavilySearchToolSet(tavily_token=tavily_token),
        ]
        # Collect every tool and its description from all tool sets
        self.tools: list[ai.tool.Tool] = []
        self.tools_descriptions: list[ChatCompletionFunctionToolParam] = []
        for toolset in self.toolsets:
            self.tools.extend(toolset.functions)
            self.tools_descriptions.extend(toolset.get_all_tools_description())

    async def get_group_chat_reply(self, bot_id: int, chat_id: int,
                                   message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
        """Generate a reply for a group chat; returns ``(reply, success_flag)``."""
        return await self._handle_chat_reply(
            bot_id=bot_id,
            chat_id=chat_id,
            message=message,
            forwarded_messages=forwarded_messages,
            is_group_chat=True,
            max_messages=GROUP_CHAT_MAX_MESSAGES
        )

    async def get_private_chat_reply(self, bot_id: int, chat_id: int,
                                     message: Message, forwarded_messages: Optional[List[Message]] = None) \
            -> Tuple[Message, bool]:
        """Generate a reply for a private chat; returns ``(reply, success_flag)``.

        ``forwarded_messages`` may be omitted or ``None``, which means no
        forwarded messages.
        """
        return await self._handle_chat_reply(
            bot_id=bot_id,
            chat_id=chat_id,
            message=message,
            forwarded_messages=forwarded_messages or [],
            is_group_chat=False,
            max_messages=PRIVATE_CHAT_MAX_MESSAGES
        )

    def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
        """Return what the database reports as the last assistant message id of the chat."""
        return self.db.context_get_last_assistant_message_id(bot_id, chat_id)

    def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
        """Store ``message_id`` as the id of the last message in the chat context."""
        self.db.context_set_last_message_id(bot_id, chat_id, message_id)

    def clear_chat_context(self, bot_id: int, chat_id: int):
        """Delete the stored context of the chat."""
        self.db.context_clear(bot_id, chat_id)

    ####################################################################################

    async def _handle_chat_reply(self, bot_id: int, chat_id: int,
                                 message: Message, forwarded_messages: List[Message],
                                 is_group_chat: bool, max_messages: int) -> Tuple[Message, bool]:
        """Run one full reply cycle: extend the context, query the model, run tools, persist.

        Note: ``message.text`` and the ``text`` of every forwarded message are
        rewritten in place (timestamp prefix / quote header).

        Returns:
            ``(Message, True)`` with the model's answer and any tool artifacts on
            success, or ``(Message, False)`` with an error text. On failure
            nothing is committed to the database.
        """
        context_manager = ChatContextManager(
            db=self.db, bot_id=bot_id, chat_id=chat_id, max_messages=max_messages,
            system_prompt=self._construct_system_prompt(is_group_chat, bot_id, chat_id)
        )

        # Add the new user message (the author name is included only in group chats)
        if is_group_chat:
            message.text = _add_message_prefix(message.text, message.user_name)
        else:
            message.text = _add_message_prefix(message.text)
        context_manager.add_user_message(message)

        # Add forwarded messages, each marked with a quote header
        for fwd_message in forwarded_messages:
            message_text = f'<Цитируемое сообщение от {fwd_message.user_name}>'
            if fwd_message.text is not None:
                message_text += '\n' + fwd_message.text
            fwd_message.text = message_text
            context_manager.add_user_message(fwd_message)

        # Generate the reply, letting the model call tools
        try:
            tools_artifacts: Dict[str, Any] = {}
            response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(),
                                                  allow_tools=True)
            context_manager.add_assistant_message(response)

            while response.tool_calls:
                # Merge instead of overwrite: artifacts produced in an earlier
                # round must survive later rounds that produce none.
                tools_artifacts.update(
                    await self._process_tool_calls(tool_calls=response.tool_calls,
                                                   context_manager=context_manager)
                )
                response = await self._generate_reply(bot_id, chat_id,
                                                      context=context_manager.get_current_context(),
                                                      allow_tools=True)
                context_manager.add_assistant_message(response)

            # Save the updated context to the database
            context_manager.commit()

            return Message(text=response.content, image=tools_artifacts.get("generated_image"),
                           image_hires=tools_artifacts.get("generated_image_hires")), True
        except Exception as e:
            # Top-level boundary: any failure is turned into a message for the user.
            if "Rate limit exceeded" in str(e):
                return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
            else:
                print(f"Ошибка выполнения запроса к ИИ: {e}")
                return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False

    def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) \
            -> ChatCompletionSystemMessageParam:
        """Assemble the system message for a chat.

        The result is the base prompt (group or private) with ``{platform}``
        substituted, followed by the tool sets' prompts and the optional
        per-bot and per-chat ``ai_prompt`` additions read from the database.
        """
        base_prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
        parts = [
            base_prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK'),
            '\n# Доступные инструменты\n',
        ]
        parts.extend('\n' + toolset.system_prompt for toolset in self.toolsets)
        parts.append('\n# Дополнительные инструкции\n')

        bot = self.db.get_bot(bot_id)
        if bot['ai_prompt'] is not None:
            parts.append('\n' + bot['ai_prompt'] + '\n')

        chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
        if chat['ai_prompt'] is not None:
            parts.append('\n' + chat['ai_prompt'])

        return {"role": "system", "content": ''.join(parts)}

    async def _generate_reply(self, bot_id: int, chat_id: int,
                              context: List[ChatCompletionMessageParam], allow_tools: bool = False) \
            -> ChatCompletionMessage:
        """Send the context to the model and return the first choice's message.

        Args:
            bot_id: Used together with the platform and ``chat_id`` for the ``user`` field.
            chat_id: See ``bot_id``.
            context: Full message list to send.
            allow_tools: When true, tool descriptions are sent with ``tool_choice="auto"``;
                otherwise both parameters are omitted.

        Raises:
            RuntimeError: If the API response contains no choices.
        """
        response = await self.client.chat.completions.create(
            model=self.openai_model,
            messages=context,
            tools=self.tools_descriptions if allow_tools else omit,
            tool_choice="auto" if allow_tools else omit,
            max_tokens=MAX_OUTPUT_TOKENS,
            user=f'{self.platform}_{bot_id}_{chat_id}',
            extra_headers={
                "HTTP-Referer": OPENROUTER_HTTP_REFERER,
                "X-OpenRouter-Title": OPENROUTER_X_TITLE,
                "X-OpenRouter-Categories": OPENROUTER_CATEGORIES
            }
        )
        # Fail with a readable error instead of an IndexError/TypeError on choices[0].
        if not response.choices:
            raise RuntimeError("API response contains no choices")
        return response.choices[0].message

    async def _process_tool_calls(self, tool_calls: List[ChatCompletionMessageToolCallUnion],
                                  context_manager: ChatContextManager) -> Dict[str, Any]:
        """Execute the requested function tools and add their results to the context.

        Every function tool call gets exactly one ``tool`` message: the tool's
        result, or an error text when the tool is unknown or its arguments are
        not valid JSON. Calls whose type is not ``"function"`` are skipped.
        Exceptions raised by a tool itself propagate to the caller.

        Returns:
            The artifacts dictionary that was passed to (and may have been
            filled by) the executed tools.
        """
        artifacts: Dict[str, Any] = {}
        if not tool_calls:
            return artifacts

        tools_map = {tool.name: tool for tool in self.tools}
        for tool_call in tool_calls:
            if tool_call.type != "function":
                continue
            func_call = cast("ChatCompletionMessageFunctionToolCall", tool_call)
            tool_name = func_call.function.name
            tool = tools_map.get(tool_name)

            if tool is None:
                tool_result = f"Error: unknown tool '{tool_name}'"
            else:
                raw_args = func_call.function.arguments
                try:
                    # An empty argument string is treated as "no arguments".
                    tool_args = json.loads(raw_args) if raw_args else {}
                except json.JSONDecodeError as e:
                    tool_result = f"Error: invalid JSON arguments for tool '{tool_name}': {e}"
                else:
                    tool_result = await tool.execute(tool_args, artifacts)

            # A plain dict is what ChatCompletionToolMessageParam (a TypedDict) is at runtime.
            context_manager.add_tool_message(
                {"role": "tool", "tool_call_id": tool_call.id, "content": tool_result}
            )
        return artifacts

    # TODO: remove
    @staticmethod
    def _filter_response(response: ChatCompletionMessage) -> ChatCompletionMessage:
        """Strip the ``<image>`` marker from the response text, in place.

        A response without content (``None``) is returned unchanged.
        """
        if response.content is not None:
            response.content = str(response.content).replace("<image>", "")
        return response

    def _load_prompts(self):
        """Read the group and private chat system prompts from ``ai/prompts``.

        The files are read as UTF-8 so the result does not depend on the
        platform's default locale encoding.
        """
        with open("ai/prompts/group_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_group_chat = f.read()
        with open("ai/prompts/private_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_private_chat = f.read()
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
def _serialize_user_message(text: Optional[str], image: Optional[bytes]) -> ChatCompletionUserMessageParam:
if image is None:
if text is not None:
content = text
else:
raise ValueError("Either text or image must be provided")
else:
content = []
if text is not None:
content.append({"type": "text", "text": text})
content.append({"type": "image_url", "image_url": {"url": encode_image(image), "detail": "high"}})
return {"role": "user", "content": content}
def _serialize_assistant_message(message: ChatCompletionMessage) -> ChatCompletionAssistantMessageParam:
return message.model_dump(exclude_none=True)