diff --git a/app/setup.py b/app/setup.py index a946435..155ca82 100644 --- a/app/setup.py +++ b/app/setup.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from pathlib import Path import diffusers from modal import App, Image, Mount, Secret @@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler" class StableDiffusionCLISetupInterface(ABC): @abstractmethod - def download_model(self): + def download_model(self) -> None: pass class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: if config.get("version") != "sdxl": - raise ValueError("Invalid version. Must be 'sdxl'.") + msg = "Invalid version. Must be 'sdxl'." + raise ValueError(msg) if config.get("model") is None: - raise ValueError("Model is required. Please provide a model in config.yml.") + msg = "Model is required. Please provide a model in config.yml." + raise ValueError(msg) self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] self.__token: str = token def download_model(self) -> None: - cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) + cache_path = Path(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, use_auth_token=self.__token, @@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: if config.get("version") != "sd15": - raise ValueError("Invalid version. Must be 'sd15'.") + msg = "Invalid version. Must be 'sd15'." + raise ValueError(msg) if config.get("model") is None: - raise ValueError("Model is required. Please provide a model in config.yml.") + msg = "Model is required. Please provide a model in config.yml." + raise ValueError(msg) self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] self.__token: str = token def download_model(self) -> None: - cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) + cache_path = Path(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, token=self.__token, @@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): def __download_upscaler(self) -> None: upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained( - "stabilityai/sd-x2-latent-upscaler" + "stabilityai/sd-x2-latent-upscaler", ) upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True) class CommonSetup: - def __init__(self, config: dict, token: str): + def __init__(self, config: dict, token: str) -> None: self.__token: str = token self.__config: dict = config @@ -105,8 +110,8 @@ class CommonSetup: file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION, ) - def __download_vae(self, name: str, model_url: str, token: str): - cache_path = os.path.join(BASE_CACHE_PATH, name) + def __download_vae(self, name: str, model_url: str, token: str) -> None: + cache_path = Path(BASE_CACHE_PATH, name) vae = diffusers.AutoencoderKL.from_single_file( pretrained_model_link_or_path=model_url, use_auth_token=token, @@ -114,8 +119,8 @@ class CommonSetup: ) vae.save_pretrained(cache_path, safe_serialization=True) - def __download_controlnet(self, name: str, repo_id: str, token: str): - cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) + def __download_controlnet(self, name: str, repo_id: str, token: str) -> None: + cache_path = Path(BASE_CACHE_PATH, name) controlnet = diffusers.ControlNetModel.from_pretrained( repo_id, use_auth_token=token, @@ -123,7 +128,7 @@ class CommonSetup: ) controlnet.save_pretrained(cache_path, safe_serialization=True) - def __download_other_file(self, url, file_name, file_path): + def __download_other_file(self, url: str, file_name: str, file_path: str) -> None: """ Download file from the given URL for LoRA or TextualInversion. """ @@ -131,20 +136,20 @@ class CommonSetup: req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) downloaded = urlopen(req).read() - dir_names = os.path.join(file_path, file_name) + dir_names = Path(file_path, file_name) os.makedirs(os.path.dirname(dir_names), exist_ok=True) with open(dir_names, mode="wb") as f: f.write(downloaded) -def build_image(): +def build_image() -> None: """ Build the Docker image. """ import yaml token: str = os.environ["HUGGING_FACE_TOKEN"] - with open("/config.yml", "r") as file: + with open("/config.yml") as file: config: dict = yaml.safe_load(file) stable_diffusion_setup: StableDiffusionCLISetupInterface @@ -154,9 +159,8 @@ def build_image(): case "sdxl": stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token) case _: - raise ValueError( - f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." - ) + msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." + raise ValueError(msg) stable_diffusion_setup.download_model() common_setup = CommonSetup(config, token) diff --git a/cmd/domain.py b/cmd/domain.py new file mode 100644 index 0000000..ec6c9ee --- /dev/null +++ b/cmd/domain.py @@ -0,0 +1,123 @@ +"""Utility functions for the script.""" + +from __future__ import annotations + +import secrets +import time +from datetime import date +from pathlib import Path + + +class Seed: + def __init__(self, seed: int) -> None: + if seed != -1: + self.__value = seed + return + + self.__value = self.__generate_seed() + + def __generate_seed(self) -> int: + max_limit_value = 4294967295 + return secrets.randbelow(max_limit_value) + + @property + def value(self) -> int: + return self.__value + + +class Prompts: + def __init__( + self, + prompt: str, + n_prompt: str, + height: int, + width: int, + samples: int, + steps: int, + ) -> None: + if prompt == "": + msg = "prompt should not be empty." + raise ValueError(msg) + + if n_prompt == "": + msg = "n_prompt should not be empty." + raise ValueError(msg) + + if height <= 0: + msg = "height should be positive." + raise ValueError(msg) + + if width <= 0: + msg = "width should be positive." + raise ValueError(msg) + + if samples <= 0: + msg = "samples should be positive." + raise ValueError(msg) + + if steps <= 0: + msg = "steps should be positive." + raise ValueError(msg) + + self.__dict: dict[str, int | str] = { + "prompt": prompt, + "n_prompt": n_prompt, + "height": height, + "width": width, + "samples": samples, + "steps": steps, + } + + @property + def dict(self) -> dict[str, int | str]: + return self.__dict + + +class OutputDirectory: + def __init__(self) -> None: + self.__output_directory_name = "outputs" + self.__date_today = date.today().strftime("%Y-%m-%d") + self.__make_path() + + def __make_path(self) -> None: + self.__path = Path(f"{self.__output_directory_name}/{self.__date_today}") + + def make_directory(self) -> Path: + """Make a directory for saving outputs.""" + if not self.__path.exists(): + self.__path.mkdir(exist_ok=True, parents=True) + + return self.__path + + +class StableDiffusionOutputManger: + def __init__(self, prompts: Prompts, output_directory: Path) -> None: + self.__prompts = prompts + self.__output_directory = output_directory + + def save_prompts(self) -> str: + """Save prompts to a file.""" + prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) + output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt" + with Path(output_path).open("wb") as file: + for name, value in self.__prompts.dict.items(): + file.write(f"{name} = {value!r}\n".encode()) + + return output_path + + def save_image( + self, + image: bytes, + seed: int, + i: int, + j: int, + output_format: str = "png", + ) -> str: + """Save image to a file.""" + formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) + filename = f"{formatted_time}_{seed}_{i}_{j}.{output_format}" + output_path = f"{self.__output_directory}/{filename}" + with Path(output_path).open("wb") as file: + file.write(image) + + return output_path diff --git a/cmd/sd15_txt2img.py b/cmd/sd15_txt2img.py index 621c1b9..dcd38c7 100644 --- a/cmd/sd15_txt2img.py +++ b/cmd/sd15_txt2img.py @@ -1,11 +1,13 @@ +import logging import time +import domain import modal -import util app = modal.App("run-stable-diffusion-cli") run_inference = modal.Function.from_name( - "stable-diffusion-cli", "SD15.run_txt2img_inference" + "stable-diffusion-cli", + "SD15.run_txt2img_inference", ) @@ -16,49 +18,55 @@ def main( height: int = 512, width: int = 512, samples: int = 5, - batch_size: int = 1, steps: int = 20, seed: int = -1, use_upscaler: str = "", fix_by_controlnet_tile: str = "False", output_format: str = "png", -): +) -> None: + """main() is the entrypoint for the Runway CLI. + This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ - This function is the entrypoint for the Runway CLI. - The function pass the given prompt to StableDiffusion on Modal, - gets back a list of images and outputs images to local. - """ - directory = util.make_directory() - seed_generated = seed - for i in range(samples): - if seed == -1: - seed_generated = util.generate_seed() + logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger = logging.getLogger("run-stable-diffusion-cli") + + output_directory = domain.OutputDirectory() + directory_path = output_directory.make_directory() + logger.info("Made a directory: %s", directory_path) + + prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) + sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) + for sample_index in range(samples): + new_seed = domain.Seed(seed) start_time = time.time() images = run_inference.remote( prompt=prompt, n_prompt=n_prompt, height=height, width=width, - batch_size=batch_size, + batch_size=1, steps=steps, - seed=seed_generated, + seed=new_seed.value, use_upscaler=use_upscaler == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True", output_format=output_format, ) - util.save_images(directory, images, seed_generated, i, output_format) - total_time = time.time() - start_time - print( - f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)." - ) + for generated_image_index, image_bytes in enumerate(images): + saved_path = sd_output_manager.save_image( + image_bytes, + new_seed.value, + sample_index, + generated_image_index, + output_format, + ) + logger.info("Saved image to the: %s", saved_path) - prompts: dict[str, int | str] = { - "prompt": prompt, - "n_prompt": n_prompt, - "height": height, - "width": width, - "samples": samples, - "batch_size": batch_size, - "steps": steps, - } - util.save_prompts(prompts) + total_time = time.time() - start_time + logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images)) + + saved_prompts_path = sd_output_manager.save_prompts() + logger.info("Saved prompts: %s", saved_prompts_path) diff --git a/cmd/sdxl_txt2img.py b/cmd/sdxl_txt2img.py index 1528c75..32712fc 100644 --- a/cmd/sdxl_txt2img.py +++ b/cmd/sdxl_txt2img.py @@ -1,10 +1,14 @@ +import logging import time +import domain import modal -import util app = modal.App("run-stable-diffusion-cli") -run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference") +run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SDXLTxt2Img.run_inference", +) @app.local_entrypoint() @@ -18,17 +22,27 @@ def main( seed: int = -1, use_upscaler: str = "False", output_format: str = "png", -): - """ - This function is the entrypoint for the Runway CLI. +) -> None: + """This function is the entrypoint for the Runway CLI. The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ - directory = util.make_directory() - seed_generated = seed - for i in range(samples): - if seed == -1: - seed_generated = util.generate_seed() + logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger = logging.getLogger("run-stable-diffusion-cli") + + output_directory = domain.OutputDirectory() + directory_path = output_directory.make_directory() + logger.info("Made a directory: %s", directory_path) + + prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) + sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) + + for sample_index in range(samples): + new_seed = domain.Seed(seed) start_time = time.time() images = run_inference.remote( prompt=prompt, @@ -36,18 +50,23 @@ def main( height=height, width=width, steps=steps, - seed=seed_generated, + seed=new_seed.value, use_upscaler=use_upscaler == "True", output_format=output_format, ) - util.save_images(directory, images, seed_generated, i, output_format) - total_time = time.time() - start_time - print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") - prompts: dict[str, int | str] = { - "prompt": prompt, - "height": height, - "width": width, - "samples": samples, - } - util.save_prompts(prompts) + for generated_image_index, image_bytes in enumerate(images): + saved_path = sd_output_manager.save_image( + image_bytes, + new_seed.value, + sample_index, + generated_image_index, + output_format, + ) + logger.info("Saved image to the: %s", saved_path) + + total_time = time.time() - start_time + logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images)) + + saved_prompts_path = sd_output_manager.save_prompts() + logger.info("Saved prompts: %s", saved_prompts_path) diff --git a/cmd/util.py b/cmd/util.py deleted file mode 100644 index 643be2e..0000000 --- a/cmd/util.py +++ /dev/null @@ -1,55 +0,0 @@ -""" Utility functions for the script. """ -import random -import time -from datetime import date -from pathlib import Path - -OUTPUT_DIRECTORY = "outputs" -DATE_TODAY = date.today().strftime("%Y-%m-%d") - - -def generate_seed() -> int: - """ - Generate a random seed. - """ - seed = random.randint(0, 4294967295) - print(f"Generate a random seed: {seed}") - - return seed - - -def make_directory() -> Path: - """ - Make a directory for saving outputs. - """ - directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}") - if not directory.exists(): - directory.mkdir(exist_ok=True, parents=True) - print(f"Make a directory: {directory}") - - return directory - - -def save_prompts(inputs: dict): - """ - Save prompts to a file. - """ - prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) - with open( - file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8" - ) as file: - for name, value in inputs.items(): - file.write(f"{name} = {repr(value)}\n") - print(f"Save prompts: {prompts_filename}.txt") - - -def save_images(directory: Path, images: list[bytes], seed: int, i: int, output_format: str = "png"): - """ - Save images to a file. - """ - for j, image_bytes in enumerate(images): - formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) - output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.{output_format}" - print(f"Saving it to {output_path}") - with open(output_path, "wb") as file: - file.write(image_bytes)