This commit is contained in:
zhaobenny 2023-12-28 00:19:52 -08:00
commit f5fd4dc2a4
No known key found for this signature in database
GPG Key ID: D7DAF5C97F51E001
13 changed files with 325 additions and 0 deletions

8
.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
.python-version
.venv/
__pycache__/
*.safetensors
*.pt
test.py
config.yaml

24
README.md Normal file
View File

@ -0,0 +1,24 @@
# A1111 Stable Diffusion | Modal.com Serverless Worker
Deploys AUTOMATIC1111/stable-diffusion-webui as serverless worker on Modal.com
## Usage
Quickstart:
```
git clone .... && cd serverless-img-gen
pip install modal && modal token new
./deploy.sh
```
Query the Modal endpoint like an normal A1111 endpoint.
To add more models, add your urls to "config.yaml" and run:
```
modal run a1111_modal_worker/download.py
```
(Delete unneeded models using the modal cli or the web dashboard)
## Caveats
- over 50s response time (including for simple requests)
- no extension support yet
- need to add auth to api endpoints
- very WIP

View File

View File

@ -0,0 +1,79 @@
import os
import httpx
from modal import Image, Stub, Volume
from a1111_modal_worker.utils import UserModels, get_urls
MODELS = "/models"
stub = Stub("a1111")
user_models = Volume.persisted("a1111-user-models")
@stub.function(volumes={MODELS: user_models},
image=Image.debian_slim(python_version="3.10")
.pip_install(["httpx"])
)
def download_all(models: UserModels):
download_type(models.embeddings_urls, "embeddings")
download_type(models.loras_urls, "loras")
download_type(models.checkpoints_urls, "checkpoints")
download_type(models.vae_urls, "vaes")
def download_type(urls, model_type):
directory = os.path.join(MODELS, model_type)
os.makedirs(directory, exist_ok=True)
if not urls:
return
print(f"Downloading {model_type} models...")
for url in urls:
try:
download_to_folder(url, model_type)
except Exception as e:
print(f"Failed to download from \"{url}\": {str(e)}")
user_models.commit()
print(f"Downloaded all {model_type} models...")
def download_to_folder(url, folder):
with httpx.Client() as client:
with client.stream("GET", url, follow_redirects=True, timeout=5) as r:
headers = r.headers
filename = extract_filename(url, headers)
filepath = os.path.join(MODELS, folder, filename)
if os.path.exists(filepath):
return print(f"\"{filename}\" already exists in \"{folder}\", skipping...")
r = client.get(url, follow_redirects=True)
r.raise_for_status()
with open(filepath, "wb") as f:
f.write(r.content)
print(f"Downloaded \"{url}\" to \"{folder}\" as \"{filename}\"")
def extract_filename(url, headers):
content_disposition = headers.get("Content-Disposition")
if content_disposition:
filename = content_disposition.split("filename=")[1]
elif url.endswith(".safetensors") or url.endswith(".pt"):
filename = url.split("/")[-1]
else:
raise Exception(f"\"{url}\" does not contain a valid file")
return filename.strip(";").strip("\"")
@stub.local_entrypoint()
def download_models():
urls: UserModels = get_urls()
download_all.remote(urls)

View File

@ -0,0 +1,23 @@
from fastapi import FastAPI
from modal import asgi_app
from a1111_modal_worker.setup import stub
from a1111_modal_worker.worker import A1111
web_app = FastAPI()
@web_app.get("{path:path}")
def forward_get(path: str):
return A1111.api_get.remote(path)
@web_app.post("{path:path}")
def forward_post(path: str, body: dict):
return A1111.api_post.remote(path, body)
@stub.function()
@asgi_app()
def webui():
return web_app

View File

@ -0,0 +1,73 @@
# adapted from https://modal.com/docs/examples/a1111_webui#stable-diffusion-a1111
import subprocess
import time
import webbrowser
from modal import Image, Stub, Volume, forward
from a1111_modal_worker.utils import (ALWAYS_GET_LATEST_A1111, MODAL_GPU,
START_CMD, wait_for_port)
stub = Stub("a1111")
user_models = Volume.persisted("a1111-user-models")
def initialize_webui():
subprocess.Popen("bash /webui.sh -f --no-download-sd-model", shell=True)
wait_for_port(7860)
image = (
Image.debian_slim(python_version="3.10").apt_install(
"wget",
"git",
"python3",
"python3-pip",
"python3-venv",
"libgl1",
"libglib2.0-0",
"google-perftools",
).env(
{"LD_PRELOAD": "/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"}
).run_commands(
"pip3 install httpx",
"pip3 install pyyaml"
).run_commands(
"pip3 install xformers",
gpu=MODAL_GPU
).run_commands(
"wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh",
"chmod +x webui.sh",
force_build=ALWAYS_GET_LATEST_A1111
).run_function(
initialize_webui,
gpu=MODAL_GPU
)
.copy_local_dir(
"./overwrite/", "/stable-diffusion-webui"
)
)
@stub.function(gpu=MODAL_GPU, image=image, volumes={"/models/": user_models})
def web_instance():
with forward(7860) as tunnel:
p = subprocess.Popen(f"{START_CMD} --listen", shell=True)
wait_for_port(7860)
webbrowser.open(tunnel.url)
time.sleep(10) # pause to allow models to load
print("######################")
print("######################")
print("URL")
print("Accepting connections at", tunnel.url)
print("WARNING: None of your settings will be saved on this instance")
print("Press Ctrl+C to quit or be timed out in 1 hour")
print("######################")
print("######################")
p.wait(3600)
@stub.local_entrypoint()
def start_web_instance():
web_instance.remote()

View File

@ -0,0 +1,50 @@
import inspect
import shutil
import socket
import time
from dataclasses import dataclass, field
from typing import List
MODAL_GPU = "A10G"
START_CMD = "bash /webui.sh -f --lora-dir '/models/loras' --embeddings-dir '/models/embeddings' --ckpt-dir '/models/checkpoints/' --vae-dir '/models/vaes/' --xformers"
ALWAYS_GET_LATEST_A1111 = True
def wait_for_port(port: int):
while True:
try:
with socket.create_connection(("127.0.0.1", port), timeout=5.0):
break
except OSError:
time.sleep(0.1)
@dataclass(frozen=True)
class UserModels:
checkpoints_urls: List[str] = field(default_factory=list)
vae_urls: List[str] = field(default_factory=list)
loras_urls: List[str] = field(default_factory=list)
embeddings_urls: List[str] = field(default_factory=list)
@classmethod
def from_dict(cls, env):
return cls(**{
k: v for k, v in env.items()
if k in inspect.signature(cls).parameters
})
def get_urls():
""" should only be ran locally """
import yaml
try:
with open("./config.yaml") as f:
try:
config = yaml.safe_load(f)
return UserModels.from_dict(config)
except yaml.YAMLError as exc:
print(exc)
except FileNotFoundError:
shutil.copyfile("./config.example.yaml", "./config.yaml")
return UserModels()

View File

@ -0,0 +1,31 @@
import subprocess
import time
from modal import method
from a1111_modal_worker.setup import image, stub, user_models
from a1111_modal_worker.utils import MODAL_GPU, START_CMD, wait_for_port
@stub.cls(gpu=MODAL_GPU, image=image, volumes={"/models": user_models})
class A1111:
BASE_URL = "http://127.0.0.1:7860"
def __enter__(self):
subprocess.Popen(f"{START_CMD} --api", shell=True)
wait_for_port(7860)
time.sleep(15) # wait for model/embeddings to load
@method()
def api_get(self, path: str):
import httpx
with httpx.Client() as client:
r = client.get(self.BASE_URL + path)
return r.json()
@method()
def api_post(self, path: str, data: dict):
import httpx
with httpx.Client() as client:
r = client.post(self.BASE_URL + path, json=data)
return r.json()

15
config.example.yaml Normal file
View File

@ -0,0 +1,15 @@
# extra models and concepts to download
# note that the urls should be direct download links
CHECKPOINTS_URLS:
- 'https://civitai.com/api/download/models/119057'
VAE_URLS: []
LORAS_URLS:
- 'https://civitai.com/api/download/models/231021'
- 'https://huggingface.co/hollowstrawberry/holotard/resolve/main/loras/kiryu_coco_5_outfits.safetensors?download=true'
EMBEDDINGS_URLS:
- 'https://civitai.com/api/download/models/94057'

5
deploy.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash
source .venv/bin/activate || python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt &&
modal run a1111_modal_worker/download.py &&
modal deploy a1111_modal_worker/server.py

9
license.md Normal file
View File

@ -0,0 +1,9 @@
MIT License
Copyright (c) 2024 Benny Zhao
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

5
overwrite/README.md Normal file
View File

@ -0,0 +1,5 @@
Anything in the `overwrite` directory will be copied/overwriten and baked into the Modal image.
Make sure the directory structure of the file is same.
eg. Adding the file `test.pt` to `\models\hypernetworks` in the remote A1111 instance will need the file path `\overwrite\models\hypernetworks\test.pt`

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
modal
PyYAML
httpx