Modify to use seed.

This commit is contained in:
hodanov 2023-06-13 00:10:48 +09:00
parent a9063a999d
commit 337bf01048
3 changed files with 25 additions and 7 deletions

View File

@ -2,8 +2,9 @@ run:
modal run sd_cli.py \ modal run sd_cli.py \
--prompt "A woman with bob hair" \ --prompt "A woman with bob hair" \
--n-prompt "" \ --n-prompt "" \
--upscaler "RealESRGAN_x4plus_anime_6B" \
--height 768 \ --height 768 \
--width 512 \ --width 512 \
--samples 5 \ --samples 5 \
--steps 50 \ --steps 50 \
--upscaler "RealESRGAN_x4plus_anime_6B" --seed 500

View File

@ -101,6 +101,7 @@ class StableDiffusion:
""" """
import torch import torch
generator = torch.Generator("cuda").manual_seed(inputs["seed"])
with torch.inference_mode(): with torch.inference_mode():
with torch.autocast("cuda"): with torch.autocast("cuda"):
base_images = self.pipe( base_images = self.pipe(
@ -111,6 +112,7 @@ class StableDiffusion:
num_inference_steps=inputs["steps"], num_inference_steps=inputs["steps"],
guidance_scale=7.5, guidance_scale=7.5,
max_embeddings_multiples=inputs["max_embeddings_multiples"], max_embeddings_multiples=inputs["max_embeddings_multiples"],
generator=generator,
).images ).images
if inputs["upscaler"] != "": if inputs["upscaler"] != "":
@ -197,12 +199,13 @@ class StableDiffusion:
def entrypoint( def entrypoint(
prompt: str, prompt: str,
n_prompt: str, n_prompt: str,
upscaler: str,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
samples: int = 5, samples: int = 5,
batch_size: int = 1, batch_size: int = 1,
steps: int = 20, steps: int = 20,
upscaler: str = "", seed: int = -1,
): ):
""" """
This function is the entrypoint for the Runway CLI. This function is the entrypoint for the Runway CLI.
@ -210,6 +213,9 @@ def entrypoint(
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.
""" """
if seed == -1:
seed = util.generate_seed()
inputs: dict[str, int | str] = { inputs: dict[str, int | str] = {
"prompt": prompt, "prompt": prompt,
"n_prompt": n_prompt, "n_prompt": n_prompt,
@ -219,7 +225,7 @@ def entrypoint(
"batch_size": batch_size, "batch_size": batch_size,
"steps": steps, "steps": steps,
"upscaler": upscaler, "upscaler": upscaler,
# seed=-1 "seed": seed,
} }
inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt) inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
@ -229,7 +235,7 @@ def entrypoint(
for i in range(samples): for i in range(samples):
start_time = time.time() start_time = time.time()
images = sd.run_inference.call(inputs) images = sd.run_inference.call(inputs)
util.save_images(directory, images, i) util.save_images(directory, images, inputs, i)
total_time = time.time() - start_time total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")

17
util.py
View File

@ -1,4 +1,5 @@
""" Utility functions for the script. """ """ Utility functions for the script. """
import random
import time import time
from datetime import date from datetime import date
from pathlib import Path from pathlib import Path
@ -9,6 +10,16 @@ OUTPUT_DIRECTORY = "outputs"
DATE_TODAY = date.today().strftime("%Y-%m-%d") DATE_TODAY = date.today().strftime("%Y-%m-%d")
def generate_seed() -> int:
"""
Generate a random seed.
"""
seed = random.randint(0, 4294967295)
print(f"Generate a random seed: {seed}")
return seed
def make_directory() -> Path: def make_directory() -> Path:
""" """
Make a directory for saving outputs. Make a directory for saving outputs.
@ -16,7 +27,7 @@ def make_directory() -> Path:
directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}") directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
if not directory.exists(): if not directory.exists():
directory.mkdir(exist_ok=True, parents=True) directory.mkdir(exist_ok=True, parents=True)
print(f"Make directory: {directory}") print(f"Make a directory: {directory}")
return directory return directory
@ -54,13 +65,13 @@ def count_token(p: str, n: str) -> int:
return max_embeddings_multiples return max_embeddings_multiples
def save_images(directory: Path, images: list[bytes], i: int): def save_images(directory: Path, images: list[bytes], inputs: dict[str, int | str], i: int):
""" """
Save images to a file. Save images to a file.
""" """
for j, image_bytes in enumerate(images): for j, image_bytes in enumerate(images):
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = directory / f"{formatted_time}_{i}_{j}.png" output_path = directory / f"{formatted_time}_{inputs['seed']}_{i}_{j}.png"
print(f"Saving it to {output_path}") print(f"Saving it to {output_path}")
with open(output_path, "wb") as file: with open(output_path, "wb") as file:
file.write(image_bytes) file.write(image_bytes)