Fix to use SD15Txt2Img.
This commit is contained in:
parent
df6caf8f55
commit
5730b7bfad
98
cmd/infrasctucture.py
Normal file
98
cmd/infrasctucture.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import modal
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from domain import Prompts, Seed
|
||||||
|
|
||||||
|
|
||||||
|
class Txt2ImgInterface(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLTxt2Img(Txt2ImgInterface):
|
||||||
|
def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
|
||||||
|
self.__prompts = prompts
|
||||||
|
self.__output_format = output_format
|
||||||
|
self.__use_upscaler = use_upscaler
|
||||||
|
self.__run_inference = modal.Function.from_name(
|
||||||
|
"stable-diffusion-cli",
|
||||||
|
"SDXLTxt2Img.run_inference",
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||||
|
return self.__run_inference.remote(
|
||||||
|
prompt=self.__prompts.prompt,
|
||||||
|
n_prompt=self.__prompts.n_prompt,
|
||||||
|
height=self.__prompts.height,
|
||||||
|
width=self.__prompts.width,
|
||||||
|
steps=self.__prompts.steps,
|
||||||
|
seed=seed.value,
|
||||||
|
use_upscaler=self.__use_upscaler,
|
||||||
|
output_format=self.__output_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SD15Txt2Img(Txt2ImgInterface):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prompts: Prompts,
|
||||||
|
output_format: str,
|
||||||
|
*,
|
||||||
|
use_upscaler: bool,
|
||||||
|
fix_by_controlnet_tile: bool,
|
||||||
|
) -> None:
|
||||||
|
self.__prompts = prompts
|
||||||
|
self.__output_format = output_format
|
||||||
|
self.__use_upscaler = use_upscaler
|
||||||
|
self.__fix_by_controlnet_tile = fix_by_controlnet_tile
|
||||||
|
self.__run_inference = modal.Function.from_name(
|
||||||
|
"stable-diffusion-cli",
|
||||||
|
"SD15.run_txt2img_inference",
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_inference(self, seed: Seed) -> list[bytes]:
|
||||||
|
return self.__run_inference.remote(
|
||||||
|
prompt=self.__prompts.prompt,
|
||||||
|
n_prompt=self.__prompts.n_prompt,
|
||||||
|
height=self.__prompts.height,
|
||||||
|
width=self.__prompts.width,
|
||||||
|
batch_size=1,
|
||||||
|
steps=self.__prompts.steps,
|
||||||
|
seed=seed.value,
|
||||||
|
use_upscaler=self.__use_upscaler,
|
||||||
|
fix_by_controlnet_tile=self.__fix_by_controlnet_tile,
|
||||||
|
output_format=self.__output_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def new_txt2img(
|
||||||
|
version: str,
|
||||||
|
prompts: Prompts,
|
||||||
|
output_format: str,
|
||||||
|
*,
|
||||||
|
use_upscaler: bool,
|
||||||
|
fix_by_controlnet_tile: bool,
|
||||||
|
) -> Txt2ImgInterface:
|
||||||
|
match version:
|
||||||
|
case "sd15":
|
||||||
|
return SD15Txt2Img(
|
||||||
|
prompts=prompts,
|
||||||
|
output_format=output_format,
|
||||||
|
use_upscaler=use_upscaler,
|
||||||
|
fix_by_controlnet_tile=fix_by_controlnet_tile,
|
||||||
|
)
|
||||||
|
case "sdxl":
|
||||||
|
return SDXLTxt2Img(
|
||||||
|
prompts=prompts,
|
||||||
|
use_upscaler=use_upscaler,
|
||||||
|
output_format=output_format,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
||||||
|
raise ValueError(msg)
|
||||||
@ -5,33 +5,7 @@ import time
|
|||||||
|
|
||||||
import modal
|
import modal
|
||||||
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
|
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
|
||||||
from infrasctucture import RunInferenceInterface, RunInferenceSDXLTxt2Img
|
from infrasctucture import new_txt2img
|
||||||
|
|
||||||
|
|
||||||
def new_run_inference(
|
|
||||||
version: str,
|
|
||||||
prompts: Prompts,
|
|
||||||
output_format: str,
|
|
||||||
*,
|
|
||||||
use_upscaler: bool,
|
|
||||||
) -> RunInferenceInterface:
|
|
||||||
match version:
|
|
||||||
case "sd15":
|
|
||||||
# TODO: sd15用のクラスを実装したら置き換える
|
|
||||||
return RunInferenceSDXLTxt2Img(
|
|
||||||
prompts=prompts,
|
|
||||||
use_upscaler=use_upscaler,
|
|
||||||
output_format=output_format,
|
|
||||||
)
|
|
||||||
case "sdxl":
|
|
||||||
return RunInferenceSDXLTxt2Img(
|
|
||||||
prompts=prompts,
|
|
||||||
use_upscaler=use_upscaler,
|
|
||||||
output_format=output_format,
|
|
||||||
)
|
|
||||||
case _:
|
|
||||||
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
@modal.App("run-stable-diffusion-cli").local_entrypoint()
|
@modal.App("run-stable-diffusion-cli").local_entrypoint()
|
||||||
@ -45,6 +19,7 @@ def main(
|
|||||||
steps: int = 20,
|
steps: int = 20,
|
||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
use_upscaler: str = "False",
|
use_upscaler: str = "False",
|
||||||
|
fix_by_controlnet_tile: str = "True",
|
||||||
output_format: str = "png",
|
output_format: str = "png",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function is the entrypoint for the Runway CLI.
|
"""This function is the entrypoint for the Runway CLI.
|
||||||
@ -65,17 +40,18 @@ def main(
|
|||||||
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
|
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||||
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
|
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
|
||||||
|
|
||||||
run_inference = new_run_inference(
|
txt2img = new_txt2img(
|
||||||
version,
|
version,
|
||||||
prompts,
|
prompts,
|
||||||
output_format,
|
output_format,
|
||||||
use_upscaler=use_upscaler == "True",
|
use_upscaler=use_upscaler == "True",
|
||||||
|
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||||
)
|
)
|
||||||
|
|
||||||
for sample_index in range(samples):
|
for sample_index in range(samples):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_seed = Seed(seed)
|
new_seed = Seed(seed)
|
||||||
images = run_inference.exec(new_seed)
|
images = txt2img.run_inference(new_seed)
|
||||||
for generated_image_index, image_bytes in enumerate(images):
|
for generated_image_index, image_bytes in enumerate(images):
|
||||||
saved_path = sd_output_manager.save_image(
|
saved_path = sd_output_manager.save_image(
|
||||||
image_bytes,
|
image_bytes,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user