Merge pull request #61 from hodanov/feature/modify_filename

Rename a file and a class.
This commit is contained in:
hodanov 2023-12-04 10:32:52 +09:00 committed by GitHub
commit e210333d83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 20 deletions

View File

@ -4,7 +4,7 @@ import modal
import util
stub = modal.Stub("run-stable-diffusion-cli")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
@stub.local_entrypoint()
@ -20,6 +20,7 @@ def main(
upscaler: str = "",
use_face_enhancer: str = "False",
fix_by_controlnet_tile: str = "False",
output_format: str = "png",
):
"""
This function is the entrypoint for the Runway CLI.
@ -43,8 +44,9 @@ def main(
upscaler=upscaler,
use_face_enhancer=use_face_enhancer == "True",
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
output_format=output_format,
)
util.save_images(directory, images, seed_generated, i)
util.save_images(directory, images, seed_generated, i, output_format)
total_time = time.time() - start_time
print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")

View File

@ -1,12 +1,12 @@
from __future__ import annotations
from setup import stub
from txt2img import StableDiffusion
from stable_diffusion_1_5 import Txt2Img
@stub.function(gpu="A10G")
def main():
StableDiffusion
Txt2Img
if __name__ == "__main__":

View File

@ -18,7 +18,7 @@ from setup import (
gpu="A10G",
secrets=[Secret.from_dotenv(__file__)],
)
class StableDiffusion:
class Txt2Img:
"""
A class that wraps the Stable Diffusion pipeline and scheduler.
"""
@ -231,7 +231,6 @@ class StableDiffusion:
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm
model_name = upscaler
if model_name == "RealESRGAN_x4plus":
@ -271,20 +270,18 @@ class StableDiffusion:
)
upscaled_imgs = []
with tqdm(total=len(base_images)) as progress_bar:
for img in base_images:
img = numpy.array(img)
if use_face_enhancer:
_, _, enhance_result = face_enhancer.enhance(
img,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
else:
enhance_result, _ = upsampler.enhance(img)
for img in base_images:
img = numpy.array(img)
if use_face_enhancer:
_, _, enhance_result = face_enhancer.enhance(
img,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
else:
enhance_result, _ = upsampler.enhance(img)
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
progress_bar.update(1)
upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
return upscaled_imgs