Генерация обычных изображений через Replicate.
This commit is contained in:
parent
4b265b5405
commit
2e67865b1d
7 changed files with 21 additions and 37 deletions
|
|
@ -9,11 +9,11 @@ agent_instance: ai.agent.AiAgent
|
||||||
|
|
||||||
|
|
||||||
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||||
fal_token: str, replicate_token: str, tavily_token: str,
|
replicate_token: str, tavily_token: str,
|
||||||
db: BasicDatabase, platform: str):
|
db: BasicDatabase, platform: str):
|
||||||
global agent_instance
|
global agent_instance
|
||||||
agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model,
|
agent_instance = ai.agent.AiAgent(openrouter_token, openrouter_model,
|
||||||
fal_token, replicate_token, tavily_token,
|
replicate_token, tavily_token,
|
||||||
db, platform)
|
db, platform)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ class Message:
|
||||||
class AiAgent:
|
class AiAgent:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
openrouter_token: str, openrouter_model: str,
|
openrouter_token: str, openrouter_model: str,
|
||||||
fal_token: str, replicate_token: str, tavily_token: str,
|
replicate_token: str, tavily_token: str,
|
||||||
db: BasicDatabase,
|
db: BasicDatabase,
|
||||||
platform: str):
|
platform: str):
|
||||||
retry_config = RetryConfig(strategy="backoff",
|
retry_config = RetryConfig(strategy="backoff",
|
||||||
|
|
@ -58,7 +58,7 @@ class AiAgent:
|
||||||
# Создание наборов инструментов
|
# Создание наборов инструментов
|
||||||
self.toolsets: list[ai.tool.ToolSet] = []
|
self.toolsets: list[ai.tool.ToolSet] = []
|
||||||
self.toolsets.append(
|
self.toolsets.append(
|
||||||
ImageGenerationToolSet(fal_token=fal_token, replicate_token=replicate_token)
|
ImageGenerationToolSet(replicate_token=replicate_token)
|
||||||
)
|
)
|
||||||
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))
|
self.toolsets.append(TavilySearchToolSet(tavily_token=tavily_token))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,9 @@ from .generate_image_anime import GenerateImageAnimeTool
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerationToolSet(ToolSet):
|
class ImageGenerationToolSet(ToolSet):
|
||||||
def __init__(self, fal_token: str, replicate_token: str):
|
def __init__(self, replicate_token: str):
|
||||||
functions = [
|
functions = [
|
||||||
GenerateImageTool(fal_token),
|
GenerateImageTool(replicate_token),
|
||||||
GenerateImageAnimeTool(replicate_token)
|
GenerateImageAnimeTool(replicate_token)
|
||||||
]
|
]
|
||||||
with open("ai/tools/image_generation/prompt.md", "r") as f:
|
with open("ai/tools/image_generation/prompt.md", "r") as f:
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from openrouter.components import ChatToolMessageContentTypedDict
|
from openrouter.components import ChatToolMessageContentTypedDict
|
||||||
from fal_client import AsyncClient as FalClient
|
from replicate import Client as ReplicateClient
|
||||||
|
|
||||||
from ai.tool import Tool
|
from ai.tool import Tool
|
||||||
from ai.utils import *
|
from ai.utils import *
|
||||||
|
|
||||||
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
REPLICATE_MODEL = "bytedance/seedream-4.5"
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageTool(Tool):
|
class GenerateImageTool(Tool):
|
||||||
def __init__(self, fal_token: str):
|
def __init__(self, replicate_token: str):
|
||||||
self._client = FalClient(key=fal_token)
|
self._client = ReplicateClient(api_token=replicate_token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
|
@ -43,34 +43,19 @@ class GenerateImageTool(Tool):
|
||||||
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
async def execute(self, args: Dict[str, Any], artifacts: Dict[str, Any]) -> List[ChatToolMessageContentTypedDict]:
|
||||||
prompt = args.get("prompt", "")
|
prompt = args.get("prompt", "")
|
||||||
aspect_ratio = args.get("aspect_ratio", "4:3")
|
aspect_ratio = args.get("aspect_ratio", "4:3")
|
||||||
|
print(f"Генерация изображения {aspect_ratio}: {prompt}")
|
||||||
aspect_ratio_size_map = {
|
|
||||||
"1:1": "square",
|
|
||||||
"4:3": "landscape_4_3",
|
|
||||||
"3:4": "portrait_4_3",
|
|
||||||
"16:9": "landscape_16_9",
|
|
||||||
"9:16": "portrait_16_9",
|
|
||||||
"9:20": "portrait_16_9"
|
|
||||||
}
|
|
||||||
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
|
|
||||||
print(f"Генерация изображения {image_size}: {prompt}")
|
|
||||||
|
|
||||||
arguments = {
|
arguments = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"image_size": image_size,
|
"aspect_ratio": aspect_ratio,
|
||||||
"enable_safety_checker": False
|
"disable_safety_checker": True
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self._client.run(FAL_MODEL, arguments=arguments)
|
outputs: Any = await self._client.async_run(REPLICATE_MODEL, input=arguments)
|
||||||
if "images" not in result:
|
artifacts["generated_image_hires"] = await outputs[0].aread()
|
||||||
raise RuntimeError("Неожиданный ответ от сервера.")
|
|
||||||
image_url = result["images"][0]["url"]
|
|
||||||
|
|
||||||
from utils import download_file
|
|
||||||
artifacts["generated_image_hires"] = await download_file(image_url)
|
|
||||||
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
artifacts["generated_image"] = compress_image(artifacts["generated_image_hires"], 1280)
|
||||||
|
|
||||||
return serialize_message_content(
|
return serialize_message_content(
|
||||||
text="Изображение сгенерировано и будет показано пользователю.",
|
text="Изображение сгенерировано и будет показано пользователю.",
|
||||||
image=None
|
image=None
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ vkbottle-types~=5.199.99.18
|
||||||
pyodbc~=5.3.0
|
pyodbc~=5.3.0
|
||||||
openrouter==0.8.1
|
openrouter==0.8.1
|
||||||
replicate~=1.0.7
|
replicate~=1.0.7
|
||||||
fal_client~=0.13.2
|
|
||||||
tavily~=1.1.0
|
tavily~=1.1.0
|
||||||
pillow~=12.2.0
|
pillow~=12.2.0
|
||||||
pymorphy3~=2.0.6
|
pymorphy3~=2.0.6
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,8 @@ async def main() -> None:
|
||||||
database.create_database(config['db_connection_string'])
|
database.create_database(config['db_connection_string'])
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['fal_token'], config['replicate_token'],
|
config['replicate_token'], config['tavily_token'],
|
||||||
config['tavily_token'], database.DB, 'tg')
|
database.DB, 'tg')
|
||||||
|
|
||||||
bots: list[Bot] = []
|
bots: list[Bot] = []
|
||||||
for item in database.DB.get_bots():
|
for item in database.DB.get_bots():
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,8 @@ if __name__ == '__main__':
|
||||||
database.create_database(config['db_connection_string'])
|
database.create_database(config['db_connection_string'])
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['fal_token'], config['replicate_token'],
|
config['replicate_token'], config['tavily_token'],
|
||||||
config['tavily_token'], database.DB, 'vk')
|
database.DB, 'vk')
|
||||||
|
|
||||||
bot = Bot(labeler=handlers.labeler)
|
bot = Bot(labeler=handlers.labeler)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue