fix: log and async task issues with modal script
This commit is contained in:
parent
38fea1e79f
commit
d8951df35f
@ -37,7 +37,9 @@ if not deploy_test:
|
|||||||
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
|
# dockerfile_image = Image.from_dockerfile(f"{current_directory}/Dockerfile", context_mount=Mount.from_local_dir(f"{current_directory}/data", remote_path="/data"))
|
||||||
|
|
||||||
dockerfile_image = (
|
dockerfile_image = (
|
||||||
modal.Image.debian_slim()
|
modal.Image.debian_slim(
|
||||||
|
python_version="3.11",
|
||||||
|
)
|
||||||
.apt_install("git", "wget")
|
.apt_install("git", "wget")
|
||||||
.pip_install(
|
.pip_install(
|
||||||
"git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
|
"git+https://github.com/modal-labs/asgiproxy.git", "httpx", "tqdm"
|
||||||
@ -83,7 +85,7 @@ if not deploy_test:
|
|||||||
# Time to wait between API check attempts in milliseconds
|
# Time to wait between API check attempts in milliseconds
|
||||||
COMFY_API_AVAILABLE_INTERVAL_MS = 50
|
COMFY_API_AVAILABLE_INTERVAL_MS = 50
|
||||||
# Maximum number of API check attempts
|
# Maximum number of API check attempts
|
||||||
COMFY_API_AVAILABLE_MAX_RETRIES = 500
|
COMFY_API_AVAILABLE_MAX_RETRIES = 1000
|
||||||
# Time to wait between poll attempts in milliseconds
|
# Time to wait between poll attempts in milliseconds
|
||||||
COMFY_POLLING_INTERVAL_MS = 250
|
COMFY_POLLING_INTERVAL_MS = 250
|
||||||
# Maximum number of poll attempts
|
# Maximum number of poll attempts
|
||||||
@ -94,7 +96,8 @@ COMFY_HOST = "127.0.0.1:8188"
|
|||||||
|
|
||||||
async def check_server(url, retries=50, delay=500):
|
async def check_server(url, retries=50, delay=500):
|
||||||
import aiohttp
|
import aiohttp
|
||||||
for i in range(retries):
|
# for i in range(retries):
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
@ -157,30 +160,34 @@ class ComfyDeployRunner:
|
|||||||
async def read_stream(self, stream, isStderr):
|
async def read_stream(self, stream, isStderr):
|
||||||
import time
|
import time
|
||||||
while True:
|
while True:
|
||||||
line = await stream.readline()
|
try:
|
||||||
if line:
|
line = await stream.readline()
|
||||||
l = line.decode('utf-8').strip()
|
if line:
|
||||||
|
l = line.decode('utf-8').strip()
|
||||||
|
|
||||||
if l == "":
|
if l == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not isStderr:
|
if not isStderr:
|
||||||
print(l, flush=True)
|
print(l, flush=True)
|
||||||
self.machine_logs.append({
|
self.machine_logs.append({
|
||||||
"logs": l,
|
"logs": l,
|
||||||
"timestamp": time.time()
|
"timestamp": time.time()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
else:
|
||||||
|
# is error
|
||||||
|
# logger.error(l)
|
||||||
|
print(l, flush=True)
|
||||||
|
self.machine_logs.append({
|
||||||
|
"logs": l,
|
||||||
|
"timestamp": time.time()
|
||||||
|
})
|
||||||
else:
|
else:
|
||||||
# is error
|
break
|
||||||
# logger.error(l)
|
except asyncio.CancelledError:
|
||||||
print(l, flush=True)
|
# Handle the cancellation here if needed
|
||||||
self.machine_logs.append({
|
break # Break out of the loop on cancellation
|
||||||
"logs": l,
|
|
||||||
"timestamp": time.time()
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
@enter()
|
@enter()
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
@ -197,24 +204,22 @@ class ComfyDeployRunner:
|
|||||||
# env={**os.environ, "COLUMNS": "10000"}
|
# env={**os.environ, "COLUMNS": "10000"}
|
||||||
)
|
)
|
||||||
|
|
||||||
stdout_task = asyncio.create_task(
|
|
||||||
self.read_stream(self.server_process.stdout, False))
|
|
||||||
stderr_task = asyncio.create_task(
|
|
||||||
self.read_stream(self.server_process.stderr, True))
|
|
||||||
|
|
||||||
await check_server(
|
|
||||||
f"http://{COMFY_HOST}",
|
|
||||||
COMFY_API_AVAILABLE_MAX_RETRIES,
|
|
||||||
COMFY_API_AVAILABLE_INTERVAL_MS,
|
|
||||||
)
|
|
||||||
|
|
||||||
stdout_task.cancel()
|
|
||||||
stderr_task.cancel()
|
|
||||||
|
|
||||||
@exit()
|
@exit()
|
||||||
async def cleanup(self, exc_type, exc_value, traceback):
|
async def cleanup(self, exc_type, exc_value, traceback):
|
||||||
print(f"comfy-modal - cleanup", exc_type, exc_value, traceback)
|
print(f"comfy-modal - cleanup", exc_type, exc_value, traceback)
|
||||||
# self.server_process.kill()
|
# Get the current event loop
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# Check if the event loop is closed
|
||||||
|
if loop.is_closed():
|
||||||
|
print("The event loop is closed.")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.server_process.terminate()
|
||||||
|
await self.server_process.wait()
|
||||||
|
except Exception as e:
|
||||||
|
print("Issues when cleaning up", e)
|
||||||
|
print("The event loop is open.")
|
||||||
|
|
||||||
@method()
|
@method()
|
||||||
async def run(self, input: Input):
|
async def run(self, input: Input):
|
||||||
@ -228,93 +233,119 @@ class ComfyDeployRunner:
|
|||||||
stderr_task = asyncio.create_task(
|
stderr_task = asyncio.create_task(
|
||||||
self.read_stream(self.server_process.stderr, True))
|
self.read_stream(self.server_process.stderr, True))
|
||||||
|
|
||||||
class TimeoutError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def timeout_handler(signum, frame):
|
|
||||||
data = json.dumps({
|
|
||||||
"run_id": input.prompt_id,
|
|
||||||
"status": "timeout",
|
|
||||||
"time": datetime.now().isoformat()
|
|
||||||
}).encode('utf-8')
|
|
||||||
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
|
|
||||||
urllib.request.urlopen(req)
|
|
||||||
raise TimeoutError("Operation timed out")
|
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, timeout_handler)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set an alarm for some seconds in the future
|
class TimeoutError(Exception):
|
||||||
signal.alarm(run_timeout) # 5 seconds timeout
|
pass
|
||||||
|
|
||||||
|
def timeout_handler(signum, frame):
|
||||||
|
data = json.dumps({
|
||||||
|
"run_id": input.prompt_id,
|
||||||
|
"status": "timeout",
|
||||||
|
"time": datetime.now().isoformat()
|
||||||
|
}).encode('utf-8')
|
||||||
|
req = urllib.request.Request(input.status_endpoint, data=data, method='POST')
|
||||||
|
urllib.request.urlopen(req)
|
||||||
|
raise TimeoutError("Operation timed out")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGALRM, timeout_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
signal.alarm(run_timeout) # 5 seconds timeout
|
||||||
|
|
||||||
|
ok = await check_server(
|
||||||
|
f"http://{COMFY_HOST}",
|
||||||
|
COMFY_API_AVAILABLE_MAX_RETRIES,
|
||||||
|
COMFY_API_AVAILABLE_INTERVAL_MS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ok:
|
||||||
|
raise Exception("ComfyUI API is not available")
|
||||||
|
# Set an alarm for some seconds in the future
|
||||||
|
|
||||||
|
data = json.dumps({
|
||||||
|
"run_id": input.prompt_id,
|
||||||
|
"status": "started",
|
||||||
|
"time": datetime.now().isoformat()
|
||||||
|
}).encode('utf-8')
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(input.status_endpoint, data=data) as response:
|
||||||
|
pass
|
||||||
|
|
||||||
|
job_input = input
|
||||||
|
|
||||||
|
try:
|
||||||
|
queued_workflow = await queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
|
||||||
|
prompt_id = queued_workflow["prompt_id"]
|
||||||
|
print(f"comfy-modal - queued workflow with ID {prompt_id}")
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return {"error": f"Error queuing workflow: {str(e)}"}
|
||||||
|
|
||||||
|
# Poll for completion
|
||||||
|
print(f"comfy-modal - wait until image generation is complete")
|
||||||
|
retries = 0
|
||||||
|
status = ""
|
||||||
|
try:
|
||||||
|
print("getting request")
|
||||||
|
while retries < COMFY_POLLING_MAX_RETRIES:
|
||||||
|
status_result = await check_status(prompt_id=prompt_id)
|
||||||
|
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
|
||||||
|
status = status_result['status']
|
||||||
|
print(status)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Wait before trying again
|
||||||
|
await asyncio.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
|
||||||
|
retries += 1
|
||||||
|
else:
|
||||||
|
return {"error": "Max retries reached while waiting for image generation"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Error waiting for image generation: {str(e)}"}
|
||||||
|
|
||||||
|
print(f"comfy-modal - Finished, turning off")
|
||||||
|
|
||||||
|
result = {"status": status}
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
print("Operation timed out")
|
||||||
|
return {"status": "failed"}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error occurred: {str(e)}")
|
||||||
|
data = json.dumps({
|
||||||
|
"run_id": input.prompt_id,
|
||||||
|
"status": "failed",
|
||||||
|
"time": datetime.now().isoformat()
|
||||||
|
}).encode('utf-8')
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(input.status_endpoint, data=data) as response:
|
||||||
|
print("response", response)
|
||||||
|
self.machine_logs.append({
|
||||||
|
"logs": str(e),
|
||||||
|
"timestamp": time.time()
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
print("uploading log_data")
|
||||||
data = json.dumps({
|
data = json.dumps({
|
||||||
"run_id": input.prompt_id,
|
"run_id": input.prompt_id,
|
||||||
"status": "started",
|
"time": datetime.now().isoformat(),
|
||||||
"time": datetime.now().isoformat()
|
"log_data": self.machine_logs
|
||||||
}).encode('utf-8')
|
}).encode('utf-8')
|
||||||
|
print("my logs", len(self.machine_logs))
|
||||||
|
# Clear logs
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(input.status_endpoint, data=data) as response:
|
async with session.post(input.status_endpoint, data=data) as response:
|
||||||
pass
|
print("response", response)
|
||||||
|
print("uploaded log_data")
|
||||||
job_input = input
|
# print(data)
|
||||||
|
self.machine_logs = []
|
||||||
try:
|
|
||||||
queued_workflow = await queue_workflow_comfy_deploy(job_input) # queue_workflow(workflow)
|
|
||||||
prompt_id = queued_workflow["prompt_id"]
|
|
||||||
print(f"comfy-modal - queued workflow with ID {prompt_id}")
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
print(traceback.format_exc())
|
|
||||||
return {"error": f"Error queuing workflow: {str(e)}"}
|
|
||||||
|
|
||||||
# Poll for completion
|
|
||||||
print(f"comfy-modal - wait until image generation is complete")
|
|
||||||
retries = 0
|
|
||||||
status = ""
|
|
||||||
try:
|
|
||||||
print("getting request")
|
|
||||||
while retries < COMFY_POLLING_MAX_RETRIES:
|
|
||||||
status_result = await check_status(prompt_id=prompt_id)
|
|
||||||
if 'status' in status_result and (status_result['status'] == 'success' or status_result['status'] == 'failed'):
|
|
||||||
status = status_result['status']
|
|
||||||
print(status)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Wait before trying again
|
|
||||||
await asyncio.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
|
|
||||||
retries += 1
|
|
||||||
else:
|
|
||||||
return {"error": "Max retries reached while waiting for image generation"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": f"Error waiting for image generation: {str(e)}"}
|
|
||||||
|
|
||||||
print(f"comfy-modal - Finished, turning off")
|
|
||||||
|
|
||||||
result = {"status": status}
|
|
||||||
|
|
||||||
except TimeoutError:
|
|
||||||
print("Operation timed out")
|
|
||||||
return {"status": "failed"}
|
|
||||||
finally:
|
finally:
|
||||||
signal.alarm(0)
|
stdout_task.cancel()
|
||||||
|
stderr_task.cancel()
|
||||||
print("uploading log_data")
|
await stdout_task
|
||||||
data = json.dumps({
|
await stderr_task
|
||||||
"run_id": input.prompt_id,
|
|
||||||
"time": datetime.now().isoformat(),
|
|
||||||
"log_data": self.machine_logs
|
|
||||||
}).encode('utf-8')
|
|
||||||
print("my logs", len(self.machine_logs))
|
|
||||||
# Clear logs
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(input.status_endpoint, data=data) as response:
|
|
||||||
print("response", response)
|
|
||||||
print("uploaded log_data")
|
|
||||||
# print(data)
|
|
||||||
self.machine_logs = []
|
|
||||||
|
|
||||||
stdout_task.cancel()
|
|
||||||
stderr_task.cancel()
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user