From 6a1c1d0ff5f5ba5ee3472284cf73b8898cf56e2e Mon Sep 17 00:00:00 2001 From: BennyKok Date: Sun, 10 Dec 2023 12:41:36 +0800 Subject: [PATCH] chore: add custom route in comfy plugin --- routes.py | 159 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 4 deletions(-) diff --git a/routes.py b/routes.py index 942a809..c68cb0c 100644 --- a/routes.py +++ b/routes.py @@ -14,9 +14,160 @@ import time import execution import random +import uuid +import websockets +import asyncio +import atexit +import logging + +api = None +api_task = None + load_dotenv() -@server.PromptServer.instance.routes.get("/comfy-deploy/run") -async def get_web_styles(request): - filename = os.path.join(os.path.dirname(__file__), "js/tw-styles.css") - return web.FileResponse(filename) \ No newline at end of file +def post_prompt(json_data): + prompt_server = server.PromptServer.instance + json_data = prompt_server.trigger_on_prompt(json_data) + + if "number" in json_data: + number = float(json_data["number"]) + else: + number = prompt_server.number + if "front" in json_data: + if json_data["front"]: + number = -number + + prompt_server.number += 1 + + if "prompt" in json_data: + prompt = json_data["prompt"] + valid = execution.validate_prompt(prompt) + extra_data = {} + if "extra_data" in json_data: + extra_data = json_data["extra_data"] + + if "client_id" in json_data: + extra_data["client_id"] = json_data["client_id"] + if valid[0]: + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + prompt_server.prompt_queue.put( + (number, prompt_id, prompt, extra_data, outputs_to_execute) + ) + response = { + "prompt_id": prompt_id, + "number": number, + "node_errors": valid[3], + } + return response + else: + print("invalid prompt:", valid[1]) + return {"error": valid[1], "node_errors": valid[3]} + else: + return {"error": "no prompt", "node_errors": []} + +def randomSeed(num_digits=15): + range_start = 10 ** (num_digits - 1) + range_end = (10**num_digits) - 1 + return random.randint(range_start, range_end) + +@server.PromptServer.instance.routes.post("/comfy-deploy/run") +async def comfy_deploy_run(request): + prompt_server = server.PromptServer.instance + data = await request.json() + + for key in data: + if 'inputs' in data[key] and 'seed' in data[key]['inputs']: + data[key]['inputs']['seed'] = randomSeed() + + if api is None: + connect_to_websocket() + while api.client_id is None: + await asyncio.sleep(0.1) + + workflow_api = data.get("workflow_api") + # print(workflow_api) + + prompt = { + "prompt": workflow_api, + "client_id": "fake_client" #api.client_id + } + + res = post_prompt(prompt) + + # print(prompt) + + status = 200 + if "error" in res: + status = 400 + + return web.json_response(res, status=status) + +logging.basicConfig(level=logging.INFO) + +class ComfyApi: + def __init__(self): + self.websocket = None + self.client_id = None + + async def connect(self, uri): + self.websocket = await websockets.connect(uri) + + # Event listeners + await self.on_open() + await self.on_message() + await self.on_close() + + async def close(self): + await self.websocket.close() + + async def on_open(self): + print("Connection opened") + + async def on_message(self): + async for message in self.websocket: + if isinstance(message, bytes): + print("Received binary message, skipping...") + continue # skip to the next message + logging.info(f"Received message: {message}") + + try: + message_data = json.loads(message) + + msg_type = message_data["type"] + + if msg_type == "status" and message_data["data"]["sid"] is not None: + self.client_id = message_data["data"]["sid"] + logging.info(f"Received client_id: {self.client_id}") + + except json.JSONDecodeError: + logging.info(f"Failed to parse message as JSON: {message}") + + async def on_close(self): + print("Connection closed") + + async def run(self, uri): + await self.connect(uri) + +def connect_to_websocket(): + global api, api_task + api = ComfyApi() + api_task = asyncio.create_task(api.run('ws://localhost:8188/ws')) + +prompt_server = server.PromptServer.instance + +send_json = prompt_server.send_json +async def send_json_override(self, event, data, sid=None): + await self.send_json_original(event, data, sid) + print("Sending event:", sid, event, data) + +prompt_server.send_json_original = prompt_server.send_json +prompt_server.send_json = send_json_override.__get__(prompt_server, server.PromptServer) + +@atexit.register +def close_websocket(): + print("Got close_websocket") + + global api, api_task + if api_task: + api_task.cancel() \ No newline at end of file