diff --git a/.env.example b/.env.example index e19e5d7..a2c9dfb 100644 --- a/.env.example +++ b/.env.example @@ -16,3 +16,7 @@ USE_VAE="false" # - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"` LORA_NAMES="" LORA_DOWNLOAD_URLS="" + +# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`. +TEXTUAL_INVERSION_NAMES="" +TEXTUAL_INVERSION_DOWNLOAD_URLS="" diff --git a/sd_cli.py b/sd_cli.py index 3240a3d..fe02b82 100644 --- a/sd_cli.py +++ b/sd_cli.py @@ -11,20 +11,21 @@ import util BASE_CACHE_PATH = "/vol/cache" BASE_CACHE_PATH_LORA = "/vol/cache/lora" +BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion" -def download_loras(): +def download_files(urls, file_names, file_path): """ - Download LoRA. + Download files. """ - lora_names = os.getenv("LORA_NAMES").split(",") - lora_download_urls = os.getenv("LORA_DOWNLOAD_URLS").split(",") + file_names = file_names.split(",") + urls = urls.split(",") - for name, url in zip(lora_names, lora_download_urls): + for file_name, url in zip(file_names, urls): req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) downloaded = urlopen(req).read() - dir_names = os.path.join(BASE_CACHE_PATH_LORA, name) + dir_names = os.path.join(file_path, file_name) os.makedirs(os.path.dirname(dir_names), exist_ok=True) with open(dir_names, mode="wb") as f: f.write(downloaded) @@ -71,7 +72,18 @@ def build_image(): download_models() if os.environ["LORA_NAMES"] != "": - download_loras() + download_files( + os.getenv("LORA_DOWNLOAD_URLS"), + os.getenv("LORA_NAMES"), + BASE_CACHE_PATH_LORA, + ) + + if os.environ["TEXTUAL_INVERSION_NAMES"] != "": + download_files( + os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS"), + os.getenv("TEXTUAL_INVERSION_NAMES"), + BASE_CACHE_PATH_TEXTUAL_INVERSION, + ) stub_image = Image.from_dockerfile( @@ -124,15 +136,28 @@ class StableDiffusion: self.pipe.to("cuda") if os.environ["LORA_NAMES"] != "": - lora_names = os.getenv("LORA_NAMES").split(",") - for lora_name in lora_names: - path_to_lora = os.path.join(BASE_CACHE_PATH_LORA, lora_name) - if os.path.exists(path_to_lora): - print(f"The directory '{path_to_lora}' exists.") + names = os.getenv("LORA_NAMES").split(",") + urls = os.getenv("LORA_DOWNLOAD_URLS").split(",") + for name, url in zip(names, urls): + path = os.path.join(BASE_CACHE_PATH_LORA, name) + if os.path.exists(path): + print(f"The directory '{path}' exists.") else: - print(f"The directory '{path_to_lora}' does not exist. Download loras...") - download_loras() - self.pipe.load_lora_weights(".", weight_name=path_to_lora) + print(f"The directory '{path}' does not exist. Download it...") + download_files(url, name, BASE_CACHE_PATH_LORA) + self.pipe.load_lora_weights(".", weight_name=path) + + if os.environ["TEXTUAL_INVERSION_NAMES"] != "": + names = os.getenv("TEXTUAL_INVERSION_NAMES").split(",") + urls = os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS").split(",") + for name, url in zip(names, urls): + path = os.path.join(BASE_CACHE_PATH_TEXTUAL_INVERSION, name) + if os.path.exists(path): + print(f"The directory '{path}' exists.") + else: + print(f"The directory '{path}' does not exist. Download it...") + download_files(url, name, BASE_CACHE_PATH_TEXTUAL_INVERSION) + self.pipe.load_textual_inversion(path) self.pipe.enable_xformers_memory_efficient_attention()