diff --git a/Makefile b/Makefile index 8ae5f00..f9be353 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,11 @@ -.PHONY: all app clean +.PHONY: app app: cd ./app && modal deploy __main__.py img_by_sd15_txt2img: - cd ./cmd && modal run sd15_txt2img.py \ + cd ./cmd && modal run txt2img_handler.py::main \ + --version "sd15" \ --prompt "a photograph of an astronaut riding a horse" \ --n-prompt "" \ --height 512 \ @@ -27,7 +28,8 @@ img_by_sd15_img2img: --base-image-url "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" img_by_sdxl_txt2img: - cd ./cmd && modal run sdxl_txt2img.py \ + cd ./cmd && modal run txt2img_handler.py::main \ + --version "sdxl" \ --prompt "A dog is running on the grass" \ --n-prompt "" \ --height 1024 \ diff --git a/README.md b/README.md index 6d714d5..00cbda0 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,10 @@ This is a Diffusers-based script for running Stable Diffusion on [Modal](https:/ ## Features 1. Image generation using txt2img or img2img. - ![](assets/20230902_tile_imgs.png) + ![example for txt2img](assets/20230902_tile_imgs.png) + Available versions: + - SDXL + - 1.5 2. Upscaling @@ -58,10 +61,8 @@ Images are generated and output to the `outputs/` directory. ├── README.md ├── cmd/ # A directory with scripts to run inference. │   ├── outputs/ # Images are outputted this directory. -│   ├── sd15_img2img.py # A script to run sd15_img2img inference. -│   ├── sd15_txt2img.py # A script to run sd15_txt2img inference. -│   ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference. -│   └── util.py +... +│   └── txt2img_handler.py # A script to run txt2img inference. └── app/ # A directory with config files. ├── __main__.py # A main script to run inference. ├── Dockerfile # To build a base image. @@ -133,20 +134,30 @@ Set the prompt to Makefile. ```makefile # ex) -run: - cd ./cmd && modal run txt2img.py \ - --prompt "hogehoge" \ - --n-prompt "mogumogu" \ - --height 768 \ - --width 512 \ - --samples 1 \ - --steps 30 \ - --seed 12321 | - --use-upscaler "True" \ - --fix-by-controlnet-tile "True" \ - --output-fomart "avif" +img_by_sdxl_txt2img: + cd ./cmd && modal run txt2img_handler.py::main \ + --version "sdxl" \ + --prompt "A dog is running on the grass" \ + --n-prompt "" \ + --height 1024 \ + --width 1024 \ + --samples 1 \ + --steps 30 \ + --use-upscaler "True" \ + --output-format "avif" ``` +- prompt: Specifies the prompt. +- n-prompt: Specifies a negative prompt. +- height: Specifies the height of the image. +- width: Specifies the width of the image. +- samples: Specifies the number of images to generate. +- steps: Specifies the number of steps. +- seed: Specifies the seed. +- use-upscaler: Enables the upscaler to increase the image resolution. +- fix-by-controlnet-tile: Specifies whether to use ControlNet 1.1 Tile. If enabled, it will repair broken images and generate high-resolution images. Only sd15 is supported. +- output-format: Specifies the output format. Only avif and png are supported. + ### 5. Deploy an application Execute the below command. An application will be deployed on Modal. diff --git a/README_ja.md b/README_ja.md index 170ba38..e28bc00 100644 --- a/README_ja.md +++ b/README_ja.md @@ -5,8 +5,10 @@ ## このスクリプトでできること 1. txt2imgまたはimt2imgによる画像生成ができます。 - -![txt2imgでの生成画像例](assets/20230902_tile_imgs.png) + ![txt2imgでの生成画像例](assets/20230902_tile_imgs.png) + 利用可能なバージョン: + - SDXL + - 1.5 2. アップスケーラーとControlNet Tileを利用した高解像度な画像を生成することができます。 @@ -58,10 +60,8 @@ modal token new ├── README.md ├── cmd/ # A directory with scripts to run inference. │   ├── outputs/ # Images are outputted this directory. -│   ├── sd15_img2img.py # A script to run sd15_img2img inference. -│   ├── sd15_txt2img.py # A script to run sd15_txt2img inference. -│   ├── sdxl_txt2img.py # A script to run sdxl_txt2img inference. -│   └── util.py +... +│   └── txt2img_handler.py # A script to run txt2img inference. └── app/ # A directory with config files. ├── __main__.py # A main script to run inference. ├── Dockerfile # To build a base image. @@ -135,18 +135,17 @@ model: ```makefile # 設定例 -run: - cd ./cmd && modal run txt2img.py \ - --prompt "hogehoge" \ - --n-prompt "mogumogu" \ - --height 768 \ - --width 512 \ - --samples 1 \ - --steps 30 \ - --seed 12321 | - --use-upscaler "True" \ - --fix-by-controlnet-tile "True" \ - --output-fomart "png" +img_by_sdxl_txt2img: + cd ./cmd && modal run txt2img_handler.py::main \ + --version "sdxl" \ + --prompt "A dog is running on the grass" \ + --n-prompt "" \ + --height 1024 \ + --width 1024 \ + --samples 1 \ + --steps 30 \ + --use-upscaler "True" \ + --output-format "avif" ``` - prompt: プロンプトを指定します。 @@ -157,8 +156,8 @@ run: - steps: ステップ数を指定します。 - seed: seedを指定します。 - use-upscaler: 画像の解像度を上げるためのアップスケーラーを有効にします。 -- fix-by-controlnet-tile: ControlNet 1.1 Tileの利用有無を指定します。有効にすると、崩れた画像を修復しつつ、高解像度な画像を生成します。 -- output-format: 出力フォーマットを指定します。avifも指定可能です。 +- fix-by-controlnet-tile: ControlNet 1.1 Tileの利用有無を指定します。有効にすると、崩れた画像を修復しつつ、高解像度な画像を生成します。sd15のみ対応。 +- output-format: 出力フォーマットを指定します。avifとpngのみ対応。 ### 5. アプリケーションをデプロイする diff --git a/app/setup.py b/app/setup.py index 155ca82..5e0939e 100644 --- a/app/setup.py +++ b/app/setup.py @@ -1,8 +1,8 @@ import os from abc import ABC, abstractmethod -from pathlib import Path import diffusers +from huggingface_hub import login from modal import App, Image, Mount, Secret BASE_CACHE_PATH = "/vol/cache" @@ -30,10 +30,13 @@ class StableDiffusionCLISetupSDXL(StableDiffusionCLISetupInterface): self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] + + if token != "": + login(token) self.__token: str = token def download_model(self) -> None: - cache_path = Path(BASE_CACHE_PATH, self.__model_name) + cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionXLPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, use_auth_token=self.__token, @@ -54,10 +57,13 @@ class StableDiffusionCLISetupSD15(StableDiffusionCLISetupInterface): self.__model_name: str = config["model"]["name"] self.__model_url: str = config["model"]["url"] + + if token != "": + login(token) self.__token: str = token def download_model(self) -> None: - cache_path = Path(BASE_CACHE_PATH, self.__model_name) + cache_path = os.path.join(BASE_CACHE_PATH, self.__model_name) pipe = diffusers.StableDiffusionPipeline.from_single_file( pretrained_model_link_or_path=self.__model_url, token=self.__token, @@ -111,7 +117,7 @@ class CommonSetup: ) def __download_vae(self, name: str, model_url: str, token: str) -> None: - cache_path = Path(BASE_CACHE_PATH, name) + cache_path = os.path.join(BASE_CACHE_PATH, name) vae = diffusers.AutoencoderKL.from_single_file( pretrained_model_link_or_path=model_url, use_auth_token=token, @@ -120,7 +126,7 @@ class CommonSetup: vae.save_pretrained(cache_path, safe_serialization=True) def __download_controlnet(self, name: str, repo_id: str, token: str) -> None: - cache_path = Path(BASE_CACHE_PATH, name) + cache_path = os.path.join(BASE_CACHE_PATH_CONTROLNET, name) controlnet = diffusers.ControlNetModel.from_pretrained( repo_id, use_auth_token=token, @@ -136,7 +142,7 @@ class CommonSetup: req = Request(url, headers={"User-Agent": "Mozilla/5.0"}) downloaded = urlopen(req).read() - dir_names = Path(file_path, file_name) + dir_names = os.path.join(file_path, file_name) os.makedirs(os.path.dirname(dir_names), exist_ok=True) with open(dir_names, mode="wb") as f: f.write(downloaded) diff --git a/cmd/domain.py b/cmd/domain.py index ec6c9ee..6ef862b 100644 --- a/cmd/domain.py +++ b/cmd/domain.py @@ -39,10 +39,6 @@ class Prompts: msg = "prompt should not be empty." raise ValueError(msg) - if n_prompt == "": - msg = "n_prompt should not be empty." - raise ValueError(msg) - if height <= 0: msg = "height should be positive." raise ValueError(msg) @@ -59,18 +55,36 @@ class Prompts: msg = "steps should be positive." raise ValueError(msg) - self.__dict: dict[str, int | str] = { - "prompt": prompt, - "n_prompt": n_prompt, - "height": height, - "width": width, - "samples": samples, - "steps": steps, - } + self.__prompt = prompt + self.__n_prompt = n_prompt + self.__height = height + self.__width = width + self.__samples = samples + self.__steps = steps @property - def dict(self) -> dict[str, int | str]: - return self.__dict + def prompt(self) -> str: + return self.__prompt + + @property + def n_prompt(self) -> str: + return self.__n_prompt + + @property + def height(self) -> int: + return self.__height + + @property + def width(self) -> int: + return self.__width + + @property + def samples(self) -> int: + return self.__samples + + @property + def steps(self) -> int: + return self.__steps class OutputDirectory: @@ -100,8 +114,8 @@ class StableDiffusionOutputManger: prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) output_path = f"{self.__output_directory}/prompts_{prompts_filename}.txt" with Path(output_path).open("wb") as file: - for name, value in self.__prompts.dict.items(): - file.write(f"{name} = {value!r}\n".encode()) + for key, value in vars(self.__prompts).items(): + file.write(f"{key} = {value!r}\n".encode()) return output_path diff --git a/cmd/infrasctucture.py b/cmd/infrasctucture.py new file mode 100644 index 0000000..b18e174 --- /dev/null +++ b/cmd/infrasctucture.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import modal + +if TYPE_CHECKING: + from domain import Prompts, Seed + + +class Txt2ImgInterface(ABC): + @abstractmethod + def run_inference(self, seed: Seed) -> list[bytes]: + pass + + +class SDXLTxt2Img(Txt2ImgInterface): + def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None: + self.__prompts = prompts + self.__output_format = output_format + self.__use_upscaler = use_upscaler + self.__run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SDXLTxt2Img.run_inference", + ) + + def run_inference(self, seed: Seed) -> list[bytes]: + return self.__run_inference.remote( + prompt=self.__prompts.prompt, + n_prompt=self.__prompts.n_prompt, + height=self.__prompts.height, + width=self.__prompts.width, + steps=self.__prompts.steps, + seed=seed.value, + use_upscaler=self.__use_upscaler, + output_format=self.__output_format, + ) + + +class SD15Txt2Img(Txt2ImgInterface): + def __init__( + self, + prompts: Prompts, + output_format: str, + *, + use_upscaler: bool, + fix_by_controlnet_tile: bool, + ) -> None: + self.__prompts = prompts + self.__output_format = output_format + self.__use_upscaler = use_upscaler + self.__fix_by_controlnet_tile = fix_by_controlnet_tile + self.__run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SD15.run_txt2img_inference", + ) + + def run_inference(self, seed: Seed) -> list[bytes]: + return self.__run_inference.remote( + prompt=self.__prompts.prompt, + n_prompt=self.__prompts.n_prompt, + height=self.__prompts.height, + width=self.__prompts.width, + batch_size=1, + steps=self.__prompts.steps, + seed=seed.value, + use_upscaler=self.__use_upscaler, + fix_by_controlnet_tile=self.__fix_by_controlnet_tile, + output_format=self.__output_format, + ) + + +def new_txt2img( + version: str, + prompts: Prompts, + output_format: str, + *, + use_upscaler: bool, + fix_by_controlnet_tile: bool, +) -> Txt2ImgInterface: + match version: + case "sd15": + return SD15Txt2Img( + prompts=prompts, + output_format=output_format, + use_upscaler=use_upscaler, + fix_by_controlnet_tile=fix_by_controlnet_tile, + ) + case "sdxl": + return SDXLTxt2Img( + prompts=prompts, + use_upscaler=use_upscaler, + output_format=output_format, + ) + case _: + msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'." + raise ValueError(msg) diff --git a/cmd/sd15_txt2img.py b/cmd/sd15_txt2img.py deleted file mode 100644 index dcd38c7..0000000 --- a/cmd/sd15_txt2img.py +++ /dev/null @@ -1,72 +0,0 @@ -import logging -import time - -import domain -import modal - -app = modal.App("run-stable-diffusion-cli") -run_inference = modal.Function.from_name( - "stable-diffusion-cli", - "SD15.run_txt2img_inference", -) - - -@app.local_entrypoint() -def main( - prompt: str, - n_prompt: str, - height: int = 512, - width: int = 512, - samples: int = 5, - steps: int = 20, - seed: int = -1, - use_upscaler: str = "", - fix_by_controlnet_tile: str = "False", - output_format: str = "png", -) -> None: - """main() is the entrypoint for the Runway CLI. - This pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. - """ - logging.basicConfig( - level=logging.INFO, - format="[%(levelname)s] %(asctime)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - logger = logging.getLogger("run-stable-diffusion-cli") - - output_directory = domain.OutputDirectory() - directory_path = output_directory.make_directory() - logger.info("Made a directory: %s", directory_path) - - prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) - sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) - for sample_index in range(samples): - new_seed = domain.Seed(seed) - start_time = time.time() - images = run_inference.remote( - prompt=prompt, - n_prompt=n_prompt, - height=height, - width=width, - batch_size=1, - steps=steps, - seed=new_seed.value, - use_upscaler=use_upscaler == "True", - fix_by_controlnet_tile=fix_by_controlnet_tile == "True", - output_format=output_format, - ) - for generated_image_index, image_bytes in enumerate(images): - saved_path = sd_output_manager.save_image( - image_bytes, - new_seed.value, - sample_index, - generated_image_index, - output_format, - ) - logger.info("Saved image to the: %s", saved_path) - - total_time = time.time() - start_time - logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images)) - - saved_prompts_path = sd_output_manager.save_prompts() - logger.info("Saved prompts: %s", saved_prompts_path) diff --git a/cmd/sdxl_txt2img.py b/cmd/txt2img_handler.py similarity index 67% rename from cmd/sdxl_txt2img.py rename to cmd/txt2img_handler.py index 32712fc..c2d65ba 100644 --- a/cmd/sdxl_txt2img.py +++ b/cmd/txt2img_handler.py @@ -1,18 +1,16 @@ +from __future__ import annotations + import logging import time -import domain import modal - -app = modal.App("run-stable-diffusion-cli") -run_inference = modal.Function.from_name( - "stable-diffusion-cli", - "SDXLTxt2Img.run_inference", -) +from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger +from infrasctucture import new_txt2img -@app.local_entrypoint() +@modal.App("run-stable-diffusion-cli").local_entrypoint() def main( + version: str, prompt: str, n_prompt: str, height: int = 1024, @@ -21,6 +19,7 @@ def main( steps: int = 20, seed: int = -1, use_upscaler: str = "False", + fix_by_controlnet_tile: str = "True", output_format: str = "png", ) -> None: """This function is the entrypoint for the Runway CLI. @@ -34,27 +33,25 @@ def main( ) logger = logging.getLogger("run-stable-diffusion-cli") - output_directory = domain.OutputDirectory() + output_directory = OutputDirectory() directory_path = output_directory.make_directory() logger.info("Made a directory: %s", directory_path) - prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps) - sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path) + prompts = Prompts(prompt, n_prompt, height, width, samples, steps) + sd_output_manager = StableDiffusionOutputManger(prompts, directory_path) + + txt2img = new_txt2img( + version, + prompts, + output_format, + use_upscaler=use_upscaler == "True", + fix_by_controlnet_tile=fix_by_controlnet_tile == "True", + ) for sample_index in range(samples): - new_seed = domain.Seed(seed) start_time = time.time() - images = run_inference.remote( - prompt=prompt, - n_prompt=n_prompt, - height=height, - width=width, - steps=steps, - seed=new_seed.value, - use_upscaler=use_upscaler == "True", - output_format=output_format, - ) - + new_seed = Seed(seed) + images = txt2img.run_inference(new_seed) for generated_image_index, image_bytes in enumerate(images): saved_path = sd_output_manager.save_image( image_bytes, @@ -64,7 +61,6 @@ def main( output_format, ) logger.info("Saved image to the: %s", saved_path) - total_time = time.time() - start_time logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))