Modify to use LoRA.

This commit is contained in:
hodanov 2023-06-17 17:41:05 +09:00
parent 1b345879f7
commit a5adb6b9fb
2 changed files with 55 additions and 1 deletions

View File

@ -1,4 +1,18 @@
HUGGING_FACE_TOKEN="" HUGGING_FACE_TOKEN=""
MODEL_REPO_ID="stabilityai/stable-diffusion-2-1" MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
MODEL_NAME="stable-diffusion-2-1" MODEL_NAME="stable-diffusion-2-1"
# Modify `USE_VAE` to `true` if you want to use VAE.
USE_VAE="false" USE_VAE="false"
# Add LoRA if you want to use one. You can use a download link of civitai.
# ex)
# - `LORA_NAMES="hogehoge.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
#
# If you have multiple LoRAs you want to use, separate by commas like the below:
# ex)
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
LORA_NAMES=""
LORA_DOWNLOAD_URLS=""

View File

@ -3,12 +3,31 @@ from __future__ import annotations
import io import io
import os import os
import time import time
from urllib.request import Request, urlopen
from modal import Image, Mount, Secret, Stub, method from modal import Image, Mount, Secret, Stub, method
import util import util
BASE_CACHE_PATH = "/vol/cache" BASE_CACHE_PATH = "/vol/cache"
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
def download_loras():
"""
Download LoRA.
"""
lora_names = os.getenv("LORA_NAMES").split(",")
lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
for name, url in zip(lora_names, lora_download_urls):
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = os.path.join(BASE_CACHE_PATH_LORA, name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
def download_models(): def download_models():
@ -45,11 +64,21 @@ def download_models():
pipe.save_pretrained(cache_path, safe_serialization=True) pipe.save_pretrained(cache_path, safe_serialization=True)
def build_image():
"""
Build the Docker image.
"""
download_models()
if os.environ["LORA_NAMES"] != "":
download_loras()
stub_image = Image.from_dockerfile( stub_image = Image.from_dockerfile(
path="./Dockerfile", path="./Dockerfile",
context_mount=Mount.from_local_file("./requirements.txt"), context_mount=Mount.from_local_file("./requirements.txt"),
).run_function( ).run_function(
download_models, build_image,
secrets=[Secret.from_dotenv(__file__)], secrets=[Secret.from_dotenv(__file__)],
) )
stub = Stub("stable-diffusion-cli") stub = Stub("stable-diffusion-cli")
@ -100,6 +129,17 @@ class StableDiffusion:
torch_dtype=torch.float16, torch_dtype=torch.float16,
).to("cuda") ).to("cuda")
if os.environ["LORA_NAMES"] != "":
lora_names = os.getenv("LORA_NAMES").split(",")
for lora_name in lora_names:
path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name)
if os.path.exists(path_to_lora):
print(f"The directory '{path_to_lora}' exists.")
else:
print(f"The directory '{path_to_lora}' does not exist. Download loras...")
download_loras()
self.pipe.load_lora_weights(".", weight_name=path_to_lora)
self.pipe.enable_xformers_memory_efficient_attention() self.pipe.enable_xformers_memory_efficient_attention()
@method() @method()