feat(builder): move away from docker file to modal commands
parent 7ab4edb069
commit c339cc4234

@@ -15,6 +15,9 @@ import signal
import logging
from fastapi.logger import logger as fastapi_logger
import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send

from concurrent.futures import ThreadPoolExecutor

@@ -41,6 +44,34 @@ global_timeout = 60 * 4
machine_id_websocket_dict = {}
machine_id_status = {}

fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]


class FlyReplayMiddleware(BaseHTTPMiddleware):
    """
    If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
    to repeat the request again on the right instance.

    This only works if the right instance is provided as a query_string parameter.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        query_string = scope.get('query_string', b'').decode()
        query_params = parse_qs(query_string)
        target_instance = query_params.get('fly_instance_id', [fly_instance_id])[0]

        async def send_wrapper(message):
            if target_instance != fly_instance_id:
                if message['type'] == 'websocket.close' and 'Invalid session' in message['reason']:
                    # fly.io only seems to look at the fly-replay header if websocket is accepted
                    message = {'type': 'websocket.accept'}
                if 'headers' not in message:
                    message['headers'] = []
                message['headers'].append([b'fly-replay', f'instance={target_instance}'.encode()])
            await send(message)

        await self.app(scope, receive, send_wrapper)


async def check_inactivity():
    global last_activity_time
    while True:

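For context on how this middleware is meant to be used: the target instance has to arrive as a fly_instance_id query-string parameter, and the /create endpoint further down in this diff now returns a build_machine_instance_id for exactly that purpose. A minimal client-side sketch, with the builder host, the /ws/{machine_id} path, and the websockets package all being assumptions rather than things shown in this diff:

import asyncio
import json
import requests
import websockets  # assumed client library, not part of this commit

BUILDER = "https://builder.example.com"  # placeholder host

async def build_and_stream(payload: dict) -> str:
    # Queue the build; the response carries the Fly instance that accepted it.
    res = requests.post(f"{BUILDER}/create", json=payload).json()
    instance = res["build_machine_instance_id"]

    # Pin follow-up traffic to that instance via the query parameter
    # FlyReplayMiddleware looks for. The /ws/... path is an assumption.
    ws_url = (f"wss://builder.example.com/ws/{payload['machine_id']}"
              f"?fly_instance_id={instance}")
    async with websockets.connect(ws_url) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg["event"] == "LOGS":
                print(msg["data"]["logs"])
            elif msg["event"] == "FINISHED":
                return msg["data"]["status"]

The LOGS and FINISHED event names and their data fields are taken from the websocket endpoint and build_logic changes shown later in this diff.
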
@@ -49,7 +80,8 @@ async def check_inactivity():
        if len(machine_id_status) == 0:
            # The application has been inactive for more than 60 seconds.
            # Scale it down to zero here.
            logger.info(f"No activity for {global_timeout} seconds, exiting...")
            logger.info(
                f"No activity for {global_timeout} seconds, exiting...")
            # os._exit(0)
            os.kill(os.getpid(), signal.SIGINT)
            break

@@ -66,11 +98,12 @@ async def lifespan(app: FastAPI):
    yield
    logger.info("Cancelling")

#
#
app = FastAPI(lifespan=lifespan)

app.add_middleware(FlyReplayMiddleware)
# MODAL_ORG = os.environ.get("MODAL_ORG")


@app.get("/")
def read_root():
    global last_activity_time

@@ -97,14 +130,17 @@ def read_root():
    # }
    # }


class GitCustomNodes(BaseModel):
    hash: str
    disabled: bool


class Snapshot(BaseModel):
    comfyui: str
    git_custom_nodes: Dict[str, GitCustomNodes]


class Model(BaseModel):
    name: str
    type: str

@@ -115,12 +151,14 @@ class Model(BaseModel):
    filename: str
    url: str


class GPUType(str, Enum):
    T4 = "T4"
    A10G = "A10G"
    A100 = "A100"
    L4 = "L4"


class Item(BaseModel):
    machine_id: str
    name: str

@@ -133,7 +171,8 @@ class Item(BaseModel):
    @classmethod
    def check_gpu(cls, value):
        if value not in GPUType.__members__:
            raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
            raise ValueError(
                f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
        return GPUType(value)

@@ -143,9 +182,11 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
    machine_id_websocket_dict[machine_id] = websocket
    # Send existing logs
    if machine_id in machine_logs_cache:
        combined_logs = "\n".join(
            log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
        await websocket.send_text(json.dumps({"event": "LOGS", "data": {
            "machine_id": machine_id,
            "logs": json.dumps(machine_logs_cache[machine_id]),
            "logs": combined_logs,
            "timestamp": time.time()
        }}))
    try:

@@ -173,6 +214,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):

    # return {"Hello": "World"}


@app.post("/create")
async def create_item(item: Item):
    global last_activity_time

@@ -185,13 +227,14 @@ async def create_item(item: Item):
    # Run the building logic in a separate thread
    # future = executor.submit(build_logic, item)
    task = asyncio.create_task(build_logic(item))

    return JSONResponse(status_code=200, content={"message": "Build Queued"})

    return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})


# Initialize the logs cache
machine_logs_cache = {}


async def build_logic(item: Item):
    # Deploy to modal
    folder_path = f"/app/builds/{item.machine_id}"

@@ -239,16 +282,18 @@ async def build_logic(item: Item):

    if item.machine_id not in machine_logs_cache:
        machine_logs_cache[item.machine_id] = []

    machine_logs = machine_logs_cache[item.machine_id]

    async def read_stream(stream, isStderr):
        while True:
            line = await stream.readline()
            if line:
    url_queue = asyncio.Queue()

    async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
        while True:
            line = await stream.readline()
            if line:
                l = line.decode('utf-8').strip()

                if l == "":
                if l == "":
                    continue

                if not isStderr:

@@ -265,12 +310,12 @@ async def build_logic(item: Item):
                            "timestamp": time.time()
                        }}))

                    if "Created comfyui_app =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
                        if "Created comfyui_app =>" in l:
                    if "Created comfyui_api =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
                        if "Created comfyui_api =>" in l:
                            url = l.split("=>")[1].strip()
                        else:
                            # Some case it only prints the url on a blank line
                            # making sure it is a url
                        elif "comfyui_api" in l:
                            # Some case it only prints the url on a blank line
                            url = l

                    if url:

@@ -279,6 +324,8 @@ async def build_logic(item: Item):
                            "timestamp": time.time()
                        })

                        await url_queue.put(url)

                        if item.machine_id in machine_id_websocket_dict:
                            await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
                                "machine_id": item.machine_id,

@@ -288,7 +335,7 @@ async def build_logic(item: Item):
                            await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
"status": "succuss",
                            }}))


                else:
                    # is error
                    logger.error(l)

@@ -306,11 +353,15 @@ async def build_logic(item: Item):
                        await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "FINISHED", "data": {
                            "status": "failed",
                        }}))
            else:
                break
            else:
                break

    stdout_task = asyncio.create_task(read_stream(process.stdout, False))
    stderr_task = asyncio.create_task(read_stream(process.stderr, True))
    stdout_task = asyncio.create_task(
        read_stream(process.stdout, False, url_queue))
    stderr_task = asyncio.create_task(
        read_stream(process.stderr, True, url_queue))

    url = await url_queue.get()

    await asyncio.wait([stdout_task, stderr_task])

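This hunk is the core of the new log handling: both stream readers share an asyncio.Queue, so build_logic can pick up the deployment URL the moment it appears while the readers keep draining stdout and stderr. A stripped-down sketch of the same hand-off pattern; the "modal deploy" command line and file name here are assumptions, since this hunk does not show how process is created:

import asyncio

async def deploy_and_get_url() -> str:
    proc = await asyncio.create_subprocess_exec(
        "modal", "deploy", "comfyui_app.py",   # assumed command and file name
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    url_queue: asyncio.Queue = asyncio.Queue()

    async def pump(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            text = line.decode("utf-8").strip()
            # Same heuristic as read_stream above: a bare *.modal.run line is the app URL.
            if text.startswith("https://") and text.endswith(".modal.run"):
                await url_queue.put(text)

    readers = [asyncio.create_task(pump(proc.stdout)),
               asyncio.create_task(pump(proc.stderr))]
    url = await url_queue.get()      # returns as soon as the URL is printed
    await asyncio.wait(readers)      # keep draining until the process closes its pipes
    return url

A production version would also want a timeout or a process-exit check around url_queue.get(), in case the deploy never prints a URL.
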
@@ -334,8 +385,9 @@ async def build_logic(item: Item):
            "logs": "Unable to build the app image.",
            "timestamp": time.time()
        })
        requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})

        requests.post(item.callback_url, json={
            "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})

        if item.machine_id in machine_logs_cache:
            del machine_logs_cache[item.machine_id]

@@ -349,7 +401,8 @@ async def build_logic(item: Item):
            "logs": "App image built, but url is None, unable to parse the url.",
            "timestamp": time.time()
        })
        requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
        requests.post(item.callback_url, json={
            "machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})

        if item.machine_id in machine_logs_cache:
            del machine_logs_cache[item.machine_id]

@@ -359,17 +412,20 @@ async def build_logic(item: Item):
    # example https://bennykok--my-app-comfyui-app.modal.run/
    # my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"

    requests.post(item.callback_url, json={"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
    requests.post(item.callback_url, json={
        "machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
    if item.machine_id in machine_logs_cache:
        del machine_logs_cache[item.machine_id]

    del machine_logs_cache[item.machine_id]

    logger.info("done")
    logger.info(url)


def start_loop(loop):
    asyncio.set_event_loop(loop)
    loop.run_forever()


def run_in_new_thread(coroutine):
    new_loop = asyncio.new_event_loop()
    t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)

@@ -377,6 +433,7 @@ def run_in_new_thread(coroutine):
    asyncio.run_coroutine_threadsafe(coroutine, new_loop)
    return t


if __name__ == "__main__":
    import uvicorn
    # , log_level="debug"

@@ -1,3 +1,4 @@
from config import config
import modal
from modal import Image, Mount, web_endpoint, Stub, asgi_app
import json

@@ -12,7 +13,6 @@ from fastapi.responses import HTMLResponse
import os
current_directory = os.path.dirname(os.path.realpath(__file__))

from config import config
deploy_test = config["deploy_test"] == "True"
# MODAL_IMAGE_ID = os.environ.get('MODAL_IMAGE_ID', None)

@@ -30,8 +30,41 @@ print("deploy_test ", deploy_test)
stub = Stub(name=config["name"])

if not deploy_test:
    dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
    # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))

    dockerfile_image = (
        modal.Image.debian_slim()
        .apt_install("git", "wget")
        .run_commands(
            # Basic comfyui setup
            "git clone https://github.com/comfyanonymous/ComfyUI.git /comfyui",
            "cd /comfyui && pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121",

            # Install comfyui manager
            "cd /comfyui/custom_nodes && git clone --depth 1 https://github.com/ltdrdata/ComfyUI-Manager.git",
            "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
            "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",

            # Install comfy deploy
            "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
        )
        .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
        .copy_local_file(f"{current_directory}/data/snapshot.json", "/comfyui/custom_nodes/ComfyUI-Manager/startup-scripts/restore-snapshot.json")

        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
        .run_commands("chmod +x /start.sh")

        .copy_local_file(f"{current_directory}/data/install_deps.py", "/")
        .copy_local_file(f"{current_directory}/data/models.json", "/")
        .copy_local_file(f"{current_directory}/data/deps.json", "/")

        .run_commands("python install_deps.py")

        .pip_install(
            "git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
        )
    )

# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 50

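This hunk is the point of the commit: the image that used to come from Image.from_dockerfile is now declared directly with Modal's image-building calls. As a rough guide to the correspondence (a sketch of the mapping, not the removed Dockerfile itself; the package and repo names in the example chain are placeholders):

import modal

# FROM <debian base>          ->  modal.Image.debian_slim()
# RUN apt-get install -y ...  ->  .apt_install("git", "wget")
# RUN <shell command>         ->  .run_commands("...")
# COPY data/<file> /<dest>    ->  .copy_local_file("data/<file>", "/<dest>")
# RUN pip install <packages>  ->  .pip_install("...")

# A minimal equivalent chain:
minimal_image = (
    modal.Image.debian_slim()
    .apt_install("git")
    .run_commands("git clone https://github.com/comfyanonymous/ComfyUI.git /comfyui")
    .pip_install("httpx")
)

The data/ files that were previously provided through the Docker build-context mount are now baked into the image with copy_local_file.
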
@@ -44,6 +77,7 @@ COMFY_POLLING_MAX_RETRIES = 500
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"


def check_server(url, retries=50, delay=500):
    import requests
    import time

@@ -71,7 +105,6 @@ def check_server(url, retries=50, delay=500):
            # If an exception occurs, the server may not be ready
            pass

        # print(f"runpod-worker-comfy - trying")

        # Wait for the specified delay before retrying

@@ -82,29 +115,37 @@ def check_server(url, retries=50, delay=500):
    )
    return False


def check_status(prompt_id):
    req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}")
    req = urllib.request.Request(
        f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}")
    return json.loads(urllib.request.urlopen(req).read())


class Input(BaseModel):
    prompt_id: str
    workflow_api: dict
    status_endpoint: str
    file_upload_endpoint: str


def queue_workflow_comfy_deploy(data: Input):
    data_str = data.json()
    data_bytes = data_str.encode('utf-8')
    req = urllib.request.Request(f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes)
    data_bytes = data_str.encode('utf-8')
    req = urllib.request.Request(
        f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes)
    return json.loads(urllib.request.urlopen(req).read())


class RequestInput(BaseModel):
    input: Input


image = Image.debian_slim()

target_image = image if deploy_test else dockerfile_image


@stub.function(image=target_image, gpu=config["gpu"])
def run(input: Input):
    import subprocess

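Taken together, the two helpers above wrap the HTTP endpoints exposed by the comfyui-deploy custom node: one queues a workflow, the other checks it by prompt_id. A minimal sketch of how they combine, mirroring what run() does in the hunks that follow; the shape of the status payload beyond prompt_id is not shown in this diff and is left opaque here:

def queue_and_check(job_input: Input):
    # POST /comfyui-deploy/run with the serialized Input payload
    queued = queue_workflow_comfy_deploy(job_input)
    prompt_id = queued["prompt_id"]
    # GET /comfyui-deploy/check-status?prompt_id=... and return whatever it reports
    return check_status(prompt_id)
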
@@ -112,8 +153,9 @@ def run(input: Input):
    # Make sure that the ComfyUI API is available
    print(f"comfy-modal - check server")

    command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata"]
    server_process = subprocess.Popen(command)
    command = ["python", "main.py",
               "--disable-auto-launch", "--disable-metadata"]
    server_process = subprocess.Popen(command, cwd="/comfyui")

    check_server(
        f"http://{COMFY_HOST}",

@@ -128,7 +170,8 @@ def run(input: Input):
    # Queue the workflow
    try:
        # job_input is the json input
        queued_workflow = queue_workflow_comfy_deploy(job_input)  # queue_workflow(workflow)
        queued_workflow = queue_workflow_comfy_deploy(
            job_input)  # queue_workflow(workflow)
        prompt_id = queued_workflow["prompt_id"]
        print(f"comfy-modal - queued workflow with ID {prompt_id}")
    except Exception as e:

@@ -170,11 +213,12 @@ def run(input: Input):
    # Get the generated image and return it as URL in an AWS bucket or as base64
    # images_result = process_output_images(history[prompt_id].get("outputs"), job["id"])
    # result = {**images_result, "refresh_worker": REFRESH_WORKER}
    result = { "status": status }
    result = {"status": status}

    return result
    print("Running remotely on Modal!")


@web_app.post("/run")
async def bar(request_input: RequestInput):
    # print(request_input)

@@ -182,7 +226,73 @@ async def bar(request_input: RequestInput):
    return run.remote(request_input.input)
    # pass


@stub.function(image=image)
@asgi_app()
def comfyui_api():
    return web_app


HOST = "127.0.0.1"
PORT = "8188"


def spawn_comfyui_in_background():
    import socket
    import subprocess

    process = subprocess.Popen(
        [
            "python",
            "main.py",
            "--dont-print-server",
            "--port",
            PORT,
        ],
        cwd="/comfyui",
    )

    # Poll until webserver accepts connections before running inputs.
    while True:
        try:
            socket.create_connection((HOST, int(PORT)), timeout=1).close()
            print("ComfyUI webserver ready!")
            break
        except (socket.timeout, ConnectionRefusedError):
            # Check if launcher webserving process has exited.
            # If so, a connection can never be made.
            retcode = process.poll()
            if retcode is not None:
                raise RuntimeError(
                    f"comfyui main.py exited unexpectedly with code {retcode}"
                )


@stub.function(
    image=target_image,
    gpu=config["gpu"],
    # Allows 100 concurrent requests per container.
    allow_concurrent_inputs=100,
    # Restrict to 1 container because we want our ComfyUI session state
    # to be on a single container.
    concurrency_limit=1,
    timeout=10 * 60,
)
@asgi_app()
def comfyui_app():
    return web_app
    from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
    from asgiproxy.context import ProxyContext
    from asgiproxy.simple_proxy import make_simple_proxy_app

    spawn_comfyui_in_background()

    config = type(
        "Config",
        (BaseURLProxyConfigMixin, ProxyConfig),
        {
            "upstream_base_url": f"http://{HOST}:{PORT}",
            "rewrite_host_header": f"{HOST}:{PORT}",
        },
    )()

    return make_simple_proxy_app(ProxyContext(config))

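The new comfyui_app body replaces the plain "return web_app" shown above it: it boots ComfyUI inside the container and serves an asgiproxy app that forwards every request to the local server on port 8188, with concurrency_limit=1 keeping the session state on a single container. Once deployed, the proxied UI is reachable at the URL Modal prints, in the format from the build_logic comments. A small smoke-test sketch; the org and app names are placeholders taken from that example comment, not values produced by this diff:

import requests

# Format from the "# example https://bennykok--my-app-comfyui-app.modal.run/" comment above.
app_url = "https://bennykok--my-app-comfyui-app.modal.run"

# ComfyUI answers on its root path once spawn_comfyui_in_background() has finished polling.
print(requests.get(app_url, timeout=30).status_code)
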
@@ -3,9 +3,9 @@ import requests
import time
import subprocess

command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"]
command = ["python", "main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"]
# Start the server
server_process = subprocess.Popen(command)
server_process = subprocess.Popen(command, cwd="/comfyui")

def check_server(url, retries=50, delay=500):
    for i in range(retries):