diff --git a/sdcli/txt2img.py b/sdcli/txt2img.py index 2a00b14..430c403 100644 --- a/sdcli/txt2img.py +++ b/sdcli/txt2img.py @@ -4,7 +4,7 @@ import modal import util stub = modal.Stub("run-stable-diffusion-cli") -stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference") +stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference") @stub.local_entrypoint() @@ -20,6 +20,7 @@ def main( upscaler: str = "", use_face_enhancer: str = "False", fix_by_controlnet_tile: str = "False", + output_format: str = "png", ): """ This function is the entrypoint for the Runway CLI. @@ -43,8 +44,9 @@ def main( upscaler=upscaler, use_face_enhancer=use_face_enhancer == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True", + output_format=output_format, ) - util.save_images(directory, images, seed_generated, i) + util.save_images(directory, images, seed_generated, i, output_format) total_time = time.time() - start_time print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") diff --git a/setup_files/__main__.py b/setup_files/__main__.py index d1ff313..5fe3ea8 100644 --- a/setup_files/__main__.py +++ b/setup_files/__main__.py @@ -1,12 +1,12 @@ from __future__ import annotations from setup import stub -from txt2img import StableDiffusion +from stable_diffusion_1_5 import Txt2Img @stub.function(gpu="A10G") def main(): - StableDiffusion + Txt2Img if __name__ == "__main__": diff --git a/setup_files/txt2img.py b/setup_files/stable_diffusion_1_5.py similarity index 93% rename from setup_files/txt2img.py rename to setup_files/stable_diffusion_1_5.py index 0a4907f..645930c 100644 --- a/setup_files/txt2img.py +++ b/setup_files/stable_diffusion_1_5.py @@ -18,7 +18,7 @@ from setup import ( gpu="A10G", secrets=[Secret.from_dotenv(__file__)], ) -class StableDiffusion: +class Txt2Img: """ A class that wraps the Stable Diffusion pipeline and scheduler. """ @@ -231,7 +231,6 @@ class StableDiffusion: from basicsr.archs.rrdbnet_arch import RRDBNet from gfpgan import GFPGANer from realesrgan import RealESRGANer - from tqdm import tqdm model_name = upscaler if model_name == "RealESRGAN_x4plus": @@ -271,20 +270,18 @@ class StableDiffusion: ) upscaled_imgs = [] - with tqdm(total=len(base_images)) as progress_bar: - for img in base_images: - img = numpy.array(img) - if use_face_enhancer: - _, _, enhance_result = face_enhancer.enhance( - img, - has_aligned=False, - only_center_face=False, - paste_back=True, - ) - else: - enhance_result, _ = upsampler.enhance(img) + for img in base_images: + img = numpy.array(img) + if use_face_enhancer: + _, _, enhance_result = face_enhancer.enhance( + img, + has_aligned=False, + only_center_face=False, + paste_back=True, + ) + else: + enhance_result, _ = upsampler.enhance(img) - upscaled_imgs.append(PIL.Image.fromarray(enhance_result)) - progress_bar.update(1) + upscaled_imgs.append(PIL.Image.fromarray(enhance_result)) return upscaled_imgs