Modify to use Textual Inversion.
This commit is contained in:
parent
d61e212d49
commit
e4d580f28f
@ -16,3 +16,7 @@ USE_VAE="false"
|
|||||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||||
LORA_NAMES=""
|
LORA_NAMES=""
|
||||||
LORA_DOWNLOAD_URLS=""
|
LORA_DOWNLOAD_URLS=""
|
||||||
|
|
||||||
|
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
|
||||||
|
TEXTUAL_INVERSION_NAMES=""
|
||||||
|
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
|
||||||
|
|||||||
55
sd_cli.py
55
sd_cli.py
@ -11,20 +11,21 @@ import util
|
|||||||
|
|
||||||
BASE_CACHE_PATH = "/vol/cache"
|
BASE_CACHE_PATH = "/vol/cache"
|
||||||
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
||||||
|
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
|
||||||
|
|
||||||
|
|
||||||
def download_loras():
|
def download_files(urls, file_names, file_path):
|
||||||
"""
|
"""
|
||||||
Download LoRA.
|
Download files.
|
||||||
"""
|
"""
|
||||||
lora_names = os.getenv("LORA_NAMES").split(",")
|
file_names = file_names.split(",")
|
||||||
lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
|
urls = urls.split(",")
|
||||||
|
|
||||||
for name, url in zip(lora_names, lora_download_urls):
|
for file_name, url in zip(file_names, urls):
|
||||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||||
downloaded = urlopen(req).read()
|
downloaded = urlopen(req).read()
|
||||||
|
|
||||||
dir_names = os.path.join(BASE_CACHE_PATH_LORA, name)
|
dir_names = os.path.join(file_path, file_name)
|
||||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||||
with open(dir_names, mode="wb") as f:
|
with open(dir_names, mode="wb") as f:
|
||||||
f.write(downloaded)
|
f.write(downloaded)
|
||||||
@ -71,7 +72,18 @@ def build_image():
|
|||||||
download_models()
|
download_models()
|
||||||
|
|
||||||
if os.environ["LORA_NAMES"] != "":
|
if os.environ["LORA_NAMES"] != "":
|
||||||
download_loras()
|
download_files(
|
||||||
|
os.getenv("LORA_DOWNLOAD_URLS"),
|
||||||
|
os.getenv("LORA_NAMES"),
|
||||||
|
BASE_CACHE_PATH_LORA,
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
|
||||||
|
download_files(
|
||||||
|
os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS"),
|
||||||
|
os.getenv("TEXTUAL_INVERSION_NAMES"),
|
||||||
|
BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
stub_image = Image.from_dockerfile(
|
stub_image = Image.from_dockerfile(
|
||||||
@ -124,15 +136,28 @@ class StableDiffusion:
|
|||||||
self.pipe.to("cuda")
|
self.pipe.to("cuda")
|
||||||
|
|
||||||
if os.environ["LORA_NAMES"] != "":
|
if os.environ["LORA_NAMES"] != "":
|
||||||
lora_names = os.getenv("LORA_NAMES").split(",")
|
names = os.getenv("LORA_NAMES").split(",")
|
||||||
for lora_name in lora_names:
|
urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
|
||||||
path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name)
|
for name, url in zip(names, urls):
|
||||||
if os.path.exists(path_to_lora):
|
path = os.path.join(BASE_CACHE_PATH_LORA, name)
|
||||||
print(f"The directory '{path_to_lora}' exists.")
|
if os.path.exists(path):
|
||||||
|
print(f"The directory '{path}' exists.")
|
||||||
else:
|
else:
|
||||||
print(f"The directory '{path_to_lora}' does not exist. Download loras...")
|
print(f"The directory '{path}' does not exist. Download it...")
|
||||||
download_loras()
|
download_files(url, name, BASE_CACHE_PATH_LORA)
|
||||||
self.pipe.load_lora_weights(".", weight_name=path_to_lora)
|
self.pipe.load_lora_weights(".", weight_name=path)
|
||||||
|
|
||||||
|
if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
|
||||||
|
names = os.getenv("TEXTUAL_INVERSION_NAMES").split(",")
|
||||||
|
urls = os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS").split(",")
|
||||||
|
for name, url in zip(names, urls):
|
||||||
|
path = os.path.join(BASE_CACHE_PATH_TEXTUAL_INVERSION, name)
|
||||||
|
if os.path.exists(path):
|
||||||
|
print(f"The directory '{path}' exists.")
|
||||||
|
else:
|
||||||
|
print(f"The directory '{path}' does not exist. Download it...")
|
||||||
|
download_files(url, name, BASE_CACHE_PATH_TEXTUAL_INVERSION)
|
||||||
|
self.pipe.load_textual_inversion(path)
|
||||||
|
|
||||||
self.pipe.enable_xformers_memory_efficient_attention()
|
self.pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user