From b9b8f20736378d48a2063e1be9facfa11770c75c Mon Sep 17 00:00:00 2001
From: hodanov <1031hoda@gmail.com>
Date: Sat, 2 Nov 2024 17:23:41 +0900
Subject: [PATCH] Refactor setup.py.

---
 app/config.sample.yml |   1 +
 app/setup.py          | 227 ++++++++++++++++++++++++------------------
 cmd/sd15_img2img.py   |   8 +-
 cmd/sd15_txt2img.py   |   8 +-
 4 files changed, 142 insertions(+), 102 deletions(-)

diff --git a/app/config.sample.yml b/app/config.sample.yml
index ae6da56..4e7f23c 100644
--- a/app/config.sample.yml
+++ b/app/config.sample.yml
@@ -6,6 +6,7 @@
 ##########
 
 # You can use a diffusers model and VAE on hugging face.
+version: "sd15" # 'sd15' or 'sdxl'.
 model:
   name: stable-diffusion-1-5
   url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors
diff --git a/app/setup.py b/app/setup.py
index 7e36d86..a946435 100644
--- a/app/setup.py
+++ b/app/setup.py
@@ -1,6 +1,5 @@
-from __future__ import annotations
-
 import os
+from abc import ABC, abstractmethod
 
 import diffusers
 from modal import App, Image, Mount, Secret
@@ -12,79 +11,130 @@ BASE_CACHE_PATH_CONTROLNET = "/vol/cache/controlnet"
 BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
 
 
-def download_file(url, file_name, file_path):
-    """
-    Download files.
-    """
-    from urllib.request import Request, urlopen
-
-    req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
-    downloaded = urlopen(req).read()
-    dir_names = os.path.join(file_path, file_name)
-    os.makedirs(os.path.dirname(dir_names), exist_ok=True)
-    with open(dir_names, mode="wb") as f:
-        f.write(downloaded)
+class StableDiffusionCLISetupInterface(ABC):
+    @abstractmethod
+    def download_model(self):
+        pass
 
 
-def download_upscaler():
-    """
-    Download the stabilityai/sd-x2-latent-upscaler.
-    """
-    model_id = "stabilityai/sd-x2-latent-upscaler"
-    upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(model_id)
-    upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
+class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
+    def __init__(self, config: dict, token: str):
+        if config.get("version") != "sdxl":
+            raise ValueError("Invalid version. Must be 'sdxl'.")
+
+        if config.get("model") is None:
+            raise ValueError("Model is required. Please provide a model in config.yml.")
+
+        self.__model_name: str = config["model"]["name"]
+        self.__model_url: str = config["model"]["url"]
+        self.__token: str = token
+
+    def download_model(self) -> None:
+        cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
+        pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
+            pretrained_model_link_or_path=self.__model_url,
+            use_auth_token=self.__token,
+            cache_dir=cache_path,
+        )
+        pipe.save_pretrained(cache_path, safe_serialization=True)
 
 
-def download_controlnet(name: str, repo_id: str, token: str):
-    """
-    Download a controlnet.
-    """
-    cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
-    controlnet = diffusers.ControlNetModel.from_pretrained(
-        repo_id,
-        use_auth_token=token,
-        cache_dir=cache_path,
-    )
-    controlnet.save_pretrained(cache_path, safe_serialization=True)
+class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
+    def __init__(self, config: dict, token: str):
+        if config.get("version") != "sd15":
+            raise ValueError("Invalid version. Must be 'sd15'.")
+
+        if config.get("model") is None:
+            raise ValueError("Model is required. Please provide a model in config.yml.")
+
+        self.__model_name: str = config["model"]["name"]
+        self.__model_url: str = config["model"]["url"]
+        self.__token: str = token
+
+    def download_model(self) -> None:
+        cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
+        pipe = diffusers.StableDiffusionPipeline.from_single_file(
+            pretrained_model_link_or_path=self.__model_url,
+            token=self.__token,
+            cache_dir=cache_path,
+        )
+        pipe.save_pretrained(cache_path, safe_serialization=True)
+        self.__download_upscaler()
+
+    def __download_upscaler(self) -> None:
+        upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
+            "stabilityai/sd-x2-latent-upscaler"
+        )
+        upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
 
 
-def download_vae(name: str, model_url: str, token: str):
-    """
-    Download a vae.
-    """
-    cache_path = os.path.join(BASE_CACHE_PATH, name)
-    vae = diffusers.AutoencoderKL.from_single_file(
-        pretrained_model_link_or_path=model_url,
-        use_auth_token=token,
-        cache_dir=cache_path,
-    )
-    vae.save_pretrained(cache_path, safe_serialization=True)
+class CommonSetup:
+    def __init__(self, config: dict, token: str):
+        self.__token: str = token
+        self.__config: dict = config
 
+    def download_setup_files(self) -> None:
+        if self.__config.get("vae") is not None:
+            self.__download_vae(
+                name=self.__config["model"]["name"],
+                model_url=self.__config["vae"]["url"],
+                token=self.__token,
+            )
 
-def download_model(name: str, model_url: str, token: str):
-    """
-    Download a model.
-    """
-    cache_path = os.path.join(BASE_CACHE_PATH, name)
-    pipe = diffusers.StableDiffusionPipeline.from_single_file(
-        pretrained_model_link_or_path=model_url,
-        token=token,
-        cache_dir=cache_path,
-    )
-    pipe.save_pretrained(cache_path, safe_serialization=True)
+        if self.__config.get("controlnets") is not None:
+            for controlnet in self.__config["controlnets"]:
+                self.__download_controlnet(
+                    name=controlnet["name"],
+                    repo_id=controlnet["repo_id"],
+                    token=self.__token,
+                )
 
+        if self.__config.get("loras") is not None:
+            for lora in self.__config["loras"]:
+                self.__download_other_file(
+                    url=lora["url"],
+                    file_name=lora["name"],
+                    file_path=BASE_CACHE_PATH_LORA,
+                )
 
-def download_model_sdxl(name: str, model_url: str, token: str):
-    """
-    Download a sdxl model.
-    """
-    cache_path = os.path.join(BASE_CACHE_PATH, name)
-    pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
-        pretrained_model_link_or_path=model_url,
-        use_auth_token=token,
-        cache_dir=cache_path,
-    )
-    pipe.save_pretrained(cache_path, safe_serialization=True)
+        if self.__config.get("textual_inversions") is not None:
+            for textual_inversion in self.__config["textual_inversions"]:
+                self.__download_other_file(
+                    url=textual_inversion["url"],
+                    file_name=textual_inversion["name"],
+                    file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
+                )
+
+    def __download_vae(self, name: str, model_url: str, token: str):
+        cache_path = os.path.join(BASE_CACHE_PATH, name)
+        vae = diffusers.AutoencoderKL.from_single_file(
+            pretrained_model_link_or_path=model_url,
+            use_auth_token=token,
+            cache_dir=cache_path,
+        )
+        vae.save_pretrained(cache_path, safe_serialization=True)
+
+    def __download_controlnet(self, name: str, repo_id: str, token: str):
+        cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
+        controlnet = diffusers.ControlNetModel.from_pretrained(
+            repo_id,
+            use_auth_token=token,
+            cache_dir=cache_path,
+        )
+        controlnet.save_pretrained(cache_path, safe_serialization=True)
+
+    def __download_other_file(self, url, file_name, file_path):
+        """
+        Download file from the given URL for LoRA or TextualInversion.
+        """
+        from urllib.request import Request, urlopen
+
+        req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
+        downloaded = urlopen(req).read()
+        dir_names = os.path.join(file_path, file_name)
+        os.makedirs(os.path.dirname(dir_names), exist_ok=True)
+        with open(dir_names, mode="wb") as f:
+            f.write(downloaded)
 
 
 def build_image():
@@ -93,43 +143,24 @@ def build_image():
     """
     import yaml
 
-    token = os.environ["HUGGING_FACE_TOKEN"]
-    config = {}
+    token: str = os.environ["HUGGING_FACE_TOKEN"]
     with open("/config.yml", "r") as file:
-        config = yaml.safe_load(file)
+        config: dict = yaml.safe_load(file)
 
-    model = config.get("model")
-    use_xl = config.get("use_xl")
-    if model is not None:
-        if use_xl is not None and use_xl:
-            download_model_sdxl(name=model["name"], model_url=model["url"], token=token)
-        else:
-            download_model(name=model["name"], model_url=model["url"], token=token)
-
-    vae = config.get("vae")
-    if vae is not None:
-        download_vae(name=model["name"], model_url=vae["url"], token=token)
-
-    controlnets = config.get("controlnets")
-    if controlnets is not None:
-        for controlnet in controlnets:
-            download_controlnet(name=controlnet["name"], repo_id=controlnet["repo_id"], token=token)
-
-    loras = config.get("loras")
-    if loras is not None:
-        for lora in loras:
-            download_file(url=lora["url"], file_name=lora["name"], file_path=BASE_CACHE_PATH_LORA)
-
-    textual_inversions = config.get("textual_inversions")
-    if textual_inversions is not None:
-        for textual_inversion in textual_inversions:
-            download_file(
-                url=textual_inversion["url"],
-                file_name=textual_inversion["name"],
-                file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
+    stable_diffusion_setup: StableDiffusionCLISetupInterface
+    match config.get("version"):
+        case "sd15":
+            stable_diffusion_setup = StableDiffusionCLISetupSD15(config, token)
+        case "sdxl":
+            stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
+        case _:
+            raise ValueError(
+                f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
             )
 
-    download_upscaler()
+    stable_diffusion_setup.download_model()
+    common_setup = CommonSetup(config, token)
+    common_setup.download_setup_files()
 
 
 app = App("stable-diffusion-cli")
diff --git a/cmd/sd15_img2img.py b/cmd/sd15_img2img.py
index f488549..dd37db4 100644
--- a/cmd/sd15_img2img.py
+++ b/cmd/sd15_img2img.py
@@ -4,7 +4,9 @@ import modal
 import util
 
 stub = modal.Stub("run-stable-diffusion-cli")
-stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "SD15.run_img2img_inference")
+stub.run_inference = modal.Function.from_name(
+    "stable-diffusion-cli", "SD15.run_img2img_inference"
+)
 
 
 @stub.local_entrypoint()
@@ -44,7 +46,9 @@ def main(
         )
         util.save_images(directory, images, seed_generated, i, output_format)
         total_time = time.time() - start_time
-        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
+        print(
+            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
+        )
 
     prompts: dict[str, int | str] = {
         "prompt": prompt,
diff --git a/cmd/sd15_txt2img.py b/cmd/sd15_txt2img.py
index 960b023..621c1b9 100644
--- a/cmd/sd15_txt2img.py
+++ b/cmd/sd15_txt2img.py
@@ -4,7 +4,9 @@ import modal
 import util
 
 app = modal.App("run-stable-diffusion-cli")
-run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
+run_inference = modal.Function.from_name(
+    "stable-diffusion-cli", "SD15.run_txt2img_inference"
+)
 
 
 @app.local_entrypoint()
@@ -46,7 +48,9 @@ def main(
         )
         util.save_images(directory, images, seed_generated, i, output_format)
         total_time = time.time() - start_time
-        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
+        print(
+            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
+        )
 
     prompts: dict[str, int | str] = {
         "prompt": prompt,