init

commit f3cd48e67f

2 .gitignore vendored Normal file
@@ -0,0 +1,2 @@
*.safetensors
*.log
							
								
								
									
22 Dockerfile Normal file
@@ -0,0 +1,22 @@
FROM chilloutai/stable-diffusion:1.0.0

SHELL ["/bin/bash", "-c"]

ENV PATH="${PATH}:/workspace/stable-diffusion-webui/venv/bin"

WORKDIR /

ADD model.safetensors /

RUN pip install -U xformers
RUN pip install runpod

ADD koreanDollLikeness_v15.safetensors /workspace/stable-diffusion-webui/models/Lora/koreanDollLikeness_v15.safetensors
ADD japaneseDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/japaneseDollLikeness_v10.safetensors
ADD taiwanDollLikeness_v10.safetensors /workspace/stable-diffusion-webui/models/Lora/taiwanDollLikeness_v10.safetensors

ADD handler.py .
ADD start.sh /start.sh
RUN chmod +x /start.sh

CMD [ "/start.sh" ]
							
								
								
									
1 build.sh Executable file
@@ -0,0 +1 @@
DOCKER_BUILDKIT=1 docker build -t chilloutai/auto-api:1.0.0 .
							
								
								
									
3 download.sh Executable file
@@ -0,0 +1,3 @@
wget -O koreanDollLikeness_v15.safetensors https://civitai.com/api/download/models/14014
wget -O japaneseDollLikeness_v10.safetensors https://civitai.com/api/download/models/12050
wget -O taiwanDollLikeness_v10.safetensors https://civitai.com/api/download/models/9070
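download.sh only fetches the three LoRA weights; the base checkpoint (model.safetensors, which the Dockerfile above ADDs into the image) must be supplied in the build context separately. A small pre-build check, offered as a sketch rather than as part of this commit, that everything the Dockerfile ADDs is actually present:

import os

# Sketch: confirm the files the root Dockerfile ADDs exist in the build context.
# File names are taken from the Dockerfile and download.sh in this commit.
required = [
    "model.safetensors",
    "koreanDollLikeness_v15.safetensors",
    "japaneseDollLikeness_v10.safetensors",
    "taiwanDollLikeness_v10.safetensors",
]
missing = [name for name in required if not os.path.isfile(name)]
print("missing:", missing if missing else "none")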
							
								
								
									
39 handler.py Normal file
@@ -0,0 +1,39 @@
import runpod
import subprocess
import requests
import time

def check_api_availability(host):
    while True:
        try:
            response = requests.get(host)
            return
        except requests.exceptions.RequestException as e:
            print(f"API is not available, retrying in 5s... ({e})")
        except Exception as e:
            print('something went wrong')
        time.sleep(5)

check_api_availability("http://127.0.0.1:3000/sdapi/v1/txt2img")

print('run handler')

def handler(event):
    '''
    This is the handler function that will be called by the serverless.
    '''
    print('got event')
    print(event)

    response = requests.post(url=f'http://127.0.0.1:3000/sdapi/v1/txt2img', json=event["input"])

    json = response.json()
    # do the things

    print(json)

    # return the output that you want to be returned like pre-signed URLs to output artifacts
    return json


runpod.serverless.start({"handler": handler})
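handler.py forwards event["input"] verbatim to the local AUTOMATIC1111 txt2img endpoint, so a serverless request is expected to carry txt2img parameters under an "input" key. A minimal sketch of such a request body, mirroring the fields used by runpod_api_test.py below (values are illustrative and exact field support depends on the webui API version):

# Sketch of an event as consumed by handler(event) above; not part of this commit.
example_event = {
    "input": {
        "prompt": "best quality, ultra high res, 1girl",
        "negative_prompt": "lowres, worst quality",
        "steps": 28,
        "width": 512,
        "height": 768,
        "sampler_index": "DPM++ SDE Karras",
        "batch_size": 1,
        "seed": -1,
    }
}
# handler(example_event) returns the txt2img JSON response, whose "images"
# list holds base64-encoded PNGs (see runpod_api_test.py for decoding).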
							
								
								
									
47 runpod_api_test.py Normal file
@@ -0,0 +1,47 @@
import time

import requests
import base64

runpod_key = ''
api_name = ''

prompt = '<lora:koreanDollLikeness_v15:0.66>, best quality, ultra high res, (photorealistic:1.4), 1girl, solo focus, ((blue long dress)), elbow dress, black thighhighs, frills, ribbons, studio background, (Kpop idol), (aegyo sal:1), (platinum blonde hair:1), ((puffy eyes)), looking at viewer, facing front, smiling, laughing'
negative_prompt = 'paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples'


def generation():
    res = requests.post(f'https://api.runpod.ai/v1/{api_name}/run', headers={
        'Content-Type': 'application/json',
        "Authorization": f"Bearer {runpod_key}"
    }, json={
        "input": {"prompt": prompt, "steps": 28, "negative_prompt": negative_prompt, "width": 512, "height": 768, "sampler_index": "DPM++ SDE Karras",
                  "batch_size": 1, "seed": -1},
    })

    task_id = res.json()['id']

    while True:
        res = requests.get(f'https://api.runpod.ai/v1/{api_name}/status/{task_id}', headers={
            'Content-Type': 'application/json',
            "Authorization": f"Bearer {runpod_key}"
        })

        status = res.json()['status']
        print(status)
        if status == 'COMPLETED':
            for imgstring in res.json()['output']['images']:
                with open(f"test_{time.time()}.png", "wb") as fh:
                    imgdata = base64.b64decode(imgstring)
                    fh.write(imgdata)
            print(res.json())
            break

        if status == 'FAILED':
            print(res.json())
            break
        time.sleep(10)


if __name__ == '__main__':
    generation()
							
								
								
									
51 sd-auto-v1/Dockerfile Normal file
@@ -0,0 +1,51 @@
FROM runpod/stable-diffusion-models:1.5 as build

FROM ubuntu:22.04 AS runtime

RUN rm -rf /root/.cache/pip

RUN mkdir -p /root/.cache/huggingface

ENV DEBIAN_FRONTEND noninteractive

RUN apt update && \
apt install -y --no-install-recommends \
software-properties-common \
git \
openssh-server \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6 \
ffmpeg \
wget \
curl \
python3-pip python3 python3.10-venv \
apt-transport-https ca-certificates && \
update-ca-certificates

WORKDIR /workspace

RUN git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git

WORKDIR /workspace/stable-diffusion-webui
ADD webui.py /workspace/stable-diffusion-webui/webui.py
RUN python3 -m venv /workspace/stable-diffusion-webui/venv
ENV PATH="/workspace/stable-diffusion-webui/venv/bin:$PATH"

RUN pip install -U jupyterlab ipywidgets jupyter-archive
RUN jupyter nbextension enable --py widgetsnbextension

ADD install.py .
RUN python -m install --skip-torch-cuda-test

RUN apt clean && rm -rf /var/lib/apt/lists/* && \
    echo "en_US.UTF-8 UTF-8" > /etc/locale.gen

ADD relauncher.py .
ADD webui-user.sh .
ADD start.sh /start.sh
RUN chmod a+x /start.sh

SHELL ["/bin/bash", "--login", "-c"]
CMD [ "/start.sh" ]
							
								
								
									
1 sd-auto-v1/build.sh Executable file
@@ -0,0 +1 @@
DOCKER_BUILDKIT=1 docker build -t chilloutai/stable-diffusion:1.0.0 .
							
								
								
									
3 sd-auto-v1/install.py Normal file
@@ -0,0 +1,3 @@
from launch import prepare_environment

prepare_environment()
							
								
								
									
12 sd-auto-v1/relauncher.py Normal file
@@ -0,0 +1,12 @@
import os, time

n = 0
while True:
    print('Relauncher: Launching...')
    if n > 0:
        print(f'\tRelaunch count: {n}')
    launch_string = "/workspace/stable-diffusion-webui/webui.sh -f"
    os.system(launch_string)
    print('Relauncher: Process is ending. Relaunching in 2s...')
    n += 1
    time.sleep(2)
							
								
								
									
32 sd-auto-v1/start.sh Normal file
@@ -0,0 +1,32 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
python relauncher.py &

if [[ $PUBLIC_KEY ]]
then
    mkdir -p ~/.ssh
    chmod 700 ~/.ssh
    cd ~/.ssh
    echo $PUBLIC_KEY >> authorized_keys
    chmod 700 -R ~/.ssh
    cd /
    service ssh start
    echo "SSH Service Started"
fi

if [[ $JUPYTER_PASSWORD ]]
then
    ln -sf /examples /workspace
    ln -sf /root/welcome.ipynb /workspace

    cd /
    jupyter lab --allow-root --no-browser --port=8888 --ip=* \
        --ServerApp.terminado_settings='{"shell_command":["/bin/bash"]}' \
        --ServerApp.token=$JUPYTER_PASSWORD --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace
    echo "Jupyter Lab Started"
fi

sleep infinity
							
								
								
									
47 sd-auto-v1/webui-user.sh Normal file
@@ -0,0 +1,47 @@
# #!/bin/bash
#########################################################
# Uncomment and change the variables below to your need:#
#########################################################

# Install directory without trailing slash
install_dir="/workspace"

# Name of the subdirectory
#clone_dir="stable-diffusion-webui"

# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
export COMMANDLINE_ARGS="--port 3000 --xformers --ckpt /workspace/v1-5-pruned-emaonly.ckpt --listen --enable-insecure-extension-access"

# python3 executable
#python_cmd="python3"

# git executable
#export GIT="git"

# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
# venv_dir="/workspace/venv"

# script to launch to start the app
# export LAUNCH_SCRIPT="dummy.py"

# install command for torch
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"

# Requirements file to use for stable-diffusion-webui
#export REQS_FILE="requirements_versions.txt"

# Fixed git repos
#export K_DIFFUSION_PACKAGE=""
#export GFPGAN_PACKAGE=""

# Fixed git commits
#export STABLE_DIFFUSION_COMMIT_HASH=""
#export TAMING_TRANSFORMERS_COMMIT_HASH=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""

# Uncomment to enable accelerated launch
#export ACCELERATE="True"

###########################################
							
								
								
									
293 sd-auto-v1/webui.py Normal file
@@ -0,0 +1,293 @@
import os
import sys
import time
import importlib
import signal
import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version

import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call

import torch

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork


if cmd_opts.server_name:
    server_name = cmd_opts.server_name
else:
    server_name = "0.0.0.0" if cmd_opts.listen else None


def check_versions():
    if shared.cmd_opts.skip_version_check:
        return

    expected_torch_version = "1.13.1"

    if version.parse(torch.__version__) < version.parse(expected_torch_version):
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
        """.strip())

    expected_xformers_version = "0.0.16rc425"
    if shared.xformers_available:
        import xformers

        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
            errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
            """.strip())


def initialize():
    check_versions()

    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)

    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()
    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()

    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)

    shared.reload_hypernetworks()

    ui_extra_networks.intialize()
    ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
    ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
    ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

    extra_networks.initialize()
    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())

    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:

        try:
            if not os.path.exists(cmd_opts.tls_keyfile):
                print("Invalid path to TLS keyfile given")
            if not os.path.exists(cmd_opts.tls_certfile):
                print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
        except TypeError:
            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
            print("TLS setup invalid, running webui without TLS")
        else:
            print("Running with TLS")

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)


def setup_cors(app):
    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])


def create_api(app):
    from modules.api.api import Api
    api = Api(app, queue_lock)
    return api


def wait_on_server(demo=None):
    while 1:
        time.sleep(0.5)
        if shared.state.need_restart:
            shared.state.need_restart = False
            time.sleep(0.5)
            demo.close()
            time.sleep(0.5)
            break


def api_only():
    initialize()

    app = FastAPI()
    setup_cors(app)
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)
    modules.script_callbacks.before_ui_callback()
    extensions.list_extensions()
    modules.scripts.reload_scripts()
    modules.script_callbacks.model_loaded_callback(shared.sd_model)
    modelloader.load_upscalers()
    modules.sd_models.list_models()


    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


def webui():
    launch_api = cmd_opts.api
    initialize()

    while 1:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()

        modules.script_callbacks.before_ui_callback()

        shared.demo = modules.ui.create_ui()

        if cmd_opts.gradio_queue:
            shared.demo.queue(64)

        gradio_auth_creds = []
        if cmd_opts.gradio_auth:
            gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
        if cmd_opts.gradio_auth_path:
            with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
                for line in file.readlines():
                    gradio_auth_creds += [x.strip() for x in line.split(',')]

        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            debug=cmd_opts.gradio_debug,
            auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True
        )
        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_cors(app)

        app.add_middleware(GZipMiddleware, minimum_size=1000)

        modules.progress.setup_progress_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(shared.demo)
        print('Restarting UI...')

        sd_samplers.set_samplers()

        modules.script_callbacks.script_unloaded_callback()
        extensions.list_extensions()

        localization.list_localizations(cmd_opts.localizations_dir)

        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        modules.script_callbacks.model_loaded_callback(shared.sd_model)
        modelloader.load_upscalers()

        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)

        modules.sd_models.list_models()

        shared.reload_hypernetworks()

        ui_extra_networks.intialize()
        ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
        ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
        ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

        extra_networks.initialize()
        extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())


if __name__ == "__main__":
    if cmd_opts.nowebui:
        api_only()
    else:
        webui()
							
								
								
									
11 start.sh Normal file
@@ -0,0 +1,11 @@
#!/bin/bash
echo "Container Started"
export PYTHONUNBUFFERED=1
source /workspace/stable-diffusion-webui/venv/bin/activate
cd /workspace/stable-diffusion-webui
echo "starting api"
python webui.py --port 3000 --nowebui --api --xformers --enable-insecure-extension-access --ckpt /model.safetensors &
cd /

echo "starting worker"
python -u handler.py
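start.sh brings the webui up in API-only mode on port 3000 and then runs the RunPod worker against it. A quick way to poke that endpoint directly, sketched under the assumption that the container is running locally with port 3000 published (not something this commit ships):

import requests

# Sketch: hit the API-only webui that start.sh launches on port 3000.
payload = {"prompt": "smoke test", "steps": 4, "width": 256, "height": 256}
resp = requests.post("http://127.0.0.1:3000/sdapi/v1/txt2img", json=payload)
print(resp.status_code, len(resp.json().get("images", [])))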
							
								
								
									
284 webui.py Normal file
@@ -0,0 +1,284 @@
import os
import sys
import time
import importlib
import signal
import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version

import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call

import torch

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork


if cmd_opts.server_name:
    server_name = cmd_opts.server_name
else:
    server_name = "0.0.0.0" if cmd_opts.listen else None


def check_versions():
    if shared.cmd_opts.skip_version_check:
        return

    expected_torch_version = "1.13.1"

    if version.parse(torch.__version__) < version.parse(expected_torch_version):
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
        """.strip())

    expected_xformers_version = "0.0.16rc425"
    if shared.xformers_available:
        import xformers

        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
            errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
            """.strip())


def initialize():
    check_versions()

    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    shared.face_restorers.append(modules.face_restoration.FaceRestoration())

    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()
    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()

    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)

    shared.reload_hypernetworks()

    ui_extra_networks.intialize()
    ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
    ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
    ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

    extra_networks.initialize()
    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())

    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:

        try:
            if not os.path.exists(cmd_opts.tls_keyfile):
                print("Invalid path to TLS keyfile given")
            if not os.path.exists(cmd_opts.tls_certfile):
                print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
        except TypeError:
            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
            print("TLS setup invalid, running webui without TLS")
        else:
            print("Running with TLS")

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)


def setup_cors(app):
    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
    elif cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])


def create_api(app):
    from modules.api.api import Api
    api = Api(app, queue_lock)
    return api


def wait_on_server(demo=None):
    while 1:
        time.sleep(0.5)
        if shared.state.need_restart:
            shared.state.need_restart = False
            time.sleep(0.5)
            demo.close()
            time.sleep(0.5)
            break


def api_only():
    initialize()

    app = FastAPI()
    setup_cors(app)
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)
    modules.script_callbacks.before_ui_callback()
    extensions.list_extensions()
    modules.scripts.reload_scripts()
    modules.script_callbacks.model_loaded_callback(shared.sd_model)
    modelloader.load_upscalers()
    modules.sd_models.list_models()

    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)


def webui():
    launch_api = cmd_opts.api
    initialize()

    while 1:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()

        modules.script_callbacks.before_ui_callback()

        shared.demo = modules.ui.create_ui()

        if cmd_opts.gradio_queue:
            shared.demo.queue(64)

        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            debug=cmd_opts.gradio_debug,
            auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
            inbrowser=cmd_opts.autolaunch,
            prevent_thread_lock=True
        )
        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_cors(app)

        app.add_middleware(GZipMiddleware, minimum_size=1000)

        modules.progress.setup_progress_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(shared.demo)
        print('Restarting UI...')

        sd_samplers.set_samplers()

        modules.script_callbacks.script_unloaded_callback()
        extensions.list_extensions()

        localization.list_localizations(cmd_opts.localizations_dir)

        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        modules.script_callbacks.model_loaded_callback(shared.sd_model)
        modelloader.load_upscalers()

        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)

        modules.sd_models.list_models()

        shared.reload_hypernetworks()

        ui_extra_networks.intialize()
        ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
        ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
        ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

        extra_networks.initialize()
        extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())


if __name__ == "__main__":
    if cmd_opts.nowebui:
        api_only()
    else:
        webui()