Rename a file and a class.
This commit is contained in:
parent
2d1c1ffc4c
commit
0132166fe6
@ -4,7 +4,7 @@ import modal
|
||||
import util
|
||||
|
||||
stub = modal.Stub("run-stable-diffusion-cli")
|
||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
|
||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
|
||||
|
||||
|
||||
@stub.local_entrypoint()
|
||||
@ -20,6 +20,7 @@ def main(
|
||||
upscaler: str = "",
|
||||
use_face_enhancer: str = "False",
|
||||
fix_by_controlnet_tile: str = "False",
|
||||
output_format: str = "png",
|
||||
):
|
||||
"""
|
||||
This function is the entrypoint for the Runway CLI.
|
||||
@ -43,8 +44,9 @@ def main(
|
||||
upscaler=upscaler,
|
||||
use_face_enhancer=use_face_enhancer == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
output_format=output_format,
|
||||
)
|
||||
util.save_images(directory, images, seed_generated, i)
|
||||
util.save_images(directory, images, seed_generated, i, output_format)
|
||||
total_time = time.time() - start_time
|
||||
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from setup import stub
|
||||
from txt2img import StableDiffusion
|
||||
from stable_diffusion_1_5 import Txt2Img
|
||||
|
||||
|
||||
@stub.function(gpu="A10G")
|
||||
def main():
|
||||
StableDiffusion
|
||||
Txt2Img
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -18,7 +18,7 @@ from setup import (
|
||||
gpu="A10G",
|
||||
secrets=[Secret.from_dotenv(__file__)],
|
||||
)
|
||||
class StableDiffusion:
|
||||
class Txt2Img:
|
||||
"""
|
||||
A class that wraps the Stable Diffusion pipeline and scheduler.
|
||||
"""
|
||||
@ -231,7 +231,6 @@ class StableDiffusion:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from gfpgan import GFPGANer
|
||||
from realesrgan import RealESRGANer
|
||||
from tqdm import tqdm
|
||||
|
||||
model_name = upscaler
|
||||
if model_name == "RealESRGAN_x4plus":
|
||||
@ -271,20 +270,18 @@ class StableDiffusion:
|
||||
)
|
||||
|
||||
upscaled_imgs = []
|
||||
with tqdm(total=len(base_images)) as progress_bar:
|
||||
for img in base_images:
|
||||
img = numpy.array(img)
|
||||
if use_face_enhancer:
|
||||
_, _, enhance_result = face_enhancer.enhance(
|
||||
img,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
)
|
||||
else:
|
||||
enhance_result, _ = upsampler.enhance(img)
|
||||
for img in base_images:
|
||||
img = numpy.array(img)
|
||||
if use_face_enhancer:
|
||||
_, _, enhance_result = face_enhancer.enhance(
|
||||
img,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
)
|
||||
else:
|
||||
enhance_result, _ = upsampler.enhance(img)
|
||||
|
||||
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
|
||||
progress_bar.update(1)
|
||||
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
|
||||
|
||||
return upscaled_imgs
|
||||
Loading…
x
Reference in New Issue
Block a user