Add a face enhancer.
This commit is contained in:
		
							parent
							
								
									82acd39a6e
								
							
						
					
					
						commit
						888dbe3dbc
					
				
							
								
								
									
										18
									
								
								.env.example
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								.env.example
									
									
									
									
									
								
							@ -7,16 +7,26 @@ USE_VAE="false"
 | 
			
		||||
 | 
			
		||||
# Add LoRA if you want to use one. You can use a download link of civitai.
 | 
			
		||||
# ex)
 | 
			
		||||
#  - `LORA_NAMES="hogehoge.safetensors"`
 | 
			
		||||
#  - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
 | 
			
		||||
#   - `LORA_NAMES="hogehoge.safetensors"`
 | 
			
		||||
#   - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
 | 
			
		||||
#
 | 
			
		||||
# If you have multiple LoRAs you want to use, separate by commas like the below:
 | 
			
		||||
# ex)
 | 
			
		||||
#  - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
 | 
			
		||||
#  - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
 | 
			
		||||
#   - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
 | 
			
		||||
#   - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
 | 
			
		||||
LORA_NAMES=""
 | 
			
		||||
LORA_DOWNLOAD_URLS=""
 | 
			
		||||
 | 
			
		||||
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
 | 
			
		||||
TEXTUAL_INVERSION_NAMES=""
 | 
			
		||||
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
 | 
			
		||||
 | 
			
		||||
# `UPSCALER` is a name of upscaler you want to use.
 | 
			
		||||
# Set `true` if you want to use a face enhancer too.
 | 
			
		||||
# You can use upscalers the below:
 | 
			
		||||
#   - `RealESRGAN_x4plus`
 | 
			
		||||
#   - `RealESRNet_x4plus`
 | 
			
		||||
#   - `RealESRGAN_x4plus_anime_6B`
 | 
			
		||||
#   - `RealESRGAN_x2plus`
 | 
			
		||||
UPSCALER="RealESRGAN_x4plus_anime_6B"
 | 
			
		||||
USE_FACE_ENHANCER="false"
 | 
			
		||||
 | 
			
		||||
@ -7,4 +7,5 @@ RUN apt update \
 | 
			
		||||
    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
 | 
			
		||||
    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
 | 
			
		||||
    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
 | 
			
		||||
    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
 | 
			
		||||
    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan \
 | 
			
		||||
    && wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P /vol/cache/esrgan
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Makefile
									
									
									
									
									
								
							@ -2,8 +2,7 @@ run:
 | 
			
		||||
	modal run sd_cli.py \
 | 
			
		||||
	--prompt "A woman with bob hair" \
 | 
			
		||||
	--n-prompt "" \
 | 
			
		||||
	--upscaler "RealESRGAN_x4plus_anime_6B" \
 | 
			
		||||
	--height 768 \
 | 
			
		||||
	--width 512 \
 | 
			
		||||
	--samples 5 \
 | 
			
		||||
	--steps 50
 | 
			
		||||
	--steps 30
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										53
									
								
								sd_cli.py
									
									
									
									
									
								
							
							
						
						
									
										53
									
								
								sd_cli.py
									
									
									
									
									
								
							@ -7,8 +7,6 @@ from urllib.request import Request, urlopen
 | 
			
		||||
 | 
			
		||||
from modal import Image, Mount, Secret, Stub, method
 | 
			
		||||
 | 
			
		||||
import util
 | 
			
		||||
 | 
			
		||||
BASE_CACHE_PATH = "/vol/cache"
 | 
			
		||||
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
 | 
			
		||||
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
 | 
			
		||||
@ -49,14 +47,6 @@ def download_models():
 | 
			
		||||
    )
 | 
			
		||||
    vae.save_pretrained(cache_path, safe_serialization=True)
 | 
			
		||||
 | 
			
		||||
    scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
 | 
			
		||||
        model_repo_id,
 | 
			
		||||
        subfolder="scheduler",
 | 
			
		||||
        use_auth_token=hugging_face_token,
 | 
			
		||||
        cache_dir=cache_path,
 | 
			
		||||
    )
 | 
			
		||||
    scheduler.save_pretrained(cache_path, safe_serialization=True)
 | 
			
		||||
 | 
			
		||||
    pipe = diffusers.StableDiffusionPipeline.from_pretrained(
 | 
			
		||||
        model_repo_id,
 | 
			
		||||
        use_auth_token=hugging_face_token,
 | 
			
		||||
@ -107,6 +97,10 @@ class StableDiffusion:
 | 
			
		||||
        import diffusers
 | 
			
		||||
        import torch
 | 
			
		||||
 | 
			
		||||
        use_vae = os.environ["USE_VAE"] == "true"
 | 
			
		||||
        self.upscaler = os.environ["UPSCALER"]
 | 
			
		||||
        self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
 | 
			
		||||
 | 
			
		||||
        cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
 | 
			
		||||
        if os.path.exists(cache_path):
 | 
			
		||||
            print(f"The directory '{cache_path}' exists.")
 | 
			
		||||
@ -122,12 +116,14 @@ class StableDiffusion:
 | 
			
		||||
            torch_dtype=torch.float16,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
 | 
			
		||||
        # TODO: Add support for other schedulers.
 | 
			
		||||
        # self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
 | 
			
		||||
        self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
 | 
			
		||||
            cache_path,
 | 
			
		||||
            subfolder="scheduler",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if os.environ["USE_VAE"] == "true":
 | 
			
		||||
        if use_vae:
 | 
			
		||||
            self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
 | 
			
		||||
                cache_path,
 | 
			
		||||
                subfolder="vae",
 | 
			
		||||
@ -194,7 +190,7 @@ class StableDiffusion:
 | 
			
		||||
        generator = torch.Generator("cuda").manual_seed(inputs["seed"])
 | 
			
		||||
        with torch.inference_mode():
 | 
			
		||||
            with torch.autocast("cuda"):
 | 
			
		||||
                base_images = self.pipe(
 | 
			
		||||
                base_images = self.pipe.text2img(
 | 
			
		||||
                    [inputs["prompt"]] * int(inputs["batch_size"]),
 | 
			
		||||
                    negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
 | 
			
		||||
                    height=inputs["height"],
 | 
			
		||||
@ -205,10 +201,9 @@ class StableDiffusion:
 | 
			
		||||
                    generator=generator,
 | 
			
		||||
                ).images
 | 
			
		||||
 | 
			
		||||
        if inputs["upscaler"] != "":
 | 
			
		||||
        if self.upscaler != "":
 | 
			
		||||
            uplcaled_images = self.upscale(
 | 
			
		||||
                base_images=base_images,
 | 
			
		||||
                model_name="RealESRGAN_x4plus",
 | 
			
		||||
                scale_factor=4,
 | 
			
		||||
                half_precision=False,
 | 
			
		||||
                tile=700,
 | 
			
		||||
@ -227,7 +222,6 @@ class StableDiffusion:
 | 
			
		||||
    def upscale(
 | 
			
		||||
        self,
 | 
			
		||||
        base_images: list[Image.Image],
 | 
			
		||||
        model_name: str = "RealESRGAN_x4plus",
 | 
			
		||||
        scale_factor: float = 4,
 | 
			
		||||
        half_precision: bool = False,
 | 
			
		||||
        tile: int = 0,
 | 
			
		||||
@ -245,6 +239,7 @@ class StableDiffusion:
 | 
			
		||||
        from realesrgan import RealESRGANer
 | 
			
		||||
        from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
        model_name = self.upscaler
 | 
			
		||||
        if model_name == "RealESRGAN_x4plus":
 | 
			
		||||
            upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
 | 
			
		||||
            netscale = 4
 | 
			
		||||
@ -272,14 +267,35 @@ class StableDiffusion:
 | 
			
		||||
            gpu_id=None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        from gfpgan import GFPGANer
 | 
			
		||||
 | 
			
		||||
        if self.use_face_enhancer:
 | 
			
		||||
            face_enhancer = GFPGANer(
 | 
			
		||||
                model_path=os.path.join(BASE_CACHE_PATH, "esrgan", "GFPGANv1.3.pth"),
 | 
			
		||||
                upscale=netscale,
 | 
			
		||||
                arch="clean",
 | 
			
		||||
                channel_multiplier=2,
 | 
			
		||||
                bg_upsampler=upsampler,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        upscaled_imgs = []
 | 
			
		||||
        with tqdm(total=len(base_images)) as progress_bar:
 | 
			
		||||
            for i, img in enumerate(base_images):
 | 
			
		||||
                img = numpy.array(img)
 | 
			
		||||
                enhance_result = upsampler.enhance(img)[0]
 | 
			
		||||
                if self.use_face_enhancer:
 | 
			
		||||
                    _, _, enhance_result = face_enhancer.enhance(
 | 
			
		||||
                        img,
 | 
			
		||||
                        has_aligned=False,
 | 
			
		||||
                        only_center_face=False,
 | 
			
		||||
                        paste_back=True,
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    enhance_result, _ = upsampler.enhance(img)
 | 
			
		||||
 | 
			
		||||
                upscaled_imgs.append(Image.fromarray(enhance_result))
 | 
			
		||||
                progress_bar.update(1)
 | 
			
		||||
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
        return upscaled_imgs
 | 
			
		||||
@ -289,7 +305,6 @@ class StableDiffusion:
 | 
			
		||||
def entrypoint(
 | 
			
		||||
    prompt: str,
 | 
			
		||||
    n_prompt: str,
 | 
			
		||||
    upscaler: str,
 | 
			
		||||
    height: int = 512,
 | 
			
		||||
    width: int = 512,
 | 
			
		||||
    samples: int = 5,
 | 
			
		||||
@ -302,6 +317,7 @@ def entrypoint(
 | 
			
		||||
    The function pass the given prompt to StableDiffusion on Modal,
 | 
			
		||||
    gets back a list of images and outputs images to local.
 | 
			
		||||
    """
 | 
			
		||||
    import util
 | 
			
		||||
 | 
			
		||||
    inputs: dict[str, int | str] = {
 | 
			
		||||
        "prompt": prompt,
 | 
			
		||||
@ -311,7 +327,6 @@ def entrypoint(
 | 
			
		||||
        "samples": samples,
 | 
			
		||||
        "batch_size": batch_size,
 | 
			
		||||
        "steps": steps,
 | 
			
		||||
        "upscaler": upscaler,
 | 
			
		||||
        "seed": seed,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user