From 643e0e2ea6af1c9b3a3c0d4cac04d6ddb1180515 Mon Sep 17 00:00:00 2001
From: hodanov <1031hoda@gmail.com>
Date: Mon, 26 Jun 2023 21:56:58 +0900
Subject: [PATCH] Refactor sd_cli.py

---
 sd_cli.py | 169 ++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 121 insertions(+), 48 deletions(-)

diff --git a/sd_cli.py b/sd_cli.py
index a33aff9..ef328da 100644
--- a/sd_cli.py
+++ b/sd_cli.py
@@ -6,6 +6,7 @@ import time
 from urllib.request import Request, urlopen
 
 from modal import Image, Mount, Secret, Stub, method
+from modal.cls import ClsMixin
 
 BASE_CACHE_PATH = "/vol/cache"
 BASE_CACHE_PATH_LORA = "/vol/cache/lora"
@@ -88,52 +89,70 @@ stub.image = stub_image
 
 
 @stub.cls(gpu="A10G", secrets=[Secret.from_dotenv(__file__)])
-class StableDiffusion:
+class StableDiffusion(ClsMixin):
     """
     A class that wraps the Stable Diffusion pipeline and scheduler.
     """
 
-    def __enter__(self):
+    def __init__(
+        self,
+        prompt: str,
+        n_prompt: str,
+        height: int = 512,
+        width: int = 512,
+        samples: int = 1,
+        batch_size: int = 1,
+        steps: int = 30,
+    ):
         import diffusers
         import torch
 
-        use_vae = os.environ["USE_VAE"] == "true"
+        self.prompt = prompt
+        self.n_prompt = n_prompt
+        self.height = height
+        self.width = width
+        self.samples = samples
+        self.batch_size = batch_size
+        self.steps = steps
+        self.use_vae = os.environ["USE_VAE"] == "true"
         self.upscaler = os.environ["UPSCALER"]
         self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
+        self.use_hires_fix = os.environ["USE_HIRES_FIX"] == "true"
 
-        cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
-        if os.path.exists(cache_path):
-            print(f"The directory '{cache_path}' exists.")
+        self.cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
+        if os.path.exists(self.cache_path):
+            print(f"The directory '{self.cache_path}' exists.")
         else:
-            print(f"The directory '{cache_path}' does not exist. Download models...")
+            print(f"The directory '{self.cache_path}' does not exist. Downloading models...")
             download_models()
 
+        self.max_embeddings_multiples = self.count_token(p=prompt, n=n_prompt)
         torch.backends.cuda.matmul.allow_tf32 = True
 
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
-            cache_path,
+            self.cache_path,
             custom_pipeline="lpw_stable_diffusion",
             torch_dtype=torch.float16,
         )
         # TODO: Add support for other schedulers.
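+        # NOTE: Euler Ancestral becomes the active scheduler below; the
+        # DPM-Solver++ call is kept commented out so the two can be swapped.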
-        # self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
-        self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
-            cache_path,
+        self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
+        # self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
+            self.cache_path,
             subfolder="scheduler",
         )
 
-        if use_vae:
+        if self.use_vae:
             self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
-                cache_path,
+                self.cache_path,
                 subfolder="vae",
             )
 
         self.pipe.to("cuda")
 
         if os.environ["LORA_NAMES"] != "":
-            names = os.getenv("LORA_NAMES").split(",")
-            urls = os.getenv("LORA_DOWNLOAD_URLS").split(",")
+            names = os.environ["LORA_NAMES"].split(",")
+            urls = os.environ["LORA_DOWNLOAD_URLS"].split(",")
             for name, url in zip(names, urls):
                 path = os.path.join(BASE_CACHE_PATH_LORA, name)
                 if os.path.exists(path):
@@ -144,8 +163,8 @@ class StableDiffusion:
                     self.pipe.load_lora_weights(".", weight_name=path)
 
         if os.environ["TEXTUAL_INVERSION_NAMES"] != "":
-            names = os.getenv("TEXTUAL_INVERSION_NAMES").split(",")
-            urls = os.getenv("TEXTUAL_INVERSION_DOWNLOAD_URLS").split(",")
+            names = os.environ["TEXTUAL_INVERSION_NAMES"].split(",")
+            urls = os.environ["TEXTUAL_INVERSION_DOWNLOAD_URLS"].split(",")
             for name, url in zip(names, urls):
                 path = os.path.join(BASE_CACHE_PATH_TEXTUAL_INVERSION, name)
                 if os.path.exists(path):
@@ -164,7 +183,10 @@ class StableDiffusion:
         """
         from transformers import CLIPTokenizer
 
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        tokenizer = CLIPTokenizer.from_pretrained(
+            self.cache_path,
+            subfolder="tokenizer",
+        )
         token_size_p = len(tokenizer.tokenize(p))
         token_size_n = len(tokenizer.tokenize(n))
         token_size = token_size_p
@@ -181,34 +203,50 @@ class StableDiffusion:
         return max_embeddings_multiples
 
     @method()
-    def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
+    def run_inference(self, seed: int) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
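+
+        The CUDA generator is re-seeded with the given seed on every call,
+        so a sample can be reproduced by passing the same seed again.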
""" import torch - generator = torch.Generator("cuda").manual_seed(inputs["seed"]) + generator = torch.Generator("cuda").manual_seed(seed) with torch.inference_mode(): with torch.autocast("cuda"): base_images = self.pipe.text2img( - [inputs["prompt"]] * int(inputs["batch_size"]), - negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]), - height=inputs["height"], - width=inputs["width"], - num_inference_steps=inputs["steps"], + self.prompt * self.batch_size, + negative_prompt=self.n_prompt * self.batch_size, + height=self.height, + width=self.width, + num_inference_steps=self.steps, guidance_scale=7.5, - max_embeddings_multiples=inputs["max_embeddings_multiples"], + max_embeddings_multiples=self.max_embeddings_multiples, generator=generator, ).images if self.upscaler != "": - uplcaled_images = self.upscale( + upscaled = self.upscale( base_images=base_images, - scale_factor=4, half_precision=False, tile=700, ) - base_images.extend(uplcaled_images) + base_images.extend(upscaled) + if self.use_hires_fix: + torch.cuda.empty_cache() + for img in upscaled: + with torch.inference_mode(): + with torch.autocast("cuda"): + hires_fixed = self.pipe.img2img( + prompt=self.prompt * self.batch_size, + negative_prompt=self.n_prompt * self.batch_size, + num_inference_steps=self.steps, + strength=0.3, + guidance_scale=7.5, + max_embeddings_multiples=self.max_embeddings_multiples, + generator=generator, + image=img, + ).images + base_images.extend(hires_fixed) + torch.cuda.empty_cache() image_output = [] for image in base_images: @@ -222,7 +260,6 @@ class StableDiffusion: def upscale( self, base_images: list[Image.Image], - scale_factor: float = 4, half_precision: bool = False, tile: int = 0, tile_pad: int = 10, @@ -281,7 +318,7 @@ class StableDiffusion: torch.cuda.empty_cache() upscaled_imgs = [] with tqdm(total=len(base_images)) as progress_bar: - for i, img in enumerate(base_images): + for img in base_images: img = numpy.array(img) if self.use_face_enhancer: _, _, enhance_result = face_enhancer.enhance( @@ -300,6 +337,38 @@ class StableDiffusion: return upscaled_imgs + # TODO: Implement this + # @method() + # def img2img( + # self, + # prompt: str, + # n_prompt: str, + # batch_size: int = 1, + # steps: int = 20, + # strength: float = 0.3, + # max_embeddings_multiples: int = 1, + # # image: Image.Image = None, + # base_images: list[Image.Image], + # ): + # import torch + + # torch.cuda.empty_cache() + # for img in base_images: + # with torch.inference_mode(): + # with torch.autocast("cuda"): + # hires_fixed = self.pipe.img2img( + # prompt=prompt * batch_size, + # negative_prompt=n_prompt * batch_size, + # num_inference_steps=steps], + # strength=strength, + # guidance_scale=7.5, + # max_embeddings_multiples=max_embeddings_multiples, + # generator=generator, + # image=img, + # ).images + # base_images.extend(hires_fixed) + # torch.cuda.empty_cache() + @stub.local_entrypoint() def entrypoint( @@ -319,7 +388,26 @@ def entrypoint( """ import util - inputs: dict[str, int | str] = { + directory = util.make_directory() + + sd = StableDiffusion.remote( + prompt=prompt, + n_prompt=n_prompt, + height=height, + width=width, + batch_size=batch_size, + steps=steps, + ) + for i in range(samples): + if seed == -1: + seed_generated = util.generate_seed() + start_time = time.time() + images = sd.run_inference(seed=seed_generated) + util.save_images(directory, images, seed_generated, i) + total_time = time.time() - start_time + print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") 
+
+    prompts: dict[str, int | str] = {
         "prompt": prompt,
         "n_prompt": n_prompt,
         "height": height,
@@ -327,20 +415,5 @@ def entrypoint(
         "samples": samples,
         "batch_size": batch_size,
         "steps": steps,
-        "seed": seed,
     }
-
-    directory = util.make_directory()
-
-    sd = StableDiffusion()
-    inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt)
-    for i in range(samples):
-        if seed == -1:
-            inputs["seed"] = util.generate_seed()
-        start_time = time.time()
-        images = sd.run_inference.call(inputs)
-        util.save_images(directory, images, int(inputs["seed"]), i)
-        total_time = time.time() - start_time
-        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
-
-    util.save_prompts(inputs)
+    util.save_prompts(prompts)
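
Note: the entrypoint relies on a `util` module that is not included in this
patch. The sketch below is illustrative only: the helper names and call
signatures come from the calls above, but the bodies (directory layout, file
naming, output format) are assumptions, not the actual implementation.

# util.py -- hypothetical sketch; bodies are assumptions, not the real module.
import json
import os
import random
from datetime import datetime


def make_directory() -> str:
    """Create and return a timestamped output directory (layout assumed)."""
    directory = os.path.join("outputs", datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(directory, exist_ok=True)
    return directory


def generate_seed() -> int:
    """Return a random 32-bit seed for the torch generator."""
    return random.randint(0, 2**32 - 1)


def save_images(directory: str, images: list[bytes], seed: int, i: int) -> None:
    """Write each encoded image payload returned by run_inference to disk.

    The .png extension is an assumption; the encoding step of run_inference
    is not shown in this patch.
    """
    for j, image_bytes in enumerate(images):
        path = os.path.join(directory, f"{i:02}_{j:02}_seed_{seed}.png")
        with open(path, "wb") as file:
            file.write(image_bytes)


def save_prompts(prompts: dict) -> None:
    """Persist the run settings so a generation can be reproduced."""
    with open("prompts.json", "w", encoding="utf-8") as file:
        json.dump(prompts, file, indent=2)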