300 lines
12 KiB
Python
300 lines
12 KiB
Python
import datetime
|
||
import json
|
||
|
||
from dataclasses import dataclass
|
||
from typing import List, Optional, Tuple, cast, Dict, Any
|
||
|
||
from openai import AsyncOpenAI, omit
|
||
from openai.types.chat import (
|
||
ChatCompletionMessage,
|
||
ChatCompletionMessageParam,
|
||
ChatCompletionUserMessageParam,
|
||
ChatCompletionAssistantMessageParam,
|
||
ChatCompletionSystemMessageParam,
|
||
ChatCompletionToolMessageParam,
|
||
ChatCompletionFunctionToolParam,
|
||
ChatCompletionMessageToolCallUnion,
|
||
ChatCompletionMessageFunctionToolCall
|
||
)
|
||
|
||
import ai.tool
|
||
from database import BasicDatabase
|
||
|
||
from ai.utils import *
|
||
from ai.tools import *
|
||
|
||
# --- Limits -----------------------------------------------------------------
# Upper bound on tokens in one model reply (sent as max_tokens).
MAX_OUTPUT_TOKENS = 500
# History caps handed to the DB as max_messages when new context entries are stored.
GROUP_CHAT_MAX_MESSAGES = 40
PRIVATE_CHAT_MAX_MESSAGES = 40

# --- Endpoint -----------------------------------------------------------------
# Base URL used when AiAgent is constructed with openai_url=None.
OPENAI_DEFAULT_URL = "https://openrouter.ai/api/v1"

# --- OpenRouter attribution headers (sent with every completion request) -----
OPENROUTER_HTTP_REFERER = "https://kiriru.cc"
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_CATEGORIES = "general-chat,roleplay"
|
||
|
||
|
||
@dataclass
class Message:
    """A chat message passed between the bot platform layer and the AI agent.

    All fields are optional. ``image`` is what gets sent to the model for user
    messages; on replies ``image`` / ``image_hires`` carry generated-image artifacts.
    """

    user_name: Optional[str] = None  # author name; used as a prefix in group chats
    text: Optional[str] = None
    image: Optional[bytes] = None
    image_hires: Optional[bytes] = None  # high-resolution variant of a generated image
    message_id: Optional[int] = None  # platform message id, stored with the context entry
|
||
|
||
|
||
class ChatContextManager:
    """Hold the message context of one chat for a single reply turn.

    The context starts with the system prompt, followed by the messages returned by
    ``db.context_get_messages``. Messages added during the turn are appended to the
    in-memory context immediately and also queued; :meth:`commit` writes the queue
    to the database.
    """

    def __init__(self, db: BasicDatabase, bot_id: int, chat_id: int,
                 system_prompt: ChatCompletionSystemMessageParam, max_messages: int):
        """
        :param db: storage used to load the existing context and persist new messages.
        :param bot_id: bot whose context is managed.
        :param chat_id: chat whose context is managed.
        :param system_prompt: message placed first in the context; it is never queued
            for persistence.
        :param max_messages: forwarded to ``db.context_add_message`` on commit.
        """
        self.db = db
        self.bot_id = bot_id
        self.chat_id = chat_id
        self.max_messages = max_messages

        self.context: List[ChatCompletionMessageParam] = [system_prompt]
        # noinspection PyTypeChecker
        self.context.extend(self.db.context_get_messages(bot_id, chat_id))

        # Messages added during this turn, plus their platform ids (parallel lists).
        self.pending_messages: List[ChatCompletionMessageParam] = []
        self.pending_messages_ids: List[Optional[int]] = []

    def add_user_message(self, message: Message):
        """Queue a user message built from ``message.text`` and ``message.image``.

        ``message.message_id`` is stored alongside it; ``image_hires`` is not used.
        """
        self._add_pending_message(_serialize_user_message(message.text, message.image), message.message_id)

    def add_assistant_message(self, message: ChatCompletionMessage):
        """Queue a model reply (serialized via ``model_dump``), without a platform id."""
        self._add_pending_message(_serialize_assistant_message(message), None)

    def add_tool_message(self, message: ChatCompletionToolMessageParam):
        """Queue a tool-result message, without a platform id."""
        self._add_pending_message(message, None)

    def get_current_context(self):
        """Return the live context list (system prompt, history and this turn's messages)."""
        return self.context

    def commit(self):
        """Persist the queued messages in insertion order and remove them from the queue.

        Successfully written messages are always dropped from the queue, even if a
        later write raises. A repeated ``commit()`` therefore never writes the same
        message twice, and a retry after a failure continues with the
        remaining messages. The in-memory context is not changed.
        """
        committed = 0
        try:
            for message, message_id in zip(self.pending_messages, self.pending_messages_ids):
                self.db.context_add_message(self.bot_id, self.chat_id, message["role"], message=dict(message),
                                            message_id=message_id, max_messages=self.max_messages)
                committed += 1
        finally:
            del self.pending_messages[:committed]
            del self.pending_messages_ids[:committed]

    def _add_pending_message(self, message: ChatCompletionMessageParam, message_id: Optional[int] = None):
        """Append *message* to both the pending queue and the live context."""
        self.pending_messages.append(message)
        self.pending_messages_ids.append(message_id)
        self.context.append(message)
|
||
|
||
|
||
class AiAgent:
    """Chat agent that answers TG/VK messages through an OpenAI-compatible API.

    For each turn it builds the chat context (system prompt, stored history and new
    messages), requests a completion with the configured tool sets enabled, executes
    the tool calls the model makes, and stores the turn's messages in the database.
    """

    def __init__(self,
                 openai_url: Optional[str], openai_token: str, openai_model: str,
                 replicate_token: str, tavily_token: str,
                 db: BasicDatabase,
                 platform: str):
        """
        :param openai_url: base URL of the API; ``None`` selects ``OPENAI_DEFAULT_URL``.
        :param openai_token: API key for that endpoint.
        :param openai_model: model identifier used for every completion request.
        :param replicate_token: token passed to ``ImageGenerationToolSet``.
        :param tavily_token: token passed to ``TavilySearchToolSet``.
        :param db: storage for bots, chats and chat context.
        :param platform: platform code; ``'tg'`` renders as Telegram in the system
            prompt, anything else as VK.
        """
        self.db = db
        self.openai_model = openai_model
        self.platform = platform

        self._load_prompts()

        self.client = AsyncOpenAI(
            base_url=openai_url if openai_url is not None else OPENAI_DEFAULT_URL,
            api_key=openai_token
        )

        # Create the tool sets
        self.toolsets: list[ai.tool.ToolSet] = [
            ImageGenerationToolSet(replicate_token=replicate_token),
            TavilySearchToolSet(tavily_token=tavily_token),
        ]

        # Collect every tool and its schema from all tool sets
        self.tools: list[ai.tool.Tool] = []
        self.tools_descriptions: list[ChatCompletionFunctionToolParam] = []
        for toolset in self.toolsets:
            self.tools.extend(toolset.functions)
            self.tools_descriptions.extend(toolset.get_all_tools_description())

    async def get_group_chat_reply(self, bot_id: int, chat_id: int,
                                   message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
        """Reply to a group-chat message. Returns ``(reply, success)``."""
        return await self._handle_chat_reply(
            bot_id=bot_id,
            chat_id=chat_id,
            message=message,
            forwarded_messages=forwarded_messages,
            is_group_chat=True,
            max_messages=GROUP_CHAT_MAX_MESSAGES
        )

    async def get_private_chat_reply(self, bot_id: int, chat_id: int,
                                     message: Message, forwarded_messages: Optional[List[Message]] = None) \
            -> Tuple[Message, bool]:
        """Reply to a private-chat message. Returns ``(reply, success)``."""
        return await self._handle_chat_reply(
            bot_id=bot_id,
            chat_id=chat_id,
            message=message,
            forwarded_messages=forwarded_messages or [],
            is_group_chat=False,
            max_messages=PRIVATE_CHAT_MAX_MESSAGES
        )

    def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
        """Return ``db.context_get_last_assistant_message_id`` for this chat."""
        return self.db.context_get_last_assistant_message_id(bot_id, chat_id)

    def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
        """Store *message_id* via ``db.context_set_last_message_id`` for this chat."""
        self.db.context_set_last_message_id(bot_id, chat_id, message_id)

    def clear_chat_context(self, bot_id: int, chat_id: int):
        """Clear this chat's stored context via ``db.context_clear``."""
        self.db.context_clear(bot_id, chat_id)

    ####################################################################################

    async def _handle_chat_reply(self, bot_id: int, chat_id: int,
                                 message: Message, forwarded_messages: List[Message],
                                 is_group_chat: bool, max_messages: int) -> Tuple[Message, bool]:
        """Run one conversation turn and return ``(reply, success)``.

        The caller's ``message`` and ``forwarded_messages`` are left unmodified;
        prefixed copies go into the context. The turn is written to the database
        only if every request (including tool rounds) succeeds. Otherwise a
        user-facing error text is returned with ``success=False``.
        """
        from dataclasses import replace

        context_manager = ChatContextManager(
            db=self.db, bot_id=bot_id, chat_id=chat_id, max_messages=max_messages,
            system_prompt=self._construct_system_prompt(is_group_chat, bot_id, chat_id)
        )

        # Add the new user message, prefixed with the time (and the author in group
        # chats). A copy is used so a retried call does not get a double prefix.
        username = message.user_name if is_group_chat else None
        context_manager.add_user_message(replace(message, text=_add_message_prefix(message.text, username)))

        # Add forwarded (quoted) messages, headed by their author
        for fwd_message in forwarded_messages:
            quoted_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
            if fwd_message.text is not None:
                quoted_text += '\n' + fwd_message.text
            context_manager.add_user_message(replace(fwd_message, text=quoted_text))

        # Generate the reply, running tool rounds until the model stops calling tools
        try:
            response = await self._generate_reply(bot_id, chat_id, context=context_manager.get_current_context(),
                                                  allow_tools=True)
            context_manager.add_assistant_message(response)

            # Merge artifacts across rounds so that e.g. an image generated in an
            # early round is not dropped when the model makes another tool call.
            tools_artifacts: Dict[str, Any] = {}
            while response.tool_calls:
                tools_artifacts.update(await self._process_tool_calls(tool_calls=response.tool_calls,
                                                                      context_manager=context_manager))
                response = await self._generate_reply(bot_id, chat_id,
                                                      context=context_manager.get_current_context(),
                                                      allow_tools=True)
                context_manager.add_assistant_message(response)

            # Persist the updated context
            context_manager.commit()

            return Message(text=response.content, image=tools_artifacts.get("generated_image"),
                           image_hires=tools_artifacts.get("generated_image_hires")), True

        except Exception as e:
            if "Rate limit exceeded" in str(e):
                return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
            else:
                print(f"Ошибка выполнения запроса к ИИ: {e}")
                return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False

    def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) \
            -> ChatCompletionSystemMessageParam:
        """Assemble the system message for a chat.

        Order: the base prompt for the chat type (with ``{platform}`` substituted),
        each tool set's ``system_prompt``, then the bot's and the chat's
        ``ai_prompt`` when set. Calls ``db.create_chat_if_not_exists`` for the chat.
        """
        prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
        prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')

        prompt += '\n# Доступные инструменты\n'
        for toolset in self.toolsets:
            prompt += '\n' + toolset.system_prompt

        prompt += '\n' + '# Дополнительные инструкции\n'

        bot = self.db.get_bot(bot_id)
        if bot['ai_prompt'] is not None:
            prompt += '\n' + bot['ai_prompt'] + '\n'

        chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
        if chat['ai_prompt'] is not None:
            prompt += '\n' + chat['ai_prompt']

        return {"role": "system", "content": prompt}

    async def _generate_reply(self, bot_id: int, chat_id: int,
                              context: List[ChatCompletionMessageParam], allow_tools: bool = False) \
            -> ChatCompletionMessage:
        """Request one completion for *context* and return the first choice's message.

        The request's ``user`` field is ``{platform}_{bot_id}_{chat_id}``.

        :raises RuntimeError: if the response contains no choices.
        """
        response = await self.client.chat.completions.create(
            model=self.openai_model,
            messages=context,
            tools=self.tools_descriptions if allow_tools else omit,
            tool_choice="auto" if allow_tools else omit,
            max_tokens=MAX_OUTPUT_TOKENS,
            user=f'{self.platform}_{bot_id}_{chat_id}',
            extra_headers={
                "HTTP-Referer": OPENROUTER_HTTP_REFERER,
                "X-OpenRouter-Title": OPENROUTER_X_TITLE,
                "X-OpenRouter-Categories": OPENROUTER_CATEGORIES
            }
        )
        if not response.choices:
            # Report a descriptive error (including any provider ``error`` payload)
            # instead of an opaque IndexError.
            raise RuntimeError(f"AI provider returned no choices: {getattr(response, 'error', None)}")
        return response.choices[0].message

    async def _process_tool_calls(self, tool_calls: List[ChatCompletionMessageToolCallUnion],
                                  context_manager: ChatContextManager) -> Dict[str, Any]:
        """Execute the model's tool calls and add one ``tool`` message per call.

        Every call is answered, because chat-completion APIs reject a follow-up
        request in which an assistant ``tool_calls`` entry has no matching ``tool``
        message. Calls that cannot be run (non-function type, unknown tool,
        malformed JSON arguments) get an error text the model can react to.
        Exceptions raised by ``tool.execute`` propagate.

        :return: artifacts the executed tools stored during this round.
        """
        artifacts: Dict[str, Any] = {}
        if tool_calls is None:
            return artifacts

        tools_map = {tool.name: tool for tool in self.tools}

        for tool_call in tool_calls:
            tool_result = await self._run_tool_call(tool_call, tools_map, artifacts)
            context_manager.add_tool_message(
                {"role": "tool", "tool_call_id": tool_call.id, "content": tool_result}
            )

        return artifacts

    @staticmethod
    async def _run_tool_call(tool_call: ChatCompletionMessageToolCallUnion,
                             tools_map: Dict[str, ai.tool.Tool], artifacts: Dict[str, Any]) -> Any:
        """Execute one tool call and return the content to send back to the model."""
        if tool_call.type != "function":
            return f"Error: unsupported tool call type '{tool_call.type}'"

        func_call = cast("ChatCompletionMessageFunctionToolCall", tool_call)
        tool_name = func_call.function.name
        tool = tools_map.get(tool_name)
        if tool is None:
            return f"Error: unknown tool '{tool_name}'"

        try:
            # An empty argument string is treated as "no arguments".
            tool_args = json.loads(func_call.function.arguments or "{}")
        except json.JSONDecodeError as e:
            return f"Error: invalid JSON arguments for tool '{tool_name}': {e}"

        return await tool.execute(tool_args, artifacts)

    # TODO: remove
    @staticmethod
    def _filter_response(response: ChatCompletionMessage) -> ChatCompletionMessage:
        """Strip ``<image>`` markers from the message content in place and return it."""
        text = str(response.content)
        text = text.replace("<image>", "")
        response.content = text
        return response

    def _load_prompts(self):
        """Read the group- and private-chat base prompts (paths relative to the CWD)."""
        # Explicit UTF-8: the prompts contain Cyrillic text, and the platform default
        # encoding (e.g. cp1251 on Windows) would decode them incorrectly.
        with open("ai/prompts/group_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_group_chat = f.read()
        with open("ai/prompts/private_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_private_chat = f.read()
|
||
|
||
|
||
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
||
current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
|
||
prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
|
||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||
|
||
|
||
def _serialize_user_message(text: Optional[str], image: Optional[bytes]) -> ChatCompletionUserMessageParam:
|
||
if image is None:
|
||
if text is not None:
|
||
content = text
|
||
else:
|
||
raise ValueError("Either text or image must be provided")
|
||
else:
|
||
content = []
|
||
if text is not None:
|
||
content.append({"type": "text", "text": text})
|
||
content.append({"type": "image_url", "image_url": {"url": encode_image(image), "detail": "high"}})
|
||
return {"role": "user", "content": content}
|
||
|
||
|
||
def _serialize_assistant_message(message: ChatCompletionMessage) -> ChatCompletionAssistantMessageParam:
|
||
return message.model_dump(exclude_none=True)
|