This commit is contained in:
hl 2023-02-24 22:40:53 -08:00
commit f3cd48e67f
16 changed files with 849 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*.safetensors
*.log

22
Dockerfile Normal file
View File

@ -0,0 +1,22 @@
FROM chilloutai/stable-diffusion:1.0.0
SHELL ["/bin/bash", "-c"]
ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"
WORKDIR /
ADD model.safetensors /
RUN pip install -U xformers
RUN pip install runpod
ADD koreanDollLikeness_v15.safetensors /workspace/stable-diffusion-webui/models/Lora/koreanDollLikeness_v15.safetensors
ADD japaneseDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/japaneseDollLikeness_v10.safetensors
ADD taiwanDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/taiwanDollLikeness_v10.safetensors
ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh
CMD [ "/start.sh" ]

1
README.md Normal file
View File

@ -0,0 +1 @@
### How to Run Serverless in Runpod

1
build.sh Executable file
View File

@ -0,0 +1 @@
DOCKER_BUILDKIT=1 docker build -t chilloutai/auto-api:1.0.0 .

3
download.sh Executable file
View File

@ -0,0 +1,3 @@
wget -O koreanDollLikeness_v15.safetensors https://civitai.com/api/download/models/14014
wget -O japaneseDollLikeness_v10.safetensors https://civitai.com/api/download/models/12050
wget -O taiwanDollLikeness_v10.safetensors https://civitai.com/api/download/models/9070

39
handler.py Normal file
View File

@ -0,0 +1,39 @@
import runpod
import subprocess
import requests
import time
def check_api_availability(host):
while True:
try:
response = requests.get(host)
return
except requests.exceptions.RequestException as e:
print(f"API is not available, retrying in 5s... ({e})")
except Exception as e:
print('something went wrong')
time.sleep(5)
check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")
print('run handler')
def handler(event):
'''
This is the handler function that will be called by the serverless.
'''
print('got event')
print(event)
response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])
json = response.json()
# do the things
print(json)
# return the output that you want to be returned like pre-signed URLs to output artifacts
return json
runpod.serverless.start({"handler": handler})

47
runpod_api_test.py Normal file
View File

@ -0,0 +1,47 @@
import time
import requests
import base64
runpod_key = ''
api_name = ''
prompt = '<lora:koreanDollLikeness_v15:0.66>, best quality, ultra high res, (photorealistic:1.4), 1girl, solo focus, ((blue long dress)), elbow dress, black thighhighs, frills, ribbons, studio background, (Kpop idol), (aegyo sal:1), (platinum blonde hair:1), ((puffy eyes)), looking at viewer, facing front, smiling, laughing'
negative_prompt = 'paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples'
def generation():
res = requests.post(f'https://api.runpod.ai/v1/{api_name}/run', headers={
'Content-Type': 'application/json',
"Authorization": f"Bearer {runpod_key}"
}, json={
"input": {"prompt": prompt, "steps": 28, "negative_prompt": negative_prompt, "width": 512, "height": 768, "sampler_index": "DPM++ SDE Karras",
"batch_size": 1, "seed": -1},
})
task_id = res.json()['id']
while True:
res = requests.get(f'https://api.runpod.ai/v1/{api_name}/status/{task_id}', headers={
'Content-Type': 'application/json',
"Authorization": f"Bearer {runpod_key}"
})
status = res.json()['status']
print(status)
if status == 'COMPLETED':
for imgstring in res.json()['output']['images']:
with open(f"test_{time.time()}.png", "wb") as fh:
imgdata = base64.b64decode(imgstring)
fh.write(imgdata)
print(res.json())
break
if status == 'FAILED':
print(res.json())
break
time.sleep(10)
if __name__ == '__main__':
generation()

51
sd-auto-v1/Dockerfile Normal file
View File

@ -0,0 +1,51 @@
FROM runpod/stable-diffusion-models:1.5 as build
FROM ubuntu:22.04 AS runtime
RUN rm -rf /root/.cache/pip
RUN mkdir -p /root/.cache/huggingface
ENV DEBIAN_FRONTEND noninteractive
RUN apt update && \
apt install -y --no-install-recommends \
software-properties-common \
git \
openssh-server \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6 \
ffmpeg \
wget \
curl \
python3-pip python3 python3.10-venv \
apt-transport-https ca-certificates && \
update-ca-certificates
WORKDIR /workspace
RUN git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
WORKDIR /workspace/stable-diffusion-webui
ADD webui.py /workspace/stable-diffusion-webui/webui.py
RUN python3 -m venv /workspace/stable-diffusion-webui/venv
ENV PATH="/workspace/stable-diffusion-webui/venv/bin:$PATH"
RUN pip install -U jupyterlab ipywidgets jupyter-archive
RUN jupyter nbextension enable --py widgetsnbextension
ADD install.py .
RUN python -m install --skip-torch-cuda-test
RUN apt clean && rm -rf /var/lib/apt/lists/* && \
echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
ADD relauncher.py .
ADD webui-user.sh .
ADD start.sh /start.sh
RUN chmod a+x /start.sh
SHELL ["/bin/bash", "--login", "-c"]
CMD [ "/start.sh" ]

1
sd-auto-v1/build.sh Executable file
View File

@ -0,0 +1 @@
DOCKER_BUILDKIT=1 docker build -t chilloutai/stable-diffusion:1.0.0 .

3
sd-auto-v1/install.py Normal file
View File

@ -0,0 +1,3 @@
from launch import prepare_environment
prepare_environment()

12
sd-auto-v1/relauncher.py Normal file
View File

@ -0,0 +1,12 @@
import os, time
n = 0
while True:
print('Relauncher: Launching...')
if n > 0:
print(f'\tRelaunch count: {n}')
launch_string = "/workspace/stable-diffusion-webui/webui.sh -f"
os.system(launch_string)
print('Relauncher: Process is ending. Relaunching in 2s...')
n += 1
time.sleep(2)

32
sd-auto-v1/start.sh Normal file
View File

@ -0,0 +1,32 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
python relauncher.py &
if [[ $PUBLIC_KEY ]]
then
mkdir -p ~/.ssh
chmod 700 ~/.ssh
cd ~/.ssh
echo $PUBLIC_KEY >> authorized_keys
chmod 700 -R ~/.ssh
cd /
service ssh start
echo "SSH Service Started"
fi
if [[ $JUPYTER_PASSWORD ]]
then
ln -sf /examples /workspace
ln -sf /root/welcome.ipynb /workspace
cd /
jupyter lab --allow-root --no-browser --port=8888 --ip=* \
--ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' \
--ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace
echo "Jupyter Lab Started"
fi
sleep infinity

47
sd-auto-v1/webui-user.sh Normal file
View File

@ -0,0 +1,47 @@
# #!/bin/bash
#########################################################
# Uncomment and change the variables below to your need:#
#########################################################
# Install directory without trailing slash
install_dir="/workspace"
# Name of the subdirectory
#clone_dir="stable-diffusion-webui"
# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
export COMMANDLINE_ARGS="--port 3000 --xformers --ckpt /workspace/v1-5-pruned-emaonly.ckpt --listen --enable-insecure-extension-access"
# python3 executable
#python_cmd="python3"
# git executable
#export GIT="git"
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
# venv_dir="/workspace/venv"
# script to launch to start the app
# export LAUNCH_SCRIPT="dummy.py"
# install command for torch
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
# Requirements file to use for stable-diffusion-webui
#export REQS_FILE="requirements_versions.txt"
# Fixed git repos
#export K_DIFFUSION_PACKAGE=""
#export GFPGAN_PACKAGE=""
# Fixed git commits
#export STABLE_DIFFUSION_COMMIT_HASH=""
#export TAMING_TRANSFORMERS_COMMIT_HASH=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""
# Uncomment to enable accelerated launch
#export ACCELERATE="True"
###########################################

293
sd-auto-v1/webui.py Normal file
View File

@ -0,0 +1,293 @@
import os
import sys
import time
import importlib
import signal
import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress
import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
if cmd_opts.server_name:
server_name = cmd_opts.server_name
else:
server_name = "0.0.0.0" if cmd_opts.listen else None
def check_versions():
if shared.cmd_opts.skip_version_check:
return
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
Use --skip-version-check commandline argument to disable this check.
""".strip())
def initialize():
check_versions()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
return
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
try:
modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try:
if not os.path.exists(cmd_opts.tls_keyfile):
print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)
def setup_cors(app):
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app):
from modules.api.api import Api
api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
def api_only():
initialize()
app = FastAPI()
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
modules.script_callbacks.before_ui_callback()
extensions.list_extensions()
modules.scripts.reload_scripts()
modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
modules.sd_models.list_models()
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui():
launch_api = cmd_opts.api
initialize()
while 1:
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
modules.script_callbacks.before_ui_callback()
shared.demo = modules.ui.create_ui()
if cmd_opts.gradio_queue:
shared.demo.queue(64)
gradio_auth_creds = []
if cmd_opts.gradio_auth:
gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',')]
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attacker wants, including installing an extension and
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000)
modules.progress.setup_progress_api(app)
if launch_api:
create_api(app)
ui_extra_networks.add_pages_to_demo(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
modules.sd_models.list_models()
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if __name__ == "__main__":
if cmd_opts.nowebui:
api_only()
else:
webui()

11
start.sh Normal file
View File

@ -0,0 +1,11 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --enable-insecure-extension-access --ckpt /model.safetensors &
cd /
echo "starting worker"
python -u handler.py

284
webui.py Normal file
View File

@ -0,0 +1,284 @@
import os
import sys
import time
import importlib
import signal
import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress
import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
if cmd_opts.server_name:
server_name = cmd_opts.server_name
else:
server_name = "0.0.0.0" if cmd_opts.listen else None
def check_versions():
if shared.cmd_opts.skip_version_check:
return
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
Use --skip-version-check commandline argument to disable this check.
""".strip())
def initialize():
check_versions()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
return
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
try:
modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try:
if not os.path.exists(cmd_opts.tls_keyfile):
print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)
def setup_cors(app):
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app):
from modules.api.api import Api
api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
def api_only():
initialize()
app = FastAPI()
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
modules.script_callbacks.before_ui_callback()
extensions.list_extensions()
modules.scripts.reload_scripts()
modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
modules.sd_models.list_models()
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui():
launch_api = cmd_opts.api
initialize()
while 1:
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
modules.script_callbacks.before_ui_callback()
shared.demo = modules.ui.create_ui()
if cmd_opts.gradio_queue:
shared.demo.queue(64)
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attacker wants, including installing an extension and
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000)
modules.progress.setup_progress_api(app)
if launch_api:
create_api(app)
ui_extra_networks.add_pages_to_demo(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
modules.sd_models.list_models()
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if __name__ == "__main__":
if cmd_opts.nowebui:
api_only()
else:
webui()