Рефакторинг.
Вся логика ИИ перенесена в модуль ai. Логика инструментов выделена в отдельные подмодули. Исправлены все проблемы, обнаруженные PyCharm.
This commit is contained in:
parent
924d728533
commit
1c58359e44
31 changed files with 534 additions and 335 deletions
18
ai/__init__.py
Normal file
18
ai/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import ai.agent
|
||||
from database import BasicDatabase
|
||||
|
||||
Message = ai.agent.Message
|
||||
Agent = ai.agent.AiAgent
|
||||
|
||||
# Глобальный экземпляр агента
|
||||
agent: ai.agent.AiAgent
|
||||
|
||||
|
||||
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||
fal_token: str, replicate_token: str, tavily_token: str,
|
||||
db: BasicDatabase, platform: str):
|
||||
global agent
|
||||
agent = ai.agent.AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
|
||||
|
||||
|
||||
__all__ = ["agent", "Agent", "Message", "create_ai_agent"]
|
||||
|
|
@ -1,26 +1,22 @@
|
|||
import base64
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||||
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
|
||||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, \
|
||||
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict, ToolDefinitionJSONTypedDict
|
||||
from openrouter.errors import ResponseValidationError, OpenRouterError
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
||||
from fal_client import AsyncClient as FalClient
|
||||
from replicate import Client as ReplicateClient
|
||||
from tavily import TavilyClient
|
||||
|
||||
from utils import download_file
|
||||
import ai.tool
|
||||
from database import BasicDatabase
|
||||
|
||||
from ai.utils import *
|
||||
from ai.tools import *
|
||||
|
||||
|
||||
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
||||
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
|
||||
|
||||
|
|
@ -28,9 +24,6 @@ GROUP_CHAT_MAX_MESSAGES = 40
|
|||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||
MAX_OUTPUT_TOKENS = 500
|
||||
|
||||
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
|
||||
|
||||
|
||||
@dataclass()
|
||||
class Message:
|
||||
|
|
@ -60,14 +53,20 @@ class AiAgent:
|
|||
self.client_openrouter = OpenRouter(api_key=openrouter_token,
|
||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
|
||||
retry_config=retry_config)
|
||||
self.client_fal = FalClient(key=fal_token)
|
||||
self.replicate_client = ReplicateClient(api_token=replicate_token)
|
||||
self.tavily_client = TavilyClient(api_key=tavily_token)
|
||||
|
||||
@dataclass()
|
||||
class _ToolsArtifacts:
|
||||
generated_image: Optional[bytes] = None
|
||||
generated_image_hires: Optional[bytes] = None
|
||||
# Создание наборов инструментов
|
||||
self.toolsets: list[ai.tool.ToolSet] = []
|
||||
self.toolsets.append(
|
||||
ImageGenerationToolSet(fal_token=fal_token, replicate_token=replicate_token)
|
||||
)
|
||||
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))
|
||||
|
||||
# Сбор всех инструментов
|
||||
self.tools: list[ai.tool.Tool] = []
|
||||
self.tools_descriptions: list[ToolDefinitionJSONTypedDict] = []
|
||||
for toolset in self.toolsets:
|
||||
self.tools.extend(toolset.functions)
|
||||
self.tools_descriptions.extend(toolset.get_all_tools_description())
|
||||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||
|
|
@ -87,10 +86,9 @@ class AiAgent:
|
|||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
ai_response = response.content
|
||||
|
||||
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||
tools_artifacts = {}
|
||||
if response.tool_calls is not None:
|
||||
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
|
|
@ -101,11 +99,12 @@ class AiAgent:
|
|||
role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
||||
role="assistant", text=ai_response,
|
||||
image=tools_artifacts.get("generated_image"),
|
||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||
image_hires=tools_artifacts.generated_image_hires), True
|
||||
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
|
||||
image_hires=tools_artifacts.get("generated_image_hires")), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -125,21 +124,20 @@ class AiAgent:
|
|||
context.append(_serialize_assistant_message(response))
|
||||
ai_response = response.content
|
||||
|
||||
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||
tools_artifacts = {}
|
||||
if response.tool_calls is not None:
|
||||
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
tools_artifacts = await self._process_tool_calls(tool_calls=response.tool_calls, context=context)
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||
text=ai_response, image=tools_artifacts.generated_image,
|
||||
text=ai_response, image=tools_artifacts.get("generated_image"),
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||
image_hires=tools_artifacts.generated_image_hires), True
|
||||
return Message(text=ai_response, image=tools_artifacts.get("generated_image"),
|
||||
image_hires=tools_artifacts.get("generated_image_hires")), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -170,7 +168,12 @@ class AiAgent:
|
|||
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
|
||||
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
||||
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
||||
prompt += '\n' + self.system_prompt_tools
|
||||
|
||||
prompt += '\n# Доступные инструменты\n'
|
||||
for toolset in self.toolsets:
|
||||
prompt += '\n' + toolset.system_prompt
|
||||
|
||||
prompt += '\n' + '# Дополнительные инструкции\n'
|
||||
|
||||
bot = self.db.get_bot(bot_id)
|
||||
if bot['ai_prompt'] is not None:
|
||||
|
|
@ -187,139 +190,37 @@ class AiAgent:
|
|||
response = await self._async_chat_completion_request(
|
||||
model=self.openrouter_model,
|
||||
messages=context,
|
||||
tools=self.tools_description if allow_tools else None,
|
||||
tools=self.tools_descriptions if allow_tools else None,
|
||||
tool_choice="auto" if allow_tools else None,
|
||||
max_tokens=MAX_OUTPUT_TOKENS,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}'
|
||||
)
|
||||
return self._filter_response(response.choices[0].message)
|
||||
|
||||
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
|
||||
context: List[MessageTypedDict]) -> _ToolsArtifacts:
|
||||
artifacts = AiAgent._ToolsArtifacts()
|
||||
async def _process_tool_calls(self, tool_calls: List[ChatMessageToolCall],
|
||||
context: List[MessageTypedDict]) -> dict:
|
||||
artifacts = {}
|
||||
if tool_calls is None:
|
||||
return artifacts
|
||||
|
||||
functions_map: Dict[str,
|
||||
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
|
||||
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
|
||||
"generate_image": self._process_tool_generate_image,
|
||||
"generate_image_anime": self._process_tool_generate_image_anime,
|
||||
"tavily_search": self._process_tool_tavily_search
|
||||
}
|
||||
tools_map = {tool.name: tool for tool in self.tools}
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_name in functions_map:
|
||||
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
|
||||
|
||||
if tool_name in tools_map:
|
||||
tool = tools_map[tool_name]
|
||||
# Вызов инструмента с передачей artifacts
|
||||
tool_result = await tool.execute(tool_args, artifacts)
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result
|
||||
})
|
||||
|
||||
return artifacts
|
||||
|
||||
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
||||
-> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
|
||||
aspect_ratio_size_map = {
|
||||
"1:1": "square",
|
||||
"4:3": "landscape_4_3",
|
||||
"3:4": "portrait_4_3",
|
||||
"16:9": "landscape_16_9",
|
||||
"9:16": "portrait_16_9",
|
||||
"9:20": "portrait_16_9"
|
||||
}
|
||||
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
|
||||
print(f"Генерация изображения {image_size}: {prompt}")
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"image_size": image_size,
|
||||
"enable_safety_checker": False
|
||||
}
|
||||
|
||||
try:
|
||||
result = await self.client_fal.run(FAL_MODEL, arguments=arguments)
|
||||
if "images" not in result:
|
||||
raise RuntimeError("Неожиданный ответ от сервера.")
|
||||
image_url = result["images"][0]["url"]
|
||||
artifacts.generated_image_hires = await download_file(image_url)
|
||||
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
|
||||
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
|
||||
args: dict, artifacts: _ToolsArtifacts) \
|
||||
-> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
negative_prompt = args.get("negative_prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
|
||||
aspect_ratio_resolution_map = {
|
||||
"1:1": (1280, 1280),
|
||||
"4:3": (1280, 1024),
|
||||
"3:4": (1024, 1280),
|
||||
"16:9": (1280, 720),
|
||||
"9:16": (720, 1280),
|
||||
"9:20": (720, 1600)
|
||||
}
|
||||
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"add_recommended_tags": False,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"guidance_scale": 4.5,
|
||||
"num_inference_steps": 20,
|
||||
"hires_enable": True,
|
||||
"hires_num_inference_steps": 30,
|
||||
"disable_safety_checker": True
|
||||
}
|
||||
|
||||
try:
|
||||
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
|
||||
artifacts.generated_image_hires = await outputs[0].aread()
|
||||
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
|
||||
async def _process_tool_tavily_search(self, _bot_id: int, _chat_id: int, args: dict,
|
||||
_artifacts: _ToolsArtifacts) -> List[ChatMessageContentItemTypedDict]:
|
||||
query = args.get("query", "")
|
||||
print(f"Веб-поиск: {query}")
|
||||
|
||||
try:
|
||||
results = self.tavily_client.search(query=query, max_results=5)
|
||||
|
||||
if not results or "results" not in results:
|
||||
return _serialize_message_content(text="Не удалось получить результаты поиска.")
|
||||
|
||||
answer_parts = []
|
||||
for i, result in enumerate(results["results"], 1):
|
||||
title = result.get("title", "Без названия")
|
||||
url = result.get("url", "")
|
||||
content = result.get("content", "")
|
||||
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
|
||||
|
||||
answer = "\n".join(answer_parts)
|
||||
return _serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка веб-поиска: {e}")
|
||||
return _serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
|
||||
|
||||
async def _async_chat_completion_request(self, **kwargs):
|
||||
try:
|
||||
return await self.client_openrouter.chat.send_async(**kwargs)
|
||||
|
|
@ -354,24 +255,10 @@ class AiAgent:
|
|||
return response
|
||||
|
||||
def _load_prompts(self):
|
||||
with open("prompts/group_chat.md", "r") as f:
|
||||
with open("ai/prompts/group_chat.md", "r") as f:
|
||||
self.system_prompt_group_chat = f.read()
|
||||
with open("prompts/private_chat.md", "r") as f:
|
||||
with open("ai/prompts/private_chat.md", "r") as f:
|
||||
self.system_prompt_private_chat = f.read()
|
||||
with open("prompts/tools.md", "r") as f:
|
||||
self.system_prompt_tools = f.read()
|
||||
with open("prompts/tools.json", "r") as f:
|
||||
self.tools_description = json.loads(f.read())
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
||||
|
||||
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||
fal_token: str, replicate_token: str, tavily_token: str,
|
||||
db: BasicDatabase, platform: str):
|
||||
global agent
|
||||
agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
|
||||
|
||||
|
||||
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
||||
|
|
@ -380,22 +267,12 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) ->
|
|||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||||
|
||||
|
||||
def _encode_image(image: bytes) -> str:
|
||||
encoded_image = base64.b64encode(image).decode('utf-8')
|
||||
return f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
|
||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||
return {"role": role, "content": _serialize_message_content(text, image)}
|
||||
return {"role": role, "content": serialize_message_content(text, image)}
|
||||
|
||||
|
||||
def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]:
|
||||
content = []
|
||||
if text is not None:
|
||||
content.append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}})
|
||||
return content
|
||||
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
|
||||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||
|
||||
|
||||
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
|
||||
|
|
@ -413,21 +290,3 @@ def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, an
|
|||
]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
|
||||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||
|
||||
|
||||
def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
|
||||
img = Image.open(BytesIO(image)).convert("RGB")
|
||||
|
||||
if img.width > max_side or img.height > max_side:
|
||||
scale = min(max_side / img.width, max_side / img.height)
|
||||
new_width = int(img.width * scale)
|
||||
new_height = int(img.height * scale)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
output = BytesIO()
|
||||
img.save(output, format='JPEG', quality=87, optimize=True)
|
||||
return output.getvalue()
|
||||
66
ai/tool.py
Normal file
66
ai/tool.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openrouter.components import ChatMessageContentItemTypedDict, ToolDefinitionJSONTypedDict
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""Интерфейс функции"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Имя функции (snake_case)"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Текстовое описание функции"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
"""Описание параметров функции"""
|
||||
pass
|
||||
|
||||
def to_dict(self) -> ToolDefinitionJSONTypedDict:
|
||||
"""JSON-представление инструмента для OpenRouter"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters
|
||||
}
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
|
||||
"""Вызов функции.
|
||||
:param args: Параметры из JSON
|
||||
:param artifacts: Словарь для хранения артефактов
|
||||
:return: Содержимое JSON-ответа на вызов функции
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSet:
|
||||
"""Набор логически объединенных функций"""
|
||||
|
||||
functions: List[Tool]
|
||||
"""Список функций, входящих в набор"""
|
||||
|
||||
system_prompt: str
|
||||
"""Дополнение к системному запросу, описывающее, как пользоваться функциями"""
|
||||
|
||||
def get_function_by_name(self, name: str) -> Optional[Tool]:
|
||||
"""Поиск инструмента по имени"""
|
||||
return next((t for t in self.functions if t.name == name), None)
|
||||
|
||||
def get_all_tools_description(self) -> List[ToolDefinitionJSONTypedDict]:
|
||||
"""Получить JSON-описание всех инструментов"""
|
||||
return [tool.to_dict() for tool in self.functions]
|
||||
7
ai/tools/__init__.py
Normal file
7
ai/tools/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from ai.tools.image_generation import ImageGenerationToolSet
|
||||
from ai.tools.web_search import TavilySearchToolSet
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerationToolSet",
|
||||
"TavilySearchToolSet"
|
||||
]
|
||||
18
ai/tools/image_generation/__init__.py
Normal file
18
ai/tools/image_generation/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from ai.tool import ToolSet
|
||||
|
||||
from .generate_image import GenerateImageTool
|
||||
from .generate_image_anime import GenerateImageAnimeTool
|
||||
|
||||
|
||||
class ImageGenerationToolSet(ToolSet):
|
||||
def __init__(self, fal_token: str, replicate_token: str):
|
||||
functions = [
|
||||
GenerateImageTool(fal_token),
|
||||
GenerateImageAnimeTool(replicate_token)
|
||||
]
|
||||
with open("ai/tools/image_generation/prompt.md", "r") as f:
|
||||
system_prompt = f.read()
|
||||
super().__init__(functions=functions, system_prompt=system_prompt)
|
||||
|
||||
|
||||
__all__ = ["GenerateImageTool", "GenerateImageAnimeTool", "ImageGenerationToolSet"]
|
||||
79
ai/tools/image_generation/generate_image.py
Normal file
79
ai/tools/image_generation/generate_image.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from fal_client import AsyncClient as FalClient
|
||||
from openrouter.components import ChatMessageContentItemTypedDict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ai.tool import Tool
|
||||
from ai.utils import *
|
||||
|
||||
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
||||
|
||||
|
||||
class GenerateImageTool(Tool):
|
||||
def __init__(self, fal_token: str):
|
||||
self._client = FalClient(key=fal_token)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "generate_image"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Генерация изображения по описанию"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Подробное описание сцены на английском языке БЕЗ технических параметров "
|
||||
"(соотношение сторон, разрешение)"
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16"],
|
||||
"description": "Соотношение сторон"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
|
||||
aspect_ratio_size_map = {
|
||||
"1:1": "square",
|
||||
"4:3": "landscape_4_3",
|
||||
"3:4": "portrait_4_3",
|
||||
"16:9": "landscape_16_9",
|
||||
"9:16": "portrait_16_9",
|
||||
"9:20": "portrait_16_9"
|
||||
}
|
||||
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
|
||||
print(f"Генерация изображения {image_size}: {prompt}")
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"image_size": image_size,
|
||||
"enable_safety_checker": False
|
||||
}
|
||||
|
||||
try:
|
||||
result = await self._client.run(FAL_MODEL, arguments=arguments)
|
||||
if "images" not in result:
|
||||
raise RuntimeError("Неожиданный ответ от сервера.")
|
||||
image_url = result["images"][0]["url"]
|
||||
|
||||
from utils import download_file
|
||||
artifacts["generated_image_hires"] = await download_file(image_url)
|
||||
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
||||
|
||||
return serialize_message_content(
|
||||
text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
85
ai/tools/image_generation/generate_image_anime.py
Normal file
85
ai/tools/image_generation/generate_image_anime.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from openrouter.components import ChatMessageContentItemTypedDict
|
||||
from replicate import Client as ReplicateClient
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ai.tool import Tool
|
||||
from ai.utils import *
|
||||
|
||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-17:8f702486aa2852a08564ede8c83a7f58e52c83f6698e7be0e061d79c113dc88b"
|
||||
|
||||
|
||||
class GenerateImageAnimeTool(Tool):
|
||||
def __init__(self, replicate_token: str):
|
||||
self._client = ReplicateClient(api_token=replicate_token)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "generate_image_anime"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Генерация изображения в стиле аниме по описанию"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Положительный запрос"
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type": "string",
|
||||
"description": "Отрицательный запрос"
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"],
|
||||
"description": "Соотношение сторон"
|
||||
}
|
||||
},
|
||||
"required": ["prompt", "negative_prompt"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
negative_prompt = args.get("negative_prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
|
||||
aspect_ratio_resolution_map = {
|
||||
"1:1": (1280, 1280),
|
||||
"4:3": (1280, 1024),
|
||||
"3:4": (1024, 1280),
|
||||
"16:9": (1280, 720),
|
||||
"9:16": (720, 1280),
|
||||
"9:20": (720, 1600)
|
||||
}
|
||||
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"add_recommended_tags": False,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"guidance_scale": 4.5,
|
||||
"num_inference_steps": 20,
|
||||
"hires_enable": True,
|
||||
"hires_num_inference_steps": 30,
|
||||
"disable_safety_checker": True
|
||||
}
|
||||
|
||||
try:
|
||||
outputs = await self._client.async_run(REPLICATE_MODEL, input=arguments)
|
||||
artifacts["generated_image_hires"] = await outputs[0].aread()
|
||||
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
||||
|
||||
return serialize_message_content(
|
||||
text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=None
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
# Доступные инструменты
|
||||
|
||||
## Генерация изображений
|
||||
Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций.
|
||||
При вызове функции не нужно добавлять сообщение - оно будет отброшено.
|
||||
|
|
@ -28,8 +26,3 @@
|
|||
2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`.
|
||||
3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`.
|
||||
4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`.
|
||||
|
||||
## Веб-поиск
|
||||
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
|
||||
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
|
||||
- После получения результатов дай пользователю краткую сводку найденной информации.
|
||||
14
ai/tools/web_search/__init__.py
Normal file
14
ai/tools/web_search/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from ai.tool import ToolSet
|
||||
|
||||
from .tavily_search import TavilySearchTool
|
||||
|
||||
|
||||
class TavilySearchToolSet(ToolSet):
|
||||
def __init__(self, tavily_token: str):
|
||||
functions = [TavilySearchTool(tavily_token)]
|
||||
with open("ai/tools/web_search/prompt.md", "r") as f:
|
||||
system_prompt = f.read()
|
||||
super().__init__(functions=functions, system_prompt=system_prompt)
|
||||
|
||||
|
||||
__all__ = ["TavilySearchTool", "TavilySearchToolSet"]
|
||||
4
ai/tools/web_search/prompt.md
Normal file
4
ai/tools/web_search/prompt.md
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
## Веб-поиск
|
||||
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
|
||||
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
|
||||
- После получения результатов дай пользователю краткую сводку найденной информации.
|
||||
56
ai/tools/web_search/tavily_search.py
Normal file
56
ai/tools/web_search/tavily_search.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from tavily import TavilyClient
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from openrouter.components import ChatMessageContentItemTypedDict
|
||||
|
||||
from ai.tool import Tool
|
||||
from ai.utils import *
|
||||
|
||||
|
||||
class TavilySearchTool(Tool):
|
||||
def __init__(self, tavily_token: str):
|
||||
self._client = TavilyClient(api_key=tavily_token)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "tavily_search"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Поиск информации в интернете"
|
||||
|
||||
@property
|
||||
def parameters(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Запрос для поиска (на русском или английском языке)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
async def execute(self, args: Dict[str, Any], _artifacts: Dict[str, Any]) -> List[ChatMessageContentItemTypedDict]:
|
||||
query = args.get("query", "")
|
||||
print(f"Веб-поиск: {query}")
|
||||
|
||||
try:
|
||||
results = self._client.search(query=query, max_results=5)
|
||||
|
||||
if not results or "results" not in results:
|
||||
return serialize_message_content(text="Не удалось получить результаты поиска.")
|
||||
|
||||
answer_parts = []
|
||||
for i, result in enumerate(results["results"], 1):
|
||||
title = result.get("title", "Без названия")
|
||||
url = result.get("url", "")
|
||||
content = result.get("content", "")
|
||||
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
|
||||
|
||||
answer = "\n".join(answer_parts)
|
||||
return serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка веб-поиска: {e}")
|
||||
return serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
|
||||
39
ai/utils.py
Normal file
39
ai/utils.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from base64 import b64encode
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> List[Dict]:
|
||||
content = []
|
||||
if text is not None:
|
||||
content.append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
content.append({"type": "image_url", "detail": "high", "image_url": {"url": encode_image(image)}})
|
||||
return content
|
||||
|
||||
|
||||
def encode_image(image: bytes) -> str:
|
||||
encoded_image = b64encode(image).decode('utf-8')
|
||||
return f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
|
||||
def compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
|
||||
img = Image.open(BytesIO(image)).convert("RGB")
|
||||
|
||||
if img.width > max_side or img.height > max_side:
|
||||
scale = min(max_side / img.width, max_side / img.height)
|
||||
new_width = int(img.width * scale)
|
||||
new_height = int(img.height * scale)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
output = BytesIO()
|
||||
img.save(output, format='JPEG', quality=87, optimize=True)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"serialize_message_content",
|
||||
"compress_image",
|
||||
"encode_image"
|
||||
]
|
||||
|
|
@ -161,8 +161,9 @@ class BasicDatabase:
|
|||
self.cursor.execute(query, values)
|
||||
|
||||
def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int):
|
||||
self.cursor.execute("UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL",
|
||||
message_id, bot_id, chat_id)
|
||||
self.cursor.execute(
|
||||
"UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL",
|
||||
message_id, bot_id, chat_id)
|
||||
|
||||
def _context_trim(self, bot_id: int, chat_id: int, max_messages: int):
|
||||
current_count = self.context_get_count(bot_id, chat_id)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работаю в этом чате.'
|
||||
MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.'
|
||||
MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.'
|
||||
MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение или с пересылкой текстовых сообщений.'
|
||||
MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение ' \
|
||||
'или с пересылкой текстовых сообщений.'
|
||||
MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.'
|
||||
MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.'
|
||||
MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.'
|
||||
|
|
|
|||
|
|
@ -1,67 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_image",
|
||||
"description": "Генерация изображения по описанию",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Подробное описание сцены на английском языке БЕЗ технических параметров (соотношение сторон, разрешение)"
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16"],
|
||||
"description": "Соотношение сторон"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_image_anime",
|
||||
"description": "Генерация изображения в стиле аниме по описанию",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Положительный запрос"
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type": "string",
|
||||
"description": "Отрицательный запрос"
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16", "9:20"],
|
||||
"description": "Соотношение сторон"
|
||||
}
|
||||
},
|
||||
"required": ["prompt", "negative_prompt"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tavily_search",
|
||||
"description": "Веб-поиск по теме запроса. Используй для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Запрос для поиска (на русском или английском языке)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -4,7 +4,7 @@ import json
|
|||
|
||||
from aiogram import Bot, Dispatcher
|
||||
|
||||
from ai_agent import create_ai_agent
|
||||
from ai import create_ai_agent
|
||||
|
||||
import tg.tg_database as database
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from aiogram import Bot, Router, F
|
|||
from aiogram.types import Message
|
||||
from aiogram.utils.formatting import Bold
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ async def clear_context_handler(message: Message, bot: Bot):
|
|||
await message.answer(MESSAGE_PERMISSION_DENIED)
|
||||
return
|
||||
|
||||
ai_agent.agent.clear_chat_context(bot.id, chat_id)
|
||||
ai.agent.clear_chat_context(bot.id, chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from aiogram import Router, F, Bot
|
|||
from aiogram.types import Message
|
||||
from aiogram.enums.content_type import ContentType
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
|
||||
bot_user = await bot.me()
|
||||
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
ai_fwd_messages: list[ai.Message] = []
|
||||
|
||||
try:
|
||||
message_text = get_message_text(message)
|
||||
|
|
@ -64,7 +64,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
|
||||
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
|
||||
# Ответ на сообщение бота
|
||||
last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id)
|
||||
last_id = ai.agent.get_last_assistant_message_id(bot.id, chat_id)
|
||||
if message.reply_to_message.message_id != last_id:
|
||||
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
|
||||
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
|
||||
|
|
@ -77,10 +77,10 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
ai_message = await create_ai_message(message, bot)
|
||||
ai_message.text = message_text
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
@ -91,4 +91,4 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
else:
|
||||
answer_id = (await message.reply(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from functools import partial
|
||||
|
||||
from aiogram import Router, F, Bot
|
||||
from aiogram.enums import ChatType, ContentType
|
||||
from aiogram.enums import ChatType
|
||||
from aiogram.filters import Command, CommandObject, CommandStart
|
||||
from aiogram.types import Message
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ async def reset_context_handler(message: Message, bot: Bot):
|
|||
chat_id = message.chat.id
|
||||
database.DB.create_chat_if_not_exists(bot.id, chat_id)
|
||||
|
||||
ai_agent.agent.clear_chat_context(bot.id, chat_id)
|
||||
ai.agent.clear_chat_context(bot.id, chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
|
|
@ -52,10 +52,10 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
||||
partial(ai.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
@ -66,4 +66,4 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
else:
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from aiogram.enums import ContentType
|
|||
from aiogram.types import Message
|
||||
from aiogram.utils.formatting import Bold, Italic
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -214,18 +214,18 @@ async def check_rules_violation_handler(message: Message, bot: Bot):
|
|||
prompt += chat_rules + '\n\n'
|
||||
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
|
||||
|
||||
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.from_user),
|
||||
text=prompt, message_id=message.message_id)
|
||||
ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
ai_message = ai.Message(user_name=await get_user_name_for_ai(message.from_user),
|
||||
text=prompt, message_id=message.message_id)
|
||||
ai_fwd_messages = [ai.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(ai.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -39,8 +39,9 @@ class TgDatabase(database.BasicDatabase):
|
|||
warnings TINYINT NOT NULL DEFAULT 0,
|
||||
about VARCHAR(1000),
|
||||
PRIMARY KEY (bot_id, chat_id, user_id),
|
||||
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
)""")
|
||||
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS contexts (
|
||||
|
|
@ -52,8 +53,9 @@ class TgDatabase(database.BasicDatabase):
|
|||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
)""")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
|
|
|
|||
17
tg/utils.py
17
tg/utils.py
|
|
@ -6,7 +6,7 @@ from aiogram import Bot
|
|||
from aiogram.enums import ContentType
|
||||
from aiogram.types import User, PhotoSize, Message, BufferedInputFile
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
|
||||
|
||||
|
|
@ -36,8 +36,8 @@ def get_message_text(message: Message) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
||||
ai_message = ai_agent.Message()
|
||||
async def create_ai_message(message: Message, bot: Bot) -> ai.Message:
|
||||
ai_message = ai.Message()
|
||||
ai_message.message_id = message.message_id
|
||||
ai_message.user_name = await get_user_name_for_ai(message.from_user)
|
||||
if message.content_type == ContentType.TEXT:
|
||||
|
|
@ -64,3 +64,14 @@ def wrap_document(document: bytes, name_prefix: str, extension: str) -> Buffered
|
|||
|
||||
def trim_caption(caption: str) -> str:
|
||||
return caption[:1024]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"create_ai_message",
|
||||
"get_message_text",
|
||||
"get_user_name_for_ai",
|
||||
"trim_caption",
|
||||
"wrap_photo",
|
||||
"wrap_document",
|
||||
"wrap_document"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
|
||||
from vkbottle.bot import Bot, run_multibot
|
||||
|
||||
from ai_agent import create_ai_agent
|
||||
from ai import create_ai_agent
|
||||
|
||||
import vk.vk_database as database
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from vkbottle.bot import Message
|
|||
from vkbottle.framework.labeler import BotLabeler
|
||||
from vkbottle_types.codegen.objects import MessagesGetConversationMembers
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ async def clear_context_handler(message: Message):
|
|||
await message.answer(MESSAGE_PERMISSION_DENIED)
|
||||
return
|
||||
|
||||
ai_agent.agent.clear_chat_context(bot_id, chat_id)
|
||||
ai.agent.clear_chat_context(bot_id, chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from vkbottle.bot import Message
|
|||
from vkbottle.framework.labeler import BotLabeler
|
||||
from vkbottle_types.codegen.objects import GroupsGroup
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ async def any_message_handler(message: Message):
|
|||
message_text = message_text.replace(bot_username_mention, bot_user.name)
|
||||
bot_mentioned = True
|
||||
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
ai_fwd_messages: list[ai.agent.Message] = []
|
||||
|
||||
try:
|
||||
if bot_mentioned:
|
||||
|
|
@ -73,7 +73,7 @@ async def any_message_handler(message: Message):
|
|||
ai_fwd_messages.append(await create_ai_message(fwd_message))
|
||||
elif message.reply_message and message.reply_message.from_id == -bot_user.id:
|
||||
# Ответ на сообщение бота
|
||||
last_id = ai_agent.agent.get_last_assistant_message_id(bot_id, chat_id)
|
||||
last_id = ai.agent.get_last_assistant_message_id(bot_id, chat_id)
|
||||
if message.reply_message.message_id != last_id:
|
||||
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
|
||||
ai_fwd_messages = [await create_ai_message(message.reply_message)]
|
||||
|
|
@ -86,10 +86,10 @@ async def any_message_handler(message: Message):
|
|||
ai_message = await create_ai_message(message)
|
||||
ai_message.text = message_text
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
@ -100,4 +100,4 @@ async def any_message_handler(message: Message):
|
|||
answer_id = (await message.reply(answer.text)).conversation_message_id
|
||||
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from vkbottle.bot import Message
|
|||
from vkbottle.dispatch.rules.base import RegexRule
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ async def reset_context_handler(message: Message):
|
|||
chat_id = message.peer_id
|
||||
database.DB.create_chat_if_not_exists(bot_id, chat_id)
|
||||
|
||||
ai_agent.agent.clear_chat_context(bot_id, chat_id)
|
||||
ai.agent.clear_chat_context(bot_id, chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
|
|
@ -54,10 +54,10 @@ async def any_message_handler(message: Message):
|
|||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, ai_message),
|
||||
partial(ai.agent.get_private_chat_reply, bot_id, chat_id, ai_message),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
@ -68,4 +68,4 @@ async def any_message_handler(message: Message):
|
|||
answer_id = (await message.answer(answer.text)).message_id
|
||||
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from functools import partial
|
||||
from typing import List, Any
|
||||
|
||||
from vkbottle import bold, italic
|
||||
from vkbottle import bold, italic, API
|
||||
from vkbottle.bot import Message
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
|
|
@ -246,31 +246,31 @@ async def check_rules_violation_handler(message: Message):
|
|||
prompt += chat_rules + '\n\n'
|
||||
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
|
||||
|
||||
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id),
|
||||
text=prompt, message_id=message.message_id)
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
ai_message = ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id),
|
||||
text=prompt, message_id=message.message_id)
|
||||
ai_fwd_messages: list[ai.Message] = []
|
||||
if message.reply_message is not None and len(message.reply_message.text) > 0:
|
||||
ai_fwd_messages.append(
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
else:
|
||||
for fwd_message in message.fwd_messages:
|
||||
if len(fwd_message.text) > 0:
|
||||
ai_fwd_messages.append(
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
ai.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
|
||||
if len(ai_fwd_messages) == 0:
|
||||
await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD)
|
||||
return
|
||||
|
||||
answer: ai_agent.Message
|
||||
answer: ai.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(ai.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
ai.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
|
|
|
|||
17
vk/utils.py
17
vk/utils.py
|
|
@ -7,7 +7,7 @@ from vkbottle.bot import Message
|
|||
from vkbottle_types.codegen.objects import PhotosPhotoSizes
|
||||
from vkbottle_types.objects import MessagesMessageAttachmentType
|
||||
|
||||
import ai_agent
|
||||
import ai
|
||||
import utils
|
||||
|
||||
|
||||
|
|
@ -18,6 +18,7 @@ class MyAPI(API):
|
|||
|
||||
|
||||
def get_bot_id(api: API) -> int:
|
||||
# noinspection PyTypeChecker
|
||||
my_api: MyAPI = api
|
||||
return my_api.bot_id
|
||||
|
||||
|
|
@ -48,8 +49,8 @@ async def download_photo(photos: List[PhotosPhotoSizes]) -> bytes:
|
|||
raise RuntimeError(f"Failed to download photo. Status code: {response.status}")
|
||||
|
||||
|
||||
async def create_ai_message(message: Message) -> ai_agent.Message:
|
||||
ai_message = ai_agent.Message()
|
||||
async def create_ai_message(message: Message) -> ai.Message:
|
||||
ai_message = ai.Message()
|
||||
ai_message.message_id = message.conversation_message_id
|
||||
ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id)
|
||||
if len(message.text) > 0:
|
||||
|
|
@ -67,3 +68,13 @@ async def create_ai_message(message: Message) -> ai_agent.Message:
|
|||
|
||||
async def upload_photo(image: bytes, chat_id: int, api: API) -> str:
|
||||
return await PhotoMessageUploader(api).upload(file_source=image, peer_id=chat_id)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MyAPI",
|
||||
"get_bot_id",
|
||||
"get_user_name_for_ai",
|
||||
"download_photo",
|
||||
"create_ai_message",
|
||||
"upload_photo"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -42,8 +42,9 @@ class VkDatabase(database.BasicDatabase):
|
|||
happy_birthday TINYINT NOT NULL DEFAULT 1,
|
||||
about VARCHAR(1000),
|
||||
PRIMARY KEY (bot_id, chat_id, user_id),
|
||||
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
CONSTRAINT fk_users_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
)""")
|
||||
|
||||
self.cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS contexts (
|
||||
|
|
@ -55,8 +56,9 @@ class VkDatabase(database.BasicDatabase):
|
|||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id)
|
||||
ON UPDATE CASCADE ON DELETE CASCADE
|
||||
)""")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue