diff --git a/app/setup.py b/app/setup.py index a946435..155ca82 100644 --- a/app/setup.py +++ b/app/setup.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from pathlib import Path import diffusers from modal import App, Image, Mount, Secret @@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler" class StableDiffusionCLISetupInterface(ABC): @abstractmethod - def download_model(self): + def download_model(self) -> None: pass class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: if config.get("version") != "sdxl": - raise ValueError("Invalid version. Must be 'sdxl'.") + msg = "Invalid version. Must be 'sdxl'." + raise ValueError(msg) if config.get("model") is None: - raise ValueError("Model is required. Please provide a model in config.yml.") + msg = "Model is required. Please provide a model in config.yml." + raise ValueError(msg) self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] self.__token: str = token def download_model(self) -> None: - cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) + cache_path = Path(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, use_auth_token=self.__token, @@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: if config.get("version") != "sd15": - raise ValueError("Invalid version. Must be 'sd15'.") + msg = "Invalid version. Must be 'sd15'." + raise ValueError(msg) if config.get("model") is None: - raise ValueError("Model is required. Please provide a model in config.yml.") + msg = "Model is required. Please provide a model in config.yml." + raise ValueError(msg) self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] self.__token: str = token def download_model(self) -> None: - cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) + cache_path = Path(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, token=self.__token, @@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): def __download_upscaler(self) -> None: upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained( - "stabilityai/sd-x2-latent-upscaler" + "stabilityai/sd-x2-latent-upscaler", ) upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True) class CommonSetup: - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: self.__token: str = token self.__config: dict = config @@ -105,8 +110,8 @@ class CommonSetup: file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION, ) - def __download_vae(self, name: str, model_url: str, token: str): - cache_path = os.path.join(BASE_CACHE_PATH, name) + def __download_vae(self, name: str, model_url: str, token: str) -> None: + cache_path = Path(BASE_CACHE_PATH, name) vae = diffusers.AutoencoderKL.from_single_file( pretrained_model_link_or_path=model_url, use_auth_token=token, @@ -114,8 +119,8 @@ class CommonSetup: ) vae.save_pretrained(cache_path, safe_serialization=True) - def __download_controlnet(self, name: str, repo_id: str, token: str): - cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) + def __download_controlnet(self, name: str, repo_id: str, token: str) -> None: + cache_path = Path(BASE_CACHE_PATH, name) controlnet = diffusers.ControlNetModel.from_pretrained( repo_id, use_auth_token=token, @@ -123,7 +128,7 @@ class CommonSetup: ) controlnet.save_pretrained(cache_path, safe_serialization=True) - def __download_other_file(self, url, file_name, file_path): + def __download_other_file(self, url: str, file_name: str, file_path: str) -> None: """ Download file from the given URL for LoRA or TextualInversion. """ @@ -131,20 +136,20 @@ class CommonSetup: req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) downloaded = urlopen(req).read() - dir_names = os.path.join(file_path, file_name) + dir_names = Path(file_path, file_name) os.makedirs(os.path.dirname(dir_names), exist_ok=True) with open(dir_names, mode="wb") as f: f.write(downloaded) -def build_image(): +def build_image() -> None: """ Build the Docker image. """ import yaml token: str = os.environ["HUGGING_FACE_TOKEN"] - with open("/config.yml", "r") as file: + with open("/config.yml") as file: config: dict = yaml.safe_load(file) stable_diffusion_setup: StableDiffusionCLISetupInterface @@ -154,9 +159,8 @@ def build_image(): case "sdxl": stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token) case _: - raise ValueError( - f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." - ) + msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." + raise ValueError(msg) stable_diffusion_setup.download_model() common_setup = CommonSetup(config, token) diff --git a/cmd/sdxl_txt2img.py b/cmd/sdxl_txt2img.py index da086f9..32712fc 100644 --- a/cmd/sdxl_txt2img.py +++ b/cmd/sdxl_txt2img.py @@ -22,7 +22,7 @@ def main( seed: int = -1, use_upscaler: str = "False", output_format: str = "png", -): +) -> None: """This function is the entrypoint for the Runway CLI. The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.