Merge pull request #9 from hodanov/feature/modify_to_use_textual_inversion

Modify to use Textual Inversion.
This commit is contained in:
hodanov 2023-06-18 12:30:48 +09:00 committed by GitHub
commit b1b3863feb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 15 deletions

View File

@ -16,3 +16,7 @@ USE_VAE="false"
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
LORA_NAMES=""
LORA_DOWNLOAD_URLS=""
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
TEXTUAL_INVERSION_NAMES=""
TEXTUAL_INVERSION_DOWNLOAD_URLS=""

View File

@ -11,20 +11,21 @@ import util
BASE_CACHE_PATH = "/vol/cache"
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
def download_loras():
def download_files(urls, file_names, file_path):
"""
Download LoRA.
Download files.
"""
lora_names = os.getenv("LORA_NAMES").split(",")
lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
file_names = file_names.split(",")
urls = urls.split(",")
for name, url in zip(lora_names, lora_download_urls):
for file_name, url in zip(file_names, urls):
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = os.path.join(BASE_CACHE_PATH_LORA, name)
dir_names = os.path.join(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
@ -71,7 +72,18 @@ def build_image():
download_models()
if os.environ["LORA_NAMES"] != "":
download_loras()
download_files(
os.getenv("LORA_DOWNLOAD_URLS"),
os.getenv("LORA_NAMES"),
BASE_CACHE_PATH_LORA,
)
if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
download_files(
os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS"),
os.getenv("TEXTUAL_INVERSION_NAMES"),
BASE_CACHE_PATH_TEXTUAL_INVERSION,
)
stub_image = Image.from_dockerfile(
@ -124,15 +136,28 @@ class StableDiffusion:
self.pipe.to("cuda")
if os.environ["LORA_NAMES"] != "":
lora_names = os.getenv("LORA_NAMES").split(",")
for lora_name in lora_names:
path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name)
if os.path.exists(path_to_lora):
print(f"The directory '{path_to_lora}' exists.")
names = os.getenv("LORA_NAMES").split(",")
urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
for name, url in zip(names, urls):
path = os.path.join(BASE_CACHE_PATH_LORA, name)
if os.path.exists(path):
print(f"The directory '{path}' exists.")
else:
print(f"The directory '{path_to_lora}' does not exist. Download loras...")
download_loras()
self.pipe.load_lora_weights(".", weight_name=path_to_lora)
print(f"The directory '{path}' does not exist. Download it...")
download_files(url, name, BASE_CACHE_PATH_LORA)
self.pipe.load_lora_weights(".", weight_name=path)
if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
names = os.getenv("TEXTUAL_INVERSION_NAMES").split(",")
urls = os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS").split(",")
for name, url in zip(names, urls):
path = os.path.join(BASE_CACHE_PATH_TEXTUAL_INVERSION, name)
if os.path.exists(path):
print(f"The directory '{path}' exists.")
else:
print(f"The directory '{path}' does not exist. Download it...")
download_files(url, name, BASE_CACHE_PATH_TEXTUAL_INVERSION)
self.pipe.load_textual_inversion(path)
self.pipe.enable_xformers_memory_efficient_attention()