Refactor some code. Add sd_x2_latent_upscaler.

hodanov 2023-06-05 09:49:14 +09:00
parent 653403d29a
commit 3830bded6a
3 changed files with 115 additions and 60 deletions


@@ -1,3 +1,3 @@
-HUGGINGFACE_TOKEN=""
+HUGGING_FACE_TOKEN=""
 MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
 MODEL_NAME="stable-diffusion-2-1"

sd_cli.py (119 changed lines)

@@ -1,12 +1,12 @@
 from __future__ import annotations

 import io
 import os
 import time
-from datetime import date
-from pathlib import Path
-from modal import Image, Secret, Stub, method, Mount

-stub = Stub("stable-diffusion-cli")
+from modal import Image, Mount, Secret, Stub, method
+
+import util

 BASE_CACHE_PATH = "/vol/cache"
@@ -18,7 +18,7 @@ def download_models():
     """
     import diffusers

-    hugging_face_token = os.environ["HUGGINGFACE_TOKEN"]
+    hugging_face_token = os.environ["HUGGING_FACE_TOKEN"]
     model_repo_id = os.environ["MODEL_REPO_ID"]
     cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
@@ -45,6 +45,7 @@ stub_image = Image.from_dockerfile(
     download_models,
     secrets=[Secret.from_dotenv(__file__)],
 )
+stub = Stub("stable-diffusion-cli")
 stub.image = stub_image
@@ -79,17 +80,19 @@ class StableDiffusion:
         ).to("cuda")
         self.pipe.enable_xformers_memory_efficient_attention()

+        self.upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
+            "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
+        ).to("cuda")
+        self.upscaler.enable_xformers_memory_efficient_attention()
+        # model_id = "stabilityai/stable-diffusion-x4-upscaler"
+        # self.upscaler = diffusers.StableDiffusionUpscalePipeline.from_pretrained(
+        #     model_id, revision="fp16", torch_dtype=torch.float16
+        # ).to("cuda")
+        # self.upscaler.enable_xformers_memory_efficient_attention()

     @method()
-    def run_inference(
-        self,
-        prompt: str,
-        n_prompt: str,
-        steps: int = 30,
-        batch_size: int = 1,
-        height: int = 512,
-        width: int = 512,
-        max_embeddings_multiples: int = 1,
-    ) -> list[bytes]:
+    def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
@@ -98,13 +101,13 @@ class StableDiffusion:
         with torch.inference_mode():
             with torch.autocast("cuda"):
                 images = self.pipe(
-                    [prompt] * batch_size,
-                    negative_prompt=[n_prompt] * batch_size,
-                    height=height,
-                    width=width,
-                    num_inference_steps=steps,
+                    [inputs["prompt"]] * int(inputs["batch_size"]),
+                    negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
+                    height=inputs["height"],
+                    width=inputs["width"],
+                    num_inference_steps=inputs["steps"],
                     guidance_scale=7.5,
-                    max_embeddings_multiples=max_embeddings_multiples,
+                    max_embeddings_multiples=inputs["max_embeddings_multiples"],
                 ).images

                 image_output = []
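
The int() casts above follow from the loose dict[str, int | str] payload: every value is typed int | str, so list multiplication needs an explicit cast. A sketch with illustrative values:

    inputs: dict[str, int | str] = {"prompt": "a cat", "batch_size": 2}
    # list * (int | str) does not type-check, and would raise TypeError if a
    # value ever arrived as the string "2", hence the explicit cast:
    prompts = [inputs["prompt"]] * int(inputs["batch_size"])  # ["a cat", "a cat"]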
@@ -112,6 +115,19 @@ class StableDiffusion:
                     with io.BytesIO() as buf:
                         image.save(buf, format="PNG")
                         image_output.append(buf.getvalue())

+                if inputs["upscaler"] != "":
+                    upscaled_images = self.upscaler(
+                        prompt=inputs["prompt"],
+                        image=images,
+                        num_inference_steps=inputs["steps"],
+                        guidance_scale=0,
+                    ).images
+                    for image in upscaled_images:
+                        with io.BytesIO() as buf:
+                            image.save(buf, format="PNG")
+                            image_output.append(buf.getvalue())
+
         return image_output
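
When an upscaler is selected, image_output carries the base renders followed by the upscaled ones. A hypothetical check of the returned bytes (sd and inputs as built in entrypoint below; assumes batch_size = 1 and 512x512 base images, which the x2 upscaler doubles on each side):

    import io
    from PIL import Image

    png_bytes = sd.run_inference.call(inputs)
    print(Image.open(io.BytesIO(png_bytes[-1])).size)  # expect (1024, 1024)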
@@ -119,60 +135,45 @@ class StableDiffusion:
 def entrypoint(
     prompt: str,
     n_prompt: str,
-    samples: int = 5,
-    steps: int = 30,
-    batch_size: int = 1,
     height: int = 512,
     width: int = 512,
+    samples: int = 5,
+    batch_size: int = 1,
+    steps: int = 20,
+    upscaler: str = "",
 ):
     """
     This function is the entrypoint for the Runway CLI.
     The function pass the given prompt to StableDiffusion on Modal,
     gets back a list of images and outputs images to local.
-
-    The function is called with the following arguments:
-    - prompt: the prompt to run inference on
-    - n_prompt: the negative prompt to run inference on
-    - samples: the number of samples to generate
-    - steps: the number of steps to run inference for
-    - batch_size: the batch size to use
-    - height: the height of the output image
-    - width: the width of the output image
     """
-    print(f"steps => {steps}, sapmles => {samples}, batch_size => {batch_size}")
-
-    max_embeddings_multiples = 1
-    token_count = len(prompt.split())
-    if token_count > 77:
-        max_embeddings_multiples = token_count // 77 + 1
-
-    print(
-        f"token_count => {token_count}, max_embeddings_multiples => {max_embeddings_multiples}"
-    )
-
-    directory = Path(f"./outputs/{date.today().strftime('%Y-%m-%d')}")
-    if not directory.exists():
-        directory.mkdir(exist_ok=True, parents=True)
-
-    stable_diffusion = StableDiffusion()
+    inputs: dict[str, int | str] = {
+        "prompt": prompt,
+        "n_prompt": n_prompt,
+        "height": height,
+        "width": width,
+        "samples": samples,
+        "batch_size": batch_size,
+        "steps": steps,
+        "upscaler": upscaler,  # sd_x2_latent_upscaler, sd_x4_upscaler
+        # seed=-1
+    }
+    inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
+
+    directory = util.make_directory()
+    util.save_prompts(inputs)
+
+    sd = StableDiffusion()
     for i in range(samples):
         start_time = time.time()
-        images = stable_diffusion.run_inference.call(
-            prompt,
-            n_prompt,
-            steps,
-            batch_size,
-            height,
-            width,
-            max_embeddings_multiples,
-        )
-
-        total_time = time.time() - start_time
-        print(
-            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
-        )
         for j, image_bytes in enumerate(images):
             formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
             output_path = directory / f"{formatted_time}_{i}_{j}.png"
             print(f"Saving it to {output_path}")
             with open(output_path, "wb") as file:
                 file.write(image_bytes)
+        images = sd.run_inference.call(inputs)
+        for j, image_bytes in enumerate(images):
+            formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+            output_path = directory / f"{formatted_time}_{i}_{j}.png"
+            print(f"Saving it to {output_path}")
+            with open(output_path, "wb") as file:
+                file.write(image_bytes)
+
+        total_time = time.time() - start_time
+        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")

util.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+""" Utility functions for the script. """
+import time
+from datetime import date
+from pathlib import Path
+
+from PIL import Image
+
+OUTPUT_DIRECTORY = "outputs"
+DATE_TODAY = date.today().strftime("%Y-%m-%d")
+
+
+def make_directory() -> Path:
+    """
+    Make a directory for saving outputs.
+    """
+    directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
+    if not directory.exists():
+        directory.mkdir(exist_ok=True, parents=True)
+        print(f"Make directory: {directory}")
+    return directory
+
+
+def save_prompts(inputs: dict):
+    """
+    Save prompts to a file.
+    """
+    prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+    with open(
+        file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt",
+        mode="w",
+        encoding="utf-8",
+    ) as file:
+        for name, value in inputs.items():
+            file.write(f"{name} = {repr(value)}\n")
+    print(f"Save prompts: prompts_{prompts_filename}.txt")
+
+
+def count_token(p: str, n: str) -> int:
+    """
+    Count the number of tokens in the prompt and negative prompt.
+    """
+    token_count_p = len(p.split())
+    token_count_n = len(n.split())
+    if token_count_p >= token_count_n:
+        token_count = token_count_p
+    else:
+        token_count = token_count_n
+    max_embeddings_multiples = 1
+    if token_count > 77:
+        max_embeddings_multiples = token_count // 77 + 1
+    print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")
+    return max_embeddings_multiples
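
A worked example of the count_token arithmetic (hypothetical word counts; note it counts whitespace-separated words as a stand-in for CLIP tokens): a 100-word prompt against a 10-word negative prompt gives token_count = 100, and since 100 > 77, max_embeddings_multiples = 100 // 77 + 1 = 2. As a quick check:

    import util

    multiples = util.count_token(p=" ".join(["word"] * 100), n="low quality, lowres")
    # prints "token_count: 100, max_embeddings_multiples: 2"
    assert multiples == 2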