Modify to use LoRA.
This commit is contained in:
parent
1b345879f7
commit
a5adb6b9fb
14
.env.example
14
.env.example
@ -1,4 +1,18 @@
|
||||
HUGGING_FACE_TOKEN=""
|
||||
MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
|
||||
MODEL_NAME="stable-diffusion-2-1"
|
||||
|
||||
# Modify `USE_VAE` to `true` if you want to use VAE.
|
||||
USE_VAE="false"
|
||||
|
||||
# Add LoRA if you want to use one. You can use a download link of civitai.
|
||||
# ex)
|
||||
# - `LORA_NAMES="hogehoge.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
||||
#
|
||||
# If you have multiple LoRAs you want to use, separate by commas like the below:
|
||||
# ex)
|
||||
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||
LORA_NAMES=""
|
||||
LORA_DOWNLOAD_URLS=""
|
||||
|
||||
42
sd_cli.py
42
sd_cli.py
@ -3,12 +3,31 @@ from __future__ import annotations
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from modal import Image, Mount, Secret, Stub, method
|
||||
|
||||
import util
|
||||
|
||||
BASE_CACHE_PATH = "/vol/cache"
|
||||
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
||||
|
||||
|
||||
def download_loras():
|
||||
"""
|
||||
Download LoRA.
|
||||
"""
|
||||
lora_names = os.getenv("LORA_NAMES").split(",")
|
||||
lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
|
||||
|
||||
for name, url in zip(lora_names, lora_download_urls):
|
||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
downloaded = urlopen(req).read()
|
||||
|
||||
dir_names = os.path.join(BASE_CACHE_PATH_LORA, name)
|
||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||
with open(dir_names, mode="wb") as f:
|
||||
f.write(downloaded)
|
||||
|
||||
|
||||
def download_models():
|
||||
@ -45,11 +64,21 @@ def download_models():
|
||||
pipe.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
|
||||
def build_image():
|
||||
"""
|
||||
Build the Docker image.
|
||||
"""
|
||||
download_models()
|
||||
|
||||
if os.environ["LORA_NAMES"] != "":
|
||||
download_loras()
|
||||
|
||||
|
||||
stub_image = Image.from_dockerfile(
|
||||
path="./Dockerfile",
|
||||
context_mount=Mount.from_local_file("./requirements.txt"),
|
||||
).run_function(
|
||||
download_models,
|
||||
build_image,
|
||||
secrets=[Secret.from_dotenv(__file__)],
|
||||
)
|
||||
stub = Stub("stable-diffusion-cli")
|
||||
@ -100,6 +129,17 @@ class StableDiffusion:
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
if os.environ["LORA_NAMES"] != "":
|
||||
lora_names = os.getenv("LORA_NAMES").split(",")
|
||||
for lora_name in lora_names:
|
||||
path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name)
|
||||
if os.path.exists(path_to_lora):
|
||||
print(f"The directory '{path_to_lora}' exists.")
|
||||
else:
|
||||
print(f"The directory '{path_to_lora}' does not exist. Download loras...")
|
||||
download_loras()
|
||||
self.pipe.load_lora_weights(".", weight_name=path_to_lora)
|
||||
|
||||
self.pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@method()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user