fix: log and async task issues with modal script

This commit is contained in:
bennykok 2024-02-01 13:42:01 +08:00
parent 38fea1e79f
commit d8951df35f

View File

@ -37,7 +37,9 @@ if not deploy_test:
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data")) # dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
dockerfile_image = ( dockerfile_image = (
modal.Image.debian_slim() modal.Image.debian_slim(
python_version="3.11",
)
.apt_install("git", "wget") .apt_install("git", "wget")
.pip_install( .pip_install(
"git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm" "git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
@ -83,7 +85,7 @@ if not deploy_test:
# Time to wait between API check attempts in milliseconds # Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 50 COMFY_API_AVAILABLE_INTERVAL_MS = 50
# Maximum number of API check attempts # Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500 COMFY_API_AVAILABLE_MAX_RETRIES = 1000
# Time to wait between poll attempts in milliseconds # Time to wait between poll attempts in milliseconds
COMFY_POLLING_INTERVAL_MS = 250 COMFY_POLLING_INTERVAL_MS = 250
# Maximum number of poll attempts # Maximum number of poll attempts
@ -94,7 +96,8 @@ COMFY_HOST = "127.0.0.1:8188"
async def check_server(url, retries=50, delay=500): async def check_server(url, retries=50, delay=500):
import aiohttp import aiohttp
for i in range(retries): # for i in range(retries):
while True:
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(url) as response: async with session.get(url) as response:
@ -157,30 +160,34 @@ class ComfyDeployRunner:
async def read_stream(self, stream, isStderr): async def read_stream(self, stream, isStderr):
import time import time
while True: while True:
line = await stream.readline() try:
if line: line = await stream.readline()
l = line.decode('utf-8').strip() if line:
l = line.decode('utf-8').strip()
if l == "": if l == "":
continue continue
if not isStderr: if not isStderr:
print(l, flush=True) print(l, flush=True)
self.machine_logs.append({ self.machine_logs.append({
"logs": l, "logs": l,
"timestamp": time.time() "timestamp": time.time()
}) })
else:
# is error
# logger.error(l)
print(l, flush=True)
self.machine_logs.append({
"logs": l,
"timestamp": time.time()
})
else: else:
# is error break
# logger.error(l) except asyncio.CancelledError:
print(l, flush=True) # Handle the cancellation here if needed
self.machine_logs.append({ break # Break out of the loop on cancellation
"logs": l,
"timestamp": time.time()
})
else:
break
@enter() @enter()
async def setup(self): async def setup(self):
@ -197,24 +204,22 @@ class ComfyDeployRunner:
# env={**os.environ, "COLUMNS": "10000"} # env={**os.environ, "COLUMNS": "10000"}
) )
stdout_task = asyncio.create_task(
self.read_stream(self.server_process.stdout, False))
stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True))
await check_server(
f"http://{COMFY_HOST}",
COMFY_API_AVAILABLE_MAX_RETRIES,
COMFY_API_AVAILABLE_INTERVAL_MS,
)
stdout_task.cancel()
stderr_task.cancel()
@exit() @exit()
async def cleanup(self, exc_type, exc_value, traceback): async def cleanup(self, exc_type, exc_value, traceback):
print(f"comfy-modal - cleanup", exc_type, exc_value, traceback) print(f"comfy-modal - cleanup", exc_type, exc_value, traceback)
# self.server_process.kill() # Get the current event loop
loop = asyncio.get_event_loop()
# Check if the event loop is closed
if loop.is_closed():
print("The event loop is closed.")
else:
try:
self.server_process.terminate()
await self.server_process.wait()
except Exception as e:
print("Issues when cleaning up", e)
print("The event loop is open.")
@method() @method()
async def run(self, input: Input): async def run(self, input: Input):
@ -228,93 +233,119 @@ class ComfyDeployRunner:
stderr_task = asyncio.create_task( stderr_task = asyncio.create_task(
self.read_stream(self.server_process.stderr, True)) self.read_stream(self.server_process.stderr, True))
class TimeoutError(Exception):
pass
def timeout_handler(signum, frame):
data = json.dumps({
"run_id": input.prompt_id,
"status": "timeout",
"time": datetime.now().isoformat()
}).encode('utf-8')
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
raise TimeoutError("Operation timed out")
signal.signal(signal.SIGALRM, timeout_handler)
try: try:
# Set an alarm for some seconds in the future class TimeoutError(Exception):
signal.alarm(run_timeout) # 5 seconds timeout pass
def timeout_handler(signum, frame):
data = json.dumps({
"run_id": input.prompt_id,
"status": "timeout",
"time": datetime.now().isoformat()
}).encode('utf-8')
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
urllib.request.urlopen(req)
raise TimeoutError("Operation timed out")
signal.signal(signal.SIGALRM, timeout_handler)
try:
signal.alarm(run_timeout) # 5 seconds timeout
ok = await check_server(
f"http://{COMFY_HOST}",
COMFY_API_AVAILABLE_MAX_RETRIES,
COMFY_API_AVAILABLE_INTERVAL_MS,
)
if not ok:
raise Exception("ComfyUI API is not available")
# Set an alarm for some seconds in the future
data = json.dumps({
"run_id": input.prompt_id,
"status": "started",
"time": datetime.now().isoformat()
}).encode('utf-8')
async with aiohttp.ClientSession() as session:
async with session.post(input.status_endpoint, data=data) as response:
pass
job_input = input
try:
queued_workflow = await queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
prompt_id = queued_workflow["prompt_id"]
print(f"comfy-modal - queued workflow with ID {prompt_id}")
except Exception as e:
import traceback
print(traceback.format_exc())
return {"error": f"Error queuing workflow: {str(e)}"}
# Poll for completion
print(f"comfy-modal - wait until image generation is complete")
retries = 0
status = ""
try:
print("getting request")
while retries < COMFY_POLLING_MAX_RETRIES:
status_result = await check_status(prompt_id=prompt_id)
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
status = status_result['status']
print(status)
break
else:
# Wait before trying again
await asyncio.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
return {"error": "Max retries reached while waiting for image generation"}
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}
print(f"comfy-modal - Finished, turning off")
result = {"status": status}
except TimeoutError:
print("Operation timed out")
return {"status": "failed"}
except Exception as e:
print(f"Unexpected error occurred: {str(e)}")
data = json.dumps({
"run_id": input.prompt_id,
"status": "failed",
"time": datetime.now().isoformat()
}).encode('utf-8')
async with aiohttp.ClientSession() as session:
async with session.post(input.status_endpoint, data=data) as response:
print("response", response)
self.machine_logs.append({
"logs": str(e),
"timestamp": time.time()
})
finally:
signal.alarm(0)
print("uploading log_data")
data = json.dumps({ data = json.dumps({
"run_id": input.prompt_id, "run_id": input.prompt_id,
"status": "started", "time": datetime.now().isoformat(),
"time": datetime.now().isoformat() "log_data": self.machine_logs
}).encode('utf-8') }).encode('utf-8')
print("my logs", len(self.machine_logs))
# Clear logs
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(input.status_endpoint, data=data) as response: async with session.post(input.status_endpoint, data=data) as response:
pass print("response", response)
print("uploaded log_data")
job_input = input # print(data)
self.machine_logs = []
try:
queued_workflow = await queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
prompt_id = queued_workflow["prompt_id"]
print(f"comfy-modal - queued workflow with ID {prompt_id}")
except Exception as e:
import traceback
print(traceback.format_exc())
return {"error": f"Error queuing workflow: {str(e)}"}
# Poll for completion
print(f"comfy-modal - wait until image generation is complete")
retries = 0
status = ""
try:
print("getting request")
while retries < COMFY_POLLING_MAX_RETRIES:
status_result = await check_status(prompt_id=prompt_id)
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
status = status_result['status']
print(status)
break
else:
# Wait before trying again
await asyncio.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
return {"error": "Max retries reached while waiting for image generation"}
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}
print(f"comfy-modal - Finished, turning off")
result = {"status": status}
except TimeoutError:
print("Operation timed out")
return {"status": "failed"}
finally: finally:
signal.alarm(0) stdout_task.cancel()
stderr_task.cancel()
print("uploading log_data") await stdout_task
data = json.dumps({ await stderr_task
"run_id": input.prompt_id,
"time": datetime.now().isoformat(),
"log_data": self.machine_logs
}).encode('utf-8')
print("my logs", len(self.machine_logs))
# Clear logs
async with aiohttp.ClientSession() as session:
async with session.post(input.status_endpoint, data=data) as response:
print("response", response)
print("uploaded log_data")
# print(data)
self.machine_logs = []
stdout_task.cancel()
stderr_task.cancel()
return result return result