diff --git a/cmd/sd15_txt2img.py b/cmd/sd15_txt2img.py index 621c1b9..dcd38c7 100644 --- a/cmd/sd15_txt2img.py +++ b/cmd/sd15_txt2img.py @@ -1,11 +1,13 @@ +import logging import time +import domain import modal -import util app = modal.App("run-stable-diffusion-cli") run_inference = modal.Function.from_name( - "stable-diffusion-cli", "SD15.run_txt2img_inference" + "stable-diffusion-cli", + "SD15.run_txt2img_inference", ) @@ -16,49 +18,55 @@ def main( height: int = 512, width: int = 512, samples: int = 5, - batch_size: int = 1, steps: int = 20, seed: int = -1, use_upscaler: str = "", fix_by_controlnet_tile: str = "False", output_format: str = "png", -): +) -> None: + """main() is the entrypoint for the Runway CLI. + This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ - This function is the entrypoint for the Runway CLI. - The function pass the given prompt to StableDiffusion on Modal, - gets back a list of images and outputs images to local. - """ - directory = util.make_directory() - seed_generated = seed - for i in range(samples): - if seed == -1: - seed_generated = util.generate_seed() + logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger = logging.getLogger("run-stable-diffusion-cli") + + output_directory = domain.OutputDirectory() + directory_path = output_directory.make_directory() + logger.info("Made a directory: %s", directory_path) + + prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) + sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) + for sample_index in range(samples): + new_seed = domain.Seed(seed) start_time = time.time() images = run_inference.remote( prompt=prompt, n_prompt=n_prompt, height=height, width=width, - batch_size=batch_size, + batch_size=1, steps=steps, - seed=seed_generated, + seed=new_seed.value, use_upscaler=use_upscaler == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True", output_format=output_format, ) - util.save_images(directory, images, seed_generated, i, output_format) - total_time = time.time() - start_time - print( - f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)." - ) + for generated_image_index, image_bytes in enumerate(images): + saved_path = sd_output_manager.save_image( + image_bytes, + new_seed.value, + sample_index, + generated_image_index, + output_format, + ) + logger.info("Saved image to the: %s", saved_path) - prompts: dict[str, int | str] = { - "prompt": prompt, - "n_prompt": n_prompt, - "height": height, - "width": width, - "samples": samples, - "batch_size": batch_size, - "steps": steps, - } - util.save_prompts(prompts) + total_time = time.time() - start_time + logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images)) + + saved_prompts_path = sd_output_manager.save_prompts() + logger.info("Saved prompts: %s", saved_prompts_path) diff --git a/cmd/sdxl_txt2img.py b/cmd/sdxl_txt2img.py index 1528c75..da086f9 100644 --- a/cmd/sdxl_txt2img.py +++ b/cmd/sdxl_txt2img.py @@ -1,10 +1,14 @@ +import logging import time +import domain import modal -import util app = modal.App("run-stable-diffusion-cli") -run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference") +run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SDXLTxt2Img.run_inference", +) @app.local_entrypoint() @@ -19,16 +23,26 @@ def main( use_upscaler: str = "False", output_format: str = "png", ): - """ - This function is the entrypoint for the Runway CLI. + """This function is the entrypoint for the Runway CLI. The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ - directory = util.make_directory() - seed_generated = seed - for i in range(samples): - if seed == -1: - seed_generated = util.generate_seed() + logging.basicConfig( + level=logging.INFO, + format="[%(levelname)s] %(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger = logging.getLogger("run-stable-diffusion-cli") + + output_directory = domain.OutputDirectory() + directory_path = output_directory.make_directory() + logger.info("Made a directory: %s", directory_path) + + prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) + sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) + + for sample_index in range(samples): + new_seed = domain.Seed(seed) start_time = time.time() images = run_inference.remote( prompt=prompt, @@ -36,18 +50,23 @@ def main( height=height, width=width, steps=steps, - seed=seed_generated, + seed=new_seed.value, use_upscaler=use_upscaler == "True", output_format=output_format, ) - util.save_images(directory, images, seed_generated, i, output_format) - total_time = time.time() - start_time - print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") - prompts: dict[str, int | str] = { - "prompt": prompt, - "height": height, - "width": width, - "samples": samples, - } - util.save_prompts(prompts) + for generated_image_index, image_bytes in enumerate(images): + saved_path = sd_output_manager.save_image( + image_bytes, + new_seed.value, + sample_index, + generated_image_index, + output_format, + ) + logger.info("Saved image to the: %s", saved_path) + + total_time = time.time() - start_time + logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images)) + + saved_prompts_path = sd_output_manager.save_prompts() + logger.info("Saved prompts: %s", saved_prompts_path)