init
This commit is contained in:
commit
f3cd48e67f
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*.safetensors
|
||||
*.log
|
22
Dockerfile
Normal file
22
Dockerfile
Normal file
@ -0,0 +1,22 @@
|
||||
FROM chilloutai/stable-diffusion:1.0.0
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"
|
||||
|
||||
WORKDIR /
|
||||
|
||||
ADD model.safetensors /
|
||||
|
||||
RUN pip install -U xformers
|
||||
RUN pip install runpod
|
||||
|
||||
ADD koreanDollLikeness_v15.safetensors /workspace/stable-diffusion-webui/models/Lora/koreanDollLikeness_v15.safetensors
|
||||
ADD japaneseDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/japaneseDollLikeness_v10.safetensors
|
||||
ADD taiwanDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/taiwanDollLikeness_v10.safetensors
|
||||
|
||||
ADD handler.py .
|
||||
ADD start.sh /start.sh
|
||||
RUN chmod +x /start.sh
|
||||
|
||||
CMD [ "/start.sh" ]
|
1
build.sh
Executable file
1
build.sh
Executable file
@ -0,0 +1 @@
|
||||
DOCKER_BUILDKIT=1 docker build -t chilloutai/auto-api:1.0.0 .
|
3
download.sh
Executable file
3
download.sh
Executable file
@ -0,0 +1,3 @@
|
||||
wget -O koreanDollLikeness_v15.safetensors https://civitai.com/api/download/models/14014
|
||||
wget -O japaneseDollLikeness_v10.safetensors https://civitai.com/api/download/models/12050
|
||||
wget -O taiwanDollLikeness_v10.safetensors https://civitai.com/api/download/models/9070
|
39
handler.py
Normal file
39
handler.py
Normal file
@ -0,0 +1,39 @@
|
||||
import runpod
|
||||
import subprocess
|
||||
import requests
|
||||
import time
|
||||
|
||||
def check_api_availability(host):
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(host)
|
||||
return
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"API is not available, retrying in 5s... ({e})")
|
||||
except Exception as e:
|
||||
print('something went wrong')
|
||||
time.sleep(5)
|
||||
|
||||
check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")
|
||||
|
||||
print('run handler')
|
||||
|
||||
def handler(event):
|
||||
'''
|
||||
This is the handler function that will be called by the serverless.
|
||||
'''
|
||||
print('got event')
|
||||
print(event)
|
||||
|
||||
response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])
|
||||
|
||||
json = response.json()
|
||||
# do the things
|
||||
|
||||
print(json)
|
||||
|
||||
# return the output that you want to be returned like pre-signed URLs to output artifacts
|
||||
return json
|
||||
|
||||
|
||||
runpod.serverless.start({"handler": handler})
|
47
runpod_api_test.py
Normal file
47
runpod_api_test.py
Normal file
@ -0,0 +1,47 @@
|
||||
import time
|
||||
|
||||
import requests
|
||||
import base64
|
||||
|
||||
runpod_key = ''
|
||||
api_name = ''
|
||||
|
||||
prompt = '<lora:koreanDollLikeness_v15:0.66>, best quality, ultra high res, (photorealistic:1.4), 1girl, solo focus, ((blue long dress)), elbow dress, black thighhighs, frills, ribbons, studio background, (Kpop idol), (aegyo sal:1), (platinum blonde hair:1), ((puffy eyes)), looking at viewer, facing front, smiling, laughing'
|
||||
negative_prompt = 'paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples'
|
||||
|
||||
|
||||
def generation():
|
||||
res = requests.post(f'https://api.runpod.ai/v1/{api_name}/run', headers={
|
||||
'Content-Type': 'application/json',
|
||||
"Authorization": f"Bearer {runpod_key}"
|
||||
}, json={
|
||||
"input": {"prompt": prompt, "steps": 28, "negative_prompt": negative_prompt, "width": 512, "height": 768, "sampler_index": "DPM++ SDE Karras",
|
||||
"batch_size": 1, "seed": -1},
|
||||
})
|
||||
|
||||
task_id = res.json()['id']
|
||||
|
||||
while True:
|
||||
res = requests.get(f'https://api.runpod.ai/v1/{api_name}/status/{task_id}', headers={
|
||||
'Content-Type': 'application/json',
|
||||
"Authorization": f"Bearer {runpod_key}"
|
||||
})
|
||||
|
||||
status = res.json()['status']
|
||||
print(status)
|
||||
if status == 'COMPLETED':
|
||||
for imgstring in res.json()['output']['images']:
|
||||
with open(f"test_{time.time()}.png", "wb") as fh:
|
||||
imgdata = base64.b64decode(imgstring)
|
||||
fh.write(imgdata)
|
||||
print(res.json())
|
||||
break
|
||||
|
||||
if status == 'FAILED':
|
||||
print(res.json())
|
||||
break
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generation()
|
51
sd-auto-v1/Dockerfile
Normal file
51
sd-auto-v1/Dockerfile
Normal file
@ -0,0 +1,51 @@
|
||||
FROM runpod/stable-diffusion-models:1.5 as build
|
||||
|
||||
FROM ubuntu:22.04 AS runtime
|
||||
|
||||
RUN rm -rf /root/.cache/pip
|
||||
|
||||
RUN mkdir -p /root/.cache/huggingface
|
||||
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
git \
|
||||
openssh-server \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxrender1 \
|
||||
libxext6 \
|
||||
ffmpeg \
|
||||
wget \
|
||||
curl \
|
||||
python3-pip python3 python3.10-venv \
|
||||
apt-transport-https ca-certificates && \
|
||||
update-ca-certificates
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
|
||||
|
||||
WORKDIR /workspace/stable-diffusion-webui
|
||||
ADD webui.py /workspace/stable-diffusion-webui/webui.py
|
||||
RUN python3 -m venv /workspace/stable-diffusion-webui/venv
|
||||
ENV PATH="/workspace/stable-diffusion-webui/venv/bin:$PATH"
|
||||
|
||||
RUN pip install -U jupyterlab ipywidgets jupyter-archive
|
||||
RUN jupyter nbextension enable --py widgetsnbextension
|
||||
|
||||
ADD install.py .
|
||||
RUN python -m install --skip-torch-cuda-test
|
||||
|
||||
RUN apt clean && rm -rf /var/lib/apt/lists/* && \
|
||||
echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
|
||||
|
||||
ADD relauncher.py .
|
||||
ADD webui-user.sh .
|
||||
ADD start.sh /start.sh
|
||||
RUN chmod a+x /start.sh
|
||||
|
||||
SHELL ["/bin/bash", "--login", "-c"]
|
||||
CMD [ "/start.sh" ]
|
1
sd-auto-v1/build.sh
Executable file
1
sd-auto-v1/build.sh
Executable file
@ -0,0 +1 @@
|
||||
DOCKER_BUILDKIT=1 docker build -t chilloutai/stable-diffusion:1.0.0 .
|
3
sd-auto-v1/install.py
Normal file
3
sd-auto-v1/install.py
Normal file
@ -0,0 +1,3 @@
|
||||
from launch import prepare_environment
|
||||
|
||||
prepare_environment()
|
12
sd-auto-v1/relauncher.py
Normal file
12
sd-auto-v1/relauncher.py
Normal file
@ -0,0 +1,12 @@
|
||||
import os, time
|
||||
|
||||
n = 0
|
||||
while True:
|
||||
print('Relauncher: Launching...')
|
||||
if n > 0:
|
||||
print(f'\tRelaunch count: {n}')
|
||||
launch_string = "/workspace/stable-diffusion-webui/webui.sh -f"
|
||||
os.system(launch_string)
|
||||
print('Relauncher: Process is ending. Relaunching in 2s...')
|
||||
n += 1
|
||||
time.sleep(2)
|
32
sd-auto-v1/start.sh
Normal file
32
sd-auto-v1/start.sh
Normal file
@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
echo "Container Started"
|
||||
export PYTHONUNBUFFERED=1
|
||||
source /workspace/stable-diffusion-webui/venv/bin/activate
|
||||
cd /workspace/stable-diffusion-webui
|
||||
python relauncher.py &
|
||||
|
||||
if [[ $PUBLIC_KEY ]]
|
||||
then
|
||||
mkdir -p ~/.ssh
|
||||
chmod 700 ~/.ssh
|
||||
cd ~/.ssh
|
||||
echo $PUBLIC_KEY >> authorized_keys
|
||||
chmod 700 -R ~/.ssh
|
||||
cd /
|
||||
service ssh start
|
||||
echo "SSH Service Started"
|
||||
fi
|
||||
|
||||
if [[ $JUPYTER_PASSWORD ]]
|
||||
then
|
||||
ln -sf /examples /workspace
|
||||
ln -sf /root/welcome.ipynb /workspace
|
||||
|
||||
cd /
|
||||
jupyter lab --allow-root --no-browser --port=8888 --ip=* \
|
||||
--ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' \
|
||||
--ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace
|
||||
echo "Jupyter Lab Started"
|
||||
fi
|
||||
|
||||
sleep infinity
|
47
sd-auto-v1/webui-user.sh
Normal file
47
sd-auto-v1/webui-user.sh
Normal file
@ -0,0 +1,47 @@
|
||||
# #!/bin/bash
|
||||
#########################################################
|
||||
# Uncomment and change the variables below to your need:#
|
||||
#########################################################
|
||||
|
||||
# Install directory without trailing slash
|
||||
install_dir="/workspace"
|
||||
|
||||
# Name of the subdirectory
|
||||
#clone_dir="stable-diffusion-webui"
|
||||
|
||||
# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
|
||||
export COMMANDLINE_ARGS="--port 3000 --xformers --ckpt /workspace/v1-5-pruned-emaonly.ckpt --listen --enable-insecure-extension-access"
|
||||
|
||||
# python3 executable
|
||||
#python_cmd="python3"
|
||||
|
||||
# git executable
|
||||
#export GIT="git"
|
||||
|
||||
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
|
||||
# venv_dir="/workspace/venv"
|
||||
|
||||
# script to launch to start the app
|
||||
# export LAUNCH_SCRIPT="dummy.py"
|
||||
|
||||
# install command for torch
|
||||
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
|
||||
|
||||
# Requirements file to use for stable-diffusion-webui
|
||||
#export REQS_FILE="requirements_versions.txt"
|
||||
|
||||
# Fixed git repos
|
||||
#export K_DIFFUSION_PACKAGE=""
|
||||
#export GFPGAN_PACKAGE=""
|
||||
|
||||
# Fixed git commits
|
||||
#export STABLE_DIFFUSION_COMMIT_HASH=""
|
||||
#export TAMING_TRANSFORMERS_COMMIT_HASH=""
|
||||
#export CODEFORMER_COMMIT_HASH=""
|
||||
#export BLIP_COMMIT_HASH=""
|
||||
|
||||
# Uncomment to enable accelerated launch
|
||||
#export ACCELERATE="True"
|
||||
|
||||
###########################################
|
||||
|
293
sd-auto-v1/webui.py
Normal file
293
sd-auto-v1/webui.py
Normal file
@ -0,0 +1,293 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import importlib
|
||||
import signal
|
||||
import re
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from packaging import version
|
||||
|
||||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||
|
||||
import torch
|
||||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
import modules.img2img
|
||||
|
||||
import modules.lowvram
|
||||
import modules.paths
|
||||
import modules.scripts
|
||||
import modules.sd_hijack
|
||||
import modules.sd_models
|
||||
import modules.sd_vae
|
||||
import modules.txt2img
|
||||
import modules.script_callbacks
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.progress
|
||||
|
||||
import modules.ui
|
||||
from modules import modelloader
|
||||
from modules.shared import cmd_opts
|
||||
import modules.hypernetworks.hypernetwork
|
||||
|
||||
|
||||
if cmd_opts.server_name:
|
||||
server_name = cmd_opts.server_name
|
||||
else:
|
||||
server_name = "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def check_versions():
|
||||
if shared.cmd_opts.skip_version_check:
|
||||
return
|
||||
|
||||
expected_torch_version = "1.13.1"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.16rc425"
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
||||
def initialize():
|
||||
check_versions()
|
||||
|
||||
extensions.list_extensions()
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
if cmd_opts.ui_debug_mode:
|
||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||
modules.scripts.load_scripts()
|
||||
return
|
||||
|
||||
modelloader.cleanup_models()
|
||||
modules.sd_models.setup_model()
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
|
||||
modelloader.list_builtin_upscalers()
|
||||
modules.scripts.load_scripts()
|
||||
modelloader.load_upscalers()
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
|
||||
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||
|
||||
try:
|
||||
modules.sd_models.load_model()
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model")
|
||||
print("", file=sys.stderr)
|
||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||
|
||||
try:
|
||||
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||
print("Invalid path to TLS keyfile given")
|
||||
if not os.path.exists(cmd_opts.tls_certfile):
|
||||
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||
except TypeError:
|
||||
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||
print("TLS setup invalid, running webui without TLS")
|
||||
else:
|
||||
print("Running with TLS")
|
||||
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
def sigint_handler(sig, frame):
|
||||
print(f'Interrupted with signal {sig} in {frame}')
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
elif cmd_opts.cors_allow_origins:
|
||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
elif cmd_opts.cors_allow_origins_regex:
|
||||
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
api = Api(app, queue_lock)
|
||||
return api
|
||||
|
||||
|
||||
def wait_on_server(demo=None):
|
||||
while 1:
|
||||
time.sleep(0.5)
|
||||
if shared.state.need_restart:
|
||||
shared.state.need_restart = False
|
||||
time.sleep(0.5)
|
||||
demo.close()
|
||||
time.sleep(0.5)
|
||||
break
|
||||
|
||||
|
||||
def api_only():
|
||||
initialize()
|
||||
|
||||
app = FastAPI()
|
||||
setup_cors(app)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
api = create_api(app)
|
||||
|
||||
modules.script_callbacks.app_started_callback(None, app)
|
||||
modules.script_callbacks.before_ui_callback()
|
||||
extensions.list_extensions()
|
||||
modules.scripts.reload_scripts()
|
||||
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
||||
modelloader.load_upscalers()
|
||||
modules.sd_models.list_models()
|
||||
|
||||
|
||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||
|
||||
|
||||
def webui():
|
||||
launch_api = cmd_opts.api
|
||||
initialize()
|
||||
|
||||
while 1:
|
||||
if shared.opts.clean_temp_dir_at_start:
|
||||
ui_tempdir.cleanup_tmpdr()
|
||||
|
||||
modules.script_callbacks.before_ui_callback()
|
||||
|
||||
shared.demo = modules.ui.create_ui()
|
||||
|
||||
if cmd_opts.gradio_queue:
|
||||
shared.demo.queue(64)
|
||||
|
||||
gradio_auth_creds = []
|
||||
if cmd_opts.gradio_auth:
|
||||
gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
|
||||
if cmd_opts.gradio_auth_path:
|
||||
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||
for line in file.readlines():
|
||||
gradio_auth_creds += [x.strip() for x in line.split(',')]
|
||||
|
||||
app, local_url, share_url = shared.demo.launch(
|
||||
share=cmd_opts.share,
|
||||
server_name=server_name,
|
||||
server_port=cmd_opts.port,
|
||||
ssl_keyfile=cmd_opts.tls_keyfile,
|
||||
ssl_certfile=cmd_opts.tls_certfile,
|
||||
debug=cmd_opts.gradio_debug,
|
||||
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
|
||||
inbrowser=cmd_opts.autolaunch,
|
||||
prevent_thread_lock=True
|
||||
)
|
||||
# after initial launch, disable --autolaunch for subsequent restarts
|
||||
cmd_opts.autolaunch = False
|
||||
|
||||
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
||||
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
||||
# running web ui and do whatever the attacker wants, including installing an extension and
|
||||
# running its code. We disable this here. Suggested by RyotaK.
|
||||
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
||||
|
||||
setup_cors(app)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
modules.progress.setup_progress_api(app)
|
||||
|
||||
if launch_api:
|
||||
create_api(app)
|
||||
|
||||
ui_extra_networks.add_pages_to_demo(app)
|
||||
|
||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||
|
||||
wait_on_server(shared.demo)
|
||||
print('Restarting UI...')
|
||||
|
||||
sd_samplers.set_samplers()
|
||||
|
||||
modules.script_callbacks.script_unloaded_callback()
|
||||
extensions.list_extensions()
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
||||
modules.scripts.reload_scripts()
|
||||
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
||||
modelloader.load_upscalers()
|
||||
|
||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||
importlib.reload(module)
|
||||
|
||||
modules.sd_models.list_models()
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if cmd_opts.nowebui:
|
||||
api_only()
|
||||
else:
|
||||
webui()
|
11
start.sh
Normal file
11
start.sh
Normal file
@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
echo "Container Started"
|
||||
export PYTHONUNBUFFERED=1
|
||||
source /workspace/stable-diffusion-webui/venv/bin/activate
|
||||
cd /workspace/stable-diffusion-webui
|
||||
echo "starting api"
|
||||
python webui.py --port 3000 --nowebui --api --xformers --enable-insecure-extension-access --ckpt /model.safetensors &
|
||||
cd /
|
||||
|
||||
echo "starting worker"
|
||||
python -u handler.py
|
284
webui.py
Normal file
284
webui.py
Normal file
@ -0,0 +1,284 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import importlib
|
||||
import signal
|
||||
import re
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from packaging import version
|
||||
|
||||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||
|
||||
import torch
|
||||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
import modules.img2img
|
||||
|
||||
import modules.lowvram
|
||||
import modules.paths
|
||||
import modules.scripts
|
||||
import modules.sd_hijack
|
||||
import modules.sd_models
|
||||
import modules.sd_vae
|
||||
import modules.txt2img
|
||||
import modules.script_callbacks
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.progress
|
||||
|
||||
import modules.ui
|
||||
from modules import modelloader
|
||||
from modules.shared import cmd_opts
|
||||
import modules.hypernetworks.hypernetwork
|
||||
|
||||
|
||||
if cmd_opts.server_name:
|
||||
server_name = cmd_opts.server_name
|
||||
else:
|
||||
server_name = "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def check_versions():
|
||||
if shared.cmd_opts.skip_version_check:
|
||||
return
|
||||
|
||||
expected_torch_version = "1.13.1"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.16rc425"
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
||||
def initialize():
|
||||
check_versions()
|
||||
|
||||
extensions.list_extensions()
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
if cmd_opts.ui_debug_mode:
|
||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||
modules.scripts.load_scripts()
|
||||
return
|
||||
|
||||
modelloader.cleanup_models()
|
||||
modules.sd_models.setup_model()
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
||||
modelloader.list_builtin_upscalers()
|
||||
modules.scripts.load_scripts()
|
||||
modelloader.load_upscalers()
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
|
||||
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||
|
||||
try:
|
||||
modules.sd_models.load_model()
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model")
|
||||
print("", file=sys.stderr)
|
||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||
|
||||
try:
|
||||
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||
print("Invalid path to TLS keyfile given")
|
||||
if not os.path.exists(cmd_opts.tls_certfile):
|
||||
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||
except TypeError:
|
||||
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||
print("TLS setup invalid, running webui without TLS")
|
||||
else:
|
||||
print("Running with TLS")
|
||||
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
def sigint_handler(sig, frame):
|
||||
print(f'Interrupted with signal {sig} in {frame}')
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
elif cmd_opts.cors_allow_origins:
|
||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
elif cmd_opts.cors_allow_origins_regex:
|
||||
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
api = Api(app, queue_lock)
|
||||
return api
|
||||
|
||||
|
||||
def wait_on_server(demo=None):
|
||||
while 1:
|
||||
time.sleep(0.5)
|
||||
if shared.state.need_restart:
|
||||
shared.state.need_restart = False
|
||||
time.sleep(0.5)
|
||||
demo.close()
|
||||
time.sleep(0.5)
|
||||
break
|
||||
|
||||
|
||||
def api_only():
|
||||
initialize()
|
||||
|
||||
app = FastAPI()
|
||||
setup_cors(app)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
api = create_api(app)
|
||||
|
||||
modules.script_callbacks.app_started_callback(None, app)
|
||||
modules.script_callbacks.before_ui_callback()
|
||||
extensions.list_extensions()
|
||||
modules.scripts.reload_scripts()
|
||||
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
||||
modelloader.load_upscalers()
|
||||
modules.sd_models.list_models()
|
||||
|
||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||
|
||||
|
||||
def webui():
|
||||
launch_api = cmd_opts.api
|
||||
initialize()
|
||||
|
||||
while 1:
|
||||
if shared.opts.clean_temp_dir_at_start:
|
||||
ui_tempdir.cleanup_tmpdr()
|
||||
|
||||
modules.script_callbacks.before_ui_callback()
|
||||
|
||||
shared.demo = modules.ui.create_ui()
|
||||
|
||||
if cmd_opts.gradio_queue:
|
||||
shared.demo.queue(64)
|
||||
|
||||
app, local_url, share_url = shared.demo.launch(
|
||||
share=cmd_opts.share,
|
||||
server_name=server_name,
|
||||
server_port=cmd_opts.port,
|
||||
ssl_keyfile=cmd_opts.tls_keyfile,
|
||||
ssl_certfile=cmd_opts.tls_certfile,
|
||||
debug=cmd_opts.gradio_debug,
|
||||
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
|
||||
inbrowser=cmd_opts.autolaunch,
|
||||
prevent_thread_lock=True
|
||||
)
|
||||
# after initial launch, disable --autolaunch for subsequent restarts
|
||||
cmd_opts.autolaunch = False
|
||||
|
||||
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
||||
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
||||
# running web ui and do whatever the attacker wants, including installing an extension and
|
||||
# running its code. We disable this here. Suggested by RyotaK.
|
||||
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
||||
|
||||
setup_cors(app)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
modules.progress.setup_progress_api(app)
|
||||
|
||||
if launch_api:
|
||||
create_api(app)
|
||||
|
||||
ui_extra_networks.add_pages_to_demo(app)
|
||||
|
||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||
|
||||
wait_on_server(shared.demo)
|
||||
print('Restarting UI...')
|
||||
|
||||
sd_samplers.set_samplers()
|
||||
|
||||
modules.script_callbacks.script_unloaded_callback()
|
||||
extensions.list_extensions()
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
||||
modules.scripts.reload_scripts()
|
||||
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
||||
modelloader.load_upscalers()
|
||||
|
||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||
importlib.reload(module)
|
||||
|
||||
modules.sd_models.list_models()
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if cmd_opts.nowebui:
|
||||
api_only()
|
||||
else:
|
||||
webui()
|
Loading…
x
Reference in New Issue
Block a user