commit f3cd48e67f38461a006c451a78cfa7915029abaf Author: hl Date: Fri Feb 24 22:40:53 2023 -0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b3882dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.safetensors +*.log \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e9dbf8c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM chilloutai/stable-diffusion:1.0.0 + +SHELL ["/bin/bash", "-c"] + +ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin" + +WORKDIR / + +ADD model.safetensors / + +RUN pip install -U xformers +RUN pip install runpod + +ADD koreanDollLikeness_v15.safetensors /workspace/stable-diffusion-webui/models/Lora/koreanDollLikeness_v15.safetensors +ADD japaneseDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/japaneseDollLikeness_v10.safetensors +ADD taiwanDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/taiwanDollLikeness_v10.safetensors + +ADD handler.py . +ADD start.sh /start.sh +RUN chmod +x /start.sh + +CMD [ "/start.sh" ] diff --git a/README.md b/README.md new file mode 100644 index 0000000..ad84106 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +### How to Run Serverless in Runpod \ No newline at end of file diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..b70adcb --- /dev/null +++ b/build.sh @@ -0,0 +1 @@ +DOCKER_BUILDKIT=1 docker build -t chilloutai/auto-api:1.0.0 . \ No newline at end of file diff --git a/download.sh b/download.sh new file mode 100755 index 0000000..ebaad48 --- /dev/null +++ b/download.sh @@ -0,0 +1,3 @@ +wget -O koreanDollLikeness_v15.safetensors https://civitai.com/api/download/models/14014 +wget -O japaneseDollLikeness_v10.safetensors https://civitai.com/api/download/models/12050 +wget -O taiwanDollLikeness_v10.safetensors https://civitai.com/api/download/models/9070 \ No newline at end of file diff --git a/handler.py b/handler.py new file mode 100644 index 0000000..49292bc --- /dev/null +++ b/handler.py @@ -0,0 +1,39 @@ +import runpod +import subprocess +import requests +import time + +def check_api_availability(host): + while True: + try: + response = requests.get(host) + return + except requests.exceptions.RequestException as e: + print(f"API is not available, retrying in 5s... ({e})") + except Exception as e: + print('something went wrong') + time.sleep(5) + +check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img") + +print('run handler') + +def handler(event): + ''' + This is the handler function that will be called by the serverless. + ''' + print('got event') + print(event) + + response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"]) + + json = response.json() + # do the things + + print(json) + + # return the output that you want to be returned like pre-signed URLs to output artifacts + return json + + +runpod.serverless.start({"handler": handler}) \ No newline at end of file diff --git a/runpod_api_test.py b/runpod_api_test.py new file mode 100644 index 0000000..b6b9dc3 --- /dev/null +++ b/runpod_api_test.py @@ -0,0 +1,47 @@ +import time + +import requests +import base64 + +runpod_key = '' +api_name = '' + +prompt = ', best quality, ultra high res, (photorealistic:1.4), 1girl, solo focus, ((blue long dress)), elbow dress, black thighhighs, frills, ribbons, studio background, (Kpop idol), (aegyo sal:1), (platinum blonde hair:1), ((puffy eyes)), looking at viewer, facing front, smiling, laughing' +negative_prompt = 'paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples' + + +def generation(): + res = requests.post(f'https://api.runpod.ai/v1/{api_name}/run', headers={ + 'Content-Type': 'application/json', + "Authorization": f"Bearer {runpod_key}" + }, json={ + "input": {"prompt": prompt, "steps": 28, "negative_prompt": negative_prompt, "width": 512, "height": 768, "sampler_index": "DPM++ SDE Karras", + "batch_size": 1, "seed": -1}, + }) + + task_id = res.json()['id'] + + while True: + res = requests.get(f'https://api.runpod.ai/v1/{api_name}/status/{task_id}', headers={ + 'Content-Type': 'application/json', + "Authorization": f"Bearer {runpod_key}" + }) + + status = res.json()['status'] + print(status) + if status == 'COMPLETED': + for imgstring in res.json()['output']['images']: + with open(f"test_{time.time()}.png", "wb") as fh: + imgdata = base64.b64decode(imgstring) + fh.write(imgdata) + print(res.json()) + break + + if status == 'FAILED': + print(res.json()) + break + time.sleep(10) + + +if __name__ == '__main__': + generation() diff --git a/sd-auto-v1/Dockerfile b/sd-auto-v1/Dockerfile new file mode 100644 index 0000000..63aca7b --- /dev/null +++ b/sd-auto-v1/Dockerfile @@ -0,0 +1,51 @@ +FROM runpod/stable-diffusion-models:1.5 as build + +FROM ubuntu:22.04 AS runtime + +RUN rm -rf /root/.cache/pip + +RUN mkdir -p /root/.cache/huggingface + +ENV DEBIAN_FRONTEND noninteractive + +RUN apt update && \ +apt install -y --no-install-recommends \ +software-properties-common \ +git \ +openssh-server \ +libglib2.0-0 \ +libsm6 \ +libxrender1 \ +libxext6 \ +ffmpeg \ +wget \ +curl \ +python3-pip python3 python3.10-venv \ +apt-transport-https ca-certificates && \ +update-ca-certificates + +WORKDIR /workspace + +RUN git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git + +WORKDIR /workspace/stable-diffusion-webui +ADD webui.py /workspace/stable-diffusion-webui/webui.py +RUN python3 -m venv /workspace/stable-diffusion-webui/venv +ENV PATH="/workspace/stable-diffusion-webui/venv/bin:$PATH" + +RUN pip install -U jupyterlab ipywidgets jupyter-archive +RUN jupyter nbextension enable --py widgetsnbextension + +ADD install.py . +RUN python -m install --skip-torch-cuda-test + +RUN apt clean && rm -rf /var/lib/apt/lists/* && \ + echo "en_US.UTF-8 UTF-8" > /etc/locale.gen + +ADD relauncher.py . +ADD webui-user.sh . +ADD start.sh /start.sh +RUN chmod a+x /start.sh + +SHELL ["/bin/bash", "--login", "-c"] +CMD [ "/start.sh" ] diff --git a/sd-auto-v1/build.sh b/sd-auto-v1/build.sh new file mode 100755 index 0000000..26b8ae1 --- /dev/null +++ b/sd-auto-v1/build.sh @@ -0,0 +1 @@ +DOCKER_BUILDKIT=1 docker build -t chilloutai/stable-diffusion:1.0.0 . \ No newline at end of file diff --git a/sd-auto-v1/install.py b/sd-auto-v1/install.py new file mode 100644 index 0000000..604d693 --- /dev/null +++ b/sd-auto-v1/install.py @@ -0,0 +1,3 @@ +from launch import prepare_environment + +prepare_environment() diff --git a/sd-auto-v1/relauncher.py b/sd-auto-v1/relauncher.py new file mode 100644 index 0000000..6a1dc95 --- /dev/null +++ b/sd-auto-v1/relauncher.py @@ -0,0 +1,12 @@ +import os, time + +n = 0 +while True: + print('Relauncher: Launching...') + if n > 0: + print(f'\tRelaunch count: {n}') + launch_string = "/workspace/stable-diffusion-webui/webui.sh -f" + os.system(launch_string) + print('Relauncher: Process is ending. Relaunching in 2s...') + n += 1 + time.sleep(2) diff --git a/sd-auto-v1/start.sh b/sd-auto-v1/start.sh new file mode 100644 index 0000000..cade55d --- /dev/null +++ b/sd-auto-v1/start.sh @@ -0,0 +1,32 @@ +#!/bin/bash +echo "Container Started" +export PYTHONUNBUFFERED=1 +source /workspace/stable-diffusion-webui/venv/bin/activate +cd /workspace/stable-diffusion-webui +python relauncher.py & + +if [[ $PUBLIC_KEY ]] +then + mkdir -p ~/.ssh + chmod 700 ~/.ssh + cd ~/.ssh + echo $PUBLIC_KEY >> authorized_keys + chmod 700 -R ~/.ssh + cd / + service ssh start + echo "SSH Service Started" +fi + +if [[ $JUPYTER_PASSWORD ]] +then + ln -sf /examples /workspace + ln -sf /root/welcome.ipynb /workspace + + cd / + jupyter lab --allow-root --no-browser --port=8888 --ip=* \ + --ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' \ + --ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace + echo "Jupyter Lab Started" +fi + +sleep infinity diff --git a/sd-auto-v1/webui-user.sh b/sd-auto-v1/webui-user.sh new file mode 100644 index 0000000..d696a06 --- /dev/null +++ b/sd-auto-v1/webui-user.sh @@ -0,0 +1,47 @@ +# #!/bin/bash +######################################################### +# Uncomment and change the variables below to your need:# +######################################################### + +# Install directory without trailing slash +install_dir="/workspace" + +# Name of the subdirectory +#clone_dir="stable-diffusion-webui" + +# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" +export COMMANDLINE_ARGS="--port 3000 --xformers --ckpt /workspace/v1-5-pruned-emaonly.ckpt --listen --enable-insecure-extension-access" + +# python3 executable +#python_cmd="python3" + +# git executable +#export GIT="git" + +# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) +# venv_dir="/workspace/venv" + +# script to launch to start the app +# export LAUNCH_SCRIPT="dummy.py" + +# install command for torch +#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" + +# Requirements file to use for stable-diffusion-webui +#export REQS_FILE="requirements_versions.txt" + +# Fixed git repos +#export K_DIFFUSION_PACKAGE="" +#export GFPGAN_PACKAGE="" + +# Fixed git commits +#export STABLE_DIFFUSION_COMMIT_HASH="" +#export TAMING_TRANSFORMERS_COMMIT_HASH="" +#export CODEFORMER_COMMIT_HASH="" +#export BLIP_COMMIT_HASH="" + +# Uncomment to enable accelerated launch +#export ACCELERATE="True" + +########################################### + diff --git a/sd-auto-v1/webui.py b/sd-auto-v1/webui.py new file mode 100644 index 0000000..c4d5f95 --- /dev/null +++ b/sd-auto-v1/webui.py @@ -0,0 +1,293 @@ +import os +import sys +import time +import importlib +import signal +import re +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from packaging import version + +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call + +import torch + +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks +import modules.codeformer_model as codeformer +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.sd_vae +import modules.txt2img +import modules.script_callbacks +import modules.textual_inversion.textual_inversion +import modules.progress + +import modules.ui +from modules import modelloader +from modules.shared import cmd_opts +import modules.hypernetworks.hypernetwork + + +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None + + +def check_versions(): + if shared.cmd_opts.skip_version_check: + return + + expected_torch_version = "1.13.1" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + +def initialize(): + check_versions() + + extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) + + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() + return + + modelloader.cleanup_models() + modules.sd_models.setup_model() + codeformer.setup_model(cmd_opts.codeformer_models_path) + gfpgan.setup_model(cmd_opts.gfpgan_models_path) + + modelloader.list_builtin_upscalers() + modules.scripts.load_scripts() + modelloader.load_upscalers() + + modules.sd_vae.refresh_vae_list() + + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + +def setup_cors(app): + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + + +def create_api(app): + from modules.api.api import Api + api = Api(app, queue_lock) + return api + + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if shared.state.need_restart: + shared.state.need_restart = False + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + +def api_only(): + initialize() + + app = FastAPI() + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + modules.script_callbacks.app_started_callback(None, app) + modules.script_callbacks.before_ui_callback() + extensions.list_extensions() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + modules.sd_models.list_models() + + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(): + launch_api = cmd_opts.api + initialize() + + while 1: + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + modules.script_callbacks.before_ui_callback() + + shared.demo = modules.ui.create_ui() + + if cmd_opts.gradio_queue: + shared.demo.queue(64) + + gradio_auth_creds = [] + if cmd_opts.gradio_auth: + gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',')] + + app, local_url, share_url = shared.demo.launch( + share=cmd_opts.share, + server_name=server_name, + server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + # after initial launch, disable --autolaunch for subsequent restarts + cmd_opts.autolaunch = False + + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attacker wants, including installing an extension and + # running its code. We disable this here. Suggested by RyotaK. + app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + + setup_cors(app) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + modules.progress.setup_progress_api(app) + + if launch_api: + create_api(app) + + ui_extra_networks.add_pages_to_demo(app) + + modules.script_callbacks.app_started_callback(shared.demo, app) + + wait_on_server(shared.demo) + print('Restarting UI...') + + sd_samplers.set_samplers() + + modules.script_callbacks.script_unloaded_callback() + extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + + modelloader.forbid_loaded_nonbuiltin_upscalers() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + + modules.sd_models.list_models() + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + +if __name__ == "__main__": + if cmd_opts.nowebui: + api_only() + else: + webui() diff --git a/start.sh b/start.sh new file mode 100644 index 0000000..8092eec --- /dev/null +++ b/start.sh @@ -0,0 +1,11 @@ +#!/bin/bash +echo "Container Started" +export PYTHONUNBUFFERED=1 +source /workspace/stable-diffusion-webui/venv/bin/activate +cd /workspace/stable-diffusion-webui +echo "starting api" +python webui.py --port 3000 --nowebui --api --xformers --enable-insecure-extension-access --ckpt /model.safetensors & +cd / + +echo "starting worker" +python -u handler.py \ No newline at end of file diff --git a/webui.py b/webui.py new file mode 100644 index 0000000..68812ab --- /dev/null +++ b/webui.py @@ -0,0 +1,284 @@ +import os +import sys +import time +import importlib +import signal +import re +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from packaging import version + +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call + +import torch + +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks +import modules.codeformer_model as codeformer +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.sd_vae +import modules.txt2img +import modules.script_callbacks +import modules.textual_inversion.textual_inversion +import modules.progress + +import modules.ui +from modules import modelloader +from modules.shared import cmd_opts +import modules.hypernetworks.hypernetwork + + +if cmd_opts.server_name: + server_name = cmd_opts.server_name +else: + server_name = "0.0.0.0" if cmd_opts.listen else None + + +def check_versions(): + if shared.cmd_opts.skip_version_check: + return + + expected_torch_version = "1.13.1" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + +def initialize(): + check_versions() + + extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) + + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() + return + + modelloader.cleanup_models() + modules.sd_models.setup_model() + codeformer.setup_model(cmd_opts.codeformer_models_path) + gfpgan.setup_model(cmd_opts.gfpgan_models_path) + shared.face_restorers.append(modules.face_restoration.FaceRestoration()) + + modelloader.list_builtin_upscalers() + modules.scripts.load_scripts() + modelloader.load_upscalers() + + modules.sd_vae.refresh_vae_list() + + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + +def setup_cors(app): + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + + +def create_api(app): + from modules.api.api import Api + api = Api(app, queue_lock) + return api + + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if shared.state.need_restart: + shared.state.need_restart = False + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + +def api_only(): + initialize() + + app = FastAPI() + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + modules.script_callbacks.app_started_callback(None, app) + modules.script_callbacks.before_ui_callback() + extensions.list_extensions() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + modules.sd_models.list_models() + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(): + launch_api = cmd_opts.api + initialize() + + while 1: + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() + + modules.script_callbacks.before_ui_callback() + + shared.demo = modules.ui.create_ui() + + if cmd_opts.gradio_queue: + shared.demo.queue(64) + + app, local_url, share_url = shared.demo.launch( + share=cmd_opts.share, + server_name=server_name, + server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + # after initial launch, disable --autolaunch for subsequent restarts + cmd_opts.autolaunch = False + + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attacker wants, including installing an extension and + # running its code. We disable this here. Suggested by RyotaK. + app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + + setup_cors(app) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + modules.progress.setup_progress_api(app) + + if launch_api: + create_api(app) + + ui_extra_networks.add_pages_to_demo(app) + + modules.script_callbacks.app_started_callback(shared.demo, app) + + wait_on_server(shared.demo) + print('Restarting UI...') + + sd_samplers.set_samplers() + + modules.script_callbacks.script_unloaded_callback() + extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + + modelloader.forbid_loaded_nonbuiltin_upscalers() + modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) + modelloader.load_upscalers() + + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + + modules.sd_models.list_models() + + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + + +if __name__ == "__main__": + if cmd_opts.nowebui: + api_only() + else: + webui() \ No newline at end of file