Modify to use LoRA.
This commit is contained in:
parent
1b345879f7
commit
a5adb6b9fb
14
.env.example
14
.env.example
@ -1,4 +1,18 @@
|
|||||||
HUGGING_FACE_TOKEN=""
|
HUGGING_FACE_TOKEN=""
|
||||||
MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
|
MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
|
||||||
MODEL_NAME="stable-diffusion-2-1"
|
MODEL_NAME="stable-diffusion-2-1"
|
||||||
|
|
||||||
|
# Modify `USE_VAE` to `true` if you want to use VAE.
|
||||||
USE_VAE="false"
|
USE_VAE="false"
|
||||||
|
|
||||||
|
# Add LoRA if you want to use one. You can use a download link of civitai.
|
||||||
|
# ex)
|
||||||
|
# - `LORA_NAMES="hogehoge.safetensors"`
|
||||||
|
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
||||||
|
#
|
||||||
|
# If you have multiple LoRAs you want to use, separate by commas like the below:
|
||||||
|
# ex)
|
||||||
|
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
||||||
|
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||||
|
LORA_NAMES=""
|
||||||
|
LORA_DOWNLOAD_URLS=""
|
||||||
|
|||||||
42
sd_cli.py
42
sd_cli.py
@ -3,12 +3,31 @@ from __future__ import annotations
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
from modal import Image, Mount, Secret, Stub, method
|
from modal import Image, Mount, Secret, Stub, method
|
||||||
|
|
||||||
import util
|
import util
|
||||||
|
|
||||||
BASE_CACHE_PATH = "/vol/cache"
|
BASE_CACHE_PATH = "/vol/cache"
|
||||||
|
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
||||||
|
|
||||||
|
|
||||||
|
def download_loras():
|
||||||
|
"""
|
||||||
|
Download LoRA.
|
||||||
|
"""
|
||||||
|
lora_names = os.getenv("LORA_NAMES").split(",")
|
||||||
|
lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
|
||||||
|
|
||||||
|
for name, url in zip(lora_names, lora_download_urls):
|
||||||
|
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||||
|
downloaded = urlopen(req).read()
|
||||||
|
|
||||||
|
dir_names = os.path.join(BASE_CACHE_PATH_LORA, name)
|
||||||
|
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||||
|
with open(dir_names, mode="wb") as f:
|
||||||
|
f.write(downloaded)
|
||||||
|
|
||||||
|
|
||||||
def download_models():
|
def download_models():
|
||||||
@ -45,11 +64,21 @@ def download_models():
|
|||||||
pipe.save_pretrained(cache_path, safe_serialization=True)
|
pipe.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_image():
|
||||||
|
"""
|
||||||
|
Build the Docker image.
|
||||||
|
"""
|
||||||
|
download_models()
|
||||||
|
|
||||||
|
if os.environ["LORA_NAMES"] != "":
|
||||||
|
download_loras()
|
||||||
|
|
||||||
|
|
||||||
stub_image = Image.from_dockerfile(
|
stub_image = Image.from_dockerfile(
|
||||||
path="./Dockerfile",
|
path="./Dockerfile",
|
||||||
context_mount=Mount.from_local_file("./requirements.txt"),
|
context_mount=Mount.from_local_file("./requirements.txt"),
|
||||||
).run_function(
|
).run_function(
|
||||||
download_models,
|
build_image,
|
||||||
secrets=[Secret.from_dotenv(__file__)],
|
secrets=[Secret.from_dotenv(__file__)],
|
||||||
)
|
)
|
||||||
stub = Stub("stable-diffusion-cli")
|
stub = Stub("stable-diffusion-cli")
|
||||||
@ -100,6 +129,17 @@ class StableDiffusion:
|
|||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
|
|
||||||
|
if os.environ["LORA_NAMES"] != "":
|
||||||
|
lora_names = os.getenv("LORA_NAMES").split(",")
|
||||||
|
for lora_name in lora_names:
|
||||||
|
path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name)
|
||||||
|
if os.path.exists(path_to_lora):
|
||||||
|
print(f"The directory '{path_to_lora}' exists.")
|
||||||
|
else:
|
||||||
|
print(f"The directory '{path_to_lora}' does not exist. Download loras...")
|
||||||
|
download_loras()
|
||||||
|
self.pipe.load_lora_weights(".", weight_name=path_to_lora)
|
||||||
|
|
||||||
self.pipe.enable_xformers_memory_efficient_attention()
|
self.pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
@method()
|
@method()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user