Merge pull request #156 from hodanov/feature/refactoring

Fix some lint errors. Refactor cmd.
This commit is contained in:
hodanov 2024-11-03 19:54:00 +09:00 committed by GitHub
commit d7b143ce5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 227 additions and 128 deletions

View File

@ -1,5 +1,6 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
import diffusers import diffusers
from modal import App, Image, Mount, Secret from modal import App, Image, Mount, Secret
@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
class StableDiffusionCLISetupInterface(ABC): class StableDiffusionCLISetupInterface(ABC):
@abstractmethod @abstractmethod
def download_model(self): def download_model(self) -> None:
pass pass
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sdxl": if config.get("version") != "sdxl":
raise ValueError("Invalid version. Must be 'sdxl'.") msg = "Invalid version. Must be 'sdxl'."
raise ValueError(msg)
if config.get("model") is None: if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.") msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"] self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"] self.__model_url: str = config["model"]["url"]
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token, use_auth_token=self.__token,
@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sd15": if config.get("version") != "sd15":
raise ValueError("Invalid version. Must be 'sd15'.") msg = "Invalid version. Must be 'sd15'."
raise ValueError(msg)
if config.get("model") is None: if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.") msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"] self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"] self.__model_url: str = config["model"]["url"]
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionPipeline.from_single_file( pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
token=self.__token, token=self.__token,
@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __download_upscaler(self) -> None: def __download_upscaler(self) -> None:
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained( upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler" "stabilityai/sd-x2-latent-upscaler",
) )
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True) upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
class CommonSetup: class CommonSetup:
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
self.__token: str = token self.__token: str = token
self.__config: dict = config self.__config: dict = config
@ -105,8 +110,8 @@ class CommonSetup:
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION, file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
) )
def __download_vae(self, name: str, model_url: str, token: str): def __download_vae(self, name: str, model_url: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, name) cache_path = Path(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file( vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url, pretrained_model_link_or_path=model_url,
use_auth_token=token, use_auth_token=token,
@ -114,8 +119,8 @@ class CommonSetup:
) )
vae.save_pretrained(cache_path, safe_serialization=True) vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str): def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) cache_path = Path(BASE_CACHE_PATH, name)
controlnet = diffusers.ControlNetModel.from_pretrained( controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id, repo_id,
use_auth_token=token, use_auth_token=token,
@ -123,7 +128,7 @@ class CommonSetup:
) )
controlnet.save_pretrained(cache_path, safe_serialization=True) controlnet.save_pretrained(cache_path, safe_serialization=True)
def __download_other_file(self, url, file_name, file_path): def __download_other_file(self, url: str, file_name: str, file_path: str) -> None:
""" """
Download file from the given URL for LoRA or TextualInversion. Download file from the given URL for LoRA or TextualInversion.
""" """
@ -131,20 +136,20 @@ class CommonSetup:
req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read() downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name) dir_names = Path(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True) os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f: with open(dir_names, mode="wb") as f:
f.write(downloaded) f.write(downloaded)
def build_image(): def build_image() -> None:
""" """
Build the Docker image. Build the Docker image.
""" """
import yaml import yaml
token: str = os.environ["HUGGING_FACE_TOKEN"] token: str = os.environ["HUGGING_FACE_TOKEN"]
with open("/config.yml", "r") as file: with open("/config.yml") as file:
config: dict = yaml.safe_load(file) config: dict = yaml.safe_load(file)
stable_diffusion_setup: StableDiffusionCLISetupInterface stable_diffusion_setup: StableDiffusionCLISetupInterface
@ -154,9 +159,8 @@ def build_image():
case "sdxl": case "sdxl":
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token) stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
case _: case _:
raise ValueError( msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." raise ValueError(msg)
)
stable_diffusion_setup.download_model() stable_diffusion_setup.download_model()
common_setup = CommonSetup(config, token) common_setup = CommonSetup(config, token)

123
cmd/domain.py Normal file
View File

@ -0,0 +1,123 @@
"""Utility functions for the script."""
from __future__ import annotations
import secrets
import time
from datetime import date
from pathlib import Path
class Seed:
def __init__(self, seed: int) -> None:
if seed != -1:
self.__value = seed
return
self.__value = self.__generate_seed()
def __generate_seed(self) -> int:
max_limit_value = 4294967295
return secrets.randbelow(max_limit_value)
@property
def value(self) -> int:
return self.__value
class Prompts:
def __init__(
self,
prompt: str,
n_prompt: str,
height: int,
width: int,
samples: int,
steps: int,
) -> None:
if prompt == "":
msg = "prompt should not be empty."
raise ValueError(msg)
if n_prompt == "":
msg = "n_prompt should not be empty."
raise ValueError(msg)
if height <= 0:
msg = "height should be positive."
raise ValueError(msg)
if width <= 0:
msg = "width should be positive."
raise ValueError(msg)
if samples <= 0:
msg = "samples should be positive."
raise ValueError(msg)
if steps <= 0:
msg = "steps should be positive."
raise ValueError(msg)
self.__dict: dict[str, int | str] = {
"prompt": prompt,
"n_prompt": n_prompt,
"height": height,
"width": width,
"samples": samples,
"steps": steps,
}
@property
def dict(self) -> dict[str, int | str]:
return self.__dict
class OutputDirectory:
def __init__(self) -> None:
self.__output_directory_name = "outputs"
self.__date_today = date.today().strftime("%Y-%m-%d")
self.__make_path()
def __make_path(self) -> None:
self.__path = Path(f"{self.__output_directory_name}/{self.__date_today}")
def make_directory(self) -> Path:
"""Make a directory for saving outputs."""
if not self.__path.exists():
self.__path.mkdir(exist_ok=True, parents=True)
return self.__path
class StableDiffusionOutputManger:
def __init__(self, prompts: Prompts, output_directory: Path) -> None:
self.__prompts = prompts
self.__output_directory = output_directory
def save_prompts(self) -> str:
"""Save prompts to a file."""
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt"
with Path(output_path).open("wb") as file:
for name, value in self.__prompts.dict.items():
file.write(f"{name} = {value!r}\n".encode())
return output_path
def save_image(
self,
image: bytes,
seed: int,
i: int,
j: int,
output_format: str = "png",
) -> str:
"""Save image to a file."""
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
filename = f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
output_path = f"{self.__output_directory}/{filename}"
with Path(output_path).open("wb") as file:
file.write(image)
return output_path

View File

@ -1,11 +1,13 @@
import logging
import time import time
import domain
import modal import modal
import util
app = modal.App("run-stable-diffusion-cli") app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name( run_inference = modal.Function.from_name(
"stable-diffusion-cli", "SD15.run_txt2img_inference" "stable-diffusion-cli",
"SD15.run_txt2img_inference",
) )
@ -16,49 +18,55 @@ def main(
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
samples: int = 5, samples: int = 5,
batch_size: int = 1,
steps: int = 20, steps: int = 20,
seed: int = -1, seed: int = -1,
use_upscaler: str = "", use_upscaler: str = "",
fix_by_controlnet_tile: str = "False", fix_by_controlnet_tile: str = "False",
output_format: str = "png", output_format: str = "png",
): ) -> None:
"""main() is the entrypoint for the Runway CLI.
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
""" """
This function is the entrypoint for the Runway CLI. logging.basicConfig(
The function pass the given prompt to StableDiffusion on Modal, level=logging.INFO,
gets back a list of images and outputs images to local. format="[%(levelname)s] %(asctime)s - %(message)s",
""" datefmt="%Y-%m-%d %H:%M:%S",
directory = util.make_directory() )
seed_generated = seed logger = logging.getLogger("run-stable-diffusion-cli")
for i in range(samples):
if seed == -1: output_directory = domain.OutputDirectory()
seed_generated = util.generate_seed() directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time() start_time = time.time()
images = run_inference.remote( images = run_inference.remote(
prompt=prompt, prompt=prompt,
n_prompt=n_prompt, n_prompt=n_prompt,
height=height, height=height,
width=width, width=width,
batch_size=batch_size, batch_size=1,
steps=steps, steps=steps,
seed=seed_generated, seed=new_seed.value,
use_upscaler=use_upscaler == "True", use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format, output_format=output_format,
) )
util.save_images(directory, images, seed_generated, i, output_format) for generated_image_index, image_bytes in enumerate(images):
total_time = time.time() - start_time saved_path = sd_output_manager.save_image(
print( image_bytes,
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)." new_seed.value,
) sample_index,
generated_image_index,
output_format,
)
logger.info("Saved image to the: %s", saved_path)
prompts: dict[str, int | str] = { total_time = time.time() - start_time
"prompt": prompt, logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
"n_prompt": n_prompt,
"height": height, saved_prompts_path = sd_output_manager.save_prompts()
"width": width, logger.info("Saved prompts: %s", saved_prompts_path)
"samples": samples,
"batch_size": batch_size,
"steps": steps,
}
util.save_prompts(prompts)

View File

@ -1,10 +1,14 @@
import logging
import time import time
import domain
import modal import modal
import util
app = modal.App("run-stable-diffusion-cli") app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference") run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
@app.local_entrypoint() @app.local_entrypoint()
@ -18,17 +22,27 @@ def main(
seed: int = -1, seed: int = -1,
use_upscaler: str = "False", use_upscaler: str = "False",
output_format: str = "png", output_format: str = "png",
): ) -> None:
""" """This function is the entrypoint for the Runway CLI.
This function is the entrypoint for the Runway CLI.
The function pass the given prompt to StableDiffusion on Modal, The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.
""" """
directory = util.make_directory() logging.basicConfig(
seed_generated = seed level=logging.INFO,
for i in range(samples): format="[%(levelname)s] %(asctime)s - %(message)s",
if seed == -1: datefmt="%Y-%m-%d %H:%M:%S",
seed_generated = util.generate_seed() )
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time() start_time = time.time()
images = run_inference.remote( images = run_inference.remote(
prompt=prompt, prompt=prompt,
@ -36,18 +50,23 @@ def main(
height=height, height=height,
width=width, width=width,
steps=steps, steps=steps,
seed=seed_generated, seed=new_seed.value,
use_upscaler=use_upscaler == "True", use_upscaler=use_upscaler == "True",
output_format=output_format, output_format=output_format,
) )
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
prompts: dict[str, int | str] = { for generated_image_index, image_bytes in enumerate(images):
"prompt": prompt, saved_path = sd_output_manager.save_image(
"height": height, image_bytes,
"width": width, new_seed.value,
"samples": samples, sample_index,
} generated_image_index,
util.save_prompts(prompts) output_format,
)
logger.info("Saved image to the: %s", saved_path)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
saved_prompts_path = sd_output_manager.save_prompts()
logger.info("Saved prompts: %s", saved_prompts_path)

View File

@ -1,55 +0,0 @@
""" Utility functions for the script. """
import random
import time
from datetime import date
from pathlib import Path
OUTPUT_DIRECTORY = "outputs"
DATE_TODAY = date.today().strftime("%Y-%m-%d")
def generate_seed() -> int:
"""
Generate a random seed.
"""
seed = random.randint(0, 4294967295)
print(f"Generate a random seed: {seed}")
return seed
def make_directory() -> Path:
"""
Make a directory for saving outputs.
"""
directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
if not directory.exists():
directory.mkdir(exist_ok=True, parents=True)
print(f"Make a directory: {directory}")
return directory
def save_prompts(inputs: dict):
"""
Save prompts to a file.
"""
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
with open(
file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8"
) as file:
for name, value in inputs.items():
file.write(f"{name} = {repr(value)}\n")
print(f"Save prompts: {prompts_filename}.txt")
def save_images(directory: Path, images: list[bytes], seed: int, i: int, output_format: str = "png"):
"""
Save images to a file.
"""
for j, image_bytes in enumerate(images):
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as file:
file.write(image_bytes)