# Viewer metadata: 99 lines, 2.9 KiB, Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING
|
|
|
|
import modal
|
|
|
|
if TYPE_CHECKING:
|
|
from domain import Prompts, Seed
|
|
|
|
|
|
class Txt2ImgInterface(ABC):
    """Abstract interface for text-to-image inference backends.

    Concrete subclasses must implement :meth:`run_inference`.
    """

    @abstractmethod
    def run_inference(self, seed: Seed) -> list[bytes]:
        """Run inference for the given seed.

        Args:
            seed: Seed object supplied by the caller.

        Returns:
            A list of ``bytes`` objects produced by the backend.
        """
|
|
|
|
|
|
class SDXLTxt2Img(Txt2ImgInterface):
    """Text-to-image backend that forwards requests to a Modal function.

    The function handle is obtained once, at construction time, via
    ``modal.Function.from_name("stable-diffusion-cli",
    "SDXLTxt2Img.run_inference")``.
    """

    def __init__(self, prompts: Prompts, output_format: str, *, use_upscaler: bool) -> None:
        """Store the generation settings and look up the Modal function.

        Args:
            prompts: Object providing ``prompt``, ``n_prompt``, ``height``,
                ``width`` and ``steps`` attributes.
            output_format: Forwarded unchanged as ``output_format``.
            use_upscaler: Forwarded unchanged as ``use_upscaler``.
        """
        self.__prompts = prompts
        self.__output_format = output_format
        self.__use_upscaler = use_upscaler

        app_name = "stable-diffusion-cli"
        function_name = "SDXLTxt2Img.run_inference"
        self.__run_inference = modal.Function.from_name(app_name, function_name)

    def run_inference(self, seed: Seed) -> list[bytes]:
        """Invoke the Modal function with the stored settings and ``seed.value``.

        Args:
            seed: Object whose ``value`` attribute is sent as ``seed``.

        Returns:
            Whatever the function's ``remote`` call returns (annotated as a
            list of ``bytes``).
        """
        settings = self.__prompts
        # Keyword names and their order mirror the remote function's call site.
        payload = {
            "prompt": settings.prompt,
            "n_prompt": settings.n_prompt,
            "height": settings.height,
            "width": settings.width,
            "steps": settings.steps,
            "seed": seed.value,
            "use_upscaler": self.__use_upscaler,
            "output_format": self.__output_format,
        }
        return self.__run_inference.remote(**payload)
|
|
|
|
|
|
class SD15Txt2Img(Txt2ImgInterface):
    """Text-to-image backend that forwards requests to a Modal function.

    The function handle is obtained once, at construction time, via
    ``modal.Function.from_name("stable-diffusion-cli",
    "SD15.run_txt2img_inference")``.
    """

    def __init__(
        self,
        prompts: Prompts,
        output_format: str,
        *,
        use_upscaler: bool,
        fix_by_controlnet_tile: bool,
    ) -> None:
        """Store the generation settings and look up the Modal function.

        Args:
            prompts: Object providing ``prompt``, ``n_prompt``, ``height``,
                ``width`` and ``steps`` attributes.
            output_format: Forwarded unchanged as ``output_format``.
            use_upscaler: Forwarded unchanged as ``use_upscaler``.
            fix_by_controlnet_tile: Forwarded unchanged as
                ``fix_by_controlnet_tile``.
        """
        self.__prompts = prompts
        self.__output_format = output_format
        self.__use_upscaler = use_upscaler
        self.__fix_by_controlnet_tile = fix_by_controlnet_tile

        app_name = "stable-diffusion-cli"
        function_name = "SD15.run_txt2img_inference"
        self.__run_inference = modal.Function.from_name(app_name, function_name)

    def run_inference(self, seed: Seed) -> list[bytes]:
        """Invoke the Modal function with the stored settings and ``seed.value``.

        ``batch_size`` is always sent as ``1``.

        Args:
            seed: Object whose ``value`` attribute is sent as ``seed``.

        Returns:
            Whatever the function's ``remote`` call returns (annotated as a
            list of ``bytes``).
        """
        settings = self.__prompts
        # Keyword names and their order mirror the remote function's call site.
        payload = {
            "prompt": settings.prompt,
            "n_prompt": settings.n_prompt,
            "height": settings.height,
            "width": settings.width,
            "batch_size": 1,
            "steps": settings.steps,
            "seed": seed.value,
            "use_upscaler": self.__use_upscaler,
            "fix_by_controlnet_tile": self.__fix_by_controlnet_tile,
            "output_format": self.__output_format,
        }
        return self.__run_inference.remote(**payload)
|
|
|
|
|
|
def new_txt2img(
    version: str,
    prompts: Prompts,
    output_format: str,
    *,
    use_upscaler: bool,
    fix_by_controlnet_tile: bool,
) -> Txt2ImgInterface:
    """Build the text-to-image backend that matches ``version``.

    Args:
        version: Backend selector; ``"sd15"`` or ``"sdxl"``.
        prompts: Passed through to the selected backend.
        output_format: Passed through to the selected backend.
        use_upscaler: Passed through to the selected backend.
        fix_by_controlnet_tile: Passed to the ``"sd15"`` backend only; the
            ``"sdxl"`` backend does not receive it.

    Returns:
        An ``SD15Txt2Img`` for ``"sd15"`` or an ``SDXLTxt2Img`` for ``"sdxl"``.

    Raises:
        ValueError: If ``version`` is neither ``"sd15"`` nor ``"sdxl"``.
    """
    if version == "sd15":
        return SD15Txt2Img(
            prompts=prompts,
            output_format=output_format,
            use_upscaler=use_upscaler,
            fix_by_controlnet_tile=fix_by_controlnet_tile,
        )

    if version == "sdxl":
        return SDXLTxt2Img(
            prompts=prompts,
            use_upscaler=use_upscaler,
            output_format=output_format,
        )

    error_text = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
    raise ValueError(error_text)
|