Merge pull request #61 from hodanov/feature/modify_filename
Rename a file and a class.
This commit is contained in:
		
						commit
						e210333d83
					
				@ -4,7 +4,7 @@ import modal
 | 
				
			|||||||
import util
 | 
					import util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
stub = modal.Stub("run-stable-diffusion-cli")
 | 
					stub = modal.Stub("run-stable-diffusion-cli")
 | 
				
			||||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
 | 
					stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "Txt2Img.run_inference")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@stub.local_entrypoint()
 | 
					@stub.local_entrypoint()
 | 
				
			||||||
@ -20,6 +20,7 @@ def main(
 | 
				
			|||||||
    upscaler: str = "",
 | 
					    upscaler: str = "",
 | 
				
			||||||
    use_face_enhancer: str = "False",
 | 
					    use_face_enhancer: str = "False",
 | 
				
			||||||
    fix_by_controlnet_tile: str = "False",
 | 
					    fix_by_controlnet_tile: str = "False",
 | 
				
			||||||
 | 
					    output_format: str = "png",
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    This function is the entrypoint for the Runway CLI.
 | 
					    This function is the entrypoint for the Runway CLI.
 | 
				
			||||||
@ -43,8 +44,9 @@ def main(
 | 
				
			|||||||
            upscaler=upscaler,
 | 
					            upscaler=upscaler,
 | 
				
			||||||
            use_face_enhancer=use_face_enhancer == "True",
 | 
					            use_face_enhancer=use_face_enhancer == "True",
 | 
				
			||||||
            fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
 | 
					            fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
 | 
				
			||||||
 | 
					            output_format=output_format,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        util.save_images(directory, images, seed_generated, i)
 | 
					        util.save_images(directory, images, seed_generated, i, output_format)
 | 
				
			||||||
        total_time = time.time() - start_time
 | 
					        total_time = time.time() - start_time
 | 
				
			||||||
        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
 | 
					        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +1,12 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					from __future__ import annotations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from setup import stub
 | 
					from setup import stub
 | 
				
			||||||
from txt2img import StableDiffusion
 | 
					from stable_diffusion_1_5 import Txt2Img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@stub.function(gpu="A10G")
 | 
					@stub.function(gpu="A10G")
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    StableDiffusion
 | 
					    Txt2Img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ from setup import (
 | 
				
			|||||||
    gpu="A10G",
 | 
					    gpu="A10G",
 | 
				
			||||||
    secrets=[Secret.from_dotenv(__file__)],
 | 
					    secrets=[Secret.from_dotenv(__file__)],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
class StableDiffusion:
 | 
					class Txt2Img:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    A class that wraps the Stable Diffusion pipeline and scheduler.
 | 
					    A class that wraps the Stable Diffusion pipeline and scheduler.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@ -231,7 +231,6 @@ class StableDiffusion:
 | 
				
			|||||||
        from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
					        from basicsr.archs.rrdbnet_arch import RRDBNet
 | 
				
			||||||
        from gfpgan import GFPGANer
 | 
					        from gfpgan import GFPGANer
 | 
				
			||||||
        from realesrgan import RealESRGANer
 | 
					        from realesrgan import RealESRGANer
 | 
				
			||||||
        from tqdm import tqdm
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model_name = upscaler
 | 
					        model_name = upscaler
 | 
				
			||||||
        if model_name == "RealESRGAN_x4plus":
 | 
					        if model_name == "RealESRGAN_x4plus":
 | 
				
			||||||
@ -271,20 +270,18 @@ class StableDiffusion:
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        upscaled_imgs = []
 | 
					        upscaled_imgs = []
 | 
				
			||||||
        with tqdm(total=len(base_images)) as progress_bar:
 | 
					        for img in base_images:
 | 
				
			||||||
            for img in base_images:
 | 
					            img = numpy.array(img)
 | 
				
			||||||
                img = numpy.array(img)
 | 
					            if use_face_enhancer:
 | 
				
			||||||
                if use_face_enhancer:
 | 
					                _, _, enhance_result = face_enhancer.enhance(
 | 
				
			||||||
                    _, _, enhance_result = face_enhancer.enhance(
 | 
					                    img,
 | 
				
			||||||
                        img,
 | 
					                    has_aligned=False,
 | 
				
			||||||
                        has_aligned=False,
 | 
					                    only_center_face=False,
 | 
				
			||||||
                        only_center_face=False,
 | 
					                    paste_back=True,
 | 
				
			||||||
                        paste_back=True,
 | 
					                )
 | 
				
			||||||
                    )
 | 
					            else:
 | 
				
			||||||
                else:
 | 
					                enhance_result, _ = upsampler.enhance(img)
 | 
				
			||||||
                    enhance_result, _ = upsampler.enhance(img)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
 | 
					            upscaled_imgs.append(PIL.Image.fromarray(enhance_result))
 | 
				
			||||||
                progress_bar.update(1)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return upscaled_imgs
 | 
					        return upscaled_imgs
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user