Rename a file and a class.

This commit is contained in:
hodanov 2023-12-04 10:31:39 +09:00
parent 2d1c1ffc4c
commit 0132166fe6
3 changed files with 19 additions and 20 deletions

View File

@ -4,7 +4,7 @@ import modal
import util import util
stub = modal.Stub("run-stable-diffusion-cli") stub = modal.Stub("run-stable-diffusion-cli")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference") stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
@stub.local_entrypoint() @stub.local_entrypoint()
@ -20,6 +20,7 @@ def main(
upscaler: str = "", upscaler: str = "",
use_face_enhancer: str = "False", use_face_enhancer: str = "False",
fix_by_controlnet_tile: str = "False", fix_by_controlnet_tile: str = "False",
output_format: str = "png",
): ):
""" """
This function is the entrypoint for the Runway CLI. This function is the entrypoint for the Runway CLI.
@ -43,8 +44,9 @@ def main(
upscaler=upscaler, upscaler=upscaler,
use_face_enhancer=use_face_enhancer == "True", use_face_enhancer=use_face_enhancer == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format,
) )
util.save_images(directory, images, seed_generated, i) util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).") print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")

View File

@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from setup import stub from setup import stub
from txt2img import StableDiffusion from stable_diffusion_1_5 import Txt2Img
@stub.function(gpu="A10G") @stub.function(gpu="A10G")
def main(): def main():
StableDiffusion Txt2Img
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -18,7 +18,7 @@ from setup import (
gpu="A10G", gpu="A10G",
secrets=[Secret.from_dotenv(__file__)], secrets=[Secret.from_dotenv(__file__)],
) )
class StableDiffusion: class Txt2Img:
""" """
A class that wraps the Stable Diffusion pipeline and scheduler. A class that wraps the Stable Diffusion pipeline and scheduler.
""" """
@ -231,7 +231,6 @@ class StableDiffusion:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer from gfpgan import GFPGANer
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from tqdm import tqdm
model_name = upscaler model_name = upscaler
if model_name == "RealESRGAN_x4plus": if model_name == "RealESRGAN_x4plus":
@ -271,20 +270,18 @@ class StableDiffusion:
) )
upscaled_imgs = [] upscaled_imgs = []
with tqdm(total=len(base_images)) as progress_bar: for img in base_images:
for img in base_images: img = numpy.array(img)
img = numpy.array(img) if use_face_enhancer:
if use_face_enhancer: _, _, enhance_result = face_enhancer.enhance(
_, _, enhance_result = face_enhancer.enhance( img,
img, has_aligned=False,
has_aligned=False, only_center_face=False,
only_center_face=False, paste_back=True,
paste_back=True, )
) else:
else: enhance_result, _ = upsampler.enhance(img)
enhance_result, _ = upsampler.enhance(img)
upscaled_imgs.append(PIL.Image.fromarray(enhance_result)) upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
progress_bar.update(1)
return upscaled_imgs return upscaled_imgs