Merge pull request #61 from hodanov/feature/modify_filename
Rename a file and a class.
This commit is contained in:
		
						commit
						e210333d83
					
				@ -4,7 +4,7 @@ import modal
 | 
			
		||||
import util
 | 
			
		||||
 | 
			
		||||
stub = modal.Stub("run-stable-diffusion-cli")
 | 
			
		||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
 | 
			
		||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.local_entrypoint()
 | 
			
		||||
@ -20,6 +20,7 @@ def main(
 | 
			
		||||
    upscaler: str = "",
 | 
			
		||||
    use_face_enhancer: str = "False",
 | 
			
		||||
    fix_by_controlnet_tile: str = "False",
 | 
			
		||||
    output_format: str = "png",
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    This function is the entrypoint for the Runway CLI.
 | 
			
		||||
@ -43,8 +44,9 @@ def main(
 | 
			
		||||
            upscaler=upscaler,
 | 
			
		||||
            use_face_enhancer=use_face_enhancer == "True",
 | 
			
		||||
            fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
 | 
			
		||||
            output_format=output_format,
 | 
			
		||||
        )
 | 
			
		||||
        util.save_images(directory, images, seed_generated, i)
 | 
			
		||||
        util.save_images(directory, images, seed_generated, i, output_format)
 | 
			
		||||
        total_time = time.time() - start_time
 | 
			
		||||
        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,12 +1,12 @@
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from setup import stub
 | 
			
		||||
from txt2img import StableDiffusion
 | 
			
		||||
from stable_diffusion_1_5 import Txt2Img
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@stub.function(gpu="A10G")
 | 
			
		||||
def main():
 | 
			
		||||
    StableDiffusion
 | 
			
		||||
    Txt2Img
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from setup import (
 | 
			
		||||
    gpu="A10G",
 | 
			
		||||
    secrets=[Secret.from_dotenv(__file__)],
 | 
			
		||||
)
 | 
			
		||||
class StableDiffusion:
 | 
			
		||||
class Txt2Img:
 | 
			
		||||
    """
 | 
			
		||||
    A class that wraps the Stable Diffusion pipeline and scheduler.
 | 
			
		||||
    """
 | 
			
		||||
@ -231,7 +231,6 @@ class StableDiffusion:
 | 
			
		||||
        from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
			
		||||
        from gfpgan import GFPGANer
 | 
			
		||||
        from realesrgan import RealESRGANer
 | 
			
		||||
        from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
        model_name = upscaler
 | 
			
		||||
        if model_name == "RealESRGAN_x4plus":
 | 
			
		||||
@ -271,20 +270,18 @@ class StableDiffusion:
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        upscaled_imgs = []
 | 
			
		||||
        with tqdm(total=len(base_images)) as progress_bar:
 | 
			
		||||
            for img in base_images:
 | 
			
		||||
                img = numpy.array(img)
 | 
			
		||||
                if use_face_enhancer:
 | 
			
		||||
                    _, _, enhance_result = face_enhancer.enhance(
 | 
			
		||||
                        img,
 | 
			
		||||
                        has_aligned=False,
 | 
			
		||||
                        only_center_face=False,
 | 
			
		||||
                        paste_back=True,
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    enhance_result, _ = upsampler.enhance(img)
 | 
			
		||||
        for img in base_images:
 | 
			
		||||
            img = numpy.array(img)
 | 
			
		||||
            if use_face_enhancer:
 | 
			
		||||
                _, _, enhance_result = face_enhancer.enhance(
 | 
			
		||||
                    img,
 | 
			
		||||
                    has_aligned=False,
 | 
			
		||||
                    only_center_face=False,
 | 
			
		||||
                    paste_back=True,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                enhance_result, _ = upsampler.enhance(img)
 | 
			
		||||
 | 
			
		||||
                upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
 | 
			
		||||
                progress_bar.update(1)
 | 
			
		||||
            upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
 | 
			
		||||
 | 
			
		||||
        return upscaled_imgs
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user