feat(builder): move away from docker file to modal commands
This commit is contained in:
		
							parent
							
								
									7ab4edb069
								
							
						
					
					
						commit
						c339cc4234
					
				@ -15,6 +15,9 @@ import signal
 | 
			
		||||
import logging
 | 
			
		||||
from fastapi.logger import logger as fastapi_logger
 | 
			
		||||
import requests
 | 
			
		||||
from urllib.parse import parse_qs
 | 
			
		||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
			
		||||
from starlette.types import ASGIApp, Scope, Receive, Send
 | 
			
		||||
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
 | 
			
		||||
@ -41,6 +44,34 @@ global_timeout = 60 * 4
 | 
			
		||||
machine_id_websocket_dict = {}
 | 
			
		||||
machine_id_status = {}
 | 
			
		||||
 | 
			
		||||
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
 | 
			
		||||
 | 
			
		||||
class FlyReplayMiddleware(BaseHTTPMiddleware):
 | 
			
		||||
    """
 | 
			
		||||
    If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
 | 
			
		||||
    to repeat the request again on the right instance.
 | 
			
		||||
 | 
			
		||||
    This only works if the right instance is provided as a query_string parameter.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, app: ASGIApp) -> None:
 | 
			
		||||
        self.app = app
 | 
			
		||||
 | 
			
		||||
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 | 
			
		||||
        query_string = scope.get('query_string', b'').decode()
 | 
			
		||||
        query_params = parse_qs(query_string)
 | 
			
		||||
        target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]
 | 
			
		||||
        async def send_wrapper(message):
 | 
			
		||||
            if target_instance != fly_instance_id:
 | 
			
		||||
                if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
 | 
			
		||||
                    # fly.io only seems to look at the fly-replay header if websocket is accepted
 | 
			
		||||
                    message = {'type': 'websocket.accept'}
 | 
			
		||||
                if 'headers' not in message:
 | 
			
		||||
                    message['headers'] = []
 | 
			
		||||
                message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
 | 
			
		||||
            await send(message)
 | 
			
		||||
        await self.app(scope, receive, send_wrapper)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def check_inactivity():
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
    while True:
 | 
			
		||||
@ -49,7 +80,8 @@ async def check_inactivity():
 | 
			
		||||
            if len(machine_id_status) == 0:
 | 
			
		||||
                # The application has been inactive for more than 60 seconds.
 | 
			
		||||
                # Scale it down to zero here.
 | 
			
		||||
                logger.info(f"No activity for {global_timeout} seconds, exiting...")
 | 
			
		||||
                logger.info(
 | 
			
		||||
                    f"No activity for {global_timeout} seconds, exiting...")
 | 
			
		||||
                # os._exit(0)
 | 
			
		||||
                os.kill(os.getpid(), signal.SIGINT)
 | 
			
		||||
                break
 | 
			
		||||
@ -68,9 +100,10 @@ async def lifespan(app: FastAPI):
 | 
			
		||||
 | 
			
		||||
#
 | 
			
		||||
app = FastAPI(lifespan=lifespan)
 | 
			
		||||
 | 
			
		||||
app.add_middleware(FlyReplayMiddleware)
 | 
			
		||||
# MODAL_ORG = os.environ.get("MODAL_ORG")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/")
 | 
			
		||||
def read_root():
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
@ -97,14 +130,17 @@ def read_root():
 | 
			
		||||
#     }
 | 
			
		||||
# }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GitCustomNodes(BaseModel):
 | 
			
		||||
    hash: str
 | 
			
		||||
    disabled: bool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Snapshot(BaseModel):
 | 
			
		||||
    comfyui: str
 | 
			
		||||
    git_custom_nodes: Dict[str, GitCustomNodes]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Model(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    type: str
 | 
			
		||||
@ -115,12 +151,14 @@ class Model(BaseModel):
 | 
			
		||||
    filename: str
 | 
			
		||||
    url: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GPUType(str, Enum):
 | 
			
		||||
    T4 = "T4"
 | 
			
		||||
    A10G = "A10G"
 | 
			
		||||
    A100 = "A100"
 | 
			
		||||
    L4 = "L4"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Item(BaseModel):
 | 
			
		||||
    machine_id: str
 | 
			
		||||
    name: str
 | 
			
		||||
@ -133,7 +171,8 @@ class Item(BaseModel):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def check_gpu(cls, value):
 | 
			
		||||
        if value not in GPUType.__members__:
 | 
			
		||||
            raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
 | 
			
		||||
        return GPUType(value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -143,9 +182,11 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
			
		||||
    machine_id_websocket_dict[machine_id] = websocket
 | 
			
		||||
    # Send existing logs
 | 
			
		||||
    if machine_id in machine_logs_cache:
 | 
			
		||||
        combined_logs = "\n".join(
 | 
			
		||||
            log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
 | 
			
		||||
        await websocket.send_text(json.dumps({"event": "LOGS", "data": {
 | 
			
		||||
            "machine_id": machine_id,
 | 
			
		||||
            "logs": json.dumps(machine_logs_cache[machine_id]) ,
 | 
			
		||||
            "logs": combined_logs,
 | 
			
		||||
            "timestamp": time.time()
 | 
			
		||||
        }}))
 | 
			
		||||
    try:
 | 
			
		||||
@ -173,6 +214,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
 | 
			
		||||
 | 
			
		||||
#     return {"Hello": "World"}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/create")
 | 
			
		||||
async def create_item(item: Item):
 | 
			
		||||
    global last_activity_time
 | 
			
		||||
@ -186,12 +228,13 @@ async def create_item(item: Item):
 | 
			
		||||
    # future = executor.submit(build_logic, item)
 | 
			
		||||
    task = asyncio.create_task(build_logic(item))
 | 
			
		||||
 | 
			
		||||
    return JSONResponse(status_code=200, content={"message": "Build Queued"})
 | 
			
		||||
    return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize the logs cache
 | 
			
		||||
machine_logs_cache = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def build_logic(item: Item):
 | 
			
		||||
    # Deploy to modal
 | 
			
		||||
    folder_path = f"/app/builds/{item.machine_id}"
 | 
			
		||||
@ -242,7 +285,9 @@ async def build_logic(item: Item):
 | 
			
		||||
 | 
			
		||||
    machine_logs = machine_logs_cache[item.machine_id]
 | 
			
		||||
 | 
			
		||||
    async def read_stream(stream, isStderr):
 | 
			
		||||
    url_queue = asyncio.Queue()
 | 
			
		||||
 | 
			
		||||
    async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
 | 
			
		||||
        while True:
 | 
			
		||||
            line = await stream.readline()
 | 
			
		||||
            if line:
 | 
			
		||||
@ -265,11 +310,11 @@ async def build_logic(item: Item):
 | 
			
		||||
                            "timestamp": time.time()
 | 
			
		||||
                        }}))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                    if "Created comfyui_app =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
 | 
			
		||||
                        if "Created comfyui_app =>" in l:
 | 
			
		||||
                    if "Created comfyui_api =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
 | 
			
		||||
                        if "Created comfyui_api =>" in l:
 | 
			
		||||
                            url = l.split("=>")[1].strip()
 | 
			
		||||
                        else:
 | 
			
		||||
                        # making sure it is a url
 | 
			
		||||
                        elif "comfyui_api" in l:
 | 
			
		||||
                            # Some case it only prints the url on a blank line
 | 
			
		||||
                            url = l
 | 
			
		||||
 | 
			
		||||
@ -279,6 +324,8 @@ async def build_logic(item: Item):
 | 
			
		||||
                                "timestamp": time.time()
 | 
			
		||||
                            })
 | 
			
		||||
 | 
			
		||||
                            await url_queue.put(url)
 | 
			
		||||
 | 
			
		||||
                            if item.machine_id in machine_id_websocket_dict:
 | 
			
		||||
                                await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
 | 
			
		||||
                                    "machine_id": item.machine_id,
 | 
			
		||||
@ -309,8 +356,12 @@ async def build_logic(item: Item):
 | 
			
		||||
            else:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
    stdout_task = asyncio.create_task(read_stream(process.stdout, False))
 | 
			
		||||
    stderr_task = asyncio.create_task(read_stream(process.stderr, True))
 | 
			
		||||
    stdout_task = asyncio.create_task(
 | 
			
		||||
        read_stream(process.stdout, False, url_queue))
 | 
			
		||||
    stderr_task = asyncio.create_task(
 | 
			
		||||
        read_stream(process.stderr, True, url_queue))
 | 
			
		||||
 | 
			
		||||
    url = await url_queue.get()
 | 
			
		||||
 | 
			
		||||
    await asyncio.wait([stdout_task, stderr_task])
 | 
			
		||||
 | 
			
		||||
@ -334,7 +385,8 @@ async def build_logic(item: Item):
 | 
			
		||||
            "logs": "Unable to build the app image.",
 | 
			
		||||
            "timestamp": time.time()
 | 
			
		||||
        })
 | 
			
		||||
        requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
        requests.post(item.callback_url, json={
 | 
			
		||||
                      "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
 | 
			
		||||
        if item.machine_id in machine_logs_cache:
 | 
			
		||||
            del machine_logs_cache[item.machine_id]
 | 
			
		||||
@ -349,7 +401,8 @@ async def build_logic(item: Item):
 | 
			
		||||
            "logs": "App image built, but url is None, unable to parse the url.",
 | 
			
		||||
            "timestamp": time.time()
 | 
			
		||||
        })
 | 
			
		||||
        requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
        requests.post(item.callback_url, json={
 | 
			
		||||
                      "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
 | 
			
		||||
        if item.machine_id in machine_logs_cache:
 | 
			
		||||
            del machine_logs_cache[item.machine_id]
 | 
			
		||||
@ -359,17 +412,20 @@ async def build_logic(item: Item):
 | 
			
		||||
    # example https://bennykok--my-app-comfyui-app.modal.run/
 | 
			
		||||
    # my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"
 | 
			
		||||
 | 
			
		||||
    requests.post(item.callback_url, json={"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
    requests.post(item.callback_url, json={
 | 
			
		||||
                  "machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
 | 
			
		||||
    if item.machine_id in machine_logs_cache:
 | 
			
		||||
        del machine_logs_cache[item.machine_id]
 | 
			
		||||
 | 
			
		||||
    logger.info("done")
 | 
			
		||||
    logger.info(url)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def start_loop(loop):
 | 
			
		||||
    asyncio.set_event_loop(loop)
 | 
			
		||||
    loop.run_forever()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_in_new_thread(coroutine):
 | 
			
		||||
    new_loop = asyncio.new_event_loop()
 | 
			
		||||
    t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
 | 
			
		||||
@ -377,6 +433,7 @@ def run_in_new_thread(coroutine):
 | 
			
		||||
    asyncio.run_coroutine_threadsafe(coroutine, new_loop)
 | 
			
		||||
    return t
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    import uvicorn
 | 
			
		||||
    # , log_level="debug"
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
from config import config
 | 
			
		||||
import modal
 | 
			
		||||
from modal import Image, Mount, web_endpoint, Stub, asgi_app
 | 
			
		||||
import json
 | 
			
		||||
@ -12,7 +13,6 @@ from fastapi.responses import HTMLResponse
 | 
			
		||||
import os
 | 
			
		||||
current_directory = os.path.dirname(os.path.realpath(__file__))
 | 
			
		||||
 | 
			
		||||
from config import config
 | 
			
		||||
deploy_test = config["deploy_test"] == "True"
 | 
			
		||||
# MODAL_IMAGE_ID = os.environ.get('MODAL_IMAGE_ID', None)
 | 
			
		||||
 | 
			
		||||
@ -30,8 +30,41 @@ print("deploy_test ", deploy_test)
 | 
			
		||||
stub = Stub(name=config["name"])
 | 
			
		||||
 | 
			
		||||
if not deploy_test:
 | 
			
		||||
    dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
 | 
			
		||||
    # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
 | 
			
		||||
    # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
 | 
			
		||||
 | 
			
		||||
    dockerfile_image = (
 | 
			
		||||
        modal.Image.debian_slim()
 | 
			
		||||
        .apt_install("git", "wget")
 | 
			
		||||
        .run_commands(
 | 
			
		||||
            # Basic comfyui setup
 | 
			
		||||
            "git clone https://github.com/comfyanonymous/ComfyUI.git /comfyui",
 | 
			
		||||
            "cd /comfyui && pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121",
 | 
			
		||||
 | 
			
		||||
            # Install comfyui manager
 | 
			
		||||
            "cd /comfyui/custom_nodes && git clone --depth 1 https://github.com/ltdrdata/ComfyUI-Manager.git",
 | 
			
		||||
            "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
 | 
			
		||||
            "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
 | 
			
		||||
 | 
			
		||||
            # Install comfy deploy
 | 
			
		||||
            "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
 | 
			
		||||
        )
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/snapshot.json", "/comfyui/custom_nodes/ComfyUI-Manager/startup-scripts/restore-snapshot.json")
 | 
			
		||||
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
 | 
			
		||||
        .run_commands("chmod +x /start.sh")
 | 
			
		||||
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/install_deps.py", "/")
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/models.json", "/")
 | 
			
		||||
        .copy_local_file(f"{current_directory}/data/deps.json", "/")
 | 
			
		||||
 | 
			
		||||
        .run_commands("python install_deps.py")
 | 
			
		||||
 | 
			
		||||
        .pip_install(
 | 
			
		||||
            "git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
# Time to wait between API check attempts in milliseconds
 | 
			
		||||
COMFY_API_AVAILABLE_INTERVAL_MS = 50
 | 
			
		||||
@ -44,6 +77,7 @@ COMFY_POLLING_MAX_RETRIES = 500
 | 
			
		||||
# Host where ComfyUI is running
 | 
			
		||||
COMFY_HOST = "127.0.0.1:8188"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_server(url, retries=50, delay=500):
 | 
			
		||||
    import requests
 | 
			
		||||
    import time
 | 
			
		||||
@ -71,7 +105,6 @@ def check_server(url, retries=50, delay=500):
 | 
			
		||||
            # If an exception occurs, the server may not be ready
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # print(f"runpod-worker-comfy - trying")
 | 
			
		||||
 | 
			
		||||
        # Wait for the specified delay before retrying
 | 
			
		||||
@ -82,29 +115,37 @@ def check_server(url, retries=50, delay=500):
 | 
			
		||||
    )
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_status(prompt_id):
 | 
			
		||||
    req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}")
 | 
			
		||||
    req = urllib.request.Request(
 | 
			
		||||
        f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}")
 | 
			
		||||
    return json.loads(urllib.request.urlopen(req).read())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Input(BaseModel):
 | 
			
		||||
    prompt_id: str
 | 
			
		||||
    workflow_api: dict
 | 
			
		||||
    status_endpoint: str
 | 
			
		||||
    file_upload_endpoint: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def queue_workflow_comfy_deploy(data: Input):
 | 
			
		||||
    data_str = data.json()
 | 
			
		||||
    data_bytes = data_str.encode('utf-8')
 | 
			
		||||
    req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes)
 | 
			
		||||
    req = urllib.request.Request(
 | 
			
		||||
        f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes)
 | 
			
		||||
    return json.loads(urllib.request.urlopen(req).read())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RequestInput(BaseModel):
 | 
			
		||||
    input: Input
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
image = Image.debian_slim()
 | 
			
		||||
 | 
			
		||||
target_image = image if deploy_test else dockerfile_image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(image=target_image, gpu=config["gpu"])
 | 
			
		||||
def run(input: Input):
 | 
			
		||||
    import subprocess
 | 
			
		||||
@ -112,8 +153,9 @@ def run(input: Input):
 | 
			
		||||
    # Make sure that the ComfyUI API is available
 | 
			
		||||
    print(f"comfy-modal - check server")
 | 
			
		||||
 | 
			
		||||
    command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata"]
 | 
			
		||||
    server_process = subprocess.Popen(command)
 | 
			
		||||
    command = ["python", "main.py",
 | 
			
		||||
               "--disable-auto-launch", "--disable-metadata"]
 | 
			
		||||
    server_process = subprocess.Popen(command, cwd="/comfyui")
 | 
			
		||||
 | 
			
		||||
    check_server(
 | 
			
		||||
        f"http://{COMFY_HOST}",
 | 
			
		||||
@ -128,7 +170,8 @@ def run(input: Input):
 | 
			
		||||
    # Queue the workflow
 | 
			
		||||
    try:
 | 
			
		||||
        # job_input is the json input
 | 
			
		||||
        queued_workflow = queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
 | 
			
		||||
        queued_workflow = queue_workflow_comfy_deploy(
 | 
			
		||||
            job_input)  # queue_workflow(workflow)
 | 
			
		||||
        prompt_id = queued_workflow["prompt_id"]
 | 
			
		||||
        print(f"comfy-modal - queued workflow with ID {prompt_id}")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
@ -170,11 +213,12 @@ def run(input: Input):
 | 
			
		||||
    # Get the generated image and return it as URL in an AWS bucket or as base64
 | 
			
		||||
    # images_result = process_output_images(history[prompt_id].get("outputs"), job["id"])
 | 
			
		||||
    # result = {**images_result, "refresh_worker": REFRESH_WORKER}
 | 
			
		||||
    result = { "status": status }
 | 
			
		||||
    result = {"status": status}
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
    print("Running remotely on Modal!")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@web_app.post("/run")
 | 
			
		||||
async def bar(request_input: RequestInput):
 | 
			
		||||
    # print(request_input)
 | 
			
		||||
@ -182,7 +226,73 @@ async def bar(request_input: RequestInput):
 | 
			
		||||
        return run.remote(request_input.input)
 | 
			
		||||
    # pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(image=image)
 | 
			
		||||
@asgi_app()
 | 
			
		||||
def comfyui_app():
 | 
			
		||||
def comfyui_api():
 | 
			
		||||
    return web_app
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
HOST = "127.0.0.1"
 | 
			
		||||
PORT = "8188"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def spawn_comfyui_in_background():
 | 
			
		||||
    import socket
 | 
			
		||||
    import subprocess
 | 
			
		||||
 | 
			
		||||
    process = subprocess.Popen(
 | 
			
		||||
        [
 | 
			
		||||
            "python",
 | 
			
		||||
            "main.py",
 | 
			
		||||
            "--dont-print-server",
 | 
			
		||||
            "--port",
 | 
			
		||||
            PORT,
 | 
			
		||||
        ],
 | 
			
		||||
        cwd="/comfyui",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Poll until webserver accepts connections before running inputs.
 | 
			
		||||
    while True:
 | 
			
		||||
        try:
 | 
			
		||||
            socket.create_connection((HOST, int(PORT)), timeout=1).close()
 | 
			
		||||
            print("ComfyUI webserver ready!")
 | 
			
		||||
            break
 | 
			
		||||
        except (socket.timeout, ConnectionRefusedError):
 | 
			
		||||
            # Check if launcher webserving process has exited.
 | 
			
		||||
            # If so, a connection can never be made.
 | 
			
		||||
            retcode = process.poll()
 | 
			
		||||
            if retcode is not None:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    f"comfyui main.py exited unexpectedly with code {retcode}"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(
 | 
			
		||||
    image=target_image,
 | 
			
		||||
    gpu=config["gpu"],
 | 
			
		||||
    # Allows 100 concurrent requests per container.
 | 
			
		||||
    allow_concurrent_inputs=100,
 | 
			
		||||
    # Restrict to 1 container because we want to our ComfyUI session state
 | 
			
		||||
    # to be on a single container.
 | 
			
		||||
    concurrency_limit=1,
 | 
			
		||||
    timeout=10 * 60,
 | 
			
		||||
)
 | 
			
		||||
@asgi_app()
 | 
			
		||||
def comfyui_app():
 | 
			
		||||
    from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
 | 
			
		||||
    from asgiproxy.context import ProxyContext
 | 
			
		||||
    from asgiproxy.simple_proxy import make_simple_proxy_app
 | 
			
		||||
 | 
			
		||||
    spawn_comfyui_in_background()
 | 
			
		||||
 | 
			
		||||
    config = type(
 | 
			
		||||
        "Config",
 | 
			
		||||
        (BaseURLProxyConfigMixin, ProxyConfig),
 | 
			
		||||
        {
 | 
			
		||||
            "upstream_base_url": f"http://{HOST}:{PORT}",
 | 
			
		||||
            "rewrite_host_header": f"{HOST}:{PORT}",
 | 
			
		||||
        },
 | 
			
		||||
    )()
 | 
			
		||||
 | 
			
		||||
    return make_simple_proxy_app(ProxyContext(config))
 | 
			
		||||
 | 
			
		||||
@ -3,9 +3,9 @@ import requests
 | 
			
		||||
import time
 | 
			
		||||
import subprocess
 | 
			
		||||
 | 
			
		||||
command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"]
 | 
			
		||||
command = ["python", "main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"]
 | 
			
		||||
# Start the server
 | 
			
		||||
server_process = subprocess.Popen(command)
 | 
			
		||||
server_process = subprocess.Popen(command, cwd="/comfyui")
 | 
			
		||||
 | 
			
		||||
def check_server(url, retries=50, delay=500):
 | 
			
		||||
    for i in range(retries):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user