Merge pull request #155 from hodanov/feature/refactoring
Refactor setup.py.
Commit: b6c26f4616
README.md (14 changed lines)
@@ -97,6 +97,7 @@ Add the model used for inference. Use the Safetensors file as is. VAE, LoRA, and

```yml
# ex)
version: "sd15" # Specify 'sd15' or 'sdxl'.
model:
  name: stable-diffusion-1-5
  url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
@@ -117,6 +118,15 @@ loras:
    url: download_link_here # Specify the download link for the safetensor file.
```

If you want to use SDXL, set `version` to `sdxl` and point the model `url` at the SDXL checkpoint you want to use:

```yml
version: "sdxl"
model:
  name: stable-diffusion-xl
  url: https://huggingface.co/xxxx/xxxx
```

### 4. Setting prompts

Set the prompts in the Makefile.
@@ -151,6 +161,10 @@ The txt2img inference is executed with the following command.

```bash
make img_by_sd15_txt2img

# or

make img_by_sdxl_txt2img
```

Thank you.
README_ja.md (14 changed lines)
@@ -99,6 +99,7 @@ HUGGING_FACE_TOKEN="Write your Hugging Face token here"

```yml
# Example configuration
version: "sd15" # Specify 'sd15' or 'sdxl'.
model:
  name: stable-diffusion-1-5
  url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
@@ -119,6 +120,15 @@ loras:
    url: https://civitai.com/api/download/models/150907?type=Model&format=SafeTensor # Specify the download link
```

If you want to use SDXL, specify `sdxl` for `version` and set `url` to the SDXL model you want to use:

```yml
version: "sdxl"
model:
  name: stable-diffusion-xl
  url: https://huggingface.co/xxxx/xxxx
```

### 4. Configuring the Makefile (setting prompts)

Set the prompts in the Makefile.
@@ -164,4 +174,8 @@ make app

```bash
make img_by_sd15_txt2img

# or

make img_by_sdxl_txt2img
```
@@ -6,6 +6,7 @@

##########
# You can use a diffusers model and VAE on hugging face.
version: "sd15" # 'sd15' or 'sdxl'.
model:
  name: stable-diffusion-1-5
  url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors
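This is the sample config that `build_image()` in app/setup.py reads. As a rough illustration of how the `version` and `model` keys are expected to be shaped before any download starts, here is a minimal sketch (not repository code; the `config.yml` path, the helper name, and the error messages are assumptions for illustration):

```python
# Minimal sketch: load a config like the one above and validate the keys
# that the setup classes below rely on. Assumes PyYAML is installed.
import yaml


def load_and_check_config(path: str = "config.yml") -> dict:
    with open(path, "r") as f:
        config: dict = yaml.safe_load(f)

    if config.get("version") not in ("sd15", "sdxl"):
        raise ValueError("version must be 'sd15' or 'sdxl'")
    model = config.get("model")
    if not model or "name" not in model or "url" not in model:
        raise ValueError("model.name and model.url are required")
    return config


if __name__ == "__main__":
    cfg = load_and_check_config()
    print(cfg["version"], cfg["model"]["name"])
```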
app/setup.py (227 changed lines)
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod

import diffusers
from modal import App, Image, Mount, Secret
@@ -12,79 +11,130 @@ BASE_CACHE_PATH_CONTROLNET = "/vol/cache/controlnet"
BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"


def download_file(url, file_name, file_path):
    """
    Download files.
    """
    from urllib.request import Request, urlopen

    req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
    downloaded = urlopen(req).read()
    dir_names = os.path.join(file_path, file_name)
    os.makedirs(os.path.dirname(dir_names), exist_ok=True)
    with open(dir_names, mode="wb") as f:
        f.write(downloaded)


class StableDiffusionCLISetupInterface(ABC):
    @abstractmethod
    def download_model(self):
        pass


def download_upscaler():
    """
    Download the stabilityai/sd-x2-latent-upscaler.
    """
    model_id = "stabilityai/sd-x2-latent-upscaler"
    upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(model_id)
    upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)


class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
    def __init__(self, config: dict, token: str):
        if config.get("version") != "sdxl":
            raise ValueError("Invalid version. Must be 'sdxl'.")

        if config.get("model") is None:
            raise ValueError("Model is required. Please provide a model in config.yml.")

        self.__model_name: str = config["model"]["name"]
        self.__model_url: str = config["model"]["url"]
        self.__token: str = token

    def download_model(self) -> None:
        cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
        pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
            pretrained_model_link_or_path=self.__model_url,
            use_auth_token=self.__token,
            cache_dir=cache_path,
        )
        pipe.save_pretrained(cache_path, safe_serialization=True)


def download_controlnet(name: str, repo_id: str, token: str):
    """
    Download a controlnet.
    """
    cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
    controlnet = diffusers.ControlNetModel.from_pretrained(
        repo_id,
        use_auth_token=token,
        cache_dir=cache_path,
    )
    controlnet.save_pretrained(cache_path, safe_serialization=True)


class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
    def __init__(self, config: dict, token: str):
        if config.get("version") != "sd15":
            raise ValueError("Invalid version. Must be 'sd15'.")

        if config.get("model") is None:
            raise ValueError("Model is required. Please provide a model in config.yml.")

        self.__model_name: str = config["model"]["name"]
        self.__model_url: str = config["model"]["url"]
        self.__token: str = token

    def download_model(self) -> None:
        cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
        pipe = diffusers.StableDiffusionPipeline.from_single_file(
            pretrained_model_link_or_path=self.__model_url,
            token=self.__token,
            cache_dir=cache_path,
        )
        pipe.save_pretrained(cache_path, safe_serialization=True)
        self.__download_upscaler()

    def __download_upscaler(self) -> None:
        upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
            "stabilityai/sd-x2-latent-upscaler"
        )
        upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)


def download_vae(name: str, model_url: str, token: str):
    """
    Download a vae.
    """
    cache_path = os.path.join(BASE_CACHE_PATH, name)
    vae = diffusers.AutoencoderKL.from_single_file(
        pretrained_model_link_or_path=model_url,
        use_auth_token=token,
        cache_dir=cache_path,
    )
    vae.save_pretrained(cache_path, safe_serialization=True)


class CommonSetup:
    def __init__(self, config: dict, token: str):
        self.__token: str = token
        self.__config: dict = config

    def download_setup_files(self) -> None:
        if self.__config.get("vae") is not None:
            self.__download_vae(
                name=self.__config["model"]["name"],
                model_url=self.__config["vae"]["url"],
                token=self.__token,
            )


def download_model(name: str, model_url: str, token: str):
    """
    Download a model.
    """
    cache_path = os.path.join(BASE_CACHE_PATH, name)
    pipe = diffusers.StableDiffusionPipeline.from_single_file(
        pretrained_model_link_or_path=model_url,
        token=token,
        cache_dir=cache_path,
    )
    pipe.save_pretrained(cache_path, safe_serialization=True)

        if self.__config.get("controlnets") is not None:
            for controlnet in self.__config["controlnets"]:
                self.__download_controlnet(
                    name=controlnet["name"],
                    repo_id=controlnet["repo_id"],
                    token=self.__token,
                )

        if self.__config.get("loras") is not None:
            for lora in self.__config["loras"]:
                self.__download_other_file(
                    url=lora["url"],
                    file_name=lora["name"],
                    file_path=BASE_CACHE_PATH_LORA,
                )


def download_model_sdxl(name: str, model_url: str, token: str):
    """
    Download a sdxl model.
    """
    cache_path = os.path.join(BASE_CACHE_PATH, name)
    pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
        pretrained_model_link_or_path=model_url,
        use_auth_token=token,
        cache_dir=cache_path,
    )
    pipe.save_pretrained(cache_path, safe_serialization=True)

        if self.__config.get("textual_inversions") is not None:
            for textual_inversion in self.__config["textual_inversions"]:
                self.__download_other_file(
                    url=textual_inversion["url"],
                    file_name=textual_inversion["name"],
                    file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
                )

    def __download_vae(self, name: str, model_url: str, token: str):
        cache_path = os.path.join(BASE_CACHE_PATH, name)
        vae = diffusers.AutoencoderKL.from_single_file(
            pretrained_model_link_or_path=model_url,
            use_auth_token=token,
            cache_dir=cache_path,
        )
        vae.save_pretrained(cache_path, safe_serialization=True)

    def __download_controlnet(self, name: str, repo_id: str, token: str):
        cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
        controlnet = diffusers.ControlNetModel.from_pretrained(
            repo_id,
            use_auth_token=token,
            cache_dir=cache_path,
        )
        controlnet.save_pretrained(cache_path, safe_serialization=True)

    def __download_other_file(self, url, file_name, file_path):
        """
        Download file from the given URL for LoRA or TextualInversion.
        """
        from urllib.request import Request, urlopen

        req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
        downloaded = urlopen(req).read()
        dir_names = os.path.join(file_path, file_name)
        os.makedirs(os.path.dirname(dir_names), exist_ok=True)
        with open(dir_names, mode="wb") as f:
            f.write(downloaded)


def build_image():
@@ -93,43 +143,24 @@ def build_image():
    """
    import yaml

    token = os.environ["HUGGING_FACE_TOKEN"]
    config = {}
    token: str = os.environ["HUGGING_FACE_TOKEN"]
    with open("/config.yml", "r") as file:
        config = yaml.safe_load(file)
        config: dict = yaml.safe_load(file)

    model = config.get("model")
    use_xl = config.get("use_xl")
    if model is not None:
        if use_xl is not None and use_xl:
            download_model_sdxl(name=model["name"], model_url=model["url"], token=token)
        else:
            download_model(name=model["name"], model_url=model["url"], token=token)

    vae = config.get("vae")
    if vae is not None:
        download_vae(name=model["name"], model_url=vae["url"], token=token)

    controlnets = config.get("controlnets")
    if controlnets is not None:
        for controlnet in controlnets:
            download_controlnet(name=controlnet["name"], repo_id=controlnet["repo_id"], token=token)

    loras = config.get("loras")
    if loras is not None:
        for lora in loras:
            download_file(url=lora["url"], file_name=lora["name"], file_path=BASE_CACHE_PATH_LORA)

    textual_inversions = config.get("textual_inversions")
    if textual_inversions is not None:
        for textual_inversion in textual_inversions:
            download_file(
                url=textual_inversion["url"],
                file_name=textual_inversion["name"],
                file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
            )

    stable_diffusion_setup: StableDiffusionCLISetupInterface
    match config.get("version"):
        case "sd15":
            stable_diffusion_setup = StableDiffusionCLISetupSD15(config, token)
        case "sdxl":
            stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
        case _:
            raise ValueError(
                f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
            )

    download_upscaler()
    stable_diffusion_setup.download_model()
    common_setup = CommonSetup(config, token)
    common_setup.download_setup_files()


app = App("stable-diffusion-cli")
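For context on how the new pieces fit together, here is a minimal usage sketch of the class-based setup. This is not repository code: the import path, the placeholder config values, and reading HUGGING_FACE_TOKEN from the environment are assumptions; in setup.py itself the selection is done by the `match` statement inside `build_image()` above.

```python
# Hypothetical standalone usage of the new setup classes; in the repository
# this logic runs inside build_image() while the Modal image is built.
import os

from setup import (  # assumes the working directory is app/
    CommonSetup,
    StableDiffusionCLISetupInterface,
    StableDiffusionCLISetupSD15,
    StableDiffusionCLISetupSDXL,
)

config: dict = {
    "version": "sdxl",
    "model": {
        "name": "stable-diffusion-xl",
        "url": "https://huggingface.co/xxxx/xxxx",  # placeholder URL, as in the README
    },
}
token: str = os.environ.get("HUGGING_FACE_TOKEN", "")

setup: StableDiffusionCLISetupInterface
if config["version"] == "sdxl":
    setup = StableDiffusionCLISetupSDXL(config, token)
else:
    setup = StableDiffusionCLISetupSD15(config, token)

setup.download_model()  # downloads the checkpoint and caches it under /vol/cache
CommonSetup(config, token).download_setup_files()  # VAE / ControlNet / LoRA / TI, if configured
```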
@@ -4,7 +4,9 @@ import modal
import util

stub = modal.Stub("run-stable-diffusion-cli")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "SD15.run_img2img_inference")
stub.run_inference = modal.Function.from_name(
    "stable-diffusion-cli", "SD15.run_img2img_inference"
)


@stub.local_entrypoint()
@@ -44,7 +46,9 @@ def main(
        )
        util.save_images(directory, images, seed_generated, i, output_format)
        total_time = time.time() - start_time
        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
        print(
            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
        )

    prompts: dict[str, int | str] = {
        "prompt": prompt,
@@ -4,7 +4,9 @@ import modal
import util

app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
run_inference = modal.Function.from_name(
    "stable-diffusion-cli", "SD15.run_txt2img_inference"
)


@app.local_entrypoint()
@@ -46,7 +48,9 @@ def main(
        )
        util.save_images(directory, images, seed_generated, i, output_format)
        total_time = time.time() - start_time
        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
        print(
            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
        )

    prompts: dict[str, int | str] = {
        "prompt": prompt,
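Both runner scripts above look up a function deployed under the "stable-diffusion-cli" app by name and call it from a local entrypoint. A minimal sketch of that pattern follows; the example app name, the prompt default, and the keyword arguments passed to `.remote()` are placeholders, not the project's actual signature.

```python
# Sketch of the lookup-and-call pattern used by the runner scripts.
import modal

app = modal.App("run-stable-diffusion-cli-example")  # hypothetical app name
# Reference a function deployed under the "stable-diffusion-cli" app by name.
run_inference = modal.Function.from_name("stable-diffusion-cli", "SD15.run_img2img_inference")


@app.local_entrypoint()
def main(prompt: str = "a photo of an astronaut riding a horse"):
    # .remote() executes the deployed function in the cloud; the keyword
    # argument here is an illustrative placeholder, not the real signature.
    images = run_inference.remote(prompt=prompt)
    print(f"received {len(images)} image(s)")
```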