Fix some lint errors.
This commit is contained in:
parent
f074a1b3ca
commit
c6ce2f61ad
48
app/setup.py
48
app/setup.py
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
from modal import App, Image, Mount, Secret
|
from modal import App, Image, Mount, Secret
|
||||||
@ -13,24 +14,26 @@ BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
|
|||||||
|
|
||||||
class StableDiffusionCLISetupInterface(ABC):
|
class StableDiffusionCLISetupInterface(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def download_model(self):
|
def download_model(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
||||||
def __init__(self, config: dict, token: str):
|
def __init__(self, config: dict, token: str) -> None:
|
||||||
if config.get("version") != "sdxl":
|
if config.get("version") != "sdxl":
|
||||||
raise ValueError("Invalid version. Must be 'sdxl'.")
|
msg = "Invalid version. Must be 'sdxl'."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if config.get("model") is None:
|
if config.get("model") is None:
|
||||||
raise ValueError("Model is required. Please provide a model in config.yml.")
|
msg = "Model is required. Please provide a model in config.yml."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.__model_name: str = config["model"]["name"]
|
self.__model_name: str = config["model"]["name"]
|
||||||
self.__model_url: str = config["model"]["url"]
|
self.__model_url: str = config["model"]["url"]
|
||||||
self.__token: str = token
|
self.__token: str = token
|
||||||
|
|
||||||
def download_model(self) -> None:
|
def download_model(self) -> None:
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||||
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
||||||
pretrained_model_link_or_path=self.__model_url,
|
pretrained_model_link_or_path=self.__model_url,
|
||||||
use_auth_token=self.__token,
|
use_auth_token=self.__token,
|
||||||
@ -40,19 +43,21 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
|||||||
|
|
||||||
|
|
||||||
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
||||||
def __init__(self, config: dict, token: str):
|
def __init__(self, config: dict, token: str) -> None:
|
||||||
if config.get("version") != "sd15":
|
if config.get("version") != "sd15":
|
||||||
raise ValueError("Invalid version. Must be 'sd15'.")
|
msg = "Invalid version. Must be 'sd15'."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if config.get("model") is None:
|
if config.get("model") is None:
|
||||||
raise ValueError("Model is required. Please provide a model in config.yml.")
|
msg = "Model is required. Please provide a model in config.yml."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.__model_name: str = config["model"]["name"]
|
self.__model_name: str = config["model"]["name"]
|
||||||
self.__model_url: str = config["model"]["url"]
|
self.__model_url: str = config["model"]["url"]
|
||||||
self.__token: str = token
|
self.__token: str = token
|
||||||
|
|
||||||
def download_model(self) -> None:
|
def download_model(self) -> None:
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||||
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
||||||
pretrained_model_link_or_path=self.__model_url,
|
pretrained_model_link_or_path=self.__model_url,
|
||||||
token=self.__token,
|
token=self.__token,
|
||||||
@ -63,13 +68,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
|||||||
|
|
||||||
def __download_upscaler(self) -> None:
|
def __download_upscaler(self) -> None:
|
||||||
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
|
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
|
||||||
"stabilityai/sd-x2-latent-upscaler"
|
"stabilityai/sd-x2-latent-upscaler",
|
||||||
)
|
)
|
||||||
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
|
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
class CommonSetup:
|
class CommonSetup:
|
||||||
def __init__(self, config: dict, token: str):
|
def __init__(self, config: dict, token: str) -> None:
|
||||||
self.__token: str = token
|
self.__token: str = token
|
||||||
self.__config: dict = config
|
self.__config: dict = config
|
||||||
|
|
||||||
@ -105,8 +110,8 @@ class CommonSetup:
|
|||||||
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __download_vae(self, name: str, model_url: str, token: str):
|
def __download_vae(self, name: str, model_url: str, token: str) -> None:
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
cache_path = Path(BASE_CACHE_PATH, name)
|
||||||
vae = diffusers.AutoencoderKL.from_single_file(
|
vae = diffusers.AutoencoderKL.from_single_file(
|
||||||
pretrained_model_link_or_path=model_url,
|
pretrained_model_link_or_path=model_url,
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
@ -114,8 +119,8 @@ class CommonSetup:
|
|||||||
)
|
)
|
||||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
def __download_controlnet(self, name: str, repo_id: str, token: str):
|
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
|
cache_path = Path(BASE_CACHE_PATH, name)
|
||||||
controlnet = diffusers.ControlNetModel.from_pretrained(
|
controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||||
repo_id,
|
repo_id,
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
@ -123,7 +128,7 @@ class CommonSetup:
|
|||||||
)
|
)
|
||||||
controlnet.save_pretrained(cache_path, safe_serialization=True)
|
controlnet.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
def __download_other_file(self, url, file_name, file_path):
|
def __download_other_file(self, url: str, file_name: str, file_path: str) -> None:
|
||||||
"""
|
"""
|
||||||
Download file from the given URL for LoRA or TextualInversion.
|
Download file from the given URL for LoRA or TextualInversion.
|
||||||
"""
|
"""
|
||||||
@ -131,20 +136,20 @@ class CommonSetup:
|
|||||||
|
|
||||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||||
downloaded = urlopen(req).read()
|
downloaded = urlopen(req).read()
|
||||||
dir_names = os.path.join(file_path, file_name)
|
dir_names = Path(file_path, file_name)
|
||||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||||
with open(dir_names, mode="wb") as f:
|
with open(dir_names, mode="wb") as f:
|
||||||
f.write(downloaded)
|
f.write(downloaded)
|
||||||
|
|
||||||
|
|
||||||
def build_image():
|
def build_image() -> None:
|
||||||
"""
|
"""
|
||||||
Build the Docker image.
|
Build the Docker image.
|
||||||
"""
|
"""
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
token: str = os.environ["HUGGING_FACE_TOKEN"]
|
token: str = os.environ["HUGGING_FACE_TOKEN"]
|
||||||
with open("/config.yml", "r") as file:
|
with open("/config.yml") as file:
|
||||||
config: dict = yaml.safe_load(file)
|
config: dict = yaml.safe_load(file)
|
||||||
|
|
||||||
stable_diffusion_setup: StableDiffusionCLISetupInterface
|
stable_diffusion_setup: StableDiffusionCLISetupInterface
|
||||||
@ -154,9 +159,8 @@ def build_image():
|
|||||||
case "sdxl":
|
case "sdxl":
|
||||||
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
|
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(
|
msg = f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
|
||||||
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
|
raise ValueError(msg)
|
||||||
)
|
|
||||||
|
|
||||||
stable_diffusion_setup.download_model()
|
stable_diffusion_setup.download_model()
|
||||||
common_setup = CommonSetup(config, token)
|
common_setup = CommonSetup(config, token)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def main(
|
|||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
use_upscaler: str = "False",
|
use_upscaler: str = "False",
|
||||||
output_format: str = "png",
|
output_format: str = "png",
|
||||||
):
|
) -> None:
|
||||||
"""This function is the entrypoint for the Runway CLI.
|
"""This function is the entrypoint for the Runway CLI.
|
||||||
The function pass the given prompt to StableDiffusion on Modal,
|
The function pass the given prompt to StableDiffusion on Modal,
|
||||||
gets back a list of images and outputs images to local.
|
gets back a list of images and outputs images to local.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user