feat(builder): move away from Dockerfile to Modal commands

This commit is contained in:
BennyKok 2024-01-07 21:07:48 +08:00
parent 7ab4edb069
commit c339cc4234
3 changed files with 209 additions and 42 deletions

View File

@ -15,6 +15,9 @@ import signal
import logging
from fastapi.logger import logger as fastapi_logger
import requests
from urllib.parse import parse_qs
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send
from concurrent.futures import ThreadPoolExecutor
@ -41,6 +44,34 @@ global_timeout = 60 * 4
machine_id_websocket_dict = {}
machine_id_status = {}
fly_instance_id = os.environ.get('FLY_ALLOC_ID', 'local').split('-')[0]
class FlyReplayMiddleware(BaseHTTPMiddleware):
    """
    If the wrong instance was picked by the fly.io load balancer we use the fly-replay header
    to repeat the request again on the right instance.

    This only works if the right instance is provided as a ``fly_instance_id``
    query-string parameter; otherwise the request is handled locally.
    """

    def __init__(self, app: ASGIApp) -> None:
        # NOTE: BaseHTTPMiddleware.__init__ is deliberately not called —
        # __call__ is fully overridden, so only the raw ASGI app is kept.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        query_string = scope.get('query_string', b'').decode()
        query_params = parse_qs(query_string)
        # Default to this instance so requests without the param pass through.
        target_instance = query_params.get(
            'fly_instance_id', [fly_instance_id])[0]

        async def send_wrapper(message):
            if target_instance != fly_instance_id:
                # The ASGI spec makes 'reason' optional on websocket.close —
                # use .get() so a bare close event does not raise KeyError.
                if message['type'] == 'websocket.close' and 'Invalid session' in message.get('reason', ''):
                    # fly.io only seems to look at the fly-replay header if websocket is accepted
                    message = {'type': 'websocket.accept'}
                if 'headers' not in message:
                    message['headers'] = []
                message['headers'].append(
                    [b'fly-replay', f'instance={target_instance}'.encode()])
            await send(message)

        await self.app(scope, receive, send_wrapper)
async def check_inactivity():
global last_activity_time
while True:
@ -49,7 +80,8 @@ async def check_inactivity():
if len(machine_id_status) == 0:
# The application has been inactive for more than 60 seconds.
# Scale it down to zero here.
logger.info(f"No activity for {global_timeout} seconds, exiting...")
logger.info(
f"No activity for {global_timeout} seconds, exiting...")
# os._exit(0)
os.kill(os.getpid(), signal.SIGINT)
break
@ -68,9 +100,10 @@ async def lifespan(app: FastAPI):
#
app = FastAPI(lifespan=lifespan)
app.add_middleware(FlyReplayMiddleware)
# MODAL_ORG = os.environ.get("MODAL_ORG")
@app.get("/")
def read_root():
global last_activity_time
@ -97,14 +130,17 @@ def read_root():
# }
# }
class GitCustomNodes(BaseModel):
    """One pinned custom-node entry from a ComfyUI snapshot."""
    hash: str  # git commit hash the custom node is pinned to
    disabled: bool  # whether the node is present but disabled in the snapshot
class Snapshot(BaseModel):
    """ComfyUI-Manager snapshot: core version plus pinned custom nodes."""
    comfyui: str  # ComfyUI version identifier — presumably a git hash; verify against snapshot.json
    git_custom_nodes: Dict[str, GitCustomNodes]  # keyed by repo URL — TODO confirm against caller
class Model(BaseModel):
name: str
type: str
@ -115,12 +151,14 @@ class Model(BaseModel):
filename: str
url: str
class GPUType(str, Enum):
    """Allowed GPU options for a build; str-mixin so values serialize directly."""
    T4 = "T4"
    A10G = "A10G"
    A100 = "A100"
    L4 = "L4"
class Item(BaseModel):
machine_id: str
name: str
@ -133,7 +171,8 @@ class Item(BaseModel):
@classmethod
def check_gpu(cls, value):
if value not in GPUType.__members__:
raise ValueError(f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
raise ValueError(
f"Invalid GPU option. Choose from: {', '.join(GPUType.__members__.keys())}")
return GPUType(value)
@ -143,9 +182,11 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
machine_id_websocket_dict[machine_id] = websocket
# Send existing logs
if machine_id in machine_logs_cache:
combined_logs = "\n".join(
log_entry['logs'] for log_entry in machine_logs_cache[machine_id])
await websocket.send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": machine_id,
"logs": json.dumps(machine_logs_cache[machine_id]) ,
"logs": combined_logs,
"timestamp": time.time()
}}))
try:
@ -173,6 +214,7 @@ async def websocket_endpoint(websocket: WebSocket, machine_id: str):
# return {"Hello": "World"}
@app.post("/create")
async def create_item(item: Item):
global last_activity_time
@ -186,12 +228,13 @@ async def create_item(item: Item):
# future = executor.submit(build_logic, item)
task = asyncio.create_task(build_logic(item))
return JSONResponse(status_code=200, content={"message": "Build Queued"})
return JSONResponse(status_code=200, content={"message": "Build Queued", "build_machine_instance_id": fly_instance_id})
# Initialize the logs cache
machine_logs_cache = {}
async def build_logic(item: Item):
# Deploy to modal
folder_path = f"/app/builds/{item.machine_id}"
@ -242,7 +285,9 @@ async def build_logic(item: Item):
machine_logs = machine_logs_cache[item.machine_id]
async def read_stream(stream, isStderr):
url_queue = asyncio.Queue()
async def read_stream(stream, isStderr, url_queue: asyncio.Queue):
while True:
line = await stream.readline()
if line:
@ -265,11 +310,11 @@ async def build_logic(item: Item):
"timestamp": time.time()
}}))
if "Created comfyui_app =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
if "Created comfyui_app =>" in l:
if "Created comfyui_api =>" in l or (l.startswith("https://") and l.endswith(".modal.run")):
if "Created comfyui_api =>" in l:
url = l.split("=>")[1].strip()
else:
# making sure it is a url
elif "comfyui_api" in l:
# Some case it only prints the url on a blank line
url = l
@ -279,6 +324,8 @@ async def build_logic(item: Item):
"timestamp": time.time()
})
await url_queue.put(url)
if item.machine_id in machine_id_websocket_dict:
await machine_id_websocket_dict[item.machine_id].send_text(json.dumps({"event": "LOGS", "data": {
"machine_id": item.machine_id,
@ -309,8 +356,12 @@ async def build_logic(item: Item):
else:
break
stdout_task = asyncio.create_task(read_stream(process.stdout, False))
stderr_task = asyncio.create_task(read_stream(process.stderr, True))
stdout_task = asyncio.create_task(
read_stream(process.stdout, False, url_queue))
stderr_task = asyncio.create_task(
read_stream(process.stderr, True, url_queue))
url = await url_queue.get()
await asyncio.wait([stdout_task, stderr_task])
@ -334,7 +385,8 @@ async def build_logic(item: Item):
"logs": "Unable to build the app image.",
"timestamp": time.time()
})
requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
@ -349,7 +401,8 @@ async def build_logic(item: Item):
"logs": "App image built, but url is None, unable to parse the url.",
"timestamp": time.time()
})
requests.post(item.callback_url, json={"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
@ -359,17 +412,20 @@ async def build_logic(item: Item):
# example https://bennykok--my-app-comfyui-app.modal.run/
# my_url = f"https://{MODAL_ORG}--{item.container_id}-{app_suffix}.modal.run"
requests.post(item.callback_url, json={"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
requests.post(item.callback_url, json={
"machine_id": item.machine_id, "endpoint": url, "build_log": json.dumps(machine_logs)})
if item.machine_id in machine_logs_cache:
del machine_logs_cache[item.machine_id]
logger.info("done")
logger.info(url)
def start_loop(loop):
    """Thread entry point: bind `loop` to this thread and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()
def run_in_new_thread(coroutine):
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=start_loop, args=(new_loop,), daemon=True)
@ -377,6 +433,7 @@ def run_in_new_thread(coroutine):
asyncio.run_coroutine_threadsafe(coroutine, new_loop)
return t
if __name__ == "__main__":
import uvicorn
# , log_level="debug"

View File

@ -1,3 +1,4 @@
from config import config
import modal
from modal import Image, Mount, web_endpoint, Stub, asgi_app
import json
@ -12,7 +13,6 @@ from fastapi.responses import HTMLResponse
import os
current_directory = os.path.dirname(os.path.realpath(__file__))
from config import config
deploy_test = config["deploy_test"] == "True"
# MODAL_IMAGE_ID = os.environ.get('MODAL_IMAGE_ID', None)
@ -30,8 +30,41 @@ print("deploy_test ", deploy_test)
# One Modal stub per generated machine; the name comes from the build config.
stub = Stub(name=config["name"])

if not deploy_test:
    # Previously built from a Dockerfile:
    # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
    dockerfile_image = (
        modal.Image.debian_slim()
        .apt_install("git", "wget")
        .run_commands(
            # Basic comfyui setup
            "git clone https://github.com/comfyanonymous/ComfyUI.git /comfyui",
            "cd /comfyui && pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121",
            # Install comfyui manager
            "cd /comfyui/custom_nodes && git clone --depth 1 https://github.com/ltdrdata/ComfyUI-Manager.git",
            "cd /comfyui/custom_nodes/ComfyUI-Manager && pip install -r requirements.txt",
            "cd /comfyui/custom_nodes/ComfyUI-Manager && mkdir startup-scripts",
            # Install comfy deploy
            "cd /comfyui/custom_nodes && git clone https://github.com/BennyKok/comfyui-deploy.git",
        )
        .copy_local_file(f"{current_directory}/data/extra_model_paths.yaml", "/comfyui")
        # Snapshot is restored by ComfyUI-Manager's startup-scripts mechanism.
        .copy_local_file(f"{current_directory}/data/snapshot.json", "/comfyui/custom_nodes/ComfyUI-Manager/startup-scripts/restore-snapshot.json")
        .copy_local_file(f"{current_directory}/data/start.sh", "/start.sh")
        .run_commands("chmod +x /start.sh")
        .copy_local_file(f"{current_directory}/data/install_deps.py", "/")
        .copy_local_file(f"{current_directory}/data/models.json", "/")
        .copy_local_file(f"{current_directory}/data/deps.json", "/")
        # Downloads models/deps declared in the JSON files at image-build time.
        .run_commands("python install_deps.py")
        .pip_install(
            "git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
        )
    )
# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 50
@ -44,6 +77,7 @@ COMFY_POLLING_MAX_RETRIES = 500
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
def check_server(url, retries=50, delay=500):
import requests
import time
@ -71,7 +105,6 @@ def check_server(url, retries=50, delay=500):
# If an exception occurs, the server may not be ready
pass
# print(f"runpod-worker-comfy - trying")
# Wait for the specified delay before retrying
@ -82,29 +115,37 @@ def check_server(url, retries=50, delay=500):
)
return False
def check_status(prompt_id):
    """Poll the local ComfyUI deploy plugin for the status of `prompt_id`.

    Returns the decoded JSON response as a dict.
    """
    # The merged diff had duplicated this Request construction; build it once.
    req = urllib.request.Request(
        f"http://{COMFY_HOST}/comfyui-deploy/check-status?prompt_id={prompt_id}")
    return json.loads(urllib.request.urlopen(req).read())
class Input(BaseModel):
    """Payload for queuing a workflow run through the ComfyUI deploy plugin."""
    prompt_id: str  # id used later to poll run status (see check_status)
    workflow_api: dict  # ComfyUI workflow in API format — presumably; confirm with caller
    status_endpoint: str  # callback URL — semantics assumed from name; verify
    file_upload_endpoint: str  # callback URL — semantics assumed from name; verify
def queue_workflow_comfy_deploy(data: Input):
    """POST the serialized `Input` to the local ComfyUI deploy plugin's /run endpoint.

    Returns the decoded JSON response (which includes the queued prompt_id).
    """
    data_str = data.json()
    data_bytes = data_str.encode('utf-8')
    # The merged diff had duplicated this Request construction; build it once.
    req = urllib.request.Request(
        f"http://{COMFY_HOST}/comfyui-deploy/run", data=data_bytes)
    return json.loads(urllib.request.urlopen(req).read())
class RequestInput(BaseModel):
    """Wrapper so the POST body has the shape {"input": {...}} (see the /run handler)."""
    input: Input
image = Image.debian_slim()
target_image = image if deploy_test else dockerfile_image
@stub.function(image=target_image, gpu=config["gpu"])
def run(input: Input):
import subprocess
@ -112,8 +153,9 @@ def run(input: Input):
# Make sure that the ComfyUI API is available
print(f"comfy-modal - check server")
command = ["python3", "/comfyui/main.py", "--disable-auto-launch", "--disable-metadata"]
server_process = subprocess.Popen(command)
command = ["python", "main.py",
"--disable-auto-launch", "--disable-metadata"]
server_process = subprocess.Popen(command, cwd="/comfyui")
check_server(
f"http://{COMFY_HOST}",
@ -128,7 +170,8 @@ def run(input: Input):
# Queue the workflow
try:
# job_input is the json input
queued_workflow = queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
queued_workflow = queue_workflow_comfy_deploy(
job_input) # queue_workflow(workflow)
prompt_id = queued_workflow["prompt_id"]
print(f"comfy-modal - queued workflow with ID {prompt_id}")
except Exception as e:
@ -170,11 +213,12 @@ def run(input: Input):
# Get the generated image and return it as URL in an AWS bucket or as base64
# images_result = process_output_images(history[prompt_id].get("outputs"), job["id"])
# result = {**images_result, "refresh_worker": REFRESH_WORKER}
result = { "status": status }
result = {"status": status}
return result
print("Running remotely on Modal!")
@web_app.post("/run")
async def bar(request_input: RequestInput):
# print(request_input)
@ -182,7 +226,73 @@ async def bar(request_input: RequestInput):
return run.remote(request_input.input)
# pass
@stub.function(image=image)
@asgi_app()
def comfyui_api():
    """Expose the FastAPI `web_app` (the /run API) as a Modal ASGI app.

    Note: the merged diff showed both the old `comfyui_app` and the renamed
    `comfyui_api` def lines; only the renamed function is kept.
    """
    return web_app
# Address where the in-container ComfyUI server listens.
HOST = "127.0.0.1"
PORT = "8188"  # kept as a string so it can be passed directly as a CLI argument
def spawn_comfyui_in_background():
    """Launch ComfyUI headless in the background and block until it accepts connections.

    Raises:
        RuntimeError: if the server process exits before the port opens.
    """
    import socket
    import subprocess
    import time

    process = subprocess.Popen(
        [
            "python",
            "main.py",
            "--dont-print-server",
            "--port",
            PORT,
        ],
        cwd="/comfyui",
    )

    # Poll until webserver accepts connections before running inputs.
    while True:
        try:
            socket.create_connection((HOST, int(PORT)), timeout=1).close()
            print("ComfyUI webserver ready!")
            break
        except (socket.timeout, ConnectionRefusedError):
            # Check if launcher webserving process has exited.
            # If so, a connection can never be made.
            retcode = process.poll()
            if retcode is not None:
                raise RuntimeError(
                    f"comfyui main.py exited unexpectedly with code {retcode}"
                )
            # A refused connection fails immediately, so without a pause this
            # loop busy-spins at 100% CPU while the server boots.
            time.sleep(0.1)
@stub.function(
    image=target_image,
    gpu=config["gpu"],
    # Allows 100 concurrent requests per container.
    allow_concurrent_inputs=100,
    # Restrict to 1 container because we want our ComfyUI session state
    # to be on a single container.
    concurrency_limit=1,
    timeout=10 * 60,
)
@asgi_app()
def comfyui_app():
    """Proxy the interactive ComfyUI web UI through a Modal ASGI app."""
    from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
    from asgiproxy.context import ProxyContext
    from asgiproxy.simple_proxy import make_simple_proxy_app

    spawn_comfyui_in_background()

    # Named `proxy_config` (not `config`) to avoid shadowing the module-level
    # `config` imported from the config module.
    proxy_config = type(
        "Config",
        (BaseURLProxyConfigMixin, ProxyConfig),
        {
            "upstream_base_url": f"http://{HOST}:{PORT}",
            "rewrite_host_header": f"{HOST}:{PORT}",
        },
    )()
    return make_simple_proxy_app(ProxyContext(proxy_config))

View File

@ -3,9 +3,9 @@ import requests
import time
import subprocess
# Launch ComfyUI headless in CPU mode; run from /comfyui so its relative
# paths (models, custom_nodes, ...) resolve correctly.
# NOTE: the merged diff showed both the old and new command/Popen lines —
# the duplicate Popen would have started the server twice; only the new
# versions are kept.
command = ["python", "main.py", "--disable-auto-launch", "--disable-metadata", "--cpu"]

# Start the server
server_process = subprocess.Popen(command, cwd="/comfyui")
def check_server(url, retries=50, delay=500):
for i in range(retries):