Refactor cmd/sd15_txt2img.py and cmd/sdxl_txt2img.py.

This commit is contained in:
hodanov 2024-11-03 19:33:00 +09:00
parent fd2e8912a4
commit f074a1b3ca
2 changed files with 77 additions and 50 deletions

View File

@ -1,11 +1,13 @@
import logging
import time import time
import domain
import modal import modal
import util
app = modal.App("run-stable-diffusion-cli") app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name( run_inference = modal.Function.from_name(
"stable-diffusion-cli", "SD15.run_txt2img_inference" "stable-diffusion-cli",
"SD15.run_txt2img_inference",
) )
@ -16,49 +18,55 @@ def main(
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
samples: int = 5, samples: int = 5,
batch_size: int = 1,
steps: int = 20, steps: int = 20,
seed: int = -1, seed: int = -1,
use_upscaler: str = "", use_upscaler: str = "",
fix_by_controlnet_tile: str = "False", fix_by_controlnet_tile: str = "False",
output_format: str = "png", output_format: str = "png",
): ) -> None:
"""main() is the entrypoint for the Runway CLI.
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
""" """
This function is the entrypoint for the Runway CLI. logging.basicConfig(
The function pass the given prompt to StableDiffusion on Modal, level=logging.INFO,
gets back a list of images and outputs images to local. format="[%(levelname)s] %(asctime)s - %(message)s",
""" datefmt="%Y-%m-%d %H:%M:%S",
directory = util.make_directory() )
seed_generated = seed logger = logging.getLogger("run-stable-diffusion-cli")
for i in range(samples):
if seed == -1: output_directory = domain.OutputDirectory()
seed_generated = util.generate_seed() directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time() start_time = time.time()
images = run_inference.remote( images = run_inference.remote(
prompt=prompt, prompt=prompt,
n_prompt=n_prompt, n_prompt=n_prompt,
height=height, height=height,
width=width, width=width,
batch_size=batch_size, batch_size=1,
steps=steps, steps=steps,
seed=seed_generated, seed=new_seed.value,
use_upscaler=use_upscaler == "True", use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format, output_format=output_format,
) )
util.save_images(directory, images, seed_generated, i, output_format) for generated_image_index, image_bytes in enumerate(images):
total_time = time.time() - start_time saved_path = sd_output_manager.save_image(
print( image_bytes,
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)." new_seed.value,
) sample_index,
generated_image_index,
output_format,
)
logger.info("Saved image to the: %s", saved_path)
prompts: dict[str, int | str] = { total_time = time.time() - start_time
"prompt": prompt, logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
"n_prompt": n_prompt,
"height": height, saved_prompts_path = sd_output_manager.save_prompts()
"width": width, logger.info("Saved prompts: %s", saved_prompts_path)
"samples": samples,
"batch_size": batch_size,
"steps": steps,
}
util.save_prompts(prompts)

View File

@ -1,10 +1,14 @@
import logging
import time import time
import domain
import modal import modal
import util
app = modal.App("run-stable-diffusion-cli") app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference") run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
@app.local_entrypoint() @app.local_entrypoint()
@ -19,16 +23,26 @@ def main(
use_upscaler: str = "False", use_upscaler: str = "False",
output_format: str = "png", output_format: str = "png",
): ):
""" """This function is the entrypoint for the Runway CLI.
This function is the entrypoint for the Runway CLI.
The function pass the given prompt to StableDiffusion on Modal, The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.
""" """
directory = util.make_directory() logging.basicConfig(
seed_generated = seed level=logging.INFO,
for i in range(samples): format="[%(levelname)s] %(asctime)s - %(message)s",
if seed == -1: datefmt="%Y-%m-%d %H:%M:%S",
seed_generated = util.generate_seed() )
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time() start_time = time.time()
images = run_inference.remote( images = run_inference.remote(
prompt=prompt, prompt=prompt,
@ -36,18 +50,23 @@ def main(
height=height, height=height,
width=width, width=width,
steps=steps, steps=steps,
seed=seed_generated, seed=new_seed.value,
use_upscaler=use_upscaler == "True", use_upscaler=use_upscaler == "True",
output_format=output_format, output_format=output_format,
) )
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
prompts: dict[str, int | str] = { for generated_image_index, image_bytes in enumerate(images):
"prompt": prompt, saved_path = sd_output_manager.save_image(
"height": height, image_bytes,
"width": width, new_seed.value,
"samples": samples, sample_index,
} generated_image_index,
util.save_prompts(prompts) output_format,
)
logger.info("Saved image to the: %s", saved_path)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
saved_prompts_path = sd_output_manager.save_prompts()
logger.info("Saved prompts: %s", saved_prompts_path)