Merge pull request #155 from hodanov/feature/refactoring

Refactor setup.py.
hodanov 2024-11-02 17:41:22 +09:00 committed by GitHub
commit b6c26f4616
6 changed files with 170 additions and 102 deletions
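At a glance, this refactor replaces the flat `download_*` helpers in setup.py with version-specific setup classes behind a shared interface, selected in `build_image()` from the new `version` key in config.yml. Below is a condensed orientation sketch of the new flow, simplified from the diff that follows (method bodies are elided and variable names shortened, so treat it as a reading aid rather than the exact code):

```python
# Condensed sketch of the refactored setup.py (see the diff below for the real bodies).
import os
from abc import ABC, abstractmethod

import yaml


class StableDiffusionCLISetupInterface(ABC):
    @abstractmethod
    def download_model(self) -> None: ...


class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
    def __init__(self, config: dict, token: str): ...  # validates version == "sd15"
    def download_model(self) -> None: ...  # StableDiffusionPipeline + latent upscaler


class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
    def __init__(self, config: dict, token: str): ...  # validates version == "sdxl"
    def download_model(self) -> None: ...  # StableDiffusionXLPipeline


class CommonSetup:
    def __init__(self, config: dict, token: str): ...
    def download_setup_files(self) -> None: ...  # VAE, ControlNets, LoRAs, Textual Inversions


def build_image():
    token = os.environ["HUGGING_FACE_TOKEN"]
    with open("/config.yml", "r") as file:
        config = yaml.safe_load(file)
    match config.get("version"):
        case "sd15":
            setup = StableDiffusionCLISetupSD15(config, token)
        case "sdxl":
            setup = StableDiffusionCLISetupSDXL(config, token)
        case _:
            raise ValueError("version must be 'sd15' or 'sdxl'")
    setup.download_model()
    CommonSetup(config, token).download_setup_files()
```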


@ -97,6 +97,7 @@ Add the model used for inference. Use the Safetensors file as is. VAE, LoRA, and
```yml
# ex)
version: "sd15" # Specify 'sd15' or 'sdxl'.
model:
name: stable-diffusion-1-5
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
@ -117,6 +118,15 @@ loras:
url: download_link_here # Specify the download link for the safetensor file.
```
If you want to use SDXL, set `version` to `"sdxl"` and specify the URL of the SDXL model you want to use:
```yml
version: "sdxl"
model:
name: stable-diffusion-xl
url: https://huggingface.co/xxxx/xxxx
```
### 4. Setting prompts
Set the prompts in the Makefile.
@ -151,6 +161,10 @@ The txt2img inference is executed with the following command.
```bash
make img_by_sd15_txt2img
# or
make img_by_sdxl_txt2img
```
Thank you.


@ -99,6 +99,7 @@ HUGGING_FACE_TOKEN="Write your Hugging Face token here"
```yml
# Example configuration
version: "sd15" # Specify 'sd15' or 'sdxl'.
model:
name: stable-diffusion-1-5
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
@ -119,6 +120,15 @@ loras:
url: https://civitai.com/api/download/models/150907?type=Model&format=SafeTensor # Specify the download link for the safetensor file.
```
If you want to use SDXL, set `version` to `sdxl` and specify the SDXL model you want to use in `url`:
```yml
version: "sdxl"
model:
name: stable-diffusion-xl
url: https://huggingface.co/xxxx/xxxx
```
### 4. Setting prompts
Set the prompts in the Makefile.
@ -164,4 +174,8 @@ make app
```bash
make img_by_sd15_txt2img
# or
make img_by_sdxl_txt2img
```


@ -6,6 +6,7 @@
##########
# You can use a diffusers model and VAE on hugging face.
version: "sd15" # 'sd15' or 'sdxl'.
model:
name: stable-diffusion-1-5
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors


@ -1,6 +1,5 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
import diffusers
from modal import App, Image, Mount, Secret
@ -12,79 +11,130 @@ BASE_CACHE_PATH_CONTROLNET = "/vol/cache/controlnet"
BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
def download_file(url, file_name, file_path):
"""
Download files.
"""
from urllib.request import Request, urlopen
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
class StableDiffusionCLISetupInterface(ABC):
@abstractmethod
def download_model(self):
pass
def download_upscaler():
"""
Download the stabilityai/sd-x2-latent-upscaler.
"""
model_id = "stabilityai/sd-x2-latent-upscaler"
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(model_id)
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str):
if config.get("version") != "sdxl":
raise ValueError("Invalid version. Must be 'sdxl'.")
if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.")
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
self.__token: str = token
def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token,
cache_dir=cache_path,
)
pipe.save_pretrained(cache_path, safe_serialization=True)
def download_controlnet(name: str, repo_id: str, token: str):
"""
Download a controlnet.
"""
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id,
use_auth_token=token,
cache_dir=cache_path,
)
controlnet.save_pretrained(cache_path, safe_serialization=True)
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
def __init__(self, config: dict, token: str):
if config.get("version") != "sd15":
raise ValueError("Invalid version. Must be 'sd15'.")
if config.get("model") is None:
raise ValueError("Model is required. Please provide a model in config.yml.")
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
self.__token: str = token
def download_model(self) -> None:
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
token=self.__token,
cache_dir=cache_path,
)
pipe.save_pretrained(cache_path, safe_serialization=True)
self.__download_upscaler()
def __download_upscaler(self) -> None:
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler"
)
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
def download_vae(name: str, model_url: str, token: str):
"""
Download a vae.
"""
cache_path = os.path.join(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url,
use_auth_token=token,
cache_dir=cache_path,
)
vae.save_pretrained(cache_path, safe_serialization=True)
class CommonSetup:
def __init__(self, config: dict, token: str):
self.__token: str = token
self.__config: dict = config
def download_setup_files(self) -> None:
if self.__config.get("vae") is not None:
self.__download_vae(
name=self.__config["model"]["name"],
model_url=self.__config["vae"]["url"],
token=self.__token,
)
def download_model(name: str, model_url: str, token: str):
"""
Download a model.
"""
cache_path = os.path.join(BASE_CACHE_PATH, name)
pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=model_url,
token=token,
cache_dir=cache_path,
)
pipe.save_pretrained(cache_path, safe_serialization=True)
if self.__config.get("controlnets") is not None:
for controlnet in self.__config["controlnets"]:
self.__download_controlnet(
name=controlnet["name"],
repo_id=controlnet["repo_id"],
token=self.__token,
)
if self.__config.get("loras") is not None:
for lora in self.__config["loras"]:
self.__download_other_file(
url=lora["url"],
file_name=lora["name"],
file_path=BASE_CACHE_PATH_LORA,
)
def download_model_sdxl(name: str, model_url: str, token: str):
"""
Download a sdxl model.
"""
cache_path = os.path.join(BASE_CACHE_PATH, name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=model_url,
use_auth_token=token,
cache_dir=cache_path,
)
pipe.save_pretrained(cache_path, safe_serialization=True)
if self.__config.get("textual_inversions") is not None:
for textual_inversion in self.__config["textual_inversions"]:
self.__download_other_file(
url=textual_inversion["url"],
file_name=textual_inversion["name"],
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
)
def __download_vae(self, name: str, model_url: str, token: str):
cache_path = os.path.join(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url,
use_auth_token=token,
cache_dir=cache_path,
)
vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str):
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id,
use_auth_token=token,
cache_dir=cache_path,
)
controlnet.save_pretrained(cache_path, safe_serialization=True)
def __download_other_file(self, url, file_name, file_path):
"""
Download file from the given URL for LoRA or TextualInversion.
"""
from urllib.request import Request, urlopen
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = os.path.join(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
def build_image():
@ -93,43 +143,24 @@ def build_image():
"""
import yaml
token = os.environ["HUGGING_FACE_TOKEN"]
config = {}
token: str = os.environ["HUGGING_FACE_TOKEN"]
with open("/config.yml", "r") as file:
config = yaml.safe_load(file)
config: dict = yaml.safe_load(file)
model = config.get("model")
use_xl = config.get("use_xl")
if model is not None:
if use_xl is not None and use_xl:
download_model_sdxl(name=model["name"], model_url=model["url"], token=token)
else:
download_model(name=model["name"], model_url=model["url"], token=token)
vae = config.get("vae")
if vae is not None:
download_vae(name=model["name"], model_url=vae["url"], token=token)
controlnets = config.get("controlnets")
if controlnets is not None:
for controlnet in controlnets:
download_controlnet(name=controlnet["name"], repo_id=controlnet["repo_id"], token=token)
loras = config.get("loras")
if loras is not None:
for lora in loras:
download_file(url=lora["url"], file_name=lora["name"], file_path=BASE_CACHE_PATH_LORA)
textual_inversions = config.get("textual_inversions")
if textual_inversions is not None:
for textual_inversion in textual_inversions:
download_file(
url=textual_inversion["url"],
file_name=textual_inversion["name"],
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
stable_diffusion_setup: StableDiffusionCLISetupInterface
match config.get("version"):
case "sd15":
stable_diffusion_setup = StableDiffusionCLISetupSD15(config, token)
case "sdxl":
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
case _:
raise ValueError(
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
)
download_upscaler()
stable_diffusion_setup.download_model()
common_setup = CommonSetup(config, token)
common_setup.download_setup_files()
app = App("stable-diffusion-cli")
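One benefit of routing model downloads through `StableDiffusionCLISetupInterface` is that supporting another model family only needs one more implementation plus one extra `case` in `build_image()`. A hypothetical sketch follows; it reuses `os`, `diffusers`, `BASE_CACHE_PATH`, and the interface from setup.py above, and the `sd21` version string and `StableDiffusionCLISetupSD21` class are not part of this PR and are illustrative only:

```python
# Hypothetical extension sketch: "sd21" and StableDiffusionCLISetupSD21 are NOT in this PR.
# Assumes it lives in setup.py, next to the classes shown in the diff above.
class StableDiffusionCLISetupSD21(StableDiffusionCLISetupInterface):
    def __init__(self, config: dict, token: str):
        if config.get("version") != "sd21":
            raise ValueError("Invalid version. Must be 'sd21'.")
        if config.get("model") is None:
            raise ValueError("Model is required. Please provide a model in config.yml.")
        self.__model_name: str = config["model"]["name"]
        self.__model_url: str = config["model"]["url"]
        self.__token: str = token

    def download_model(self) -> None:
        # SD 2.x single-file checkpoints also load via StableDiffusionPipeline.from_single_file.
        cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
        pipe = diffusers.StableDiffusionPipeline.from_single_file(
            pretrained_model_link_or_path=self.__model_url,
            token=self.__token,
            cache_dir=cache_path,
        )
        pipe.save_pretrained(cache_path, safe_serialization=True)


# ...plus one extra branch in build_image():
#     case "sd21":
#         stable_diffusion_setup = StableDiffusionCLISetupSD21(config, token)
```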


@ -4,7 +4,9 @@ import modal
import util
stub = modal.Stub("run-stable-diffusion-cli")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "SD15.run_img2img_inference")
stub.run_inference = modal.Function.from_name(
"stable-diffusion-cli", "SD15.run_img2img_inference"
)
@stub.local_entrypoint()
@ -44,7 +46,9 @@ def main(
)
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
print(
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
)
prompts: dict[str, int | str] = {
"prompt": prompt,


@ -4,7 +4,9 @@ import modal
import util
app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
run_inference = modal.Function.from_name(
"stable-diffusion-cli", "SD15.run_txt2img_inference"
)
@app.local_entrypoint()
@ -46,7 +48,9 @@ def main(
)
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
print(
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
)
prompts: dict[str, int | str] = {
"prompt": prompt,