Merge pull request #156 from hodanov/feature/refactoring
Fix some lint errors. Refactor cmd.
This commit is contained in:
commit
d7b143ce5c
48
app/setup.py
48
app/setup.py
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import diffusers
|
||||
from modal import App, Image, Mount, Secret
|
||||
@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
|
||||
|
||||
class StableDiffusionCLISetupInterface(ABC):
|
||||
@abstractmethod
|
||||
def download_model(self):
|
||||
def download_model(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
||||
def __init__(self, config: dict, token: str):
|
||||
def __init__(self, config: dict, token: str) -> None:
|
||||
if config.get("version") != "sdxl":
|
||||
raise ValueError("Invalid version. Must be 'sdxl'.")
|
||||
msg = "Invalid version. Must be 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
|
||||
if config.get("model") is None:
|
||||
raise ValueError("Model is required. Please provide a model in config.yml.")
|
||||
msg = "Model is required. Please provide a model in config.yml."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__model_name: str = config["model"]["name"]
|
||||
self.__model_url: str = config["model"]["url"]
|
||||
self.__token: str = token
|
||||
|
||||
def download_model(self) -> None:
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
||||
pretrained_model_link_or_path=self.__model_url,
|
||||
use_auth_token=self.__token,
|
||||
@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
||||
|
||||
|
||||
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
||||
def __init__(self, config: dict, token: str):
|
||||
def __init__(self, config: dict, token: str) -> None:
|
||||
if config.get("version") != "sd15":
|
||||
raise ValueError("Invalid version. Must be 'sd15'.")
|
||||
msg = "Invalid version. Must be 'sd15'."
|
||||
raise ValueError(msg)
|
||||
|
||||
if config.get("model") is None:
|
||||
raise ValueError("Model is required. Please provide a model in config.yml.")
|
||||
msg = "Model is required. Please provide a model in config.yml."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__model_name: str = config["model"]["name"]
|
||||
self.__model_url: str = config["model"]["url"]
|
||||
self.__token: str = token
|
||||
|
||||
def download_model(self) -> None:
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
||||
pretrained_model_link_or_path=self.__model_url,
|
||||
token=self.__token,
|
||||
@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
||||
|
||||
def __download_upscaler(self) -> None:
|
||||
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
|
||||
"stabilityai/sd-x2-latent-upscaler"
|
||||
"stabilityai/sd-x2-latent-upscaler",
|
||||
)
|
||||
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
|
||||
|
||||
|
||||
class CommonSetup:
|
||||
def __init__(self, config: dict, token: str):
|
||||
def __init__(self, config: dict, token: str) -> None:
|
||||
self.__token: str = token
|
||||
self.__config: dict = config
|
||||
|
||||
@ -105,8 +110,8 @@ class CommonSetup:
|
||||
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
||||
)
|
||||
|
||||
def __download_vae(self, name: str, model_url: str, token: str):
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
||||
def __download_vae(self, name: str, model_url: str, token: str) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, name)
|
||||
vae = diffusers.AutoencoderKL.from_single_file(
|
||||
pretrained_model_link_or_path=model_url,
|
||||
use_auth_token=token,
|
||||
@ -114,8 +119,8 @@ class CommonSetup:
|
||||
)
|
||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
def __download_controlnet(self, name: str, repo_id: str, token: str):
|
||||
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
|
||||
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, name)
|
||||
controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||
repo_id,
|
||||
use_auth_token=token,
|
||||
@ -123,7 +128,7 @@ class CommonSetup:
|
||||
)
|
||||
controlnet.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
def __download_other_file(self, url, file_name, file_path):
|
||||
def __download_other_file(self, url: str, file_name: str, file_path: str) -> None:
|
||||
"""
|
||||
Download file from the given URL for LoRA or TextualInversion.
|
||||
"""
|
||||
@ -131,20 +136,20 @@ class CommonSetup:
|
||||
|
||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
downloaded = urlopen(req).read()
|
||||
dir_names = os.path.join(file_path, file_name)
|
||||
dir_names = Path(file_path, file_name)
|
||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||
with open(dir_names, mode="wb") as f:
|
||||
f.write(downloaded)
|
||||
|
||||
|
||||
def build_image():
|
||||
def build_image() -> None:
|
||||
"""
|
||||
Build the Docker image.
|
||||
"""
|
||||
import yaml
|
||||
|
||||
token: str = os.environ["HUGGING_FACE_TOKEN"]
|
||||
with open("/config.yml", "r") as file:
|
||||
with open("/config.yml") as file:
|
||||
config: dict = yaml.safe_load(file)
|
||||
|
||||
stable_diffusion_setup: StableDiffusionCLISetupInterface
|
||||
@ -154,9 +159,8 @@ def build_image():
|
||||
case "sdxl":
|
||||
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
|
||||
)
|
||||
msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
|
||||
stable_diffusion_setup.download_model()
|
||||
common_setup = CommonSetup(config, token)
|
||||
|
||||
123
cmd/domain.py
Normal file
123
cmd/domain.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""Utility functions for the script."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Seed:
|
||||
def __init__(self, seed: int) -> None:
|
||||
if seed != -1:
|
||||
self.__value = seed
|
||||
return
|
||||
|
||||
self.__value = self.__generate_seed()
|
||||
|
||||
def __generate_seed(self) -> int:
|
||||
max_limit_value = 4294967295
|
||||
return secrets.randbelow(max_limit_value)
|
||||
|
||||
@property
|
||||
def value(self) -> int:
|
||||
return self.__value
|
||||
|
||||
|
||||
class Prompts:
|
||||
def __init__(
|
||||
self,
|
||||
prompt: str,
|
||||
n_prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
samples: int,
|
||||
steps: int,
|
||||
) -> None:
|
||||
if prompt == "":
|
||||
msg = "prompt should not be empty."
|
||||
raise ValueError(msg)
|
||||
|
||||
if n_prompt == "":
|
||||
msg = "n_prompt should not be empty."
|
||||
raise ValueError(msg)
|
||||
|
||||
if height <= 0:
|
||||
msg = "height should be positive."
|
||||
raise ValueError(msg)
|
||||
|
||||
if width <= 0:
|
||||
msg = "width should be positive."
|
||||
raise ValueError(msg)
|
||||
|
||||
if samples <= 0:
|
||||
msg = "samples should be positive."
|
||||
raise ValueError(msg)
|
||||
|
||||
if steps <= 0:
|
||||
msg = "steps should be positive."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__dict: dict[str, int | str] = {
|
||||
"prompt": prompt,
|
||||
"n_prompt": n_prompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"samples": samples,
|
||||
"steps": steps,
|
||||
}
|
||||
|
||||
@property
|
||||
def dict(self) -> dict[str, int | str]:
|
||||
return self.__dict
|
||||
|
||||
|
||||
class OutputDirectory:
|
||||
def __init__(self) -> None:
|
||||
self.__output_directory_name = "outputs"
|
||||
self.__date_today = date.today().strftime("%Y-%m-%d")
|
||||
self.__make_path()
|
||||
|
||||
def __make_path(self) -> None:
|
||||
self.__path = Path(f"{self.__output_directory_name}/{self.__date_today}")
|
||||
|
||||
def make_directory(self) -> Path:
|
||||
"""Make a directory for saving outputs."""
|
||||
if not self.__path.exists():
|
||||
self.__path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return self.__path
|
||||
|
||||
|
||||
class StableDiffusionOutputManger:
|
||||
def __init__(self, prompts: Prompts, output_directory: Path) -> None:
|
||||
self.__prompts = prompts
|
||||
self.__output_directory = output_directory
|
||||
|
||||
def save_prompts(self) -> str:
|
||||
"""Save prompts to a file."""
|
||||
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
|
||||
output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt"
|
||||
with Path(output_path).open("wb") as file:
|
||||
for name, value in self.__prompts.dict.items():
|
||||
file.write(f"{name} = {value!r}\n".encode())
|
||||
|
||||
return output_path
|
||||
|
||||
def save_image(
|
||||
self,
|
||||
image: bytes,
|
||||
seed: int,
|
||||
i: int,
|
||||
j: int,
|
||||
output_format: str = "png",
|
||||
) -> str:
|
||||
"""Save image to a file."""
|
||||
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
|
||||
filename = f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
|
||||
output_path = f"{self.__output_directory}/{filename}"
|
||||
with Path(output_path).open("wb") as file:
|
||||
file.write(image)
|
||||
|
||||
return output_path
|
||||
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import domain
|
||||
import modal
|
||||
import util
|
||||
|
||||
app = modal.App("run-stable-diffusion-cli")
|
||||
run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli", "SD15.run_txt2img_inference"
|
||||
"stable-diffusion-cli",
|
||||
"SD15.run_txt2img_inference",
|
||||
)
|
||||
|
||||
|
||||
@ -16,49 +18,55 @@ def main(
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
samples: int = 5,
|
||||
batch_size: int = 1,
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
use_upscaler: str = "",
|
||||
fix_by_controlnet_tile: str = "False",
|
||||
output_format: str = "png",
|
||||
):
|
||||
) -> None:
|
||||
"""main() is the entrypoint for the Runway CLI.
|
||||
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
|
||||
"""
|
||||
This function is the entrypoint for the Runway CLI.
|
||||
The function pass the given prompt to StableDiffusion on Modal,
|
||||
gets back a list of images and outputs images to local.
|
||||
"""
|
||||
directory = util.make_directory()
|
||||
seed_generated = seed
|
||||
for i in range(samples):
|
||||
if seed == -1:
|
||||
seed_generated = util.generate_seed()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(levelname)s] %(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("run-stable-diffusion-cli")
|
||||
|
||||
output_directory = domain.OutputDirectory()
|
||||
directory_path = output_directory.make_directory()
|
||||
logger.info("Made a directory: %s", directory_path)
|
||||
|
||||
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
|
||||
for sample_index in range(samples):
|
||||
new_seed = domain.Seed(seed)
|
||||
start_time = time.time()
|
||||
images = run_inference.remote(
|
||||
prompt=prompt,
|
||||
n_prompt=n_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
batch_size=batch_size,
|
||||
batch_size=1,
|
||||
steps=steps,
|
||||
seed=seed_generated,
|
||||
seed=new_seed.value,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
util.save_images(directory, images, seed_generated, i, output_format)
|
||||
total_time = time.time() - start_time
|
||||
print(
|
||||
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
new_seed.value,
|
||||
sample_index,
|
||||
generated_image_index,
|
||||
output_format,
|
||||
)
|
||||
logger.info("Saved image to the: %s", saved_path)
|
||||
|
||||
prompts: dict[str, int | str] = {
|
||||
"prompt": prompt,
|
||||
"n_prompt": n_prompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"samples": samples,
|
||||
"batch_size": batch_size,
|
||||
"steps": steps,
|
||||
}
|
||||
util.save_prompts(prompts)
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
|
||||
|
||||
saved_prompts_path = sd_output_manager.save_prompts()
|
||||
logger.info("Saved prompts: %s", saved_prompts_path)
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import domain
|
||||
import modal
|
||||
import util
|
||||
|
||||
app = modal.App("run-stable-diffusion-cli")
|
||||
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
|
||||
run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SDXLTxt2Img.run_inference",
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@ -18,17 +22,27 @@ def main(
|
||||
seed: int = -1,
|
||||
use_upscaler: str = "False",
|
||||
output_format: str = "png",
|
||||
):
|
||||
"""
|
||||
This function is the entrypoint for the Runway CLI.
|
||||
) -> None:
|
||||
"""This function is the entrypoint for the Runway CLI.
|
||||
The function pass the given prompt to StableDiffusion on Modal,
|
||||
gets back a list of images and outputs images to local.
|
||||
"""
|
||||
directory = util.make_directory()
|
||||
seed_generated = seed
|
||||
for i in range(samples):
|
||||
if seed == -1:
|
||||
seed_generated = util.generate_seed()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(levelname)s] %(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("run-stable-diffusion-cli")
|
||||
|
||||
output_directory = domain.OutputDirectory()
|
||||
directory_path = output_directory.make_directory()
|
||||
logger.info("Made a directory: %s", directory_path)
|
||||
|
||||
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
|
||||
|
||||
for sample_index in range(samples):
|
||||
new_seed = domain.Seed(seed)
|
||||
start_time = time.time()
|
||||
images = run_inference.remote(
|
||||
prompt=prompt,
|
||||
@ -36,18 +50,23 @@ def main(
|
||||
height=height,
|
||||
width=width,
|
||||
steps=steps,
|
||||
seed=seed_generated,
|
||||
seed=new_seed.value,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
util.save_images(directory, images, seed_generated, i, output_format)
|
||||
total_time = time.time() - start_time
|
||||
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
||||
|
||||
prompts: dict[str, int | str] = {
|
||||
"prompt": prompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"samples": samples,
|
||||
}
|
||||
util.save_prompts(prompts)
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
new_seed.value,
|
||||
sample_index,
|
||||
generated_image_index,
|
||||
output_format,
|
||||
)
|
||||
logger.info("Saved image to the: %s", saved_path)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
|
||||
|
||||
saved_prompts_path = sd_output_manager.save_prompts()
|
||||
logger.info("Saved prompts: %s", saved_prompts_path)
|
||||
|
||||
55
cmd/util.py
55
cmd/util.py
@ -1,55 +0,0 @@
|
||||
""" Utility functions for the script. """
|
||||
import random
|
||||
import time
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
OUTPUT_DIRECTORY = "outputs"
|
||||
DATE_TODAY = date.today().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def generate_seed() -> int:
|
||||
"""
|
||||
Generate a random seed.
|
||||
"""
|
||||
seed = random.randint(0, 4294967295)
|
||||
print(f"Generate a random seed: {seed}")
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def make_directory() -> Path:
|
||||
"""
|
||||
Make a directory for saving outputs.
|
||||
"""
|
||||
directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
|
||||
if not directory.exists():
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
print(f"Make a directory: {directory}")
|
||||
|
||||
return directory
|
||||
|
||||
|
||||
def save_prompts(inputs: dict):
|
||||
"""
|
||||
Save prompts to a file.
|
||||
"""
|
||||
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
|
||||
with open(
|
||||
file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8"
|
||||
) as file:
|
||||
for name, value in inputs.items():
|
||||
file.write(f"{name} = {repr(value)}\n")
|
||||
print(f"Save prompts: {prompts_filename}.txt")
|
||||
|
||||
|
||||
def save_images(directory: Path, images: list[bytes], seed: int, i: int, output_format: str = "png"):
|
||||
"""
|
||||
Save images to a file.
|
||||
"""
|
||||
for j, image_bytes in enumerate(images):
|
||||
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
|
||||
output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
|
||||
print(f"Saving it to {output_path}")
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(image_bytes)
|
||||
Loading…
x
Reference in New Issue
Block a user