Merge pull request #156 from hodanov/feature/refactoring

Fix some lint errors. Refactor cmd.
This commit is contained in:
hodanov 2024-11-03 19:54:00 +09:00 committed by GitHub
commit d7b143ce5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 227 additions and 128 deletions

View File

@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
import diffusers
from modal import App, Image, Mount, Secret
@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
class StableDiffusionCLISetupInterface(ABC):
@abstractmethod
def download_model(self):
def download_model(self) -> None:
pass
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str):
def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sdxl":
raise ValueError("Invalid version. Must be 'sdxl'.")
msg = "Invalid version. Must be 'sdxl'."
raise ValueError(msg)
if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.")
msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
self.__token: str = token
def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token,
@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str):
def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sd15":
raise ValueError("Invalid version. Must be 'sd15'.")
msg = "Invalid version. Must be 'sd15'."
raise ValueError(msg)
if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.")
msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
self.__token: str = token
def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
token=self.__token,
@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __download_upscaler(self) -> None:
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler"
"stabilityai/sd-x2-latent-upscaler",
)
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
class CommonSetup:
def __init__(self, config: dict, token: str):
def __init__(self, config: dict, token: str) -> None:
self.__token: str = token
self.__config: dict = config
@ -105,8 +110,8 @@ class CommonSetup:
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
)
def __download_vae(self, name: str, model_url: str, token: str):
cache_path = os.path.join(BASE_CACHE_PATH, name)
def __download_vae(self, name: str, model_url: str, token: str) -> None:
cache_path = Path(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url,
use_auth_token=token,
@ -114,8 +119,8 @@ class CommonSetup:
)
vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str):
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
cache_path = Path(BASE_CACHE_PATH_CONTROLNET, name)
controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id,
use_auth_token=token,
@ -123,7 +128,7 @@ class CommonSetup:
)
controlnet.save_pretrained(cache_path, safe_serialization=True)
def __download_other_file(self, url, file_name, file_path):
def __download_other_file(self, url: str, file_name: str, file_path: str) -> None:
"""
Download file from the given URL for LoRA or TextualInversion.
"""
@ -131,20 +136,20 @@ class CommonSetup:
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name)
dir_names = Path(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
def build_image():
def build_image() -> None:
"""
Build the Docker image.
"""
import yaml
token: str = os.environ["HUGGING_FACE_TOKEN"]
with open("/config.yml", "r") as file:
with open("/config.yml") as file:
config: dict = yaml.safe_load(file)
stable_diffusion_setup: StableDiffusionCLISetupInterface
@ -154,9 +159,8 @@ def build_image():
case "sdxl":
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
case _:
raise ValueError(
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
)
msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
raise ValueError(msg)
stable_diffusion_setup.download_model()
common_setup = CommonSetup(config, token)

123
cmd/domain.py Normal file
View File

@ -0,0 +1,123 @@
"""Utility functions for the script."""
from __future__ import annotations
import secrets
import time
from datetime import date
from pathlib import Path
class Seed:
    """Seed value for a single image-generation request.

    Passing ``-1`` asks for a random seed; any other integer is stored
    unchanged.
    """

    def __init__(self, seed: int) -> None:
        """Store ``seed``, or draw a random one when ``seed`` is ``-1``."""
        self.__value = self.__generate_seed() if seed == -1 else seed

    def __generate_seed(self) -> int:
        """Return a random integer in the inclusive range [0, 4294967295]."""
        max_limit_value = 4294967295
        # secrets.randbelow(n) draws from [0, n), so add one to keep the
        # maximum seed value itself reachable.
        return secrets.randbelow(max_limit_value + 1)

    @property
    def value(self) -> int:
        """The seed chosen at construction time."""
        return self.__value
class Prompts:
    """Validated bundle of the text prompts and numeric generation settings.

    Raises:
        ValueError: if ``prompt`` or ``n_prompt`` is an empty string, or if
            ``height``, ``width``, ``samples`` or ``steps`` is not positive.
            Arguments are checked in signature order and the first failing
            one is reported.
    """

    def __init__(
        self,
        prompt: str,
        n_prompt: str,
        height: int,
        width: int,
        samples: int,
        steps: int,
    ) -> None:
        text_fields = (("prompt", prompt), ("n_prompt", n_prompt))
        for field_name, text in text_fields:
            if text == "":
                msg = f"{field_name} should not be empty."
                raise ValueError(msg)

        numeric_fields = (
            ("height", height),
            ("width", width),
            ("samples", samples),
            ("steps", steps),
        )
        for field_name, number in numeric_fields:
            if number <= 0:
                msg = f"{field_name} should be positive."
                raise ValueError(msg)

        self.__dict: dict[str, int | str] = {
            "prompt": prompt,
            "n_prompt": n_prompt,
            "height": height,
            "width": width,
            "samples": samples,
            "steps": steps,
        }

    @property
    def dict(self) -> dict[str, int | str]:
        """Return the settings as a name-to-value mapping.

        A shallow copy is returned so that callers cannot modify the
        validated values held by this object.
        """
        return self.__dict.copy()
class OutputDirectory:
    """Dated output location ``outputs/<YYYY-MM-DD>``.

    The path is relative, so it resolves against the current working
    directory. The date is the local date at construction time.
    """

    def __init__(self) -> None:
        self.__output_directory_name = "outputs"
        self.__date_today = date.today().strftime("%Y-%m-%d")
        self.__make_path()

    def __make_path(self) -> None:
        """Compose the dated directory path from its two components."""
        self.__path = Path(self.__output_directory_name) / self.__date_today

    def make_directory(self) -> Path:
        """Create the output directory (and any parents) and return its path.

        Calling this when the directory already exists is a no-op.

        Raises:
            FileExistsError: if the path exists but is not a directory.
        """
        # No prior exists() check: exist_ok=True already makes this
        # idempotent, and checking first would only add a race window.
        self.__path.mkdir(parents=True, exist_ok=True)
        return self.__path
class StableDiffusionOutputManger:
    """Writes generated images and the prompts that produced them to disk.

    Files are placed directly inside the output directory given at
    construction time. File names start with a local timestamp in
    ``YYYYmmddHHMMSS`` form.
    """

    def __init__(self, prompts: Prompts, output_directory: Path) -> None:
        self.__prompts = prompts
        self.__output_directory = output_directory

    def save_prompts(self) -> str:
        """Write one ``name = repr(value)`` line per prompt entry.

        Returns:
            The path of the written ``prompts_<timestamp>.txt`` file.
        """
        stamp = time.strftime("%Y%m%d%H%M%S")
        output_path = f"{self.__output_directory}/prompts_{stamp}.txt"
        body = "".join(
            f"{name} = {value!r}\n" for name, value in self.__prompts.dict.items()
        )
        Path(output_path).write_bytes(body.encode())
        return output_path

    def save_image(
        self,
        image: bytes,
        seed: int,
        i: int,
        j: int,
        output_format: str = "png",
    ) -> str:
        """Write the raw image bytes to a file.

        The file is named ``<timestamp>_<seed>_<i>_<j>.<output_format>``.

        Returns:
            The path of the written image file.
        """
        stamp = time.strftime("%Y%m%d%H%M%S")
        output_path = f"{self.__output_directory}/{stamp}_{seed}_{i}_{j}.{output_format}"
        Path(output_path).write_bytes(image)
        return output_path

View File

@ -1,11 +1,13 @@
import logging
import time
import domain
import modal
import util
app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name(
"stable-diffusion-cli", "SD15.run_txt2img_inference"
"stable-diffusion-cli",
"SD15.run_txt2img_inference",
)
@ -16,49 +18,55 @@ def main(
height: int = 512,
width: int = 512,
samples: int = 5,
batch_size: int = 1,
steps: int = 20,
seed: int = -1,
use_upscaler: str = "",
fix_by_controlnet_tile: str = "False",
output_format: str = "png",
):
) -> None:
"""main() is the entrypoint for the Runway CLI.
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
"""
This function is the entrypoint for the Runway CLI.
The function passes the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local.
"""
directory = util.make_directory()
seed_generated = seed
for i in range(samples):
if seed == -1:
seed_generated = util.generate_seed()
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] %(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time()
images = run_inference.remote(
prompt=prompt,
n_prompt=n_prompt,
height=height,
width=width,
batch_size=batch_size,
batch_size=1,
steps=steps,
seed=seed_generated,
seed=new_seed.value,
use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format,
)
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
)
for generated_image_index, image_bytes in enumerate(images):
saved_path = sd_output_manager.save_image(
image_bytes,
new_seed.value,
sample_index,
generated_image_index,
output_format,
)
logger.info("Saved image to the: %s", saved_path)
prompts: dict[str, int | str] = {
"prompt": prompt,
"n_prompt": n_prompt,
"height": height,
"width": width,
"samples": samples,
"batch_size": batch_size,
"steps": steps,
}
util.save_prompts(prompts)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
saved_prompts_path = sd_output_manager.save_prompts()
logger.info("Saved prompts: %s", saved_prompts_path)

View File

@ -1,10 +1,14 @@
import logging
import time
import domain
import modal
import util
app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
@app.local_entrypoint()
@ -18,17 +22,27 @@ def main(
seed: int = -1,
use_upscaler: str = "False",
output_format: str = "png",
):
"""
This function is the entrypoint for the Runway CLI.
) -> None:
"""This function is the entrypoint for the Runway CLI.
The function passes the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local.
"""
directory = util.make_directory()
seed_generated = seed
for i in range(samples):
if seed == -1:
seed_generated = util.generate_seed()
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] %(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time()
images = run_inference.remote(
prompt=prompt,
@ -36,18 +50,23 @@ def main(
height=height,
width=width,
steps=steps,
seed=seed_generated,
seed=new_seed.value,
use_upscaler=use_upscaler == "True",
output_format=output_format,
)
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
prompts: dict[str, int | str] = {
"prompt": prompt,
"height": height,
"width": width,
"samples": samples,
}
util.save_prompts(prompts)
for generated_image_index, image_bytes in enumerate(images):
saved_path = sd_output_manager.save_image(
image_bytes,
new_seed.value,
sample_index,
generated_image_index,
output_format,
)
logger.info("Saved image to the: %s", saved_path)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
saved_prompts_path = sd_output_manager.save_prompts()
logger.info("Saved prompts: %s", saved_prompts_path)

View File

@ -1,55 +0,0 @@
""" Utility functions for the script. """
import random
import time
from datetime import date
from pathlib import Path
OUTPUT_DIRECTORY = "outputs"
DATE_TODAY = date.today().strftime("%Y-%m-%d")
def generate_seed() -> int:
"""
Generate a random seed.
"""
seed = random.randint(0, 4294967295)
print(f"Generate a random seed: {seed}")
return seed
def make_directory() -> Path:
"""
Make a directory for saving outputs.
"""
directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
if not directory.exists():
directory.mkdir(exist_ok=True, parents=True)
print(f"Make a directory: {directory}")
return directory
def save_prompts(inputs: dict):
"""
Save prompts to a file.
"""
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
with open(
file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8"
) as file:
for name, value in inputs.items():
file.write(f"{name} = {repr(value)}\n")
print(f"Save prompts: {prompts_filename}.txt")
def save_images(directory: Path, images: list[bytes], seed: int, i: int, output_format: str = "png"):
"""
Save images to a file.
"""
for j, image_bytes in enumerate(images):
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as file:
file.write(image_bytes)