diff --git a/custom_routes.py b/custom_routes.py index f14563e..6575ebf 100644 --- a/custom_routes.py +++ b/custom_routes.py @@ -336,6 +336,7 @@ def apply_inputs_to_workflow(workflow_api: Any, inputs: Any, sid: str = None): def send_prompt(sid: str, inputs: StreamingPrompt): # workflow_api = inputs.workflow_api workflow_api = copy.deepcopy(inputs.workflow_api) + workflow = copy.deepcopy(inputs.workflow) # Random seed apply_random_seed_to_workflow(workflow_api) @@ -351,7 +352,8 @@ def send_prompt(sid: str, inputs: StreamingPrompt): prompt = { "prompt": workflow_api, "client_id": sid, #"comfy_deploy_instance", #api.client_id - "prompt_id": prompt_id + "prompt_id": prompt_id, + "extra_data": {"extra_pnginfo": {"workflow": workflow}}, } try: @@ -418,6 +420,7 @@ async def comfy_deploy_run(request): # The prompt id generated from comfy deploy, can be None prompt_id = data.get("prompt_id") inputs = data.get("inputs") + workflow = data.get("workflow") # Now it handles directly in here apply_random_seed_to_workflow(workflow_api) @@ -427,6 +430,7 @@ async def comfy_deploy_run(request): "prompt": workflow_api, "client_id": "comfy_deploy_instance" if client_id is None else client_id, "prompt_id": prompt_id, + "extra_data": {"extra_pnginfo": {"workflow": workflow}} } prompt_metadata[prompt_id] = SimplePrompt( @@ -477,6 +481,7 @@ async def stream_prompt(data, token): # The prompt id generated from comfy deploy, can be None prompt_id = data.get("prompt_id") inputs = data.get("inputs") + workflow = data.get("workflow") # Now it handles directly in here apply_random_seed_to_workflow(workflow_api) @@ -485,7 +490,8 @@ async def stream_prompt(data, token): prompt = { "prompt": workflow_api, "client_id": "comfy_deploy_instance", #api.client_id - "prompt_id": prompt_id + "prompt_id": prompt_id, + "extra_data": {"extra_pnginfo": {"workflow": workflow}}, } prompt_metadata[prompt_id] = SimplePrompt( @@ -819,6 +825,7 @@ async def websocket_handler(request): inputs={}, status_endpoint=status_endpoint, file_upload_endpoint=request.rel_url.query.get('file_upload_endpoint', None), + workflow=workflow["workflow"], ) await update_realtime_run_status(realtime_id, status_endpoint, Status.RUNNING) diff --git a/globals.py b/globals.py index 3ee2658..ea705ee 100644 --- a/globals.py +++ b/globals.py @@ -24,6 +24,7 @@ class StreamingPrompt(BaseModel): running_prompt_ids: set[str] = set() status_endpoint: Optional[str] file_upload_endpoint: Optional[str] + workflow: Any class SimplePrompt(BaseModel): status_endpoint: Optional[str]