diff --git a/builder/modal-builder/src/template/app.py b/builder/modal-builder/src/template/app.py index 1d4145d..ab6a1fb 100644 --- a/builder/modal-builder/src/template/app.py +++ b/builder/modal-builder/src/template/app.py @@ -1,6 +1,6 @@ from config import config import modal -from modal import Image, Mount, web_endpoint, Stub, asgi_app +from modal import Image, Mount, web_endpoint, Stub, asgi_app, Volume import json import urllib.request import urllib.parse @@ -28,6 +28,8 @@ web_app = FastAPI() print(config) print("deploy_test ", deploy_test) stub = Stub(name=config["name"]) +volume = modal.Volume.persisted("model-store") +MODEL_DIR = "/comfyui/models/checkpoints/" # print(stub.app_id) if not deploy_test: @@ -72,6 +74,7 @@ if not deploy_test: .copy_local_file(f"{current_directory}/data/deps.json", "/") .run_commands("python install_deps.py") + .run_commands(f"rm -rf {MODEL_DIR}") # clear model dir so volume can mount, NOTE: could instead use the extra_model_paths ) # Time to wait between API check attempts in milliseconds @@ -154,7 +157,9 @@ image = Image.debian_slim() target_image = image if deploy_test else dockerfile_image -@stub.function(image=target_image, gpu=config["gpu"]) +@stub.function(image=target_image, gpu=config["gpu"] + , volumes={MODEL_DIR: volume} + ) def run(input: Input): import subprocess import time @@ -235,7 +240,9 @@ async def bar(request_input: RequestInput): # pass -@stub.function(image=image) +@stub.function(image=image + , volumes={MODEL_DIR: volume} + ) @asgi_app() def comfyui_api(): return web_app @@ -285,6 +292,7 @@ def spawn_comfyui_in_background(): # to be on a single container. concurrency_limit=1, timeout=10 * 60, + volumes={MODEL_DIR: volume} ) @asgi_app() def comfyui_app(): @@ -303,4 +311,4 @@ def comfyui_app(): }, )() - return make_simple_proxy_app(ProxyContext(config)) \ No newline at end of file + return make_simple_proxy_app(ProxyContext(config)) diff --git a/builder/modal-builder/src/template/data/insert_models.py b/builder/modal-builder/src/template/data/insert_models.py new file mode 100644 index 0000000..c6571d9 --- /dev/null +++ b/builder/modal-builder/src/template/data/insert_models.py @@ -0,0 +1,44 @@ +import modal +import subprocess +import requests + +stub = modal.Stub() + +# NOTE: volume name can be variable +volume = modal.Volume.persisted("model-store") +model_store_path = "/vol/models" +MODEL_ROUTE = "models" +image = ( + modal.Image.debian_slim().apt_install("wget").pip_install("requests") +) + +@stub.function(volumes={model_store_path: volume}, gpu="any", image=image, timeout=600) +def download_model(model): + # wget https://civitai.com/api/download/models/{modelVersionId} --content-disposition + model_id = model['modelVersions'][0]['id'] + download_url = f"https://civitai.com/api/download/models/{model_id}" + subprocess.run(["wget", download_url, "--content-disposition", "-P", model_store_path]) + subprocess.run(["ls", "-la", model_store_path]) + volume.commit() + +# file is raw output from Civitai API https://github.com/civitai/civitai/wiki/REST-API-Reference + +@stub.function() +def get_civitai_models(model_type: str, sort: str = "Highest Rated", page: int = 1): + """Fetch models from CivitAI API based on type.""" + try: + response = requests.get(f"https://civitai.com/api/v1/models", params={"types": model_type, "page": page, "sort": sort}) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + print(f"Error fetching models: {e}") + return None + +@stub.local_entrypoint() +def insert_model(type: str = "Checkpoint", sort = "Highest Rated", page: int = 1): + civitai_models = get_civitai_models.local(type, sort, page) + if civitai_models: + for _ in download_model.map(civitai_models['items'][1:]): + pass + else: + print("Failed to retrieve models.") diff --git a/builder/modal-builder/src/template/data/install_deps.py b/builder/modal-builder/src/template/data/install_deps.py index 3ff3ca3..6796417 100644 --- a/builder/modal-builder/src/template/data/install_deps.py +++ b/builder/modal-builder/src/template/data/install_deps.py @@ -45,13 +45,13 @@ for package in packages: response = requests.request("POST", f"{root_url}/customnode/install", json=package, headers=headers) print(response.text) -with open('models.json') as f: - models = json.load(f) - -for model in models: - response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers) - print(response.text) +# with open('models.json') as f: +# models = json.load(f) +# +# for model in models: +# response = requests.request("POST", f"{root_url}/model/install", json=model, headers=headers) +# print(response.text) # Close the server server_process.terminate() -print("Finished installing dependencies.") \ No newline at end of file +print("Finished installing dependencies.")