comfyui-deploy/routes.py

from aiohttp import web
from dotenv import load_dotenv
import os
import requests
import folder_paths
import json
import numpy as np
import server
import re
import base64
from PIL import Image
import io
import time
import execution
import random
import uuid
import asyncio
import atexit
import logging
from enum import Enum
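
# Custom ComfyUI server routes for comfy-deploy: queues workflows submitted to
# /comfy-deploy/run and reports execution status back to a caller-provided endpoint.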

api = None
api_task = None
prompt_metadata = {}

load_dotenv()
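
# Validate and enqueue a workflow, mirroring ComfyUI's built-in /prompt handler:
# assign a queue number, generate a fresh prompt_id, and put the job on the server queue.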
def post_prompt(json_data):
    prompt_server = server.PromptServer.instance
    json_data = prompt_server.trigger_on_prompt(json_data)

    if "number" in json_data:
        number = float(json_data["number"])
    else:
        number = prompt_server.number
        if "front" in json_data:
            if json_data["front"]:
                number = -number
        prompt_server.number += 1

    if "prompt" in json_data:
        prompt = json_data["prompt"]
        valid = execution.validate_prompt(prompt)
        extra_data = {}
        if "extra_data" in json_data:
            extra_data = json_data["extra_data"]

        if "client_id" in json_data:
            extra_data["client_id"] = json_data["client_id"]

        if valid[0]:
            prompt_id = str(uuid.uuid4())
            outputs_to_execute = valid[2]
            prompt_server.prompt_queue.put(
                (number, prompt_id, prompt, extra_data, outputs_to_execute)
            )
            response = {
                "prompt_id": prompt_id,
                "number": number,
                "node_errors": valid[3],
            }
            return response
        else:
            print("invalid prompt:", valid[1])
            return {"error": valid[1], "node_errors": valid[3]}
    else:
        return {"error": "no prompt", "node_errors": []}
def randomSeed(num_digits=15):
    range_start = 10 ** (num_digits - 1)
    range_end = (10**num_digits) - 1
    return random.randint(range_start, range_end)
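
# POST /comfy-deploy/run expects a JSON body with a "workflow_api" graph (a workflow
# exported in ComfyUI's API format) and a "status_endpoint" URL that will receive
# run-status callbacks. A rough client-side sketch; the host/port and callback URL
# below are placeholders, not part of this file:
#
#   import requests
#   requests.post(
#       "http://127.0.0.1:8188/comfy-deploy/run",
#       json={
#           "workflow_api": workflow_api_dict,   # dict from "Save (API Format)"
#           "status_endpoint": "https://example.com/api/update-run",
#       },
#   )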
@server.PromptServer.instance.routes.post("/comfy-deploy/run")
async def comfy_deploy_run(request):
    prompt_server = server.PromptServer.instance
    data = await request.json()

    workflow_api = data.get("workflow_api")

    # Randomize every seed input so repeated runs are not served from the cache
    for key in workflow_api:
        if 'inputs' in workflow_api[key] and 'seed' in workflow_api[key]['inputs']:
            workflow_api[key]['inputs']['seed'] = randomSeed()

    prompt = {
        "prompt": workflow_api,
        "client_id": "fake_client"  # api.client_id
    }

    res = post_prompt(prompt)

    # post_prompt only returns a prompt_id for valid prompts
    if "prompt_id" in res:
        prompt_metadata[res["prompt_id"]] = {
            "status_endpoint": data.get("status_endpoint"),
        }

    status = 200
    if "error" in res:
        status = 400

    return web.json_response(res, status=status)

logging.basicConfig(level=logging.INFO)

prompt_server = server.PromptServer.instance
send_json = prompt_server.send_json
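
# Wrap PromptServer.send_json so execution events can be observed and the run
# status forwarded to the registered status_endpoint before delegating to the original.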
async def send_json_override(self, event, data, sid=None):
    print("INTERNAL:", event, data, sid)
    prompt_id = data.get('prompt_id')

    if event == 'execution_start':
        update_run(prompt_id, Status.RUNNING)

    # if event == 'executing':
    #     update_run(prompt_id, Status.RUNNING)

    if event == 'executed':
        update_run(prompt_id, Status.SUCCESS)

    await self.send_json_original(event, data, sid)

class Status(Enum):
    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
def update_run(prompt_id, status: Status):
    if prompt_id in prompt_metadata and ('status' not in prompt_metadata[prompt_id] or prompt_metadata[prompt_id]['status'] != status):
        status_endpoint = prompt_metadata[prompt_id]['status_endpoint']
        body = {
            "run_id": prompt_id,
            "status": status.value,
        }
        prompt_metadata[prompt_id]['status'] = status
        requests.post(status_endpoint, json=body)
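
# Keep a reference to the original send_json and bind the override as a method
# on the running PromptServer instance.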
prompt_server.send_json_original = prompt_server.send_json
prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer)