Merge pull request #158 from hodanov/feature/refactoring
Refactor cmd files. Fix some bugs.
This commit is contained in:
commit
8845d26052
8
Makefile
8
Makefile
@ -1,10 +1,11 @@
|
||||
.PHONY: all app clean
|
||||
.PHONY: app
|
||||
|
||||
app:
|
||||
cd ./app && modal deploy __main__.py
|
||||
|
||||
img_by_sd15_txt2img:
|
||||
cd ./cmd && modal run sd15_txt2img.py \
|
||||
cd ./cmd && modal run txt2img_handler.py::main \
|
||||
--version "sd15" \
|
||||
--prompt "a photograph of an astronaut riding a horse" \
|
||||
--n-prompt "" \
|
||||
--height 512 \
|
||||
@ -27,7 +28,8 @@ img_by_sd15_img2img:
|
||||
--base-image-url "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
|
||||
|
||||
img_by_sdxl_txt2img:
|
||||
cd ./cmd && modal run sdxl_txt2img.py \
|
||||
cd ./cmd && modal run txt2img_handler.py::main \
|
||||
--version "sdxl" \
|
||||
--prompt "A dog is running on the grass" \
|
||||
--n-prompt "" \
|
||||
--height 1024 \
|
||||
|
||||
45
README.md
45
README.md
@ -7,7 +7,10 @@ This is a Diffusers-based script for running Stable Diffusion on [Modal](https:/
|
||||
## Features
|
||||
|
||||
1. Image generation using txt2img or img2img.
|
||||

|
||||

|
||||
Available versions:
|
||||
- SDXL
|
||||
- 1.5
|
||||
|
||||
2. Upscaling
|
||||
|
||||
@ -58,10 +61,8 @@ Images are generated and output to the `outputs/` directory.
|
||||
├── README.md
|
||||
├── cmd/ # A directory with scripts to run inference.
|
||||
│ ├── outputs/ # Images are outputted this directory.
|
||||
│ ├── sd15_img2img.py # A script to run sd15_img2img inference.
|
||||
│ ├── sd15_txt2img.py # A script to run sd15_txt2img inference.
|
||||
│ ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference.
|
||||
│ └── util.py
|
||||
...
|
||||
│ └── txt2img_handler.py # A script to run txt2img inference.
|
||||
└── app/ # A directory with config files.
|
||||
├── __main__.py # A main script to run inference.
|
||||
├── Dockerfile # To build a base image.
|
||||
@ -133,20 +134,30 @@ Set the prompt to Makefile.
|
||||
|
||||
```makefile
|
||||
# ex)
|
||||
run:
|
||||
cd ./cmd && modal run txt2img.py \
|
||||
--prompt "hogehoge" \
|
||||
--n-prompt "mogumogu" \
|
||||
--height 768 \
|
||||
--width 512 \
|
||||
--samples 1 \
|
||||
--steps 30 \
|
||||
--seed 12321 |
|
||||
--use-upscaler "True" \
|
||||
--fix-by-controlnet-tile "True" \
|
||||
--output-fomart "avif"
|
||||
img_by_sdxl_txt2img:
|
||||
cd ./cmd && modal run txt2img_handler.py::main \
|
||||
--version "sdxl" \
|
||||
--prompt "A dog is running on the grass" \
|
||||
--n-prompt "" \
|
||||
--height 1024 \
|
||||
--width 1024 \
|
||||
--samples 1 \
|
||||
--steps 30 \
|
||||
--use-upscaler "True" \
|
||||
--output-format "avif"
|
||||
```
|
||||
|
||||
- prompt: Specifies the prompt.
|
||||
- n-prompt: Specifies a negative prompt.
|
||||
- height: Specifies the height of the image.
|
||||
- width: Specifies the width of the image.
|
||||
- samples: Specifies the number of images to generate.
|
||||
- steps: Specifies the number of steps.
|
||||
- seed: Specifies the seed.
|
||||
- use-upscaler: Enables the upscaler to increase the image resolution.
|
||||
- fix-by-controlnet-tile: Specifies whether to use ControlNet 1.1 Tile. If enabled, it will repair broken images and generate high-resolution images. Only sd15 is supported.
|
||||
- output-format: Specifies the output format. Only avif and png are supported.
|
||||
|
||||
### 5. Deploy an application
|
||||
|
||||
Execute the below command. An application will be deployed on Modal.
|
||||
|
||||
39
README_ja.md
39
README_ja.md
@ -5,8 +5,10 @@
|
||||
## このスクリプトでできること
|
||||
|
||||
1. txt2imgまたはimt2imgによる画像生成ができます。
|
||||
|
||||

|
||||

|
||||
利用可能なバージョン:
|
||||
- SDXL
|
||||
- 1.5
|
||||
|
||||
2. アップスケーラーとControlNet Tileを利用した高解像度な画像を生成することができます。
|
||||
|
||||
@ -58,10 +60,8 @@ modal token new
|
||||
├── README.md
|
||||
├── cmd/ # A directory with scripts to run inference.
|
||||
│ ├── outputs/ # Images are outputted this directory.
|
||||
│ ├── sd15_img2img.py # A script to run sd15_img2img inference.
|
||||
│ ├── sd15_txt2img.py # A script to run sd15_txt2img inference.
|
||||
│ ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference.
|
||||
│ └── util.py
|
||||
...
|
||||
│ └── txt2img_handler.py # A script to run txt2img inference.
|
||||
└── app/ # A directory with config files.
|
||||
├── __main__.py # A main script to run inference.
|
||||
├── Dockerfile # To build a base image.
|
||||
@ -135,18 +135,17 @@ model:
|
||||
|
||||
```makefile
|
||||
# 設定例
|
||||
run:
|
||||
cd ./cmd && modal run txt2img.py \
|
||||
--prompt "hogehoge" \
|
||||
--n-prompt "mogumogu" \
|
||||
--height 768 \
|
||||
--width 512 \
|
||||
--samples 1 \
|
||||
--steps 30 \
|
||||
--seed 12321 |
|
||||
--use-upscaler "True" \
|
||||
--fix-by-controlnet-tile "True" \
|
||||
--output-fomart "png"
|
||||
img_by_sdxl_txt2img:
|
||||
cd ./cmd && modal run txt2img_handler.py::main \
|
||||
--version "sdxl" \
|
||||
--prompt "A dog is running on the grass" \
|
||||
--n-prompt "" \
|
||||
--height 1024 \
|
||||
--width 1024 \
|
||||
--samples 1 \
|
||||
--steps 30 \
|
||||
--use-upscaler "True" \
|
||||
--output-format "avif"
|
||||
```
|
||||
|
||||
- prompt: プロンプトを指定します。
|
||||
@ -157,8 +156,8 @@ run:
|
||||
- steps: ステップ数を指定します。
|
||||
- seed: seedを指定します。
|
||||
- use-upscaler: 画像の解像度を上げるためのアップスケーラーを有効にします。
|
||||
- fix-by-controlnet-tile: ControlNet 1.1 Tileの利用有無を指定します。有効にすると、崩れた画像を修復しつつ、高解像度な画像を生成します。
|
||||
- output-format: 出力フォーマットを指定します。avifも指定可能です。
|
||||
- fix-by-controlnet-tile: ControlNet 1.1 Tileの利用有無を指定します。有効にすると、崩れた画像を修復しつつ、高解像度な画像を生成します。sd15のみ対応。
|
||||
- output-format: 出力フォーマットを指定します。avifとpngのみ対応。
|
||||
|
||||
### 5. アプリケーションをデプロイする
|
||||
|
||||
|
||||
18
app/setup.py
18
app/setup.py
@ -1,8 +1,8 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import diffusers
|
||||
from huggingface_hub import login
|
||||
from modal import App, Image, Mount, Secret
|
||||
|
||||
BASE_CACHE_PATH = "/vol/cache"
|
||||
@ -30,10 +30,13 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
|
||||
|
||||
self.__model_name: str = config["model"]["name"]
|
||||
self.__model_url: str = config["model"]["url"]
|
||||
|
||||
if token != "":
|
||||
login(token)
|
||||
self.__token: str = token
|
||||
|
||||
def download_model(self) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
|
||||
pretrained_model_link_or_path=self.__model_url,
|
||||
use_auth_token=self.__token,
|
||||
@ -54,10 +57,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
|
||||
|
||||
self.__model_name: str = config["model"]["name"]
|
||||
self.__model_url: str = config["model"]["url"]
|
||||
|
||||
if token != "":
|
||||
login(token)
|
||||
self.__token: str = token
|
||||
|
||||
def download_model(self) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
|
||||
pipe = diffusers.StableDiffusionPipeline.from_single_file(
|
||||
pretrained_model_link_or_path=self.__model_url,
|
||||
token=self.__token,
|
||||
@ -111,7 +117,7 @@ class CommonSetup:
|
||||
)
|
||||
|
||||
def __download_vae(self, name: str, model_url: str, token: str) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, name)
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, name)
|
||||
vae = diffusers.AutoencoderKL.from_single_file(
|
||||
pretrained_model_link_or_path=model_url,
|
||||
use_auth_token=token,
|
||||
@ -120,7 +126,7 @@ class CommonSetup:
|
||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
|
||||
cache_path = Path(BASE_CACHE_PATH, name)
|
||||
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
|
||||
controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||
repo_id,
|
||||
use_auth_token=token,
|
||||
@ -136,7 +142,7 @@ class CommonSetup:
|
||||
|
||||
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
|
||||
downloaded = urlopen(req).read()
|
||||
dir_names = Path(file_path, file_name)
|
||||
dir_names = os.path.join(file_path, file_name)
|
||||
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
|
||||
with open(dir_names, mode="wb") as f:
|
||||
f.write(downloaded)
|
||||
|
||||
@ -39,10 +39,6 @@ class Prompts:
|
||||
msg = "prompt should not be empty."
|
||||
raise ValueError(msg)
|
||||
|
||||
if n_prompt == "":
|
||||
msg = "n_prompt should not be empty."
|
||||
raise ValueError(msg)
|
||||
|
||||
if height <= 0:
|
||||
msg = "height should be positive."
|
||||
raise ValueError(msg)
|
||||
@ -59,18 +55,36 @@ class Prompts:
|
||||
msg = "steps should be positive."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.__dict: dict[str, int | str] = {
|
||||
"prompt": prompt,
|
||||
"n_prompt": n_prompt,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"samples": samples,
|
||||
"steps": steps,
|
||||
}
|
||||
self.__prompt = prompt
|
||||
self.__n_prompt = n_prompt
|
||||
self.__height = height
|
||||
self.__width = width
|
||||
self.__samples = samples
|
||||
self.__steps = steps
|
||||
|
||||
@property
|
||||
def dict(self) -> dict[str, int | str]:
|
||||
return self.__dict
|
||||
def prompt(self) -> str:
|
||||
return self.__prompt
|
||||
|
||||
@property
|
||||
def n_prompt(self) -> str:
|
||||
return self.__n_prompt
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return self.__height
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
return self.__width
|
||||
|
||||
@property
|
||||
def samples(self) -> int:
|
||||
return self.__samples
|
||||
|
||||
@property
|
||||
def steps(self) -> int:
|
||||
return self.__steps
|
||||
|
||||
|
||||
class OutputDirectory:
|
||||
@ -100,8 +114,8 @@ class StableDiffusionOutputManger:
|
||||
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
|
||||
output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt"
|
||||
with Path(output_path).open("wb") as file:
|
||||
for name, value in self.__prompts.dict.items():
|
||||
file.write(f"{name} = {value!r}\n".encode())
|
||||
for key, value in vars(self.__prompts).items():
|
||||
file.write(f"{key} = {value!r}\n".encode())
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
98
cmd/infrasctucture.py
Normal file
98
cmd/infrasctucture.py
Normal file
@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import modal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from domain import Prompts, Seed
|
||||
|
||||
|
||||
class Txt2ImgInterface(ABC):
|
||||
@abstractmethod
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
pass
|
||||
|
||||
|
||||
class SDXLTxt2Img(Txt2ImgInterface):
|
||||
def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
|
||||
self.__prompts = prompts
|
||||
self.__output_format = output_format
|
||||
self.__use_upscaler = use_upscaler
|
||||
self.__run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SDXLTxt2Img.run_inference",
|
||||
)
|
||||
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
return self.__run_inference.remote(
|
||||
prompt=self.__prompts.prompt,
|
||||
n_prompt=self.__prompts.n_prompt,
|
||||
height=self.__prompts.height,
|
||||
width=self.__prompts.width,
|
||||
steps=self.__prompts.steps,
|
||||
seed=seed.value,
|
||||
use_upscaler=self.__use_upscaler,
|
||||
output_format=self.__output_format,
|
||||
)
|
||||
|
||||
|
||||
class SD15Txt2Img(Txt2ImgInterface):
|
||||
def __init__(
|
||||
self,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
fix_by_controlnet_tile: bool,
|
||||
) -> None:
|
||||
self.__prompts = prompts
|
||||
self.__output_format = output_format
|
||||
self.__use_upscaler = use_upscaler
|
||||
self.__fix_by_controlnet_tile = fix_by_controlnet_tile
|
||||
self.__run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SD15.run_txt2img_inference",
|
||||
)
|
||||
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
return self.__run_inference.remote(
|
||||
prompt=self.__prompts.prompt,
|
||||
n_prompt=self.__prompts.n_prompt,
|
||||
height=self.__prompts.height,
|
||||
width=self.__prompts.width,
|
||||
batch_size=1,
|
||||
steps=self.__prompts.steps,
|
||||
seed=seed.value,
|
||||
use_upscaler=self.__use_upscaler,
|
||||
fix_by_controlnet_tile=self.__fix_by_controlnet_tile,
|
||||
output_format=self.__output_format,
|
||||
)
|
||||
|
||||
|
||||
def new_txt2img(
|
||||
version: str,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
fix_by_controlnet_tile: bool,
|
||||
) -> Txt2ImgInterface:
|
||||
match version:
|
||||
case "sd15":
|
||||
return SD15Txt2Img(
|
||||
prompts=prompts,
|
||||
output_format=output_format,
|
||||
use_upscaler=use_upscaler,
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile,
|
||||
)
|
||||
case "sdxl":
|
||||
return SDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case _:
|
||||
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
@ -1,72 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import domain
|
||||
import modal
|
||||
|
||||
app = modal.App("run-stable-diffusion-cli")
|
||||
run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SD15.run_txt2img_inference",
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main(
|
||||
prompt: str,
|
||||
n_prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
samples: int = 5,
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
use_upscaler: str = "",
|
||||
fix_by_controlnet_tile: str = "False",
|
||||
output_format: str = "png",
|
||||
) -> None:
|
||||
"""main() is the entrypoint for the Runway CLI.
|
||||
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
|
||||
"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(levelname)s] %(asctime)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("run-stable-diffusion-cli")
|
||||
|
||||
output_directory = domain.OutputDirectory()
|
||||
directory_path = output_directory.make_directory()
|
||||
logger.info("Made a directory: %s", directory_path)
|
||||
|
||||
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
|
||||
for sample_index in range(samples):
|
||||
new_seed = domain.Seed(seed)
|
||||
start_time = time.time()
|
||||
images = run_inference.remote(
|
||||
prompt=prompt,
|
||||
n_prompt=n_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
batch_size=1,
|
||||
steps=steps,
|
||||
seed=new_seed.value,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
new_seed.value,
|
||||
sample_index,
|
||||
generated_image_index,
|
||||
output_format,
|
||||
)
|
||||
logger.info("Saved image to the: %s", saved_path)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
|
||||
|
||||
saved_prompts_path = sd_output_manager.save_prompts()
|
||||
logger.info("Saved prompts: %s", saved_prompts_path)
|
||||
@ -1,18 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import domain
|
||||
import modal
|
||||
|
||||
app = modal.App("run-stable-diffusion-cli")
|
||||
run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SDXLTxt2Img.run_inference",
|
||||
)
|
||||
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
|
||||
from infrasctucture import new_txt2img
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@modal.App("run-stable-diffusion-cli").local_entrypoint()
|
||||
def main(
|
||||
version: str,
|
||||
prompt: str,
|
||||
n_prompt: str,
|
||||
height: int = 1024,
|
||||
@ -21,6 +19,7 @@ def main(
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
use_upscaler: str = "False",
|
||||
fix_by_controlnet_tile: str = "True",
|
||||
output_format: str = "png",
|
||||
) -> None:
|
||||
"""This function is the entrypoint for the Runway CLI.
|
||||
@ -34,27 +33,25 @@ def main(
|
||||
)
|
||||
logger = logging.getLogger("run-stable-diffusion-cli")
|
||||
|
||||
output_directory = domain.OutputDirectory()
|
||||
output_directory = OutputDirectory()
|
||||
directory_path = output_directory.make_directory()
|
||||
logger.info("Made a directory: %s", directory_path)
|
||||
|
||||
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
|
||||
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
|
||||
|
||||
txt2img = new_txt2img(
|
||||
version,
|
||||
prompts,
|
||||
output_format,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
)
|
||||
|
||||
for sample_index in range(samples):
|
||||
new_seed = domain.Seed(seed)
|
||||
start_time = time.time()
|
||||
images = run_inference.remote(
|
||||
prompt=prompt,
|
||||
n_prompt=n_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
steps=steps,
|
||||
seed=new_seed.value,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
|
||||
new_seed = Seed(seed)
|
||||
images = txt2img.run_inference(new_seed)
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
@ -64,7 +61,6 @@ def main(
|
||||
output_format,
|
||||
)
|
||||
logger.info("Saved image to the: %s", saved_path)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user