Rename a file and a class.
This commit is contained in:
parent
2d1c1ffc4c
commit
0132166fe6
@ -4,7 +4,7 @@ import modal
|
|||||||
import util
|
import util
|
||||||
|
|
||||||
stub = modal.Stub("run-stable-diffusion-cli")
|
stub = modal.Stub("run-stable-diffusion-cli")
|
||||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
|
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@stub.local_entrypoint()
|
||||||
@ -20,6 +20,7 @@ def main(
|
|||||||
upscaler: str = "",
|
upscaler: str = "",
|
||||||
use_face_enhancer: str = "False",
|
use_face_enhancer: str = "False",
|
||||||
fix_by_controlnet_tile: str = "False",
|
fix_by_controlnet_tile: str = "False",
|
||||||
|
output_format: str = "png",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This function is the entrypoint for the Runway CLI.
|
This function is the entrypoint for the Runway CLI.
|
||||||
@ -43,8 +44,9 @@ def main(
|
|||||||
upscaler=upscaler,
|
upscaler=upscaler,
|
||||||
use_face_enhancer=use_face_enhancer == "True",
|
use_face_enhancer=use_face_enhancer == "True",
|
||||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||||
|
output_format=output_format,
|
||||||
)
|
)
|
||||||
util.save_images(directory, images, seed_generated, i)
|
util.save_images(directory, images, seed_generated, i, output_format)
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from setup import stub
|
from setup import stub
|
||||||
from txt2img import StableDiffusion
|
from stable_diffusion_1_5 import Txt2Img
|
||||||
|
|
||||||
|
|
||||||
@stub.function(gpu="A10G")
|
@stub.function(gpu="A10G")
|
||||||
def main():
|
def main():
|
||||||
StableDiffusion
|
Txt2Img
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from setup import (
|
|||||||
gpu="A10G",
|
gpu="A10G",
|
||||||
secrets=[Secret.from_dotenv(__file__)],
|
secrets=[Secret.from_dotenv(__file__)],
|
||||||
)
|
)
|
||||||
class StableDiffusion:
|
class Txt2Img:
|
||||||
"""
|
"""
|
||||||
A class that wraps the Stable Diffusion pipeline and scheduler.
|
A class that wraps the Stable Diffusion pipeline and scheduler.
|
||||||
"""
|
"""
|
||||||
@ -231,7 +231,6 @@ class StableDiffusion:
|
|||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
model_name = upscaler
|
model_name = upscaler
|
||||||
if model_name == "RealESRGAN_x4plus":
|
if model_name == "RealESRGAN_x4plus":
|
||||||
@ -271,20 +270,18 @@ class StableDiffusion:
|
|||||||
)
|
)
|
||||||
|
|
||||||
upscaled_imgs = []
|
upscaled_imgs = []
|
||||||
with tqdm(total=len(base_images)) as progress_bar:
|
for img in base_images:
|
||||||
for img in base_images:
|
img = numpy.array(img)
|
||||||
img = numpy.array(img)
|
if use_face_enhancer:
|
||||||
if use_face_enhancer:
|
_, _, enhance_result = face_enhancer.enhance(
|
||||||
_, _, enhance_result = face_enhancer.enhance(
|
img,
|
||||||
img,
|
has_aligned=False,
|
||||||
has_aligned=False,
|
only_center_face=False,
|
||||||
only_center_face=False,
|
paste_back=True,
|
||||||
paste_back=True,
|
)
|
||||||
)
|
else:
|
||||||
else:
|
enhance_result, _ = upsampler.enhance(img)
|
||||||
enhance_result, _ = upsampler.enhance(img)
|
|
||||||
|
|
||||||
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
|
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
|
||||||
progress_bar.update(1)
|
|
||||||
|
|
||||||
return upscaled_imgs
|
return upscaled_imgs
|
||||||
Loading…
x
Reference in New Issue
Block a user