From 5730b7bfadb99805683d673002f9a8dead31770b Mon Sep 17 00:00:00 2001 From: hodanov <1031hoda@gmail.com> Date: Mon, 4 Nov 2024 08:50:08 +0900 Subject: [PATCH] Fix to use SD15Txt2Img. --- cmd/infrasctucture.py | 98 +++++++++++++++++++++++++++++++++++++++++++ cmd/sdxl_txt2img.py | 34 +++------------ 2 files changed, 103 insertions(+), 29 deletions(-) create mode 100644 cmd/infrasctucture.py diff --git a/cmd/infrasctucture.py b/cmd/infrasctucture.py new file mode 100644 index 0000000..b18e174 --- /dev/null +++ b/cmd/infrasctucture.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import modal + +if TYPE_CHECKING: + from domain import Prompts, Seed + + +class Txt2ImgInterface(ABC): + @abstractmethod + def run_inference(self, seed: Seed) -> list[bytes]: + pass + + +class SDXLTxt2Img(Txt2ImgInterface): + def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None: + self.__prompts = prompts + self.__output_format = output_format + self.__use_upscaler = use_upscaler + self.__run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SDXLTxt2Img.run_inference", + ) + + def run_inference(self, seed: Seed) -> list[bytes]: + return self.__run_inference.remote( + prompt=self.__prompts.prompt, + n_prompt=self.__prompts.n_prompt, + height=self.__prompts.height, + width=self.__prompts.width, + steps=self.__prompts.steps, + seed=seed.value, + use_upscaler=self.__use_upscaler, + output_format=self.__output_format, + ) + + +class SD15Txt2Img(Txt2ImgInterface): + def __init__( + self, + prompts: Prompts, + output_format: str, + *, + use_upscaler: bool, + fix_by_controlnet_tile: bool, + ) -> None: + self.__prompts = prompts + self.__output_format = output_format + self.__use_upscaler = use_upscaler + self.__fix_by_controlnet_tile = fix_by_controlnet_tile + self.__run_inference = modal.Function.from_name( + "stable-diffusion-cli", + "SD15.run_txt2img_inference", + ) + + def run_inference(self, seed: Seed) -> list[bytes]: + return self.__run_inference.remote( + prompt=self.__prompts.prompt, + n_prompt=self.__prompts.n_prompt, + height=self.__prompts.height, + width=self.__prompts.width, + batch_size=1, + steps=self.__prompts.steps, + seed=seed.value, + use_upscaler=self.__use_upscaler, + fix_by_controlnet_tile=self.__fix_by_controlnet_tile, + output_format=self.__output_format, + ) + + +def new_txt2img( + version: str, + prompts: Prompts, + output_format: str, + *, + use_upscaler: bool, + fix_by_controlnet_tile: bool, +) -> Txt2ImgInterface: + match version: + case "sd15": + return SD15Txt2Img( + prompts=prompts, + output_format=output_format, + use_upscaler=use_upscaler, + fix_by_controlnet_tile=fix_by_controlnet_tile, + ) + case "sdxl": + return SDXLTxt2Img( + prompts=prompts, + use_upscaler=use_upscaler, + output_format=output_format, + ) + case _: + msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'." + raise ValueError(msg) diff --git a/cmd/sdxl_txt2img.py b/cmd/sdxl_txt2img.py index 682e7c6..c2d65ba 100644 --- a/cmd/sdxl_txt2img.py +++ b/cmd/sdxl_txt2img.py @@ -5,33 +5,7 @@ import time import modal from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger -from infrasctucture import RunInferenceInterface, RunInferenceSDXLTxt2Img - - -def new_run_inference( - version: str, - prompts: Prompts, - output_format: str, - *, - use_upscaler: bool, -) -> RunInferenceInterface: - match version: - case "sd15": - # TODO: sd15用のクラスを実装したら置き換える - return RunInferenceSDXLTxt2Img( - prompts=prompts, - use_upscaler=use_upscaler, - output_format=output_format, - ) - case "sdxl": - return RunInferenceSDXLTxt2Img( - prompts=prompts, - use_upscaler=use_upscaler, - output_format=output_format, - ) - case _: - msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'." - raise ValueError(msg) +from infrasctucture import new_txt2img @modal.App("run-stable-diffusion-cli").local_entrypoint() @@ -45,6 +19,7 @@ def main( steps: int = 20, seed: int = -1, use_upscaler: str = "False", + fix_by_controlnet_tile: str = "True", output_format: str = "png", ) -> None: """This function is the entrypoint for the Runway CLI. @@ -65,17 +40,18 @@ def main( prompts = Prompts(prompt, n_prompt, height, width, samples, steps) sd_output_manager = StableDiffusionOutputManger(prompts, directory_path) - run_inference = new_run_inference( + txt2img = new_txt2img( version, prompts, output_format, use_upscaler=use_upscaler == "True", + fix_by_controlnet_tile=fix_by_controlnet_tile == "True", ) for sample_index in range(samples): start_time = time.time() new_seed = Seed(seed) - images = run_inference.exec(new_seed) + images = txt2img.run_inference(new_seed) for generated_image_index, image_bytes in enumerate(images): saved_path = sd_output_manager.save_image( image_bytes,