Merge pull request #159 from hodanov/feature/refactoring

Fix some lint errors. Refactor app.
This commit is contained in:
hodanov 2024-11-04 14:11:23 +09:00 committed by GitHub
commit 82c162b947
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 31 deletions

View File

@ -1,5 +1,6 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
import diffusers import diffusers
from huggingface_hub import login from huggingface_hub import login
@ -36,7 +37,7 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH) / self.__model_name
pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token, use_auth_token=self.__token,
@ -63,7 +64,7 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH) / self.__model_name
pipe = diffusers.StableDiffusionPipeline.from_single_file( pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
token=self.__token, token=self.__token,
@ -117,7 +118,7 @@ class CommonSetup:
) )
def __download_vae(self, name: str, model_url: str, token: str) -> None: def __download_vae(self, name: str, model_url: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, name) cache_path = Path(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file( vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url, pretrained_model_link_or_path=model_url,
use_auth_token=token, use_auth_token=token,
@ -126,7 +127,7 @@ class CommonSetup:
vae.save_pretrained(cache_path, safe_serialization=True) vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None: def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) cache_path = Path(BASE_CACHE_PATH_CONTROLNET) / name
controlnet = diffusers.ControlNetModel.from_pretrained( controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id, repo_id,
use_auth_token=token, use_auth_token=token,
@ -142,7 +143,7 @@ class CommonSetup:
req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read() downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name) dir_names = Path(file_path) / file_name
os.makedirs(os.path.dirname(dir_names), exist_ok=True) os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f: with open(dir_names, mode="wb") as f:
f.write(downloaded) f.write(downloaded)

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import io import io
import os from pathlib import Path
import PIL.Image import PIL.Image
from modal import Secret, enter, method from modal import Secret, enter, method
@ -18,40 +18,39 @@ class SDXLTxt2Img:
""" """
@enter() @enter()
def _setup(self): def setup(self) -> None:
import diffusers import diffusers
import torch import torch
import yaml import yaml
config = {} config = {}
with open("/config.yml", "r") as file: with Path("/config.yml").open() as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
self.cache_path = os.path.join(BASE_CACHE_PATH, config["model"]["name"]) self.__cache_path = Path(BASE_CACHE_PATH) / config["model"]["name"]
if os.path.exists(self.cache_path): if not Path.exists(self.__cache_path):
print(f"The directory '{self.cache_path}' exists.") msg = f"The directory '{self.__cache_path}' does not exist."
else: raise ValueError(msg)
print(f"The directory '{self.cache_path}' does not exist.")
self.pipe = diffusers.StableDiffusionXLPipeline.from_pretrained( self.__pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
self.cache_path, self.__cache_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
use_safetensors=True, use_safetensors=True,
) )
self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained( self.__refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
self.cache_path, self.__cache_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
use_safetensors=True, use_safetensors=True,
) )
def _count_token(self, p: str, n: str) -> int: def __count_token(self, p: str, n: str) -> int:
""" """
Count the number of tokens in the prompt and negative prompt. Count the number of tokens in the prompt and negative prompt.
""" """
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
self.cache_path, self.__cache_path,
subfolder="tokenizer", subfolder="tokenizer",
) )
token_size_p = len(tokenizer.tokenize(p)) token_size_p = len(tokenizer.tokenize(p))
@ -72,49 +71,53 @@ class SDXLTxt2Img:
@method() @method()
def run_inference( def run_inference(
self, self,
*,
prompt: str, prompt: str,
n_prompt: str, n_prompt: str,
height: int = 1024, height: int = 1024,
width: int = 1024, width: int = 1024,
steps: int = 30, steps: int = 30,
seed: int = 1, seed: int = 1,
use_upscaler: bool = False,
output_format: str = "png", output_format: str = "png",
use_upscaler: bool = False,
) -> list[bytes]: ) -> list[bytes]:
""" """
Runs the Stable Diffusion pipeline on the given prompt and outputs images. Runs the Stable Diffusion pipeline on the given prompt and outputs images.
""" """
import pillow_avif # noqa import pillow_avif # noqa: F401
import torch import torch
max_embeddings_multiples = self.__count_token(p=prompt, n=n_prompt)
generator = torch.Generator("cuda").manual_seed(seed) generator = torch.Generator("cuda").manual_seed(seed)
self.pipe.to("cuda") self.__pipe.to("cuda")
self.pipe.enable_vae_tiling() self.__pipe.enable_vae_tiling()
self.pipe.enable_xformers_memory_efficient_attention() self.__pipe.enable_xformers_memory_efficient_attention()
generated_image = self.pipe( generated_image = self.__pipe(
prompt=prompt, prompt=prompt,
negative_prompt=n_prompt, negative_prompt=n_prompt,
guidance_scale=7, guidance_scale=7,
height=height, height=height,
width=width, width=width,
generator=generator, generator=generator,
max_embeddings_multiples=max_embeddings_multiples,
num_inference_steps=steps, num_inference_steps=steps,
).images[0] ).images[0]
generated_images = [generated_image] generated_images = [generated_image]
if use_upscaler: if use_upscaler:
self.refiner.to("cuda") self.__refiner.to("cuda")
self.refiner.enable_vae_tiling() self.__refiner.enable_vae_tiling()
self.refiner.enable_xformers_memory_efficient_attention() self.__refiner.enable_xformers_memory_efficient_attention()
base_image = self._double_image_size(generated_image) base_image = self.__double_image_size(generated_image)
image = self.refiner( image = self.__refiner(
prompt=prompt, prompt=prompt,
negative_prompt=n_prompt, negative_prompt=n_prompt,
num_inference_steps=steps, num_inference_steps=steps,
strength=0.3, strength=0.3,
guidance_scale=7.5, guidance_scale=7.5,
generator=generator, generator=generator,
max_embeddings_multiples=max_embeddings_multiples,
image=base_image, image=base_image,
).images[0] ).images[0]
generated_images.append(image) generated_images.append(image)
@ -127,7 +130,7 @@ class SDXLTxt2Img:
return image_output return image_output
def _double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image: def __double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image:
image = image.convert("RGB") image = image.convert("RGB")
width, height = image.size width, height = image.size
return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS) return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS)