Fix to use SD15Txt2Img.
This commit is contained in:
parent
df6caf8f55
commit
5730b7bfad
98
cmd/infrasctucture.py
Normal file
98
cmd/infrasctucture.py
Normal file
@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import modal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from domain import Prompts, Seed
|
||||
|
||||
|
||||
class Txt2ImgInterface(ABC):
|
||||
@abstractmethod
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
pass
|
||||
|
||||
|
||||
class SDXLTxt2Img(Txt2ImgInterface):
|
||||
def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
|
||||
self.__prompts = prompts
|
||||
self.__output_format = output_format
|
||||
self.__use_upscaler = use_upscaler
|
||||
self.__run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SDXLTxt2Img.run_inference",
|
||||
)
|
||||
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
return self.__run_inference.remote(
|
||||
prompt=self.__prompts.prompt,
|
||||
n_prompt=self.__prompts.n_prompt,
|
||||
height=self.__prompts.height,
|
||||
width=self.__prompts.width,
|
||||
steps=self.__prompts.steps,
|
||||
seed=seed.value,
|
||||
use_upscaler=self.__use_upscaler,
|
||||
output_format=self.__output_format,
|
||||
)
|
||||
|
||||
|
||||
class SD15Txt2Img(Txt2ImgInterface):
|
||||
def __init__(
|
||||
self,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
fix_by_controlnet_tile: bool,
|
||||
) -> None:
|
||||
self.__prompts = prompts
|
||||
self.__output_format = output_format
|
||||
self.__use_upscaler = use_upscaler
|
||||
self.__fix_by_controlnet_tile = fix_by_controlnet_tile
|
||||
self.__run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SD15.run_txt2img_inference",
|
||||
)
|
||||
|
||||
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||
return self.__run_inference.remote(
|
||||
prompt=self.__prompts.prompt,
|
||||
n_prompt=self.__prompts.n_prompt,
|
||||
height=self.__prompts.height,
|
||||
width=self.__prompts.width,
|
||||
batch_size=1,
|
||||
steps=self.__prompts.steps,
|
||||
seed=seed.value,
|
||||
use_upscaler=self.__use_upscaler,
|
||||
fix_by_controlnet_tile=self.__fix_by_controlnet_tile,
|
||||
output_format=self.__output_format,
|
||||
)
|
||||
|
||||
|
||||
def new_txt2img(
|
||||
version: str,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
fix_by_controlnet_tile: bool,
|
||||
) -> Txt2ImgInterface:
|
||||
match version:
|
||||
case "sd15":
|
||||
return SD15Txt2Img(
|
||||
prompts=prompts,
|
||||
output_format=output_format,
|
||||
use_upscaler=use_upscaler,
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile,
|
||||
)
|
||||
case "sdxl":
|
||||
return SDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case _:
|
||||
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
@ -5,33 +5,7 @@ import time
|
||||
|
||||
import modal
|
||||
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
|
||||
from infrasctucture import RunInferenceInterface, RunInferenceSDXLTxt2Img
|
||||
|
||||
|
||||
def new_run_inference(
|
||||
version: str,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
) -> RunInferenceInterface:
|
||||
match version:
|
||||
case "sd15":
|
||||
# TODO: sd15用のクラスを実装したら置き換える
|
||||
return RunInferenceSDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case "sdxl":
|
||||
return RunInferenceSDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case _:
|
||||
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
from infrasctucture import new_txt2img
|
||||
|
||||
|
||||
@modal.App("run-stable-diffusion-cli").local_entrypoint()
|
||||
@ -45,6 +19,7 @@ def main(
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
use_upscaler: str = "False",
|
||||
fix_by_controlnet_tile: str = "True",
|
||||
output_format: str = "png",
|
||||
) -> None:
|
||||
"""This function is the entrypoint for the Runway CLI.
|
||||
@ -65,17 +40,18 @@ def main(
|
||||
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
|
||||
|
||||
run_inference = new_run_inference(
|
||||
txt2img = new_txt2img(
|
||||
version,
|
||||
prompts,
|
||||
output_format,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
)
|
||||
|
||||
for sample_index in range(samples):
|
||||
start_time = time.time()
|
||||
new_seed = Seed(seed)
|
||||
images = run_inference.exec(new_seed)
|
||||
images = txt2img.run_inference(new_seed)
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user