Refactor sd_cli.py

hodanov 2023-06-26 21:56:58 +09:00
parent ddb685e4f3
commit 643e0e2ea6

sd_cli.py (169 changed lines)

@@ -6,6 +6,7 @@ import time
 from urllib.request import Request, urlopen
 from modal import Image, Mount, Secret, Stub, method
+from modal.cls import ClsMixin
 BASE_CACHE_PATH = "/vol/cache"
 BASE_CACHE_PATH_LORA = "/vol/cache/lora"
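
The ClsMixin import above is the point of the refactor: in Modal's 0.x client, modal.cls.ClsMixin let a class decorated with stub.cls take constructor arguments and be built remotely via .remote(...), with its method-decorated functions then called on the returned handle. A minimal sketch of that pattern as this file uses it (the Greeter class and its names are illustrative, not from the repo):

    from modal import Stub, method
    from modal.cls import ClsMixin

    stub = Stub("clsmixin-sketch")

    @stub.cls()
    class Greeter(ClsMixin):
        def __init__(self, greeting: str):
            # Constructor arguments are serialized and the instance
            # is constructed inside the remote container.
            self.greeting = greeting

        @method()
        def greet(self, name: str) -> str:
            return f"{self.greeting}, {name}!"

    @stub.local_entrypoint()
    def main():
        greeter = Greeter.remote(greeting="Hello")
        # Mirrors how entrypoint() below calls sd.run_inference(seed=...).
        print(greeter.greet("world"))
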
@@ -88,52 +89,70 @@ stub.image = stub_image
 @stub.cls(gpu="A10G", secrets=[Secret.from_dotenv(__file__)])
-class StableDiffusion:
+class StableDiffusion(ClsMixin):
     """
     A class that wraps the Stable Diffusion pipeline and scheduler.
     """
-    def __enter__(self):
+    def __init__(
+        self,
+        prompt: str,
+        n_prompt: str,
+        height: int = 512,
+        width: int = 512,
+        samples: int = 1,
+        batch_size: int = 1,
+        steps: int = 30,
+    ):
         import diffusers
         import torch
-        use_vae = os.environ["USE_VAE"] == "true"
+        self.prompt = prompt
+        self.n_prompt = n_prompt
+        self.height = height
+        self.width = width
+        self.samples = samples
+        self.batch_size = batch_size
+        self.steps = steps
+        self.use_vae = os.environ["USE_VAE"] == "true"
         self.upscaler = os.environ["UPSCALER"]
         self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
-        cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
-        if os.path.exists(cache_path):
-            print(f"The directory '{cache_path}' exists.")
+        self.use_hires_fix = os.environ["USE_HIRES_FIX"] == "true"
+        self.cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
+        if os.path.exists(self.cache_path):
+            print(f"The directory '{self.cache_path}' exists.")
         else:
-            print(f"The directory '{cache_path}' does not exist. Download models...")
+            print(f"The directory '{self.cache_path}' does not exist. Download models...")
             download_models()
+        self.max_embeddings_multiples = self.count_token(p=prompt, n=n_prompt)
         torch.backends.cuda.matmul.allow_tf32 = True
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
-            cache_path,
+            self.cache_path,
             custom_pipeline="lpw_stable_diffusion",
             torch_dtype=torch.float16,
         )
         # TODO: Add support for other schedulers.
-        # self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
-        self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
-            cache_path,
+        self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
+        # self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
+            self.cache_path,
             subfolder="scheduler",
         )
-        if use_vae:
+        if self.use_vae:
             self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
-                cache_path,
+                self.cache_path,
                 subfolder="vae",
             )
         self.pipe.to("cuda")
         if os.environ["LORA_NAMES"] != "":
-            names = os.getenv("LORA_NAMES").split(",")
-            urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
+            names = os.environ["LORA_NAMES"].split(",")
+            urls = os.environ["LORA_DOWNLOAD_URLS"].split(",")
             for name, url in zip(names, urls):
                 path = os.path.join(BASE_CACHE_PATH_LORA, name)
                 if os.path.exists(path):
@@ -144,8 +163,8 @@ class StableDiffusion:
                 self.pipe.load_lora_weights(".", weight_name=path)
         if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
-            names = os.getenv("TEXTUAL_INVERSION_NAMES").split(",")
-            urls = os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS").split(",")
+            names = os.environ["TEXTUAL_INVERSION_NAMES"].split(",")
+            urls = os.environ["TEXTUAL_INVERSION_DOWNLOAD_URLS"].split(",")
             for name, url in zip(names, urls):
                 path = os.path.join(BASE_CACHE_PATH_TEXTUAL_INVERSION, name)
                 if os.path.exists(path):
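
Both the LoRA hunk above and this textual-inversion hunk switch from os.getenv to os.environ[...], which raises KeyError immediately when a variable is missing instead of returning None and failing later at .split. They share a convention: two comma-separated env vars zipped into (name, URL) pairs, each file downloaded only when absent from the cache. A standalone sketch of the pairing (the values here are made up; the real ones come from the dotenv file loaded by Secret.from_dotenv):

    import os

    # Hypothetical values for illustration only.
    os.environ["LORA_NAMES"] = "style_a.safetensors,style_b.safetensors"
    os.environ["LORA_DOWNLOAD_URLS"] = "https://example.com/a,https://example.com/b"

    names = os.environ["LORA_NAMES"].split(",")
    urls = os.environ["LORA_DOWNLOAD_URLS"].split(",")
    for name, url in zip(names, urls):
        print(f"{name} <- {url}")

    # Caveat: zip() stops at the shorter list, so a name without a
    # matching URL is skipped silently rather than raising an error.
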
@@ -164,7 +183,10 @@ class StableDiffusion:
         """
         from transformers import CLIPTokenizer
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        tokenizer = CLIPTokenizer.from_pretrained(
+            self.cache_path,
+            subfolder="tokenizer",
+        )
         token_size_p = len(tokenizer.tokenize(p))
         token_size_n = len(tokenizer.tokenize(n))
         token_size = token_size_p
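
This hunk makes count_token load the CLIP tokenizer from the model's own cache directory instead of openai/clip-vit-large-patch14 from the hub. Its job is to size max_embeddings_multiples so the lpw_stable_diffusion custom pipeline can chunk prompts longer than CLIP's 77-token window. The actual arithmetic falls outside the hunk; a plausible sketch of its shape (an assumption, not the repo's code), using 75 usable tokens per chunk since two slots go to the begin/end-of-text markers:

    def max_embeddings_multiples_for(token_size_p: int, token_size_n: int) -> int:
        # Size the multiple to the longer of prompt and negative prompt.
        token_size = max(token_size_p, token_size_n)
        return token_size // 75 + 1

    print(max_embeddings_multiples_for(40, 20))   # 1 -> fits in one window
    print(max_embeddings_multiples_for(160, 90))  # 3 -> needs three chunks
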
@@ -181,34 +203,50 @@ class StableDiffusion:
         return max_embeddings_multiples
     @method()
-    def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
+    def run_inference(self, seed: int) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
         import torch
-        generator = torch.Generator("cuda").manual_seed(inputs["seed"])
+        generator = torch.Generator("cuda").manual_seed(seed)
         with torch.inference_mode():
             with torch.autocast("cuda"):
                 base_images = self.pipe.text2img(
-                    [inputs["prompt"]] * int(inputs["batch_size"]),
-                    negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
-                    height=inputs["height"],
-                    width=inputs["width"],
-                    num_inference_steps=inputs["steps"],
+                    self.prompt * self.batch_size,
+                    negative_prompt=self.n_prompt * self.batch_size,
+                    height=self.height,
+                    width=self.width,
+                    num_inference_steps=self.steps,
                     guidance_scale=7.5,
-                    max_embeddings_multiples=inputs["max_embeddings_multiples"],
+                    max_embeddings_multiples=self.max_embeddings_multiples,
                     generator=generator,
                 ).images
         if self.upscaler != "":
-            uplcaled_images = self.upscale(
+            upscaled = self.upscale(
                 base_images=base_images,
-                scale_factor=4,
                 half_precision=False,
                 tile=700,
             )
-            base_images.extend(uplcaled_images)
+            base_images.extend(upscaled)
+            if self.use_hires_fix:
+                torch.cuda.empty_cache()
+                for img in upscaled:
+                    with torch.inference_mode():
+                        with torch.autocast("cuda"):
+                            hires_fixed = self.pipe.img2img(
+                                prompt=self.prompt * self.batch_size,
+                                negative_prompt=self.n_prompt * self.batch_size,
+                                num_inference_steps=self.steps,
+                                strength=0.3,
+                                guidance_scale=7.5,
+                                max_embeddings_multiples=self.max_embeddings_multiples,
+                                generator=generator,
+                                image=img,
+                            ).images
+                    base_images.extend(hires_fixed)
+                torch.cuda.empty_cache()
         image_output = []
         for image in base_images:
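
One detail worth flagging in the rewritten text2img call: the old code passed a list, [inputs["prompt"]] * int(inputs["batch_size"]), while the new self.prompt * self.batch_size multiplies the string itself, since str * int in Python is repetition. The pipeline still accepts the result, but with batch_size greater than 1 it receives one duplicated prompt rather than a batch of prompts:

    prompt = "a cat"
    batch_size = 2

    print(prompt * batch_size)    # "a cata cat"  -- one long string
    print([prompt] * batch_size)  # ["a cat", "a cat"]  -- a batch of two
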
@@ -222,7 +260,6 @@ class StableDiffusion:
     def upscale(
         self,
         base_images: list[Image.Image],
-        scale_factor: float = 4,
         half_precision: bool = False,
         tile: int = 0,
         tile_pad: int = 10,
@@ -281,7 +318,7 @@ class StableDiffusion:
             torch.cuda.empty_cache()
         upscaled_imgs = []
         with tqdm(total=len(base_images)) as progress_bar:
-            for i, img in enumerate(base_images):
+            for img in base_images:
                 img = numpy.array(img)
                 if self.use_face_enhancer:
                     _, _, enhance_result = face_enhancer.enhance(
@@ -300,6 +337,38 @@ class StableDiffusion:
         return upscaled_imgs
+    # TODO: Implement this
+    # @method()
+    # def img2img(
+    #     self,
+    #     prompt: str,
+    #     n_prompt: str,
+    #     batch_size: int = 1,
+    #     steps: int = 20,
+    #     strength: float = 0.3,
+    #     max_embeddings_multiples: int = 1,
+    #     # image: Image.Image = None,
+    #     base_images: list[Image.Image],
+    # ):
+    #     import torch
+    #     torch.cuda.empty_cache()
+    #     for img in base_images:
+    #         with torch.inference_mode():
+    #             with torch.autocast("cuda"):
+    #                 hires_fixed = self.pipe.img2img(
+    #                     prompt=prompt * batch_size,
+    #                     negative_prompt=n_prompt * batch_size,
+    #                     num_inference_steps=steps,
+    #                     strength=strength,
+    #                     guidance_scale=7.5,
+    #                     max_embeddings_multiples=max_embeddings_multiples,
+    #                     generator=generator,
+    #                     image=img,
+    #                 ).images
+    #         base_images.extend(hires_fixed)
+    #     torch.cuda.empty_cache()
 @stub.local_entrypoint()
 def entrypoint(
@@ -319,7 +388,26 @@ def entrypoint(
     """
     import util
-    inputs: dict[str, int | str] = {
+    directory = util.make_directory()
+    sd = StableDiffusion.remote(
+        prompt=prompt,
+        n_prompt=n_prompt,
+        height=height,
+        width=width,
+        batch_size=batch_size,
+        steps=steps,
+    )
+    for i in range(samples):
+        if seed == -1:
+            seed_generated = util.generate_seed()
+        start_time = time.time()
+        images = sd.run_inference(seed=seed_generated)
+        util.save_images(directory, images, seed_generated, i)
+        total_time = time.time() - start_time
+        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
+    prompts: dict[str, int | str] = {
         "prompt": prompt,
         "n_prompt": n_prompt,
         "height": height,
@@ -327,20 +415,5 @@ def entrypoint(
         "samples": samples,
         "batch_size": batch_size,
         "steps": steps,
-        "seed": seed,
     }
-    directory = util.make_directory()
-    sd = StableDiffusion()
-    inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt)
-    for i in range(samples):
-        if seed == -1:
-            inputs["seed"] = util.generate_seed()
-        start_time = time.time()
-        images = sd.run_inference.call(inputs)
-        util.save_images(directory, images, int(inputs["seed"]), i)
-        total_time = time.time() - start_time
-        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
-    util.save_prompts(inputs)
+    util.save_prompts(prompts)
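
One more edge case in the new entrypoint loop: seed_generated is only assigned when seed == -1, so a run with an explicit seed raises UnboundLocalError at sd.run_inference(seed=seed_generated). A sketch of the presumably intended fallback, reusing the repo's util helpers:

    for i in range(samples):
        # Honor an explicit seed; otherwise draw a fresh one per sample.
        seed_generated = seed if seed != -1 else util.generate_seed()
        images = sd.run_inference(seed=seed_generated)
        util.save_images(directory, images, seed_generated, i)

With Modal's CLI the entrypoint would be invoked along the lines of modal run sd_cli.py --prompt "...", though the exact flags depend on the full signature of entrypoint, which this diff only partially shows.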