From c84646dcd3ea295e6971b9609fc7bbcbcf3088bd Mon Sep 17 00:00:00 2001
From: hodanov <1031hoda@gmail.com>
Date: Mon, 4 Nov 2024 12:20:04 +0900
Subject: [PATCH] Modify some instance variables to private.

---
 app/stable_diffusion_xl.py | 52 ++++++++++++++++++++++++++++------------------------
 1 file changed, 28 insertions(+), 24 deletions(-)

diff --git a/app/stable_diffusion_xl.py b/app/stable_diffusion_xl.py
index e0df6cd..c113569 100644
--- a/app/stable_diffusion_xl.py
+++ b/app/stable_diffusion_xl.py
@@ -18,40 +18,40 @@ class SDXLTxt2Img:
     """
 
     @enter()
-    def _setup(self):
+    def setup(self) -> None:
         import diffusers
         import torch
         import yaml
 
         config = {}
-        with open("/config.yml", "r") as file:
+        with open("/config.yml") as file:
             config = yaml.safe_load(file)
 
-        self.cache_path = os.path.join(BASE_CACHE_PATH, config["model"]["name"])
-        if os.path.exists(self.cache_path):
-            print(f"The directory '{self.cache_path}' exists.")
+        self.__cache_path = os.path.join(BASE_CACHE_PATH, config["model"]["name"])
+        if os.path.exists(self.__cache_path):
+            print(f"The directory '{self.__cache_path}' exists.")
         else:
-            print(f"The directory '{self.cache_path}' does not exist.")
+            print(f"The directory '{self.__cache_path}' does not exist.")
 
-        self.pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
-            self.cache_path,
+        self.__pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
+            self.__cache_path,
             torch_dtype=torch.float16,
             use_safetensors=True,
         )
 
-        self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
-            self.cache_path,
+        self.__refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            self.__cache_path,
             torch_dtype=torch.float16,
             use_safetensors=True,
         )
 
-    def _count_token(self, p: str, n: str) -> int:
+    def __count_token(self, p: str, n: str) -> int:
         """
         Count the number of tokens in the prompt and negative prompt.
         """
         from transformers import CLIPTokenizer
 
         tokenizer = CLIPTokenizer.from_pretrained(
-            self.cache_path,
+            self.__cache_path,
             subfolder="tokenizer",
         )
         token_size_p = len(tokenizer.tokenize(p))
@@ -72,49 +72,53 @@ class SDXLTxt2Img:
     @method()
     def run_inference(
         self,
+        *,
         prompt: str,
         n_prompt: str,
         height: int = 1024,
         width: int = 1024,
         steps: int = 30,
         seed: int = 1,
-        use_upscaler: bool = False,
         output_format: str = "png",
+        use_upscaler: bool = False,
     ) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
-        import pillow_avif  # noqa
+        import pillow_avif  # noqa: F401
         import torch
 
+        max_embeddings_multiples = self.__count_token(p=prompt, n=n_prompt)
         generator = torch.Generator("cuda").manual_seed(seed)
 
-        self.pipe.to("cuda")
-        self.pipe.enable_vae_tiling()
-        self.pipe.enable_xformers_memory_efficient_attention()
-        generated_image = self.pipe(
+        self.__pipe.to("cuda")
+        self.__pipe.enable_vae_tiling()
+        self.__pipe.enable_xformers_memory_efficient_attention()
+        generated_image = self.__pipe(
             prompt=prompt,
             negative_prompt=n_prompt,
             guidance_scale=7,
             height=height,
             width=width,
             generator=generator,
+            max_embeddings_multiples=max_embeddings_multiples,
             num_inference_steps=steps,
         ).images[0]
 
         generated_images = [generated_image]
         if use_upscaler:
-            self.refiner.to("cuda")
-            self.refiner.enable_vae_tiling()
-            self.refiner.enable_xformers_memory_efficient_attention()
-            base_image = self._double_image_size(generated_image)
-            image = self.refiner(
+            self.__refiner.to("cuda")
+            self.__refiner.enable_vae_tiling()
+            self.__refiner.enable_xformers_memory_efficient_attention()
+            base_image = self.__double_image_size(generated_image)
+            image = self.__refiner(
                 prompt=prompt,
                 negative_prompt=n_prompt,
                 num_inference_steps=steps,
                 strength=0.3,
                 guidance_scale=7.5,
                 generator=generator,
+                max_embeddings_multiples=max_embeddings_multiples,
                 image=base_image,
             ).images[0]
             generated_images.append(image)
@@ -127,7 +131,7 @@ class SDXLTxt2Img:
 
         return image_output
 
-    def _double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image:
+    def __double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image:
         image = image.convert("RGB")
         width, height = image.size
         return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS)