Merge pull request #3 from hodanov/feature/add_sd_x2_latent_upscaler

Feature/add sd x2 latent upscaler
This commit is contained in:
hodanov 2023-06-05 10:18:38 +09:00 committed by GitHub
commit d16f8a5f66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 136 additions and 71 deletions

View File

@ -1,3 +1,3 @@
HUGGINGFACE_TOKEN="" HUGGING_FACE_TOKEN=""
MODEL_REPO_ID="stabilityai/stable-diffusion-2-1" MODEL_REPO_ID="stabilityai/stable-diffusion-2-1"
MODEL_NAME="stable-diffusion-2-1" MODEL_NAME="stable-diffusion-2-1"

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.DS_Store .DS_Store
.mypy_cache/
__pycache__/ __pycache__/
outputs/ outputs/
.env .env

View File

@ -1,7 +1,9 @@
run: run:
modal run sd_cli.py \ modal run sd_cli.py \
--prompt "a woman with bob hair" \ --prompt "a woman with bob hair" \
--n-prompt "" \ --n-prompt "" \
--height 768 \ --height 768 \
--width 512 \ --width 512 \
--samples 5 --samples 5 \
--steps 20 \
--upscaler "sd_x2_latent_upscaler"

View File

@ -6,7 +6,7 @@ This is the script to execute Stable Diffusion on [Modal](https://modal.com/).
The app requires the following to run: The app requires the following to run:
- python: v3.10 > - python: > 3.10
- modal-client - modal-client
- A token for Modal. - A token for Modal.

124
sd_cli.py
View File

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
import io import io
import os import os
import time import time
from datetime import date
from pathlib import Path
from modal import Image, Secret, Stub, method, Mount
stub = Stub("stable-diffusion-cli") from modal import Image, Mount, Secret, Stub, method
import util
BASE_CACHE_PATH = "/vol/cache" BASE_CACHE_PATH = "/vol/cache"
@ -18,7 +18,7 @@ def download_models():
""" """
import diffusers import diffusers
hugging_face_token = os.environ["HUGGINGFACE_TOKEN"] hugging_face_token = os.environ["HUGGING_FACE_TOKEN"]
model_repo_id = os.environ["MODEL_REPO_ID"] model_repo_id = os.environ["MODEL_REPO_ID"]
cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"]) cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
@ -45,6 +45,7 @@ stub_image = Image.from_dockerfile(
download_models, download_models,
secrets=[Secret.from_dotenv(__file__)], secrets=[Secret.from_dotenv(__file__)],
) )
stub = Stub("stable-diffusion-cli")
stub.image = stub_image stub.image = stub_image
@ -79,17 +80,19 @@ class StableDiffusion:
).to("cuda") ).to("cuda")
self.pipe.enable_xformers_memory_efficient_attention() self.pipe.enable_xformers_memory_efficient_attention()
self.upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")
self.upscaler.enable_xformers_memory_efficient_attention()
# model_id = "stabilityai/stable-diffusion-x4-upscaler"
# self.upscaler = diffusers.StableDiffusionUpscalePipeline.from_pretrained(
# , revision="fp16", torch_dtype=torch.float16
# ).to("cuda")
# self.upscaler.enable_xformers_memory_efficient_attention()
@method() @method()
def run_inference( def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
self,
prompt: str,
n_prompt: str,
steps: int = 30,
batch_size: int = 1,
height: int = 512,
width: int = 512,
max_embeddings_multiples: int = 1,
) -> list[bytes]:
""" """
Runs the Stable Diffusion pipeline on the given prompt and outputs images. Runs the Stable Diffusion pipeline on the given prompt and outputs images.
""" """
@ -98,13 +101,13 @@ class StableDiffusion:
with torch.inference_mode(): with torch.inference_mode():
with torch.autocast("cuda"): with torch.autocast("cuda"):
images = self.pipe( images = self.pipe(
[prompt] * batch_size, [inputs["prompt"]] * int(inputs["batch_size"]),
negative_prompt=[n_prompt] * batch_size, negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
height=height, height=inputs["height"],
width=width, width=inputs["width"],
num_inference_steps=steps, num_inference_steps=inputs["steps"],
guidance_scale=7.5, guidance_scale=7.5,
max_embeddings_multiples=max_embeddings_multiples, max_embeddings_multiples=inputs["max_embeddings_multiples"],
).images ).images
image_output = [] image_output = []
@ -112,6 +115,19 @@ class StableDiffusion:
with io.BytesIO() as buf: with io.BytesIO() as buf:
image.save(buf, format="PNG") image.save(buf, format="PNG")
image_output.append(buf.getvalue()) image_output.append(buf.getvalue())
if inputs["upscaler"] != "":
upscaled_images = self.upscaler(
prompt=inputs["prompt"],
image=images,
num_inference_steps=inputs["steps"],
guidance_scale=0,
).images
for image in upscaled_images:
with io.BytesIO() as buf:
image.save(buf, format="PNG")
image_output.append(buf.getvalue())
return image_output return image_output
@ -119,60 +135,40 @@ class StableDiffusion:
def entrypoint( def entrypoint(
prompt: str, prompt: str,
n_prompt: str, n_prompt: str,
samples: int = 5,
steps: int = 30,
batch_size: int = 1,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
samples: int = 5,
batch_size: int = 1,
steps: int = 20,
upscaler: str = "",
): ):
""" """
This function is the entrypoint for the Runway CLI. This function is the entrypoint for the Runway CLI.
The function pass the given prompt to StableDiffusion on Modal, The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.
The function is called with the following arguments:
- prompt: the prompt to run inference on
- n_prompt: the negative prompt to run inference on
- samples: the number of samples to generate
- steps: the number of steps to run inference for
- batch_size: the batch size to use
- height: the height of the output image
- width: the width of the output image
""" """
print(f"steps => {steps}, sapmles => {samples}, batch_size => {batch_size}")
max_embeddings_multiples = 1 inputs: dict[str, int | str] = {
token_count = len(prompt.split()) "prompt": prompt,
if token_count > 77: "n_prompt": n_prompt,
max_embeddings_multiples = token_count // 77 + 1 "height": height,
"width": width,
"samples": samples,
"batch_size": batch_size,
"steps": steps,
"upscaler": upscaler, # sd_x2_latent_upscaler, sd_x4_upscaler
# seed=-1
}
print( inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
f"token_count => {token_count}, max_embeddings_multiples => {max_embeddings_multiples}" directory = util.make_directory()
)
directory = Path(f"./outputs/{date.today().strftime('%Y-%m-%d')}") sd = StableDiffusion()
if not directory.exists():
directory.mkdir(exist_ok=True, parents=True)
stable_diffusion = StableDiffusion()
for i in range(samples): for i in range(samples):
start_time = time.time() start_time = time.time()
images = stable_diffusion.run_inference.call( images = sd.run_inference.call(inputs)
prompt, util.save_images(directory, images, i)
n_prompt,
steps,
batch_size,
height,
width,
max_embeddings_multiples,
)
total_time = time.time() - start_time total_time = time.time() - start_time
print( print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
) util.save_prompts(inputs)
for j, image_bytes in enumerate(images):
formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = directory / f"{formatted_time}_{i}_{j}.png"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as file:
file.write(image_bytes)

66
util.py Normal file
View File

@ -0,0 +1,66 @@
""" Utility functions for the script. """
import time
from datetime import date
from pathlib import Path
from PIL import Image
# Root directory for everything this script writes; the path is relative, so it
# resolves against the current working directory.
OUTPUT_DIRECTORY = "outputs"
# Today's date formatted as YYYY-MM-DD. Evaluated once, when this module is
# imported, so it does not change during a run.
DATE_TODAY = date.today().strftime("%Y-%m-%d")
def make_directory() -> Path:
    """
    Return the dated output directory, creating it on first use.

    The directory is ``OUTPUT_DIRECTORY/DATE_TODAY``. A message is printed
    only when the directory had to be created.
    """
    target = Path(OUTPUT_DIRECTORY) / DATE_TODAY
    if target.exists():
        # Nothing to create; stay silent.
        return target

    target.mkdir(parents=True, exist_ok=True)
    print(f"Make directory: {target}")
    return target
def save_prompts(inputs: dict):
    """
    Save the inference inputs to a timestamped text file.

    One ``name = repr(value)`` line is written per entry of ``inputs`` to
    ``OUTPUT_DIRECTORY/DATE_TODAY/prompts_<YYYYmmddHHMMSS>.txt``, encoded as
    UTF-8. The directory is created if it does not exist yet, so this function
    does not rely on make_directory() having been called beforehand.
    """
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    directory = Path(OUTPUT_DIRECTORY) / DATE_TODAY
    directory.mkdir(parents=True, exist_ok=True)
    prompts_path = directory / f"prompts_{timestamp}.txt"

    with open(file=prompts_path, mode="w", encoding="utf-8") as file:
        file.writelines(f"{name} = {value!r}\n" for name, value in inputs.items())

    # Report the path actually written; the file name carries a "prompts_"
    # prefix, which a bare "<timestamp>.txt" message would not show.
    print(f"Save prompts: {prompts_path}")
def count_token(p: str, n: str, *, max_length: int = 77) -> int:
    """
    Estimate ``max_embeddings_multiples`` for a prompt / negative prompt pair.

    Tokens are approximated by whitespace-separated words, and the longer of
    the two prompts decides the result.

    Args:
        p: the prompt.
        n: the negative prompt.
        max_length: number of tokens that fit into one embedding chunk.
            Defaults to 77, the value this function previously hard-coded.

    Returns:
        1 when the longer prompt has at most ``max_length`` words, otherwise
        ``token_count // max_length + 1``.

    Raises:
        ValueError: if ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    token_count = max(len(p.split()), len(n.split()))

    # Kept as floor-division + 1 (not a strict ceiling) so existing callers
    # get exactly the same multiples as before.
    max_embeddings_multiples = 1 if token_count <= max_length else token_count // max_length + 1

    print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")
    return max_embeddings_multiples
def save_images(directory: Path, images: list[bytes], i: int):
    """
    Write encoded images into ``directory`` and return the paths written.

    Each image is stored as ``<YYYYmmddHHMMSS>_<i>_<j>.png``, where ``j`` is
    the image's position in ``images``. The bytes are written unchanged; no
    image decoding or re-encoding happens here.

    Args:
        directory: destination directory; created (with parents) if missing.
        images: the already-encoded image payloads.
        i: index of the sample these images belong to, used in the file name.

    Returns:
        The list of paths written, in the same order as ``images``.
    """
    directory = Path(directory)
    # Do not depend on the caller having created the directory beforehand.
    directory.mkdir(parents=True, exist_ok=True)

    written = []
    for j, image_bytes in enumerate(images):
        formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
        output_path = directory / f"{formatted_time}_{i}_{j}.png"
        print(f"Saving it to {output_path}")
        output_path.write_bytes(image_bytes)
        written.append(output_path)
    return written