Fix some lint errors.

This commit is contained in:
hodanov 2024-11-03 19:52:58 +09:00
parent f074a1b3ca
commit c6ce2f61ad
2 changed files with 27 additions and 23 deletions

View File

@ -1,5 +1,6 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
import diffusers import diffusers
from modal import App, Image, Mount, Secret from modal import App, Image, Mount, Secret
@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
class StableDiffusionCLISetupInterface(ABC): class StableDiffusionCLISetupInterface(ABC):
@abstractmethod @abstractmethod
def download_model(self): def download_model(self) -> None:
pass pass
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sdxl": if config.get("version") != "sdxl":
raise ValueError("Invalid version. Must be 'sdxl'.") msg = "Invalid version. Must be 'sdxl'."
raise ValueError(msg)
if config.get("model") is None: if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.") msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"] self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"] self.__model_url: str = config["model"]["url"]
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token, use_auth_token=self.__token,
@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
if config.get("version") != "sd15": if config.get("version") != "sd15":
raise ValueError("Invalid version. Must be 'sd15'.") msg = "Invalid version. Must be 'sd15'."
raise ValueError(msg)
if config.get("model") is None: if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.") msg = "Model is required. Please provide a model in config.yml."
raise ValueError(msg)
self.__model_name: str = config["model"]["name"] self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"] self.__model_url: str = config["model"]["url"]
self.__token: str = token self.__token: str = token
def download_model(self) -> None: def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) cache_path = Path(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionPipeline.from_single_file( pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url, pretrained_model_link_or_path=self.__model_url,
token=self.__token, token=self.__token,
@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __download_upscaler(self) -> None: def __download_upscaler(self) -> None:
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained( upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler" "stabilityai/sd-x2-latent-upscaler",
) )
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True) upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
class CommonSetup: class CommonSetup:
def __init__(self, config: dict, token: str): def __init__(self, config: dict, token: str) -> None:
self.__token: str = token self.__token: str = token
self.__config: dict = config self.__config: dict = config
@ -105,8 +110,8 @@ class CommonSetup:
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION, file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
) )
def __download_vae(self, name: str, model_url: str, token: str): def __download_vae(self, name: str, model_url: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, name) cache_path = Path(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file( vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url, pretrained_model_link_or_path=model_url,
use_auth_token=token, use_auth_token=token,
@ -114,8 +119,8 @@ class CommonSetup:
) )
vae.save_pretrained(cache_path, safe_serialization=True) vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str): def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) cache_path = Path(BASE_CACHE_PATH, name)
controlnet = diffusers.ControlNetModel.from_pretrained( controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id, repo_id,
use_auth_token=token, use_auth_token=token,
@ -123,7 +128,7 @@ class CommonSetup:
) )
controlnet.save_pretrained(cache_path, safe_serialization=True) controlnet.save_pretrained(cache_path, safe_serialization=True)
def __download_other_file(self, url, file_name, file_path): def __download_other_file(self, url: str, file_name: str, file_path: str) -> None:
""" """
Download file from the given URL for LoRA or TextualInversion. Download file from the given URL for LoRA or TextualInversion.
""" """
@ -131,20 +136,20 @@ class CommonSetup:
req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read() downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name) dir_names = Path(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True) os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f: with open(dir_names, mode="wb") as f:
f.write(downloaded) f.write(downloaded)
def build_image(): def build_image() -> None:
""" """
Build the Docker image. Build the Docker image.
""" """
import yaml import yaml
token: str = os.environ["HUGGING_FACE_TOKEN"] token: str = os.environ["HUGGING_FACE_TOKEN"]
with open("/config.yml", "r") as file: with open("/config.yml") as file:
config: dict = yaml.safe_load(file) config: dict = yaml.safe_load(file)
stable_diffusion_setup: StableDiffusionCLISetupInterface stable_diffusion_setup: StableDiffusionCLISetupInterface
@ -154,9 +159,8 @@ def build_image():
case "sdxl": case "sdxl":
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token) stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
case _: case _:
raise ValueError( msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'." raise ValueError(msg)
)
stable_diffusion_setup.download_model() stable_diffusion_setup.download_model()
common_setup = CommonSetup(config, token) common_setup = CommonSetup(config, token)

View File

@ -22,7 +22,7 @@ def main(
seed: int = -1, seed: int = -1,
use_upscaler: str = "False", use_upscaler: str = "False",
output_format: str = "png", output_format: str = "png",
): ) -> None:
"""This function is the entrypoint for the Runway CLI. """This function is the entrypoint for the Runway CLI.
The function pass the given prompt to StableDiffusion on Modal, The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.