Make some instance variables and helper methods private.

This commit is contained in:
hodanov 2024-11-04 12:20:04 +09:00
parent 335b678f8f
commit c84646dcd3

View File

@ -18,40 +18,40 @@ class SDXLTxt2Img:
""" """
@enter() @enter()
def _setup(self): def setup(self) -> None:
import diffusers import diffusers
import torch import torch
import yaml import yaml
config = {} config = {}
with open("/config.yml", "r") as file: with open("/config.yml") as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
self.cache_path = os.path.join(BASE_CACHE_PATH, config["model"]["name"]) self.__cache_path = os.path.join(BASE_CACHE_PATH, config["model"]["name"])
if os.path.exists(self.cache_path): if os.path.exists(self.__cache_path):
print(f"The directory '{self.cache_path}' exists.") print(f"The directory '{self.__cache_path}' exists.")
else: else:
print(f"The directory '{self.cache_path}' does not exist.") print(f"The directory '{self.__cache_path}' does not exist.")
self.pipe = diffusers.StableDiffusionXLPipeline.from_pretrained( self.__pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
self.cache_path, self.__cache_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
use_safetensors=True, use_safetensors=True,
) )
self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained( self.__refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
self.cache_path, self.__cache_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
use_safetensors=True, use_safetensors=True,
) )
def _count_token(self, p: str, n: str) -> int: def __count_token(self, p: str, n: str) -> int:
""" """
Count the number of tokens in the prompt and negative prompt. Count the number of tokens in the prompt and negative prompt.
""" """
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
self.cache_path, self.__cache_path,
subfolder="tokenizer", subfolder="tokenizer",
) )
token_size_p = len(tokenizer.tokenize(p)) token_size_p = len(tokenizer.tokenize(p))
@ -72,49 +72,53 @@ class SDXLTxt2Img:
@method() @method()
def run_inference( def run_inference(
self, self,
*,
prompt: str, prompt: str,
n_prompt: str, n_prompt: str,
height: int = 1024, height: int = 1024,
width: int = 1024, width: int = 1024,
steps: int = 30, steps: int = 30,
seed: int = 1, seed: int = 1,
use_upscaler: bool = False,
output_format: str = "png", output_format: str = "png",
use_upscaler: bool = False,
) -> list[bytes]: ) -> list[bytes]:
""" """
Runs the Stable Diffusion pipeline on the given prompt and outputs images. Runs the Stable Diffusion pipeline on the given prompt and outputs images.
""" """
import pillow_avif # noqa import pillow_avif # noqa: F401
import torch import torch
max_embeddings_multiples = self.__count_token(p=prompt, n=n_prompt)
generator = torch.Generator("cuda").manual_seed(seed) generator = torch.Generator("cuda").manual_seed(seed)
self.pipe.to("cuda") self.__pipe.to("cuda")
self.pipe.enable_vae_tiling() self.__pipe.enable_vae_tiling()
self.pipe.enable_xformers_memory_efficient_attention() self.__pipe.enable_xformers_memory_efficient_attention()
generated_image = self.pipe( generated_image = self.__pipe(
prompt=prompt, prompt=prompt,
negative_prompt=n_prompt, negative_prompt=n_prompt,
guidance_scale=7, guidance_scale=7,
height=height, height=height,
width=width, width=width,
generator=generator, generator=generator,
max_embeddings_multiples=max_embeddings_multiples,
num_inference_steps=steps, num_inference_steps=steps,
).images[0] ).images[0]
generated_images = [generated_image] generated_images = [generated_image]
if use_upscaler: if use_upscaler:
self.refiner.to("cuda") self.__refiner.to("cuda")
self.refiner.enable_vae_tiling() self.__refiner.enable_vae_tiling()
self.refiner.enable_xformers_memory_efficient_attention() self.__refiner.enable_xformers_memory_efficient_attention()
base_image = self._double_image_size(generated_image) base_image = self.__double_image_size(generated_image)
image = self.refiner( image = self.__refiner(
prompt=prompt, prompt=prompt,
negative_prompt=n_prompt, negative_prompt=n_prompt,
num_inference_steps=steps, num_inference_steps=steps,
strength=0.3, strength=0.3,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
max_embeddings_multiples=max_embeddings_multiples,
image=base_image, image=base_image,
).images[0] ).images[0]
generated_images.append(image) generated_images.append(image)
@ -127,7 +131,7 @@ class SDXLTxt2Img:
return image_output return image_output
def _double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image: def __double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image:
image = image.convert("RGB") image = image.convert("RGB")
width, height = image.size width, height = image.size
return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS) return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS)