diff --git a/app/requirements.txt b/app/requirements.txt
index 5118aa4..eb1bcfb 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -22,4 +22,6 @@ controlnet_aux
 pyyaml
 
 # Use the below in 'download_from_original_stable_diffusion_ckpt'.
-omegaconf==2.3.0
\ No newline at end of file
+omegaconf==2.3.0
+
+peft
\ No newline at end of file
diff --git a/app/stable_diffusion_1_5.py b/app/stable_diffusion_1_5.py
index 741046a..6c61f1d 100644
--- a/app/stable_diffusion_1_5.py
+++ b/app/stable_diffusion_1_5.py
@@ -4,7 +4,7 @@ import io
 import os
 
 import PIL.Image
-from modal import Secret, method
+from modal import Secret, enter, method
 from setup import (
     BASE_CACHE_PATH,
     BASE_CACHE_PATH_CONTROLNET,
@@ -23,7 +23,8 @@ class SD15:
     SD15 is a class that runs inference using Stable Diffusion 1.5.
     """
 
-    def __enter__(self):
+    @enter()
+    def _setup(self):
         import diffusers
         import torch
         import yaml
@@ -69,6 +70,7 @@ class SD15:
                 else:
                     print(f"The directory '{path}' does not exist. Need to execute 'modal deploy' first.")
                 self.pipe.load_lora_weights(".", weight_name=path)
+                self.pipe.fuse_lora()
 
         textual_inversions = config.get("textual_inversions")
         if textual_inversions is not None:
diff --git a/app/stable_diffusion_xl.py b/app/stable_diffusion_xl.py
index c4012fa..3cc5afb 100644
--- a/app/stable_diffusion_xl.py
+++ b/app/stable_diffusion_xl.py
@@ -4,8 +4,8 @@ import io
 import os
 
 import PIL.Image
-from modal import Secret, method
-from setup import BASE_CACHE_PATH, stub
+from modal import Secret, enter, method
+from setup import BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET, stub
 
 
 @stub.cls(
@@ -17,7 +17,8 @@ class SDXLTxt2Img:
     A class that wraps the Stable Diffusion pipeline and scheduler.
     """
 
-    def __enter__(self):
+    @enter()
+    def _setup(self):
         import diffusers
         import torch
         import yaml
@@ -38,23 +39,67 @@ class SDXLTxt2Img:
             variant="fp16",
         )
 
-        self.refiner_cache_path = self.cache_path + "-refiner"
-        self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
-            self.refiner_cache_path,
-            torch_dtype=torch.float16,
-            use_safetensors=True,
-            variant="fp16",
+        # self.refiner_cache_path = self.cache_path + "-refiner"
+        # self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
+        #     self.refiner_cache_path,
+        #     torch_dtype=torch.float16,
+        #     use_safetensors=True,
+        #     variant="fp16",
+        # )
+
+        # controlnets = config.get("controlnets")
+        # if controlnets is not None:
+        #     for controlnet in controlnets:
+        #         path = os.path.join(BASE_CACHE_PATH_CONTROLNET, controlnet["name"])
+        #         controlnet = diffusers.ControlNetModel.from_pretrained(path, torch_dtype=torch.float16)
+        #         self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained(
+        #             self.cache_path,
+        #             controlnet=controlnet,
+        #             custom_pipeline="lpw_stable_diffusion",
+        #             scheduler=self.pipe.scheduler,
+        #             vae=self.pipe.vae,
+        #             torch_dtype=torch.float16,
+        #             use_safetensors=True,
+        #         )
+
+    def _count_token(self, p: str, n: str) -> int:
+        """
+        Count prompt and negative-prompt tokens and return max_embeddings_multiples.
+        """
+        from transformers import CLIPTokenizer
+
+        tokenizer = CLIPTokenizer.from_pretrained(
+            self.cache_path,
+            subfolder="tokenizer",
         )
+        token_size_p = len(tokenizer.tokenize(p))
+        token_size_n = len(tokenizer.tokenize(n))
+        # The longer of the two prompts determines the window count.
+        token_size = max(token_size_p, token_size_n)
+
+        # One tokenizer window by default; more when the prompt overflows it.
+        max_embeddings_multiples = 1
+        max_length = tokenizer.model_max_length - 2
+        if token_size > max_length:
+            max_embeddings_multiples = token_size // max_length + 1
+
+        print(f"token_size: {token_size}, max_embeddings_multiples: {max_embeddings_multiples}")
+
+        return max_embeddings_multiples
 
     @method()
     def run_inference(
         self,
         prompt: str,
+        n_prompt: str,
         height: int = 1024,
         width: int = 1024,
+        batch_size: int = 1,
+        steps: int = 30,
         seed: int = 1,
         upscaler: str = "",
         use_face_enhancer: bool = False,
+        fix_by_controlnet_tile: bool = False,
         output_format: str = "png",
     ) -> list[bytes]:
         """
@@ -67,20 +112,57 @@ class SDXLTxt2Img:
         self.pipe.to("cuda")
         generated_images = self.pipe(
             prompt=prompt,
+            negative_prompt=n_prompt,
             height=height,
             width=width,
             generator=generator,
         ).images
         base_images = generated_images
 
-        for image in base_images:
-            self.refiner.to("cuda")
-            refined_images = self.refiner(
-                prompt=prompt,
-                image=image,
-            ).images
-            generated_images.extend(refined_images)
-            base_images = refined_images
+        # for image in base_images:
+        #     image = self._resize_image(image=image, scale_factor=2)
+        #     self.refiner.to("cuda")
+        #     refined_images = self.refiner(
+        #         prompt=prompt,
+        #         negative_prompt=n_prompt,
+        #         num_inference_steps=steps,
+        #         strength=0.1,
+        #         # guidance_scale=7.5,
+        #         generator=generator,
+        #         image=image,
+        #     ).images
+        #     generated_images.extend(refined_images)
+        #     base_images = refined_images
+        """
+        Fix the generated images with control_v11f1e_sd15_tile when `fix_by_controlnet_tile` is `True`.
+        https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
+        """
+        # if fix_by_controlnet_tile:
+        #     max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
+        #     print("======================== for debugging ========================")
+        #     print("Step1")
+        #     self.controlnet_pipe.to("cuda")
+        #     self.controlnet_pipe.enable_vae_tiling()
+        #     self.controlnet_pipe.enable_xformers_memory_efficient_attention()
+        #     print("Step2")
+        #     for image in base_images:
+        #         image = self._resize_image(image=image, scale_factor=2)
+        #         print("Step3")
+        #         with torch.autocast("cuda"):
+        #             print("Step4")
+        #             fixed_by_controlnet = self.controlnet_pipe(
+        #                 prompt=[prompt] * batch_size,
+        #                 negative_prompt=[n_prompt] * batch_size,
+        #                 num_inference_steps=steps,
+        #                 strength=0.3,
+        #                 guidance_scale=7.5,
+        #                 max_embeddings_multiples=max_embeddings_multiples,
+        #                 generator=generator,
+        #                 image=image,
+        #             ).images
+        #             print("Step5")
+        #             generated_images.extend(fixed_by_controlnet)
+        #             base_images = fixed_by_controlnet
 
         if upscaler != "":
             upscaled = self._upscale(
@@ -100,6 +182,12 @@
 
         return image_output
 
+    def _resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image:
+        image = image.convert("RGB")
+        width, height = image.size
+        img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS)
+        return img
+
     def _upscale(
         self,
         base_images: list[PIL.Image],
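
Note on the new _count_token helper: it feeds the max_embeddings_multiples argument of the lpw_stable_diffusion community pipeline, which chains multiple CLIP token windows so that prompts longer than the tokenizer limit are not truncated. Below is a minimal standalone sketch of that arithmetic; the function name embeddings_multiples and the default of 77 are illustrative assumptions, while the patch itself reads the limit from tokenizer.model_max_length.

    # Illustrative sketch of the window arithmetic in _count_token; not part of the patch.
    def embeddings_multiples(token_size_p: int, token_size_n: int, model_max_length: int = 77) -> int:
        token_size = max(token_size_p, token_size_n)  # the longer prompt drives the count
        max_length = model_max_length - 2             # reserve the BOS/EOS special tokens
        if token_size > max_length:
            return token_size // max_length + 1
        return 1

    assert embeddings_multiples(40, 20) == 1   # fits in a single 75-token window
    assert embeddings_multiples(80, 20) == 2   # 80 // 75 + 1 == 2
    assert embeddings_multiples(20, 160) == 3  # the negative prompt can drive the count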