From 337bf01048e7374b1c04f82adfaa1b4e296d84af Mon Sep 17 00:00:00 2001 From: hodanov <1031hoda@gmail.com> Date: Tue, 13 Jun 2023 00:10:48 +0900 Subject: [PATCH] Modify to use seed. --- Makefile | 3 ++- sd_cli.py | 12 +++++++++--- util.py | 17 ++++++++++++++--- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 16c1533..a1c5ec5 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,9 @@ run: modal run sd_cli.py \ --prompt "A woman with bob hair" \ --n-prompt "" \ + --upscaler "RealESRGAN_x4plus_anime_6B" \ --height 768 \ --width 512 \ --samples 5 \ --steps 50 \ - --upscaler "RealESRGAN_x4plus_anime_6B" + --seed 500 diff --git a/sd_cli.py b/sd_cli.py index edb70e8..005af60 100644 --- a/sd_cli.py +++ b/sd_cli.py @@ -101,6 +101,7 @@ class StableDiffusion: """ import torch + generator = torch.Generator("cuda").manual_seed(inputs["seed"]) with torch.inference_mode(): with torch.autocast("cuda"): base_images = self.pipe( @@ -111,6 +112,7 @@ class StableDiffusion: num_inference_steps=inputs["steps"], guidance_scale=7.5, max_embeddings_multiples=inputs["max_embeddings_multiples"], + generator=generator, ).images if inputs["upscaler"] != "": @@ -197,12 +199,13 @@ class StableDiffusion: def entrypoint( prompt: str, n_prompt: str, + upscaler: str, height: int = 512, width: int = 512, samples: int = 5, batch_size: int = 1, steps: int = 20, - upscaler: str = "", + seed: int = -1, ): """ This function is the entrypoint for the Runway CLI. @@ -210,6 +213,9 @@ def entrypoint( gets back a list of images and outputs images to local. """ + if seed == -1: + seed = util.generate_seed() + inputs: dict[str, int | str] = { "prompt": prompt, "n_prompt": n_prompt, @@ -219,7 +225,7 @@ def entrypoint( "batch_size": batch_size, "steps": steps, "upscaler": upscaler, - # seed=-1 + "seed": seed, } inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt) @@ -229,7 +235,7 @@ def entrypoint( for i in range(samples): start_time = time.time() images = sd.run_inference.call(inputs) - util.save_images(directory, images, i) + util.save_images(directory, images, inputs, i) total_time = time.time() - start_time print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") diff --git a/util.py b/util.py index affb2a0..40aa05f 100644 --- a/util.py +++ b/util.py @@ -1,4 +1,5 @@ """ Utility functions for the script. """ +import random import time from datetime import date from pathlib import Path @@ -9,6 +10,16 @@ OUTPUT_DIRECTORY = "outputs" DATE_TODAY = date.today().strftime("%Y-%m-%d") +def generate_seed() -> int: + """ + Generate a random seed. + """ + seed = random.randint(0, 4294967295) + print(f"Generate a random seed: {seed}") + + return seed + + def make_directory() -> Path: """ Make a directory for saving outputs. @@ -16,7 +27,7 @@ def make_directory() -> Path: directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}") if not directory.exists(): directory.mkdir(exist_ok=True, parents=True) - print(f"Make directory: {directory}") + print(f"Make a directory: {directory}") return directory @@ -54,13 +65,13 @@ def count_token(p: str, n: str) -> int: return max_embeddings_multiples -def save_images(directory: Path, images: list[bytes], i: int): +def save_images(directory: Path, images: list[bytes], inputs: dict[str, int | str], i: int): """ Save images to a file. """ for j, image_bytes in enumerate(images): formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) - output_path = directory / f"{formatted_time}_{i}_{j}.png" + output_path = directory / f"{formatted_time}_{inputs['seed']}_{i}_{j}.png" print(f"Saving it to {output_path}") with open(output_path, "wb") as file: file.write(image_bytes)