Fix to use SD15Txt2Img.

This commit is contained in:
hodanov 2024-11-04 08:50:08 +09:00
parent df6caf8f55
commit 5730b7bfad
2 changed files with 103 additions and 29 deletions

98
cmd/infrasctucture.py Normal file
View File

@ -0,0 +1,98 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import modal
if TYPE_CHECKING:
from domain import Prompts, Seed
class Txt2ImgInterface(ABC):
@abstractmethod
def run_inference(self, seed: Seed) -> list[bytes]:
pass
class SDXLTxt2Img(Txt2ImgInterface):
def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
self.__prompts = prompts
self.__output_format = output_format
self.__use_upscaler = use_upscaler
self.__run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SDXLTxt2Img.run_inference",
)
def run_inference(self, seed: Seed) -> list[bytes]:
return self.__run_inference.remote(
prompt=self.__prompts.prompt,
n_prompt=self.__prompts.n_prompt,
height=self.__prompts.height,
width=self.__prompts.width,
steps=self.__prompts.steps,
seed=seed.value,
use_upscaler=self.__use_upscaler,
output_format=self.__output_format,
)
class SD15Txt2Img(Txt2ImgInterface):
def __init__(
self,
prompts: Prompts,
output_format: str,
*,
use_upscaler: bool,
fix_by_controlnet_tile: bool,
) -> None:
self.__prompts = prompts
self.__output_format = output_format
self.__use_upscaler = use_upscaler
self.__fix_by_controlnet_tile = fix_by_controlnet_tile
self.__run_inference = modal.Function.from_name(
"stable-diffusion-cli",
"SD15.run_txt2img_inference",
)
def run_inference(self, seed: Seed) -> list[bytes]:
return self.__run_inference.remote(
prompt=self.__prompts.prompt,
n_prompt=self.__prompts.n_prompt,
height=self.__prompts.height,
width=self.__prompts.width,
batch_size=1,
steps=self.__prompts.steps,
seed=seed.value,
use_upscaler=self.__use_upscaler,
fix_by_controlnet_tile=self.__fix_by_controlnet_tile,
output_format=self.__output_format,
)
def new_txt2img(
version: str,
prompts: Prompts,
output_format: str,
*,
use_upscaler: bool,
fix_by_controlnet_tile: bool,
) -> Txt2ImgInterface:
match version:
case "sd15":
return SD15Txt2Img(
prompts=prompts,
output_format=output_format,
use_upscaler=use_upscaler,
fix_by_controlnet_tile=fix_by_controlnet_tile,
)
case "sdxl":
return SDXLTxt2Img(
prompts=prompts,
use_upscaler=use_upscaler,
output_format=output_format,
)
case _:
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
raise ValueError(msg)

View File

@ -5,33 +5,7 @@ import time
import modal import modal
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
from infrasctucture import RunInferenceInterface, RunInferenceSDXLTxt2Img from infrasctucture import new_txt2img
def new_run_inference(
version: str,
prompts: Prompts,
output_format: str,
*,
use_upscaler: bool,
) -> RunInferenceInterface:
match version:
case "sd15":
# TODO: sd15用のクラスを実装したら置き換える
return RunInferenceSDXLTxt2Img(
prompts=prompts,
use_upscaler=use_upscaler,
output_format=output_format,
)
case "sdxl":
return RunInferenceSDXLTxt2Img(
prompts=prompts,
use_upscaler=use_upscaler,
output_format=output_format,
)
case _:
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
raise ValueError(msg)
@modal.App("run-stable-diffusion-cli").local_entrypoint() @modal.App("run-stable-diffusion-cli").local_entrypoint()
@ -45,6 +19,7 @@ def main(
steps: int = 20, steps: int = 20,
seed: int = -1, seed: int = -1,
use_upscaler: str = "False", use_upscaler: str = "False",
fix_by_controlnet_tile: str = "True",
output_format: str = "png", output_format: str = "png",
) -> None: ) -> None:
"""This function is the entrypoint for the Runway CLI. """This function is the entrypoint for the Runway CLI.
@ -65,17 +40,18 @@ def main(
prompts = Prompts(prompt, n_prompt, height, width, samples, steps) prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path) sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
run_inference = new_run_inference( txt2img = new_txt2img(
version, version,
prompts, prompts,
output_format, output_format,
use_upscaler=use_upscaler == "True", use_upscaler=use_upscaler == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
) )
for sample_index in range(samples): for sample_index in range(samples):
start_time = time.time() start_time = time.time()
new_seed = Seed(seed) new_seed = Seed(seed)
images = run_inference.exec(new_seed) images = txt2img.run_inference(new_seed)
for generated_image_index, image_bytes in enumerate(images): for generated_image_index, image_bytes in enumerate(images):
saved_path = sd_output_manager.save_image( saved_path = sd_output_manager.save_image(
image_bytes, image_bytes,