Генерация обычных изображений через Replicate.

This commit is contained in:
Kirill Kirilenko 2026-04-08 21:04:58 +03:00
parent 4b265b5405
commit 2e67865b1d
7 changed files with 21 additions and 37 deletions

View file

@ -9,11 +9,11 @@ agent_instance: ai.agent.AiAgent
def create_ai_agent(openrouter_token: str, openrouter_model: str, def create_ai_agent(openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, tavily_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, platform: str): db: BasicDatabase, platform: str):
global agent_instance global agent_instance
agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model, agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model,
fal_token, replicate_token, tavily_token, replicate_token, tavily_token,
db, platform) db, platform)

View file

@ -37,7 +37,7 @@ class Message:
class AiAgent: class AiAgent:
def __init__(self, def __init__(self,
openrouter_token: str, openrouter_model: str, openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, tavily_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, db: BasicDatabase,
platform: str): platform: str):
retry_config = RetryConfig(strategy="backoff", retry_config = RetryConfig(strategy="backoff",
@ -58,7 +58,7 @@ class AiAgent:
# Создание наборов инструментов # Создание наборов инструментов
self.toolsets: list[ai.tool.ToolSet] = [] self.toolsets: list[ai.tool.ToolSet] = []
self.toolsets.append( self.toolsets.append(
ImageGenerationToolSet(fal_token=fal_token, replicate_token=replicate_token) ImageGenerationToolSet(replicate_token=replicate_token)
) )
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token)) self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))

View file

@ -5,9 +5,9 @@ from .generate_image_anime import GenerateImageAnimeTool
class ImageGenerationToolSet(ToolSet): class ImageGenerationToolSet(ToolSet):
def __init__(self, fal_token: str, replicate_token: str): def __init__(self, replicate_token: str):
functions = [ functions = [
GenerateImageTool(fal_token), GenerateImageTool(replicate_token),
GenerateImageAnimeTool(replicate_token) GenerateImageAnimeTool(replicate_token)
] ]
with open("ai/tools/image_generation/prompt.md", "r") as f: with open("ai/tools/image_generation/prompt.md", "r") as f:

View file

@ -1,17 +1,17 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from openrouter.components import ChatToolMessageContentTypedDict from openrouter.components import ChatToolMessageContentTypedDict
from fal_client import AsyncClient as FalClient from replicate import Client as ReplicateClient
from ai.tool import Tool from ai.tool import Tool
from ai.utils import * from ai.utils import *
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image" REPLICATE_MODEL = "bytedance/seedream-4.5"
class GenerateImageTool(Tool): class GenerateImageTool(Tool):
def __init__(self, fal_token: str): def __init__(self, replicate_token: str):
self._client = FalClient(key=fal_token) self._client = ReplicateClient(api_token=replicate_token)
@property @property
def name(self) -> str: def name(self) -> str:
@ -43,34 +43,19 @@ class GenerateImageTool(Tool):
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]: async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
prompt = args.get("prompt", "") prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", "4:3") aspect_ratio = args.get("aspect_ratio", "4:3")
print(f"Генерация изображения {aspect_ratio}: {prompt}")
aspect_ratio_size_map = {
"1:1": "square",
"4:3": "landscape_4_3",
"3:4": "portrait_4_3",
"16:9": "landscape_16_9",
"9:16": "portrait_16_9",
"9:20": "portrait_16_9"
}
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
print(f"Генерация изображения {image_size}: {prompt}")
arguments = { arguments = {
"prompt": prompt, "prompt": prompt,
"image_size": image_size, "aspect_ratio": aspect_ratio,
"enable_safety_checker": False "disable_safety_checker": True
} }
try: try:
result = await self._client.run(FAL_MODEL, arguments=arguments) outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments)
if "images" not in result: artifacts["generated_image_hires"] = await outputs[0].aread()
raise RuntimeError("Неожиданный ответ от сервера.")
image_url = result["images"][0]["url"]
from utils import download_file
artifacts["generated_image_hires"] = await download_file(image_url)
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280) artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
return serialize_message_content( return serialize_message_content(
text="Изображение сгенерировано и будет показано пользователю.", text="Изображение сгенерировано и будет показано пользователю.",
image=None image=None

View file

@ -5,7 +5,6 @@ vkbottle-types~=5.199.99.18
pyodbc~=5.3.0 pyodbc~=5.3.0
openrouter==0.8.1 openrouter==0.8.1
replicate~=1.0.7 replicate~=1.0.7
fal_client~=0.13.2
tavily~=1.1.0 tavily~=1.1.0
pillow~=12.2.0 pillow~=12.2.0
pymorphy3~=2.0.6 pymorphy3~=2.0.6

View file

@ -25,8 +25,8 @@ async def main() -> None:
database.create_database(config['db_connection_string']) database.create_database(config['db_connection_string'])
create_ai_agent(config['openrouter_token'], config['openrouter_model'], create_ai_agent(config['openrouter_token'], config['openrouter_model'],
config['fal_token'], config['replicate_token'], config['replicate_token'], config['tavily_token'],
config['tavily_token'], database.DB, 'tg') database.DB, 'tg')
bots: list[Bot] = [] bots: list[Bot] = []
for item in database.DB.get_bots(): for item in database.DB.get_bots():

View file

@ -25,8 +25,8 @@ if __name__ == '__main__':
database.create_database(config['db_connection_string']) database.create_database(config['db_connection_string'])
create_ai_agent(config['openrouter_token'], config['openrouter_model'], create_ai_agent(config['openrouter_token'], config['openrouter_model'],
config['fal_token'], config['replicate_token'], config['replicate_token'], config['tavily_token'],
config['tavily_token'], database.DB, 'vk') database.DB, 'vk')
bot = Bot(labeler=handlers.labeler) bot = Bot(labeler=handlers.labeler)