Рефакторинг.

Вся логика ИИ перенесена в модуль ai.
Логика инструментов выделена в отдельные подмодули.
Исправлены все проблемы, обнаруженные PyCharm.
This commit is contained in:
Kirill Kirilenko 2026-04-02 02:26:00 +03:00
parent 924d728533
commit 1c58359e44
31 changed files with 534 additions and 335 deletions

18
ai/__init__.py Normal file
View file

@ -0,0 +1,18 @@
import ai.agent
from database import BasicDatabase
Message = ai.agent.Message
Agent = ai.agent.AiAgent
# Глобальный экземпляр агента
agent: ai.agent.AiAgent
def create_ai_agent(openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, platform: str):
global agent
agent = ai.agent.AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
__all__ = ["agent", "Agent", "Message", "create_ai_agent"]

View file

@ -1,26 +1,22 @@
import base64
import datetime import datetime
import json import json
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from typing import List, Optional, Tuple, Union
from PIL import Image
from typing import List, Tuple, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ from openrouter.components import AssistantMessage, AssistantMessageTypedDict, \
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict, ToolDefinitionJSONTypedDict
from openrouter.errors import ResponseValidationError, OpenRouterError from openrouter.errors import ResponseValidationError, OpenRouterError
from openrouter.utils import BackoffStrategy from openrouter.utils import BackoffStrategy
from fal_client import AsyncClient as FalClient import ai.tool
from replicate import Client as ReplicateClient
from tavily import TavilyClient
from utils import download_file
from database import BasicDatabase from database import BasicDatabase
from ai.utils import *
from ai.tools import *
OPENROUTER_X_TITLE = "TG/VK Chat Bot" OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://ultracoder.org" OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
@ -28,9 +24,6 @@ GROUP_CHAT_MAX_MESSAGES = 40
PRIVATE_CHAT_MAX_MESSAGES = 40 PRIVATE_CHAT_MAX_MESSAGES = 40
MAX_OUTPUT_TOKENS = 500 MAX_OUTPUT_TOKENS = 500
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
@dataclass() @dataclass()
class Message: class Message:
@ -60,14 +53,20 @@ class AiAgent:
self.client_openrouter = OpenRouter(api_key=openrouter_token, self.client_openrouter = OpenRouter(api_key=openrouter_token,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
retry_config=retry_config) retry_config=retry_config)
self.client_fal = FalClient(key=fal_token)
self.replicate_client = ReplicateClient(api_token=replicate_token)
self.tavily_client = TavilyClient(api_key=tavily_token)
@dataclass() # Создание наборов инструментов
class _ToolsArtifacts: self.toolsets: list[ai.tool.ToolSet] = []
generated_image: Optional[bytes] = None self.toolsets.append(
generated_image_hires: Optional[bytes] = None ImageGenerationToolSet(fal_token=fal_token, replicate_token=replicate_token)
)
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))
# Сбор всех инструментов
self.tools: list[ai.tool.Tool] = []
self.tools_descriptions: list[ToolDefinitionJSONTypedDict] = []
for toolset in self.toolsets:
self.tools.extend(toolset.functions)
self.tools_descriptions.extend(toolset.get_all_tools_description())
async def get_group_chat_reply(self, bot_id: int, chat_id: int, async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
@ -87,10 +86,9 @@ class AiAgent:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
ai_response = response.content ai_response = response.content
tools_artifacts = AiAgent._ToolsArtifacts() tools_artifacts = {}
if response.tool_calls is not None: if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(bot_id, chat_id, tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content ai_response = response2.content
@ -101,11 +99,12 @@ class AiAgent:
role="user", text=fwd_message.text, image=fwd_message.image, role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response, image=tools_artifacts.generated_image, role="assistant", text=ai_response,
image=tools_artifacts.get("generated_image"),
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.generated_image, return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
image_hires=tools_artifacts.generated_image_hires), True image_hires=tools_artifacts.get("generated_image_hires")), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -125,21 +124,20 @@ class AiAgent:
context.append(_serialize_assistant_message(response)) context.append(_serialize_assistant_message(response))
ai_response = response.content ai_response = response.content
tools_artifacts = AiAgent._ToolsArtifacts() tools_artifacts = {}
if response.tool_calls is not None: if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(bot_id, chat_id, tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant", self.db.context_add_message(bot_id, chat_id, role="assistant",
text=ai_response, image=tools_artifacts.generated_image, text=ai_response, image=tools_artifacts.get("generated_image"),
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.generated_image, return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
image_hires=tools_artifacts.generated_image_hires), True image_hires=tools_artifacts.get("generated_image_hires")), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -170,7 +168,12 @@ class AiAgent:
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict: def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK') prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
prompt += '\n' + self.system_prompt_tools
prompt += '\n# Доступные инструменты\n'
for toolset in self.toolsets:
prompt += '\n' + toolset.system_prompt
prompt += '\n' + '# Дополнительные инструкции\n'
bot = self.db.get_bot(bot_id) bot = self.db.get_bot(bot_id)
if bot['ai_prompt'] is not None: if bot['ai_prompt'] is not None:
@ -187,139 +190,37 @@ class AiAgent:
response = await self._async_chat_completion_request( response = await self._async_chat_completion_request(
model=self.openrouter_model, model=self.openrouter_model,
messages=context, messages=context,
tools=self.tools_description if allow_tools else None, tools=self.tools_descriptions if allow_tools else None,
tool_choice="auto" if allow_tools else None, tool_choice="auto" if allow_tools else None,
max_tokens=MAX_OUTPUT_TOKENS, max_tokens=MAX_OUTPUT_TOKENS,
user=f'{self.platform}_{bot_id}_{chat_id}' user=f'{self.platform}_{bot_id}_{chat_id}'
) )
return self._filter_response(response.choices[0].message) return self._filter_response(response.choices[0].message)
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], async def _process_tool_calls(self, tool_calls: List[ChatMessageToolCall],
context: List[MessageTypedDict]) -> _ToolsArtifacts: context: List[MessageTypedDict]) -> dict:
artifacts = AiAgent._ToolsArtifacts() artifacts = {}
if tool_calls is None: if tool_calls is None:
return artifacts return artifacts
functions_map: Dict[str, tools_map = {tool.name: tool for tool in self.tools}
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
"generate_image": self._process_tool_generate_image,
"generate_image_anime": self._process_tool_generate_image_anime,
"tavily_search": self._process_tool_tavily_search
}
for tool_call in tool_calls: for tool_call in tool_calls:
tool_name = tool_call.function.name tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments) tool_args = json.loads(tool_call.function.arguments)
if tool_name in functions_map:
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts) if tool_name in tools_map:
tool = tools_map[tool_name]
# Вызов инструмента с передачей artifacts
tool_result = await tool.execute(tool_args, artifacts)
context.append({ context.append({
"role": "tool", "role": "tool",
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
"content": tool_result "content": tool_result
}) })
return artifacts return artifacts
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
-> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio_size_map = {
"1:1": "square",
"4:3": "landscape_4_3",
"3:4": "portrait_4_3",
"16:9": "landscape_16_9",
"9:16": "portrait_16_9",
"9:20": "portrait_16_9"
}
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
print(f"Генерация изображения {image_size}: {prompt}")
arguments = {
"prompt": prompt,
"image_size": image_size,
"enable_safety_checker": False
}
try:
result = await self.client_fal.run(FAL_MODEL, arguments=arguments)
if "images" not in result:
raise RuntimeError("Неожиданный ответ от сервера.")
image_url = result["images"][0]["url"]
artifacts.generated_image_hires = await download_file(image_url)
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
image=None)
except Exception as e:
print(f"Ошибка генерации изображения: {e}")
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
args: dict, artifacts: _ToolsArtifacts) \
-> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
negative_prompt = args.get("negative_prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio_resolution_map = {
"1:1": (1280, 1280),
"4:3": (1280, 1024),
"3:4": (1024, 1280),
"16:9": (1280, 720),
"9:16": (720, 1280),
"9:20": (720, 1600)
}
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
arguments = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"add_recommended_tags": False,
"width": width,
"height": height,
"guidance_scale": 4.5,
"num_inference_steps": 20,
"hires_enable": True,
"hires_num_inference_steps": 30,
"disable_safety_checker": True
}
try:
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
artifacts.generated_image_hires = await outputs[0].aread()
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
image=None)
except Exception as e:
print(f"Ошибка генерации изображения: {e}")
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
async def _process_tool_tavily_search(self, _bot_id: int, _chat_id: int, args: dict,
_artifacts: _ToolsArtifacts) -> List[ChatMessageContentItemTypedDict]:
query = args.get("query", "")
print(f"Веб-поиск: {query}")
try:
results = self.tavily_client.search(query=query, max_results=5)
if not results or "results" not in results:
return _serialize_message_content(text="Не удалось получить результаты поиска.")
answer_parts = []
for i, result in enumerate(results["results"], 1):
title = result.get("title", "Без названия")
url = result.get("url", "")
content = result.get("content", "")
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
answer = "\n".join(answer_parts)
return _serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
except Exception as e:
print(f"Ошибка веб-поиска: {e}")
return _serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
async def _async_chat_completion_request(self, **kwargs): async def _async_chat_completion_request(self, **kwargs):
try: try:
return await self.client_openrouter.chat.send_async(**kwargs) return await self.client_openrouter.chat.send_async(**kwargs)
@ -354,24 +255,10 @@ class AiAgent:
return response return response
def _load_prompts(self): def _load_prompts(self):
with open("prompts/group_chat.md", "r") as f: with open("ai/prompts/group_chat.md", "r") as f:
self.system_prompt_group_chat = f.read() self.system_prompt_group_chat = f.read()
with open("prompts/private_chat.md", "r") as f: with open("ai/prompts/private_chat.md", "r") as f:
self.system_prompt_private_chat = f.read() self.system_prompt_private_chat = f.read()
with open("prompts/tools.md", "r") as f:
self.system_prompt_tools = f.read()
with open("prompts/tools.json", "r") as f:
self.tools_description = json.loads(f.read())
agent: AiAgent
def create_ai_agent(openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, platform: str):
global agent
agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str: def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
@ -380,22 +267,12 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) ->
return f"{prefix}: {text}" if text is not None else f"{prefix}:" return f"{prefix}: {text}" if text is not None else f"{prefix}:"
def _encode_image(image: bytes) -> str:
encoded_image = base64.b64encode(image).decode('utf-8')
return f"data:image/jpeg;base64,{encoded_image}"
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
return {"role": role, "content": _serialize_message_content(text, image)} return {"role": role, "content": serialize_message_content(text, image)}
def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]: def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
content = [] return _remove_none_recursive(message.model_dump(by_alias=True))
if text is not None:
content.append({"type": "text", "text": text})
if image is not None:
content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}})
return content
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]: def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
@ -413,21 +290,3 @@ def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, an
] ]
else: else:
return data return data
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
return _remove_none_recursive(message.model_dump(by_alias=True))
def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
img = Image.open(BytesIO(image)).convert("RGB")
if img.width > max_side or img.height > max_side:
scale = min(max_side / img.width, max_side / img.height)
new_width = int(img.width * scale)
new_height = int(img.height * scale)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
output = BytesIO()
img.save(output, format='JPEG', quality=87, optimize=True)
return output.getvalue()

66
ai/tool.py Normal file
View file

@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from openrouter.components import ChatMessageContentItemTypedDict, ToolDefinitionJSONTypedDict
class Tool(ABC):
"""Интерфейс функции"""
@property
@abstractmethod
def name(self) -> str:
"""Имя функции (snake_case)"""
pass
@property
@abstractmethod
def description(self) -> str:
"""Текстовое описание функции"""
pass
@property
@abstractmethod
def parameters(self) -> Dict[str, Any]:
"""Описание параметров функции"""
pass
def to_dict(self) -> ToolDefinitionJSONTypedDict:
"""JSON-представление инструмента для OpenRouter"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters
}
}
@abstractmethod
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
"""Вызов функции.
:param args: Параметры из JSON
:param artifacts: Словарь для хранения артефактов
:return: Содержимое JSON-ответа на вызов функции
"""
pass
@dataclass
class ToolSet:
"""Набор логически объединенных функций"""
functions: List[Tool]
"""Список функций, входящих в набор"""
system_prompt: str
"""Дополнение к системному запросу, описывающее, как пользоваться функциями"""
def get_function_by_name(self, name: str) -> Optional[Tool]:
"""Поиск инструмента по имени"""
return next((t for t in self.functions if t.name == name), None)
def get_all_tools_description(self) -> List[ToolDefinitionJSONTypedDict]:
"""Получить JSON-описание всех инструментов"""
return [tool.to_dict() for tool in self.functions]

7
ai/tools/__init__.py Normal file
View file

@ -0,0 +1,7 @@
from ai.tools.image_generation import ImageGenerationToolSet
from ai.tools.web_search import TavilySearchToolSet
__all__ = [
"ImageGenerationToolSet",
"TavilySearchToolSet"
]

View file

@ -0,0 +1,18 @@
from ai.tool import ToolSet
from .generate_image import GenerateImageTool
from .generate_image_anime import GenerateImageAnimeTool
class ImageGenerationToolSet(ToolSet):
def __init__(self, fal_token: str, replicate_token: str):
functions = [
GenerateImageTool(fal_token),
GenerateImageAnimeTool(replicate_token)
]
with open("ai/tools/image_generation/prompt.md", "r") as f:
system_prompt = f.read()
super().__init__(functions=functions, system_prompt=system_prompt)
__all__ = ["GenerateImageTool", "GenerateImageAnimeTool", "ImageGenerationToolSet"]

View file

@ -0,0 +1,79 @@
from fal_client import AsyncClient as FalClient
from openrouter.components import ChatMessageContentItemTypedDict
from typing import Any, Dict, List
from ai.tool import Tool
from ai.utils import *
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
class GenerateImageTool(Tool):
def __init__(self, fal_token: str):
self._client = FalClient(key=fal_token)
@property
def name(self) -> str:
return "generate_image"
@property
def description(self) -> str:
return "Генерация изображения по описанию"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Подробное описание сцены на английском языке БЕЗ технических параметров "
"(соотношение сторон, разрешение)"
},
"aspect_ratio": {
"type": "string",
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16"],
"description": "Соотношение сторон"
}
},
"required": ["prompt"]
}
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio_size_map = {
"1:1": "square",
"4:3": "landscape_4_3",
"3:4": "portrait_4_3",
"16:9": "landscape_16_9",
"9:16": "portrait_16_9",
"9:20": "portrait_16_9"
}
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
print(f"Генерация изображения {image_size}: {prompt}")
arguments = {
"prompt": prompt,
"image_size": image_size,
"enable_safety_checker": False
}
try:
result = await self._client.run(FAL_MODEL, arguments=arguments)
if "images" not in result:
raise RuntimeError("Неожиданный ответ от сервера.")
image_url = result["images"][0]["url"]
from utils import download_file
artifacts["generated_image_hires"] = await download_file(image_url)
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
return serialize_message_content(
text="Изображение сгенерировано и будет показано пользователю.",
image=None
)
except Exception as e:
print(f"Ошибка генерации изображения: {e}")
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")

View file

@ -0,0 +1,85 @@
from openrouter.components import ChatMessageContentItemTypedDict
from replicate import Client as ReplicateClient
from typing import Any, Dict, List
from ai.tool import Tool
from ai.utils import *
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
class GenerateImageAnimeTool(Tool):
def __init__(self, replicate_token: str):
self._client = ReplicateClient(api_token=replicate_token)
@property
def name(self) -> str:
return "generate_image_anime"
@property
def description(self) -> str:
return "Генерация изображения в стиле аниме по описанию"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Положительный запрос"
},
"negative_prompt": {
"type": "string",
"description": "Отрицательный запрос"
},
"aspect_ratio": {
"type": "string",
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"],
"description": "Соотношение сторон"
}
},
"required": ["prompt", "negative_prompt"]
}
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
negative_prompt = args.get("negative_prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
aspect_ratio_resolution_map = {
"1:1": (1280, 1280),
"4:3": (1280, 1024),
"3:4": (1024, 1280),
"16:9": (1280, 720),
"9:16": (720, 1280),
"9:20": (720, 1600)
}
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
arguments = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"add_recommended_tags": False,
"width": width,
"height": height,
"guidance_scale": 4.5,
"num_inference_steps": 20,
"hires_enable": True,
"hires_num_inference_steps": 30,
"disable_safety_checker": True
}
try:
outputs = await self._client.async_run(REPLICATE_MODEL, input=arguments)
artifacts["generated_image_hires"] = await outputs[0].aread()
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
return serialize_message_content(
text="Изображение сгенерировано и будет показано пользователю.",
image=None
)
except Exception as e:
print(f"Ошибка генерации изображения: {e}")
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")

View file

@ -1,5 +1,3 @@
# Доступные инструменты
## Генерация изображений ## Генерация изображений
Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций. Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций.
При вызове функции не нужно добавлять сообщение - оно будет отброшено. При вызове функции не нужно добавлять сообщение - оно будет отброшено.
@ -28,8 +26,3 @@
2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`. 2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`.
3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`. 3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`.
4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`. 4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`.
## Веб-поиск
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
- После получения результатов дай пользователю краткую сводку найденной информации.

View file

@ -0,0 +1,14 @@
from ai.tool import ToolSet
from .tavily_search import TavilySearchTool
class TavilySearchToolSet(ToolSet):
def __init__(self, tavily_token: str):
functions = [TavilySearchTool(tavily_token)]
with open("ai/tools/web_search/prompt.md", "r") as f:
system_prompt = f.read()
super().__init__(functions=functions, system_prompt=system_prompt)
__all__ = ["TavilySearchTool", "TavilySearchToolSet"]

View file

@ -0,0 +1,4 @@
## Веб-поиск
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
- После получения результатов дай пользователю краткую сводку найденной информации.

View file

@ -0,0 +1,56 @@
from tavily import TavilyClient
from typing import Any, Dict, List
from openrouter.components import ChatMessageContentItemTypedDict
from ai.tool import Tool
from ai.utils import *
class TavilySearchTool(Tool):
def __init__(self, tavily_token: str):
self._client = TavilyClient(api_key=tavily_token)
@property
def name(self) -> str:
return "tavily_search"
@property
def description(self) -> str:
return "Поиск информации в интернете"
@property
def parameters(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Запрос для поиска (на русском или английском языке)"
}
},
"required": ["query"]
}
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
query = args.get("query", "")
print(f"Веб-поиск: {query}")
try:
results = self._client.search(query=query, max_results=5)
if not results or "results" not in results:
return serialize_message_content(text="Не удалось получить результаты поиска.")
answer_parts = []
for i, result in enumerate(results["results"], 1):
title = result.get("title", "Без названия")
url = result.get("url", "")
content = result.get("content", "")
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
answer = "\n".join(answer_parts)
return serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
except Exception as e:
print(f"Ошибка веб-поиска: {e}")
return serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")

39
ai/utils.py Normal file
View file

@ -0,0 +1,39 @@
from base64 import b64encode
from io import BytesIO
from PIL import Image
from typing import Dict, List, Optional
def serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> List[Dict]:
content = []
if text is not None:
content.append({"type": "text", "text": text})
if image is not None:
content.append({"type": "image_url", "detail": "high", "image_url": {"url": encode_image(image)}})
return content
def encode_image(image: bytes) -> str:
encoded_image = b64encode(image).decode('utf-8')
return f"data:image/jpeg;base64,{encoded_image}"
def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
img = Image.open(BytesIO(image)).convert("RGB")
if img.width > max_side or img.height > max_side:
scale = min(max_side / img.width, max_side / img.height)
new_width = int(img.width * scale)
new_height = int(img.height * scale)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
output = BytesIO()
img.save(output, format='JPEG', quality=87, optimize=True)
return output.getvalue()
__all__ = [
"serialize_message_content",
"compress_image",
"encode_image"
]

View file

@ -161,8 +161,9 @@ class BasicDatabase:
self.cursor.execute(query, values) self.cursor.execute(query, values)
def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int): def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int):
self.cursor.execute("UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL", self.cursor.execute(
message_id, bot_id, chat_id) "UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL",
message_id, bot_id, chat_id)
def _context_trim(self, bot_id: int, chat_id: int, max_messages: int): def _context_trim(self, bot_id: int, chat_id: int, max_messages: int):
current_count = self.context_get_count(bot_id, chat_id) current_count = self.context_get_count(bot_id, chat_id)

View file

@ -1,7 +1,8 @@
MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работаю в этом чате.' MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работаю в этом чате.'
MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.' MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.'
MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.' MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.'
MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение или с пересылкой текстовых сообщений.' MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение ' \
'или с пересылкой текстовых сообщений.'
MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.' MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.'
MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.' MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.'
MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.' MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.'

View file

@ -1,67 +0,0 @@
[
{
"type": "function",
"function": {
"name": "generate_image",
"description": "Генерация изображения по описанию",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Подробное описание сцены на английском языке БЕЗ технических параметров (соотношение сторон, разрешение)"
},
"aspect_ratio": {
"type": "string",
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16"],
"description": "Соотношение сторон"
}
},
"required": ["prompt"]
}
}
},
{
"type": "function",
"function": {
"name": "generate_image_anime",
"description": "Генерация изображения в стиле аниме по описанию",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Положительный запрос"
},
"negative_prompt": {
"type": "string",
"description": "Отрицательный запрос"
},
"aspect_ratio": {
"type": "string",
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"],
"description": "Соотношение сторон"
}
},
"required": ["prompt", "negative_prompt"]
}
}
},
{
"type": "function",
"function": {
"name": "tavily_search",
"description": "Веб-поиск по теме запроса. Используй для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Запрос для поиска (на русском или английском языке)"
}
},
"required": ["query"]
}
}
}
]

View file

@ -4,7 +4,7 @@ import json
from aiogram import Bot, Dispatcher from aiogram import Bot, Dispatcher
from ai_agent import create_ai_agent from ai import create_ai_agent
import tg.tg_database as database import tg.tg_database as database

View file

@ -2,7 +2,7 @@ from aiogram import Bot, Router, F
from aiogram.types import Message from aiogram.types import Message
from aiogram.utils.formatting import Bold from aiogram.utils.formatting import Bold
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -106,7 +106,7 @@ async def clear_context_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_PERMISSION_DENIED) await message.answer(MESSAGE_PERMISSION_DENIED)
return return
ai_agent.agent.clear_chat_context(bot.id, chat_id) ai.agent.clear_chat_context(bot.id, chat_id)
await message.answer("Контекст очищен.") await message.answer("Контекст очищен.")

View file

@ -4,7 +4,7 @@ from aiogram import Router, F, Bot
from aiogram.types import Message from aiogram.types import Message
from aiogram.enums.content_type import ContentType from aiogram.enums.content_type import ContentType
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -51,7 +51,7 @@ async def any_message_handler(message: Message, bot: Bot):
bot_user = await bot.me() bot_user = await bot.me()
ai_fwd_messages: list[ai_agent.Message] = [] ai_fwd_messages: list[ai.Message] = []
try: try:
message_text = get_message_text(message) message_text = get_message_text(message)
@ -64,7 +64,7 @@ async def any_message_handler(message: Message, bot: Bot):
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id: elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
# Ответ на сообщение бота # Ответ на сообщение бота
last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id) last_id = ai.agent.get_last_assistant_message_id(bot.id, chat_id)
if message.reply_to_message.message_id != last_id: if message.reply_to_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его # Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
@ -77,10 +77,10 @@ async def any_message_handler(message: Message, bot: Bot):
ai_message = await create_ai_message(message, bot) ai_message = await create_ai_message(message, bot)
ai_message.text = message_text ai_message.text = message_text
answer: ai_agent.Message answer: ai.agent.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(message.bot.send_chat_action, chat_id, 'typing'), partial(message.bot.send_chat_action, chat_id, 'typing'),
interval=4) interval=4)
@ -91,4 +91,4 @@ async def any_message_handler(message: Message, bot: Bot):
else: else:
answer_id = (await message.reply(answer.text)).message_id answer_id = (await message.reply(answer.text)).message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) ai.agent.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -1,11 +1,11 @@
from functools import partial from functools import partial
from aiogram import Router, F, Bot from aiogram import Router, F, Bot
from aiogram.enums import ChatType, ContentType from aiogram.enums import ChatType
from aiogram.filters import Command, CommandObject, CommandStart from aiogram.filters import Command, CommandObject, CommandStart
from aiogram.types import Message from aiogram.types import Message
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -38,7 +38,7 @@ async def reset_context_handler(message: Message, bot: Bot):
chat_id = message.chat.id chat_id = message.chat.id
database.DB.create_chat_if_not_exists(bot.id, chat_id) database.DB.create_chat_if_not_exists(bot.id, chat_id)
ai_agent.agent.clear_chat_context(bot.id, chat_id) ai.agent.clear_chat_context(bot.id, chat_id)
await message.answer("Контекст очищен.") await message.answer("Контекст очищен.")
@ -52,10 +52,10 @@ async def any_message_handler(message: Message, bot: Bot):
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
return return
answer: ai_agent.Message answer: ai.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message), partial(ai.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
partial(message.bot.send_chat_action, chat_id, 'typing'), partial(message.bot.send_chat_action, chat_id, 'typing'),
interval=4) interval=4)
@ -66,4 +66,4 @@ async def any_message_handler(message: Message, bot: Bot):
else: else:
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) ai.agent.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -6,7 +6,7 @@ from aiogram.enums import ContentType
from aiogram.types import Message from aiogram.types import Message
from aiogram.utils.formatting import Bold, Italic from aiogram.utils.formatting import Bold, Italic
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -214,18 +214,18 @@ async def check_rules_violation_handler(message: Message, bot: Bot):
prompt += chat_rules + '\n\n' prompt += chat_rules + '\n\n'
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):' prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.from_user), ai_message = ai.Message(user_name=await get_user_name_for_ai(message.from_user),
text=prompt, message_id=message.message_id) text=prompt, message_id=message.message_id)
ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), ai_fwd_messages = [ai.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
text=message.reply_to_message.text)] text=message.reply_to_message.text)]
answer: ai_agent.Message answer: ai.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
partial(bot.send_chat_action, chat_id, 'typing'), partial(bot.send_chat_action, chat_id, 'typing'),
interval=4) interval=4)
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) ai.agent.set_last_response_id(bot.id, chat_id, answer_id)

View file

@ -39,8 +39,9 @@ class TgDatabase(database.BasicDatabase):
warnings TINYINT NOT NULL DEFAULT 0, warnings TINYINT NOT NULL DEFAULT 0,
about VARCHAR(1000), about VARCHAR(1000),
PRIMARY KEY (bot_id, chat_id, user_id), PRIMARY KEY (bot_id, chat_id, user_id),
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
""") ON UPDATE CASCADE ON DELETE CASCADE
)""")
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS contexts ( CREATE TABLE IF NOT EXISTS contexts (
@ -52,8 +53,9 @@ class TgDatabase(database.BasicDatabase):
text VARCHAR(4000), text VARCHAR(4000),
image MEDIUMBLOB, image MEDIUMBLOB,
PRIMARY KEY (id), PRIMARY KEY (id),
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
""") ON UPDATE CASCADE ON DELETE CASCADE
)""")
self.conn.commit() self.conn.commit()

View file

@ -6,7 +6,7 @@ from aiogram import Bot
from aiogram.enums import ContentType from aiogram.enums import ContentType
from aiogram.types import User, PhotoSize, Message, BufferedInputFile from aiogram.types import User, PhotoSize, Message, BufferedInputFile
import ai_agent import ai
import utils import utils
@ -36,8 +36,8 @@ def get_message_text(message: Message) -> Optional[str]:
return None return None
async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message: async def create_ai_message(message: Message, bot: Bot) -> ai.Message:
ai_message = ai_agent.Message() ai_message = ai.Message()
ai_message.message_id = message.message_id ai_message.message_id = message.message_id
ai_message.user_name = await get_user_name_for_ai(message.from_user) ai_message.user_name = await get_user_name_for_ai(message.from_user)
if message.content_type == ContentType.TEXT: if message.content_type == ContentType.TEXT:
@ -64,3 +64,14 @@ def wrap_document(document: bytes, name_prefix: str, extension: str) -> Buffered
def trim_caption(caption: str) -> str: def trim_caption(caption: str) -> str:
return caption[:1024] return caption[:1024]
__all__ = [
"create_ai_message",
"get_message_text",
"get_user_name_for_ai",
"trim_caption",
"wrap_photo",
"wrap_document",
"wrap_document"
]

View file

@ -3,7 +3,7 @@ import json
from vkbottle.bot import Bot, run_multibot from vkbottle.bot import Bot, run_multibot
from ai_agent import create_ai_agent from ai import create_ai_agent
import vk.vk_database as database import vk.vk_database as database

View file

@ -3,7 +3,7 @@ from vkbottle.bot import Message
from vkbottle.framework.labeler import BotLabeler from vkbottle.framework.labeler import BotLabeler
from vkbottle_types.codegen.objects import MessagesGetConversationMembers from vkbottle_types.codegen.objects import MessagesGetConversationMembers
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -170,7 +170,7 @@ async def clear_context_handler(message: Message):
await message.answer(MESSAGE_PERMISSION_DENIED) await message.answer(MESSAGE_PERMISSION_DENIED)
return return
ai_agent.agent.clear_chat_context(bot_id, chat_id) ai.agent.clear_chat_context(bot_id, chat_id)
await message.answer("Контекст очищен.") await message.answer("Контекст очищен.")

View file

@ -6,7 +6,7 @@ from vkbottle.bot import Message
from vkbottle.framework.labeler import BotLabeler from vkbottle.framework.labeler import BotLabeler
from vkbottle_types.codegen.objects import GroupsGroup from vkbottle_types.codegen.objects import GroupsGroup
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -59,7 +59,7 @@ async def any_message_handler(message: Message):
message_text = message_text.replace(bot_username_mention, bot_user.name) message_text = message_text.replace(bot_username_mention, bot_user.name)
bot_mentioned = True bot_mentioned = True
ai_fwd_messages: list[ai_agent.Message] = [] ai_fwd_messages: list[ai.agent.Message] = []
try: try:
if bot_mentioned: if bot_mentioned:
@ -73,7 +73,7 @@ async def any_message_handler(message: Message):
ai_fwd_messages.append(await create_ai_message(fwd_message)) ai_fwd_messages.append(await create_ai_message(fwd_message))
elif message.reply_message and message.reply_message.from_id == -bot_user.id: elif message.reply_message and message.reply_message.from_id == -bot_user.id:
# Ответ на сообщение бота # Ответ на сообщение бота
last_id = ai_agent.agent.get_last_assistant_message_id(bot_id, chat_id) last_id = ai.agent.get_last_assistant_message_id(bot_id, chat_id)
if message.reply_message.message_id != last_id: if message.reply_message.message_id != last_id:
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его # Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
ai_fwd_messages = [await create_ai_message(message.reply_message)] ai_fwd_messages = [await create_ai_message(message.reply_message)]
@ -86,10 +86,10 @@ async def any_message_handler(message: Message):
ai_message = await create_ai_message(message) ai_message = await create_ai_message(message)
ai_message.text = message_text ai_message.text = message_text
answer: ai_agent.Message answer: ai.agent.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4) interval=4)
@ -100,4 +100,4 @@ async def any_message_handler(message: Message):
answer_id = (await message.reply(answer.text)).conversation_message_id answer_id = (await message.reply(answer.text)).conversation_message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) ai.agent.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -4,7 +4,7 @@ from vkbottle.bot import Message
from vkbottle.dispatch.rules.base import RegexRule from vkbottle.dispatch.rules.base import RegexRule
from vkbottle.framework.labeler import BotLabeler from vkbottle.framework.labeler import BotLabeler
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -39,7 +39,7 @@ async def reset_context_handler(message: Message):
chat_id = message.peer_id chat_id = message.peer_id
database.DB.create_chat_if_not_exists(bot_id, chat_id) database.DB.create_chat_if_not_exists(bot_id, chat_id)
ai_agent.agent.clear_chat_context(bot_id, chat_id) ai.agent.clear_chat_context(bot_id, chat_id)
await message.answer("Контекст очищен.") await message.answer("Контекст очищен.")
@ -54,10 +54,10 @@ async def any_message_handler(message: Message):
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
return return
answer: ai_agent.Message answer: ai.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, ai_message), partial(ai.agent.get_private_chat_reply, bot_id, chat_id, ai_message),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4) interval=4)
@ -68,4 +68,4 @@ async def any_message_handler(message: Message):
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) ai.agent.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -1,11 +1,11 @@
from functools import partial from functools import partial
from typing import List, Any from typing import List, Any
from vkbottle import bold, italic from vkbottle import bold, italic, API
from vkbottle.bot import Message from vkbottle.bot import Message
from vkbottle.framework.labeler import BotLabeler from vkbottle.framework.labeler import BotLabeler
import ai_agent import ai
import utils import utils
from messages import * from messages import *
@ -246,31 +246,31 @@ async def check_rules_violation_handler(message: Message):
prompt += chat_rules + '\n\n' prompt += chat_rules + '\n\n'
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):' prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id), ai_message = ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id),
text=prompt, message_id=message.message_id) text=prompt, message_id=message.message_id)
ai_fwd_messages: list[ai_agent.Message] = [] ai_fwd_messages: list[ai.Message] = []
if message.reply_message is not None and len(message.reply_message.text) > 0: if message.reply_message is not None and len(message.reply_message.text) > 0:
ai_fwd_messages.append( ai_fwd_messages.append(
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id), ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
text=message.reply_message.text)) text=message.reply_message.text))
else: else:
for fwd_message in message.fwd_messages: for fwd_message in message.fwd_messages:
if len(fwd_message.text) > 0: if len(fwd_message.text) > 0:
ai_fwd_messages.append( ai_fwd_messages.append(
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id), ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
text=fwd_message.text)) text=fwd_message.text))
if len(ai_fwd_messages) == 0: if len(ai_fwd_messages) == 0:
await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD) await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD)
return return
answer: ai_agent.Message answer: ai.Message
success: bool success: bool
answer, success = await utils.run_with_progress( answer, success = await utils.run_with_progress(
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4) interval=4)
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id
if success: if success:
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) ai.agent.set_last_response_id(bot_id, chat_id, answer_id)

View file

@ -7,7 +7,7 @@ from vkbottle.bot import Message
from vkbottle_types.codegen.objects import PhotosPhotoSizes from vkbottle_types.codegen.objects import PhotosPhotoSizes
from vkbottle_types.objects import MessagesMessageAttachmentType from vkbottle_types.objects import MessagesMessageAttachmentType
import ai_agent import ai
import utils import utils
@ -18,6 +18,7 @@ class MyAPI(API):
def get_bot_id(api: API) -> int: def get_bot_id(api: API) -> int:
# noinspection PyTypeChecker
my_api: MyAPI = api my_api: MyAPI = api
return my_api.bot_id return my_api.bot_id
@ -48,8 +49,8 @@ async def download_photo(photos: List[PhotosPhotoSizes]) -> bytes:
raise RuntimeError(f"Failed to download photo. Status code: {response.status}") raise RuntimeError(f"Failed to download photo. Status code: {response.status}")
async def create_ai_message(message: Message) -> ai_agent.Message: async def create_ai_message(message: Message) -> ai.Message:
ai_message = ai_agent.Message() ai_message = ai.Message()
ai_message.message_id = message.conversation_message_id ai_message.message_id = message.conversation_message_id
ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id) ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id)
if len(message.text) > 0: if len(message.text) > 0:
@ -67,3 +68,13 @@ async def create_ai_message(message: Message) -> ai_agent.Message:
async def upload_photo(image: bytes, chat_id: int, api: API) -> str: async def upload_photo(image: bytes, chat_id: int, api: API) -> str:
return await PhotoMessageUploader(api).upload(file_source=image, peer_id=chat_id) return await PhotoMessageUploader(api).upload(file_source=image, peer_id=chat_id)
__all__ = [
"MyAPI",
"get_bot_id",
"get_user_name_for_ai",
"download_photo",
"create_ai_message",
"upload_photo"
]

View file

@ -42,8 +42,9 @@ class VkDatabase(database.BasicDatabase):
happy_birthday TINYINT NOT NULL DEFAULT 1, happy_birthday TINYINT NOT NULL DEFAULT 1,
about VARCHAR(1000), about VARCHAR(1000),
PRIMARY KEY (bot_id, chat_id, user_id), PRIMARY KEY (bot_id, chat_id, user_id),
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
""") ON UPDATE CASCADE ON DELETE CASCADE
)""")
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS contexts ( CREATE TABLE IF NOT EXISTS contexts (
@ -55,8 +56,9 @@ class VkDatabase(database.BasicDatabase):
text VARCHAR(4000), text VARCHAR(4000),
image MEDIUMBLOB, image MEDIUMBLOB,
PRIMARY KEY (id), PRIMARY KEY (id),
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
""") ON UPDATE CASCADE ON DELETE CASCADE
)""")
self.conn.commit() self.conn.commit()