Fix to switch Stable Diffusion depending on versions.
This commit is contained in:
parent
a4a161ad83
commit
df6caf8f55
@ -1,18 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import domain
|
||||
import modal
|
||||
|
||||
app = modal.App("run-stable-diffusion-cli")
|
||||
run_inference = modal.Function.from_name(
|
||||
"stable-diffusion-cli",
|
||||
"SDXLTxt2Img.run_inference",
|
||||
)
|
||||
from domain import OutputDirectory, Prompts, Seed, StableDiffusionOutputManger
|
||||
from infrasctucture import RunInferenceInterface, RunInferenceSDXLTxt2Img
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def new_run_inference(
|
||||
version: str,
|
||||
prompts: Prompts,
|
||||
output_format: str,
|
||||
*,
|
||||
use_upscaler: bool,
|
||||
) -> RunInferenceInterface:
|
||||
match version:
|
||||
case "sd15":
|
||||
# TODO: sd15用のクラスを実装したら置き換える
|
||||
return RunInferenceSDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case "sdxl":
|
||||
return RunInferenceSDXLTxt2Img(
|
||||
prompts=prompts,
|
||||
use_upscaler=use_upscaler,
|
||||
output_format=output_format,
|
||||
)
|
||||
case _:
|
||||
msg = f"Invalid version: {version}. Must be 'sd15' or 'sdxl'."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@modal.App("run-stable-diffusion-cli").local_entrypoint()
|
||||
def main(
|
||||
version: str,
|
||||
prompt: str,
|
||||
n_prompt: str,
|
||||
height: int = 1024,
|
||||
@ -34,27 +58,24 @@ def main(
|
||||
)
|
||||
logger = logging.getLogger("run-stable-diffusion-cli")
|
||||
|
||||
output_directory = domain.OutputDirectory()
|
||||
output_directory = OutputDirectory()
|
||||
directory_path = output_directory.make_directory()
|
||||
logger.info("Made a directory: %s", directory_path)
|
||||
|
||||
prompts = domain.Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = domain.StableDiffusionOutputManger(prompts, directory_path)
|
||||
prompts = Prompts(prompt, n_prompt, height, width, samples, steps)
|
||||
sd_output_manager = StableDiffusionOutputManger(prompts, directory_path)
|
||||
|
||||
for sample_index in range(samples):
|
||||
new_seed = domain.Seed(seed)
|
||||
start_time = time.time()
|
||||
images = run_inference.remote(
|
||||
prompt=prompt,
|
||||
n_prompt=n_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
steps=steps,
|
||||
seed=new_seed.value,
|
||||
run_inference = new_run_inference(
|
||||
version,
|
||||
prompts,
|
||||
output_format,
|
||||
use_upscaler=use_upscaler == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
|
||||
for sample_index in range(samples):
|
||||
start_time = time.time()
|
||||
new_seed = Seed(seed)
|
||||
images = run_inference.exec(new_seed)
|
||||
for generated_image_index, image_bytes in enumerate(images):
|
||||
saved_path = sd_output_manager.save_image(
|
||||
image_bytes,
|
||||
@ -64,7 +85,6 @@ def main(
|
||||
output_format,
|
||||
)
|
||||
logger.info("Saved image to the: %s", saved_path)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Sample %s, took %ss (%ss / image).", sample_index, total_time, (total_time) / len(images))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user