Merge pull request #155 from hodanov/feature/refactoring
Refactor setup.py.
This commit is contained in:
commit
b6c26f4616
14
README.md
14
README.md
@ -97,6 +97,7 @@ Add the model used for inference. Use the Safetensors file as is. VAE, LoRA, and
|
|||||||
|
|
||||||
```yml
|
```yml
|
||||||
# ex)
|
# ex)
|
||||||
|
version: "sd15" # Specify 'sd15' or 'sdxl'.
|
||||||
model:
|
model:
|
||||||
name: stable-diffusion-1-5
|
name: stable-diffusion-1-5
|
||||||
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
|
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
|
||||||
@ -117,6 +118,15 @@ loras:
|
|||||||
url: download_link_here # Specify the download link for the safetensor file.
|
url: download_link_here # Specify the download link for the safetensor file.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you want to use SDXL:
|
||||||
|
|
||||||
|
```yml
|
||||||
|
version: "sdxl"
|
||||||
|
model:
|
||||||
|
name: stable-diffusion-xl
|
||||||
|
url: https://huggingface.co/xxxx/xxxx
|
||||||
|
```
|
||||||
|
|
||||||
### 4. Setting prompts
|
### 4. Setting prompts
|
||||||
|
|
||||||
Set the prompt to Makefile.
|
Set the prompt to Makefile.
|
||||||
@ -151,6 +161,10 @@ The txt2img inference is executed with the following command.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
make img_by_sd15_txt2img
|
make img_by_sd15_txt2img
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
make img_by_sdxl_txt2img
|
||||||
```
|
```
|
||||||
|
|
||||||
Thank you.
|
Thank you.
|
||||||
|
|||||||
14
README_ja.md
14
README_ja.md
@ -99,6 +99,7 @@ HUGGING_FACE_TOKEN="ここにHuggingFaceのトークンを記載する"
|
|||||||
|
|
||||||
```yml
|
```yml
|
||||||
# 設定例
|
# 設定例
|
||||||
|
version: "sd15" # Specify 'sd15' or 'sdxl'.
|
||||||
model:
|
model:
|
||||||
name: stable-diffusion-1-5
|
name: stable-diffusion-1-5
|
||||||
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
|
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors # Specify URL for the safetensor file.
|
||||||
@ -119,6 +120,15 @@ loras:
|
|||||||
url: https://civitai.com/api/download/models/150907?type=Model&format=SafeTensor # ダウンロードリンクを指定
|
url: https://civitai.com/api/download/models/150907?type=Model&format=SafeTensor # ダウンロードリンクを指定
|
||||||
```
|
```
|
||||||
|
|
||||||
|
SDXLを使いたい場合は`version`に`sdxl`を指定し、urlに使いたいsdxlのモデルを指定します。
|
||||||
|
|
||||||
|
```yml
|
||||||
|
version: "sdxl"
|
||||||
|
model:
|
||||||
|
name: stable-diffusion-xl
|
||||||
|
url: https://huggingface.co/xxxx/xxxx
|
||||||
|
```
|
||||||
|
|
||||||
### 4. Makefileの設定(プロンプトの設定)
|
### 4. Makefileの設定(プロンプトの設定)
|
||||||
|
|
||||||
プロンプトをMakefileに設定します。
|
プロンプトをMakefileに設定します。
|
||||||
@ -164,4 +174,8 @@ make app
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
make img_by_sd15_txt2img
|
make img_by_sd15_txt2img
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
make img_by_sdxl_txt2img
|
||||||
```
|
```
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
##########
|
##########
|
||||||
# You can use a diffusers model and VAE on hugging face.
|
# You can use a diffusers model and VAE on hugging face.
|
||||||
|
version: "sd15" # 'sd15' or 'sdxl'.
|
||||||
model:
|
model:
|
||||||
name: stable-diffusion-1-5
|
name: stable-diffusion-1-5
|
||||||
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors
|
url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors
|
||||||
|
|||||||
227
app/setup.py
227
app/setup.py
@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
from modal import App, Image, Mount, Secret
|
from modal import App, Image, Mount, Secret
|
||||||
@ -12,79 +11,130 @@ BASE_CACHE_PATH_CONTROLNET = "/vol/cache/controlnet"
|
|||||||
BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
|
BASE_CACHE_PATH_UPSCALER = "/vol/cache/upscaler"
|
||||||
|
|
||||||
|
|
||||||
def download_file(url, file_name, file_path):
|
class StableDiffusionCLISetupInterface(ABC):
|
||||||
"""
|
@abstractmethod
|
||||||
Download files.
|
def download_model(self):
|
||||||
"""
|
pass
|
||||||
from urllib.request import Request, urlopen
|
|
||||||
|
|
||||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
|
||||||
downloaded = urlopen(req).read()
|
|
||||||
dir_names = os.path.join(file_path, file_name)
|
|
||||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
|
||||||
with open(dir_names, mode="wb") as f:
|
|
||||||
f.write(downloaded)
|
|
||||||
|
|
||||||
|
|
||||||
def download_upscaler():
|
class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
||||||
"""
|
def __init__(self, config: dict, token: str):
|
||||||
Download the stabilityai/sd-x2-latent-upscaler.
|
if config.get("version") != "sdxl":
|
||||||
"""
|
raise ValueError("Invalid version. Must be 'sdxl'.")
|
||||||
model_id = "stabilityai/sd-x2-latent-upscaler"
|
|
||||||
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(model_id)
|
if config.get("model") is None:
|
||||||
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
|
raise ValueError("Model is required. Please provide a model in config.yml.")
|
||||||
|
|
||||||
|
self.__model_name: str = config["model"]["name"]
|
||||||
|
self.__model_url: str = config["model"]["url"]
|
||||||
|
self.__token: str = token
|
||||||
|
|
||||||
|
def download_model(self) -> None:
|
||||||
|
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||||
|
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
||||||
|
pretrained_model_link_or_path=self.__model_url,
|
||||||
|
use_auth_token=self.__token,
|
||||||
|
cache_dir=cache_path,
|
||||||
|
)
|
||||||
|
pipe.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
def download_controlnet(name: str, repo_id: str, token: str):
|
class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
||||||
"""
|
def __init__(self, config: dict, token: str):
|
||||||
Download a controlnet.
|
if config.get("version") != "sd15":
|
||||||
"""
|
raise ValueError("Invalid version. Must be 'sd15'.")
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
|
|
||||||
controlnet = diffusers.ControlNetModel.from_pretrained(
|
if config.get("model") is None:
|
||||||
repo_id,
|
raise ValueError("Model is required. Please provide a model in config.yml.")
|
||||||
use_auth_token=token,
|
|
||||||
cache_dir=cache_path,
|
self.__model_name: str = config["model"]["name"]
|
||||||
)
|
self.__model_url: str = config["model"]["url"]
|
||||||
controlnet.save_pretrained(cache_path, safe_serialization=True)
|
self.__token: str = token
|
||||||
|
|
||||||
|
def download_model(self) -> None:
|
||||||
|
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||||
|
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
||||||
|
pretrained_model_link_or_path=self.__model_url,
|
||||||
|
token=self.__token,
|
||||||
|
cache_dir=cache_path,
|
||||||
|
)
|
||||||
|
pipe.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
self.__download_upscaler()
|
||||||
|
|
||||||
|
def __download_upscaler(self) -> None:
|
||||||
|
upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
|
||||||
|
"stabilityai/sd-x2-latent-upscaler"
|
||||||
|
)
|
||||||
|
upscaler.save_pretrained(BASE_CACHE_PATH_UPSCALER, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
def download_vae(name: str, model_url: str, token: str):
|
class CommonSetup:
|
||||||
"""
|
def __init__(self, config: dict, token: str):
|
||||||
Download a vae.
|
self.__token: str = token
|
||||||
"""
|
self.__config: dict = config
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
|
||||||
vae = diffusers.AutoencoderKL.from_single_file(
|
|
||||||
pretrained_model_link_or_path=model_url,
|
|
||||||
use_auth_token=token,
|
|
||||||
cache_dir=cache_path,
|
|
||||||
)
|
|
||||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
|
||||||
|
|
||||||
|
def download_setup_files(self) -> None:
|
||||||
|
if self.__config.get("vae") is not None:
|
||||||
|
self.__download_vae(
|
||||||
|
name=self.__config["model"]["name"],
|
||||||
|
model_url=self.__config["vae"]["url"],
|
||||||
|
token=self.__token,
|
||||||
|
)
|
||||||
|
|
||||||
def download_model(name: str, model_url: str, token: str):
|
if self.__config.get("controlnets") is not None:
|
||||||
"""
|
for controlnet in self.__config["controlnets"]:
|
||||||
Download a model.
|
self.__download_controlnet(
|
||||||
"""
|
name=controlnet["name"],
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
repo_id=controlnet["repo_id"],
|
||||||
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
token=self.__token,
|
||||||
pretrained_model_link_or_path=model_url,
|
)
|
||||||
token=token,
|
|
||||||
cache_dir=cache_path,
|
|
||||||
)
|
|
||||||
pipe.save_pretrained(cache_path, safe_serialization=True)
|
|
||||||
|
|
||||||
|
if self.__config.get("loras") is not None:
|
||||||
|
for lora in self.__config["loras"]:
|
||||||
|
self.__download_other_file(
|
||||||
|
url=lora["url"],
|
||||||
|
file_name=lora["name"],
|
||||||
|
file_path=BASE_CACHE_PATH_LORA,
|
||||||
|
)
|
||||||
|
|
||||||
def download_model_sdxl(name: str, model_url: str, token: str):
|
if self.__config.get("textual_inversions") is not None:
|
||||||
"""
|
for textual_inversion in self.__config["textual_inversions"]:
|
||||||
Download a sdxl model.
|
self.__download_other_file(
|
||||||
"""
|
url=textual_inversion["url"],
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
file_name=textual_inversion["name"],
|
||||||
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
||||||
pretrained_model_link_or_path=model_url,
|
)
|
||||||
use_auth_token=token,
|
|
||||||
cache_dir=cache_path,
|
def __download_vae(self, name: str, model_url: str, token: str):
|
||||||
)
|
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
||||||
pipe.save_pretrained(cache_path, safe_serialization=True)
|
vae = diffusers.AutoencoderKL.from_single_file(
|
||||||
|
pretrained_model_link_or_path=model_url,
|
||||||
|
use_auth_token=token,
|
||||||
|
cache_dir=cache_path,
|
||||||
|
)
|
||||||
|
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
|
def __download_controlnet(self, name: str, repo_id: str, token: str):
|
||||||
|
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
|
||||||
|
controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||||
|
repo_id,
|
||||||
|
use_auth_token=token,
|
||||||
|
cache_dir=cache_path,
|
||||||
|
)
|
||||||
|
controlnet.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
|
def __download_other_file(self, url, file_name, file_path):
|
||||||
|
"""
|
||||||
|
Download file from the given URL for LoRA or TextualInversion.
|
||||||
|
"""
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||||
|
downloaded = urlopen(req).read()
|
||||||
|
dir_names = os.path.join(file_path, file_name)
|
||||||
|
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||||
|
with open(dir_names, mode="wb") as f:
|
||||||
|
f.write(downloaded)
|
||||||
|
|
||||||
|
|
||||||
def build_image():
|
def build_image():
|
||||||
@ -93,43 +143,24 @@ def build_image():
|
|||||||
"""
|
"""
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
token = os.environ["HUGGING_FACE_TOKEN"]
|
token: str = os.environ["HUGGING_FACE_TOKEN"]
|
||||||
config = {}
|
|
||||||
with open("/config.yml", "r") as file:
|
with open("/config.yml", "r") as file:
|
||||||
config = yaml.safe_load(file)
|
config: dict = yaml.safe_load(file)
|
||||||
|
|
||||||
model = config.get("model")
|
stable_diffusion_setup: StableDiffusionCLISetupInterface
|
||||||
use_xl = config.get("use_xl")
|
match config.get("version"):
|
||||||
if model is not None:
|
case "sd15":
|
||||||
if use_xl is not None and use_xl:
|
stable_diffusion_setup = StableDiffusionCLISetupSD15(config, token)
|
||||||
download_model_sdxl(name=model["name"], model_url=model["url"], token=token)
|
case "sdxl":
|
||||||
else:
|
stable_diffusion_setup = StableDiffusionCLISetupSDXL(config, token)
|
||||||
download_model(name=model["name"], model_url=model["url"], token=token)
|
case _:
|
||||||
|
raise ValueError(
|
||||||
vae = config.get("vae")
|
f"Invalid version: {config.get('version')}. Must be 'sd15' or 'sdxl'."
|
||||||
if vae is not None:
|
|
||||||
download_vae(name=model["name"], model_url=vae["url"], token=token)
|
|
||||||
|
|
||||||
controlnets = config.get("controlnets")
|
|
||||||
if controlnets is not None:
|
|
||||||
for controlnet in controlnets:
|
|
||||||
download_controlnet(name=controlnet["name"], repo_id=controlnet["repo_id"], token=token)
|
|
||||||
|
|
||||||
loras = config.get("loras")
|
|
||||||
if loras is not None:
|
|
||||||
for lora in loras:
|
|
||||||
download_file(url=lora["url"], file_name=lora["name"], file_path=BASE_CACHE_PATH_LORA)
|
|
||||||
|
|
||||||
textual_inversions = config.get("textual_inversions")
|
|
||||||
if textual_inversions is not None:
|
|
||||||
for textual_inversion in textual_inversions:
|
|
||||||
download_file(
|
|
||||||
url=textual_inversion["url"],
|
|
||||||
file_name=textual_inversion["name"],
|
|
||||||
file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
download_upscaler()
|
stable_diffusion_setup.download_model()
|
||||||
|
common_setup = CommonSetup(config, token)
|
||||||
|
common_setup.download_setup_files()
|
||||||
|
|
||||||
|
|
||||||
app = App("stable-diffusion-cli")
|
app = App("stable-diffusion-cli")
|
||||||
|
|||||||
@ -4,7 +4,9 @@ import modal
|
|||||||
import util
|
import util
|
||||||
|
|
||||||
stub = modal.Stub("run-stable-diffusion-cli")
|
stub = modal.Stub("run-stable-diffusion-cli")
|
||||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "SD15.run_img2img_inference")
|
stub.run_inference = modal.Function.from_name(
|
||||||
|
"stable-diffusion-cli", "SD15.run_img2img_inference"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@stub.local_entrypoint()
|
||||||
@ -44,7 +46,9 @@ def main(
|
|||||||
)
|
)
|
||||||
util.save_images(directory, images, seed_generated, i, output_format)
|
util.save_images(directory, images, seed_generated, i, output_format)
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
print(
|
||||||
|
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
|
||||||
|
)
|
||||||
|
|
||||||
prompts: dict[str, int | str] = {
|
prompts: dict[str, int | str] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
|||||||
@ -4,7 +4,9 @@ import modal
|
|||||||
import util
|
import util
|
||||||
|
|
||||||
app = modal.App("run-stable-diffusion-cli")
|
app = modal.App("run-stable-diffusion-cli")
|
||||||
run_inference = modal.Function.from_name("stable-diffusion-cli", "SDXLTxt2Img.run_inference")
|
run_inference = modal.Function.from_name(
|
||||||
|
"stable-diffusion-cli", "SD15.run_txt2img_inference"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
@ -46,7 +48,9 @@ def main(
|
|||||||
)
|
)
|
||||||
util.save_images(directory, images, seed_generated, i, output_format)
|
util.save_images(directory, images, seed_generated, i, output_format)
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
print(
|
||||||
|
f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
|
||||||
|
)
|
||||||
|
|
||||||
prompts: dict[str, int | str] = {
|
prompts: dict[str, int | str] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user