diff --git a/chilloutmix-api-template/Dockerfile b/chilloutmix-api-template/Dockerfile new file mode 100644 index 0000000..ae1fae3 --- /dev/null +++ b/chilloutmix-api-template/Dockerfile @@ -0,0 +1,17 @@ +FROM chilloutai/chilloutmix-template:1.2.0 + +SHELL ["/bin/bash", "-c"] + +ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin" + +WORKDIR / + +RUN pip install -U xformers +RUN pip install runpod + +ADD webui.py /workspace/stable-diffusion-webui/webui.py +ADD handler.py . +ADD start.sh /start.sh +RUN chmod +x /start.sh + +CMD [ "/start.sh" ] diff --git a/chilloutmix-api-template/build.sh b/chilloutmix-api-template/build.sh new file mode 100755 index 0000000..d1b7a42 --- /dev/null +++ b/chilloutmix-api-template/build.sh @@ -0,0 +1 @@ +DOCKER_BUILDKIT=1 docker build -t chilloutai/auto-api:1.2.0 . \ No newline at end of file diff --git a/chilloutmix-api-template/handler.py b/chilloutmix-api-template/handler.py new file mode 100644 index 0000000..49292bc --- /dev/null +++ b/chilloutmix-api-template/handler.py @@ -0,0 +1,39 @@ +import runpod +import subprocess +import requests +import time + +def check_api_availability(host): + while True: + try: + response = requests.get(host) + return + except requests.exceptions.RequestException as e: + print(f"API is not available, retrying in 5s... ({e})") + except Exception as e: + print('something went wrong') + time.sleep(5) + +check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img") + +print('run handler') + +def handler(event): + ''' + This is the handler function that will be called by the serverless. + ''' + print('got event') + print(event) + + response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"]) + + json = response.json() + # do the things + + print(json) + + # return the output that you want to be returned like pre-signed URLs to output artifacts + return json + + +runpod.serverless.start({"handler": handler}) \ No newline at end of file diff --git a/chilloutmix-api-template/start.sh b/chilloutmix-api-template/start.sh new file mode 100644 index 0000000..f49d308 --- /dev/null +++ b/chilloutmix-api-template/start.sh @@ -0,0 +1,11 @@ +#!/bin/bash +echo "Container Started" +export PYTHONUNBUFFERED=1 +source /workspace/stable-diffusion-webui/venv/bin/activate +cd /workspace/stable-diffusion-webui +echo "starting api" +python webui.py --port 3000 --nowebui --api --xformers --enable-insecure-extension-access --ckpt ./models/Stable-diffusion/basemodel.safetensors & +cd / + +echo "starting worker" +python -u handler.py \ No newline at end of file diff --git a/chilloutmix-api-template/webui.py b/chilloutmix-api-template/webui.py new file mode 100644 index 0000000..68812ab --- /dev/null +++ b/chilloutmix-api-template/webui.py @@ -0,0 +1,284 @@ +import os +import sys +import time +import importlib +import signal +import re +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from packaging import version + +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call + +import torch + +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks +import modules.codeformer_model as codeformer +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.sd_vae +import modules.txt2img +import modules.script_callbacks +import modules.textual_inversion.textual_inversion +import modules.progress + +import modules.ui +from modules import modelloader +from modules.shared import cmd_opts +import modules.hypernetworks.hypernetwork + + +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None + + +def check_versions(): + if shared.cmd_opts.skip_version_check: + return + + expected_torch_version = "1.13.1" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + +def initialize(): + check_versions() + + extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) + + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() + return + + modelloader.cleanup_models() + modules.sd_models.setup_model() + codeformer.setup_model(cmd_opts.codeformer_models_path) + gfpgan.setup_model(cmd_opts.gfpgan_models_path) + shared.face_restorers.append(modules.face_restoration.FaceRestoration()) + + modelloader.list_builtin_upscalers() + modules.scripts.load_scripts() + modelloader.load_upscalers() + + modules.sd_vae.refresh_vae_list() + + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + +def setup_cors(app): + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + + +def create_api(app): + from modules.api.api import Api + api = Api(app, queue_lock) + return api + + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if shared.state.need_restart: + shared.state.need_restart = False + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + +def api_only(): + initialize() + + app = FastAPI() + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + modules.script_callbacks.app_started_callback(None, app) + modules.script_callbacks.before_ui_callback() + extensions.list_extensions() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + modules.sd_models.list_models() + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(): + launch_api = cmd_opts.api + initialize() + + while 1: + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + modules.script_callbacks.before_ui_callback() + + shared.demo = modules.ui.create_ui() + + if cmd_opts.gradio_queue: + shared.demo.queue(64) + + app, local_url, share_url = shared.demo.launch( + share=cmd_opts.share, + server_name=server_name, + server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + # after initial launch, disable --autolaunch for subsequent restarts + cmd_opts.autolaunch = False + + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attacker wants, including installing an extension and + # running its code. We disable this here. Suggested by RyotaK. + app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + + setup_cors(app) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + modules.progress.setup_progress_api(app) + + if launch_api: + create_api(app) + + ui_extra_networks.add_pages_to_demo(app) + + modules.script_callbacks.app_started_callback(shared.demo, app) + + wait_on_server(shared.demo) + print('Restarting UI...') + + sd_samplers.set_samplers() + + modules.script_callbacks.script_unloaded_callback() + extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + + modelloader.forbid_loaded_nonbuiltin_upscalers() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + + modules.sd_models.list_models() + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + +if __name__ == "__main__": + if cmd_opts.nowebui: + api_only() + else: + webui() \ No newline at end of file