Merge pull request #158 from hodanov/feature/refactoring

Refactor cmd files. Fix some bugs.
This commit is contained in:
hodanov 2024-11-04 11:45:32 +09:00 committed by GitHub
commit 8845d26052
GPG Key ID: B5690EEEBB952194
8 changed files with 212 additions and 158 deletions


@ -1,10 +1,11 @@
.PHONY: all app clean
.PHONY: app
app:
cd ./app && modal deploy __main__.py
img_by_sd15_txt2img:
cd ./cmd && modal run sd15_txt2img.py \
cd ./cmd && modal run txt2img_handler.py::main \
--version "sd15" \
--prompt "a photograph of an astronaut riding a horse" \
--n-prompt "" \
--height 512 \
@ -27,7 +28,8 @@ img_by_sd15_img2img:
--base-image-url "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
img_by_sdxl_txt2img:
cd ./cmd && modal run sdxl_txt2img.py \
cd ./cmd && modal run txt2img_handler.py::main \
--version "sdxl" \
--prompt "A dog is running on the grass" \
--n-prompt "" \
--height 1024 \


@ -7,7 +7,10 @@ This is a Diffusers-based script for running Stable Diffusion on [Modal](https:/
## Features
1. Image generation using txt2img or img2img.
![](assets/20230902_tile_imgs.png)
![example for txt2img](assets/20230902_tile_imgs.png)
Available versions:
- SDXL
- 1.5
2. Upscaling
@ -58,10 +61,8 @@ Images are generated and output to the `outputs/` directory.
├── README.md
├── cmd/ # A directory with scripts to run inference.
│   ├── outputs/ # Images are output to this directory.
│   ├── sd15_img2img.py # A script to run sd15_img2img inference.
│   ├── sd15_txt2img.py # A script to run sd15_txt2img inference.
│   ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference.
│   └── util.py
...
│   └── txt2img_handler.py # A script to run txt2img inference.
└── app/ # A directory with config files.
├── __main__.py # A main script to run inference.
├── Dockerfile # To build a base image.
@ -133,20 +134,30 @@ Set the prompt to Makefile.
```makefile
# ex)
run:
cd ./cmd && modal run txt2img.py \
--prompt "hogehoge" \
--n-prompt "mogumogu" \
--height 768 \
--width 512 \
--samples 1 \
--steps 30 \
--seed 12321 |
--use-upscaler "True" \
--fix-by-controlnet-tile "True" \
--output-fomart "avif"
img_by_sdxl_txt2img:
cd ./cmd && modal run txt2img_handler.py::main \
--version "sdxl" \
--prompt "A dog is running on the grass" \
--n-prompt "" \
--height 1024 \
--width 1024 \
--samples 1 \
--steps 30 \
--use-upscaler "True" \
--output-format "avif"
```
- prompt: Specifies the prompt.
- n-prompt: Specifies a negative prompt.
- height: Specifies the height of the image.
- width: Specifies the width of the image.
- samples: Specifies the number of images to generate.
- steps: Specifies the number of steps.
- seed: Specifies the seed.
- use-upscaler: Enables the upscaler to increase the image resolution.
- fix-by-controlnet-tile: Specifies whether to use ControlNet 1.1 Tile. If enabled, it will repair broken images and generate high-resolution images. Only sd15 is supported.
- output-format: Specifies the output format. Only avif and png are supported.
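The `use-upscaler` and `fix-by-controlnet-tile` options are passed as the strings `"True"`/`"False"` and compared literally against `"True"` inside the handler (see `use_upscaler == "True"` in `txt2img_handler.py`). A minimal sketch of that convention, using a hypothetical helper name:

```python
def parse_bool_flag(value: str) -> bool:
    # Mirrors the handler's literal comparison: only the exact string
    # "True" enables the flag; "true", "1", etc. are treated as False.
    return value == "True"

print(parse_bool_flag("True"))   # enables the flag
print(parse_bool_flag("true"))   # case-sensitive: stays disabled
```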
### 5. Deploy an application
Execute the command below. The application will be deployed on Modal.


@ -5,8 +5,10 @@
## What this script can do
1. Image generation using txt2img or img2img.
![example output from txt2img](assets/20230902_tile_imgs.png)
Available versions:
- SDXL
- 1.5
2. High-resolution image generation using the upscaler and ControlNet Tile.
@ -58,10 +60,8 @@ modal token new
├── README.md
├── cmd/ # A directory with scripts to run inference.
│   ├── outputs/ # Images are output to this directory.
│   ├── sd15_img2img.py # A script to run sd15_img2img inference.
│   ├── sd15_txt2img.py # A script to run sd15_txt2img inference.
│   ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference.
│   └── util.py
...
│   └── txt2img_handler.py # A script to run txt2img inference.
└── app/ # A directory with config files.
├── __main__.py # A main script to run inference.
├── Dockerfile # To build a base image.
@ -135,18 +135,17 @@ model:
```makefile
# Example configuration
run:
cd ./cmd && modal run txt2img.py \
--prompt "hogehoge" \
--n-prompt "mogumogu" \
--height 768 \
--width 512 \
--samples 1 \
--steps 30 \
--seed 12321 |
--use-upscaler "True" \
--fix-by-controlnet-tile "True" \
--output-fomart "png"
img_by_sdxl_txt2img:
cd ./cmd && modal run txt2img_handler.py::main \
--version "sdxl" \
--prompt "A dog is running on the grass" \
--n-prompt "" \
--height 1024 \
--width 1024 \
--samples 1 \
--steps 30 \
--use-upscaler "True" \
--output-format "avif"
```
- prompt: Specifies the prompt.
@ -157,8 +156,8 @@ run:
- steps: Specifies the number of steps.
- seed: Specifies the seed.
- use-upscaler: Enables the upscaler to increase the image resolution.
- fix-by-controlnet-tile: Specifies whether to use ControlNet 1.1 Tile. If enabled, it repairs broken images while generating high-resolution images.
- output-format: Specifies the output format. avif can also be specified.
- fix-by-controlnet-tile: Specifies whether to use ControlNet 1.1 Tile. If enabled, it repairs broken images while generating high-resolution images. Only sd15 is supported.
- output-format: Specifies the output format. Only avif and png are supported.
### 5. Deploy the application


@ -1,8 +1,8 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
import diffusers
from huggingface_hub import login
from modal import App, Image, Mount, Secret
BASE_CACHE_PATH = "/vol/cache"
@ -30,10 +30,13 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface):
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
if token != "":
login(token)
self.__token: str = token
def download_model(self) -> None:
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
use_auth_token=self.__token,
@ -54,10 +57,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface):
self.__model_name: str = config["model"]["name"]
self.__model_url: str = config["model"]["url"]
if token != "":
login(token)
self.__token: str = token
def download_model(self) -> None:
cache_path = Path(BASE_CACHE_PATH, self.__model_name)
cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name)
pipe = diffusers.StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=self.__model_url,
token=self.__token,
@ -111,7 +117,7 @@ class CommonSetup:
)
def __download_vae(self, name: str, model_url: str, token: str) -> None:
cache_path = Path(BASE_CACHE_PATH, name)
cache_path = os.path.join(BASE_CACHE_PATH, name)
vae = diffusers.AutoencoderKL.from_single_file(
pretrained_model_link_or_path=model_url,
use_auth_token=token,
@ -120,7 +126,7 @@ class CommonSetup:
vae.save_pretrained(cache_path, safe_serialization=True)
def __download_controlnet(self, name: str, repo_id: str, token: str) -> None:
cache_path = Path(BASE_CACHE_PATH, name)
cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name)
controlnet = diffusers.ControlNetModel.from_pretrained(
repo_id,
use_auth_token=token,
@ -136,7 +142,7 @@ class CommonSetup:
req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
downloaded = urlopen(req).read()
dir_names = Path(file_path, file_name)
dir_names = os.path.join(file_path, file_name)
os.makedirs(os.path.dirname(dir_names), exist_ok=True)
with open(dir_names, mode="wb") as f:
f.write(downloaded)
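This refactor replaces `Path(BASE_CACHE_PATH, name)` with `os.path.join(BASE_CACHE_PATH, name)` throughout. A small sketch showing the two forms yield the same path string on POSIX, so downstream calls that expect a `str` path are unaffected:

```python
import os
from pathlib import Path

BASE_CACHE_PATH = "/vol/cache"

p1 = str(Path(BASE_CACHE_PATH, "model"))     # pathlib form (pre-refactor)
p2 = os.path.join(BASE_CACHE_PATH, "model")  # os.path form (post-refactor)
# Both produce "/vol/cache/model" on POSIX systems.
print(p1 == p2)
```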


@ -39,10 +39,6 @@ class Prompts:
msg = "prompt should not be empty."
raise ValueError(msg)
if n_prompt == "":
msg = "n_prompt should not be empty."
raise ValueError(msg)
if height <= 0:
msg = "height should be positive."
raise ValueError(msg)
@ -59,18 +55,36 @@ class Prompts:
msg = "steps should be positive."
raise ValueError(msg)
self.__dict: dict[str, int | str] = {
"prompt": prompt,
"n_prompt": n_prompt,
"height": height,
"width": width,
"samples": samples,
"steps": steps,
}
self.__prompt = prompt
self.__n_prompt = n_prompt
self.__height = height
self.__width = width
self.__samples = samples
self.__steps = steps
@property
def dict(self) -> dict[str, int | str]:
return self.__dict
def prompt(self) -> str:
return self.__prompt
@property
def n_prompt(self) -> str:
return self.__n_prompt
@property
def height(self) -> int:
return self.__height
@property
def width(self) -> int:
return self.__width
@property
def samples(self) -> int:
return self.__samples
@property
def steps(self) -> int:
return self.__steps
class OutputDirectory:
@ -100,8 +114,8 @@ class StableDiffusionOutputManger:
prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt"
with Path(output_path).open("wb") as file:
for name, value in self.__prompts.dict.items():
file.write(f"{name} = {value!r}\n".encode())
for key, value in vars(self.__prompts).items():
file.write(f"{key} = {value!r}\n".encode())
return output_path
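`save_prompts` now iterates `vars(self.__prompts)`, i.e. the instance `__dict__`. Because the `Prompts` attributes use double underscores, Python's name mangling means the saved keys come out as `_Prompts__prompt` rather than `prompt`. A toy sketch (not the repo class) showing the effect:

```python
class Prompts:
    def __init__(self, prompt: str, steps: int) -> None:
        # Double-underscore attributes are name-mangled per class.
        self.__prompt = prompt
        self.__steps = steps

p = Prompts("an astronaut riding a horse", 30)
# vars(p) exposes the mangled names, so a file written from these
# items will contain keys like "_Prompts__prompt".
for key, value in vars(p).items():
    print(f"{key} = {value!r}")
```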

cmd/infrasctucture.py Normal file

@ -0,0 +1,98 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import modal
if TYPE_CHECKING:
from domain import Prompts, Seed
class Txt2ImgInterface(ABC):
@abstractmethod
def run_inference(self, seed: Seed) -> list[bytes]:
pass
class SDXLTxt2Img(Txt2ImgInterface):
def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
self.__prompts = prompts
self.__output_format = output_format
self.__use_upscaler = use_upscaler
self.__run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
def run_inference(self, seed: Seed) -> list[bytes]:
return self.__run_inference.remote(
prompt=self.__prompts.prompt,
n_prompt=self.__prompts.n_prompt,
height=self.__prompts.height,
width=self.__prompts.width,
steps=self.__prompts.steps,
seed=seed.value,
use_upscaler=self.__use_upscaler,
output_format=self.__output_format,
)
class SD15Txt2Img(Txt2ImgInterface):
def __init__(
self,
prompts: Prompts,
output_format: str,
*,
use_upscaler: bool,
fix_by_controlnet_tile: bool,
) -> None:
self.__prompts = prompts
self.__output_format = output_format
self.__use_upscaler = use_upscaler
self.__fix_by_controlnet_tile = fix_by_controlnet_tile
self.__run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SD15.run_txt2img_inference",
)
def run_inference(self, seed: Seed) -> list[bytes]:
return self.__run_inference.remote(
prompt=self.__prompts.prompt,
n_prompt=self.__prompts.n_prompt,
height=self.__prompts.height,
width=self.__prompts.width,
batch_size=1,
steps=self.__prompts.steps,
seed=seed.value,
use_upscaler=self.__use_upscaler,
fix_by_controlnet_tile=self.__fix_by_controlnet_tile,
output_format=self.__output_format,
)
def new_txt2img(
version: str,
prompts: Prompts,
output_format: str,
*,
use_upscaler: bool,
fix_by_controlnet_tile: bool,
) -> Txt2ImgInterface:
match version:
case "sd15":
return SD15Txt2Img(
prompts=prompts,
output_format=output_format,
use_upscaler=use_upscaler,
fix_by_controlnet_tile=fix_by_controlnet_tile,
)
case "sdxl":
return SDXLTxt2Img(
prompts=prompts,
use_upscaler=use_upscaler,
output_format=output_format,
)
case _:
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
raise ValueError(msg)


@ -1,72 +0,0 @@
import logging
import time
import domain
import modal
app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SD15.run_txt2img_inference",
)
@app.local_entrypoint()
def main(
prompt: str,
n_prompt: str,
height: int = 512,
width: int = 512,
samples: int = 5,
steps: int = 20,
seed: int = -1,
use_upscaler: str = "",
fix_by_controlnet_tile: str = "False",
output_format: str = "png",
) -> None:
"""main() is the entrypoint for the Runway CLI.
This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local.
"""
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] %(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time()
images = run_inference.remote(
prompt=prompt,
n_prompt=n_prompt,
height=height,
width=width,
batch_size=1,
steps=steps,
seed=new_seed.value,
use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format,
)
for generated_image_index, image_bytes in enumerate(images):
saved_path = sd_output_manager.save_image(
image_bytes,
new_seed.value,
sample_index,
generated_image_index,
output_format,
)
logger.info("Saved image to the: %s", saved_path)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
saved_prompts_path = sd_output_manager.save_prompts()
logger.info("Saved prompts: %s", saved_prompts_path)


@ -1,18 +1,16 @@
from __future__ import annotations
import logging
import time
import domain
import modal
app = modal.App("run-stable-diffusion-cli")
run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
from infrasctucture import new_txt2img
@app.local_entrypoint()
@modal.App("run-stable-diffusion-cli").local_entrypoint()
def main(
version: str,
prompt: str,
n_prompt: str,
height: int = 1024,
@ -21,6 +19,7 @@ def main(
steps: int = 20,
seed: int = -1,
use_upscaler: str = "False",
fix_by_controlnet_tile: str = "True",
output_format: str = "png",
) -> None:
"""This function is the entrypoint for the Runway CLI.
@ -34,27 +33,25 @@ def main(
)
logger = logging.getLogger("run-stable-diffusion-cli")
output_directory = domain.OutputDirectory()
output_directory = OutputDirectory()
directory_path = output_directory.make_directory()
logger.info("Made a directory: %s", directory_path)
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
txt2img = new_txt2img(
version,
prompts,
output_format,
use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
)
for sample_index in range(samples):
new_seed = domain.Seed(seed)
start_time = time.time()
images = run_inference.remote(
prompt=prompt,
n_prompt=n_prompt,
height=height,
width=width,
steps=steps,
seed=new_seed.value,
use_upscaler=use_upscaler == "True",
output_format=output_format,
)
new_seed = Seed(seed)
images = txt2img.run_inference(new_seed)
for generated_image_index, image_bytes in enumerate(images):
saved_path = sd_output_manager.save_image(
image_bytes,
@ -64,7 +61,6 @@ def main(
output_format,
)
logger.info("Saved image to the: %s", saved_path)
total_time = time.time() - start_time
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))