From 888dbe3dbc0534a1085a0f1c653a3278a74a393f Mon Sep 17 00:00:00 2001 From: hodanov <1031hoda@gmail.com> Date: Tue, 20 Jun 2023 22:24:19 +0900 Subject: [PATCH] Add a face enhancer. --- .env.example | 18 ++++++++++++++---- Dockerfile | 3 ++- Makefile | 3 +-- sd_cli.py | 53 +++++++++++++++++++++++++++++++++------------------- 4 files changed, 51 insertions(+), 26 deletions(-) diff --git a/.env.example b/.env.example index a2c9dfb..9438f85 100644 --- a/.env.example +++ b/.env.example @@ -7,16 +7,26 @@ USE_VAE="false" # Add LoRA if you want to use one. You can use a download link of civitai. # ex) -# - `LORA_NAMES="hogehoge.safetensors"` -# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"` +# - `LORA_NAMES="hogehoge.safetensors"` +# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"` # # If you have multiple LoRAs you want to use, separate by commas like the below: # ex) -# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"` -# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"` +# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"` +# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"` LORA_NAMES="" LORA_DOWNLOAD_URLS="" # Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`. TEXTUAL_INVERSION_NAMES="" TEXTUAL_INVERSION_DOWNLOAD_URLS="" + +# `UPSCALER` is a name of upscaler you want to use. +# Set `true` if you want to use a face enhancer too. 
+# You can use one of the upscalers below: +# - `RealESRGAN_x4plus` +# - `RealESRNet_x4plus` +# - `RealESRGAN_x4plus_anime_6B` +# - `RealESRGAN_x2plus` +UPSCALER="RealESRGAN_x4plus_anime_6B" +USE_FACE_ENHANCER="false" diff --git a/Dockerfile b/Dockerfile index 76d77a0..d69b8c1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,4 +7,5 @@ RUN apt update \ && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \ && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \ && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \ - && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan + && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan \ + && wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P /vol/cache/esrgan diff --git a/Makefile b/Makefile index febe53e..5268315 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,7 @@ run: modal run sd_cli.py \ --prompt "A woman with bob hair" \ --n-prompt "" \ - --upscaler "RealESRGAN_x4plus_anime_6B" \ --height 768 \ --width 512 \ --samples 5 \ - --steps 50 + --steps 30 diff --git a/sd_cli.py b/sd_cli.py index fe02b82..a33aff9 100644 --- a/sd_cli.py +++ b/sd_cli.py @@ -7,8 +7,6 @@ from urllib.request import Request, urlopen from modal import Image, Mount, Secret, Stub, method -import util - BASE_CACHE_PATH = "/vol/cache" BASE_CACHE_PATH_LORA = "/vol/cache/lora" BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion" @@ -49,14 +47,6 @@ def download_models(): ) vae.save_pretrained(cache_path, safe_serialization=True) - scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained( - model_repo_id, - subfolder="scheduler", - use_auth_token=hugging_face_token, - cache_dir=cache_path, - ) - 
scheduler.save_pretrained(cache_path, safe_serialization=True) - pipe = diffusers.StableDiffusionPipeline.from_pretrained( model_repo_id, use_auth_token=hugging_face_token, @@ -107,6 +97,10 @@ class StableDiffusion: import diffusers import torch + use_vae = os.environ["USE_VAE"] == "true" + self.upscaler = os.environ["UPSCALER"] + self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true" + cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"]) if os.path.exists(cache_path): print(f"The directory '{cache_path}' exists.") @@ -122,12 +116,14 @@ class StableDiffusion: torch_dtype=torch.float16, ) - self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained( + # TODO: Add support for other schedulers. + # self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained( + self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained( cache_path, subfolder="scheduler", ) - if os.environ["USE_VAE"] == "true": + if use_vae: self.pipe.vae = diffusers.AutoencoderKL.from_pretrained( cache_path, subfolder="vae", @@ -194,7 +190,7 @@ class StableDiffusion: generator = torch.Generator("cuda").manual_seed(inputs["seed"]) with torch.inference_mode(): with torch.autocast("cuda"): - base_images = self.pipe( + base_images = self.pipe.text2img( [inputs["prompt"]] * int(inputs["batch_size"]), negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]), height=inputs["height"], @@ -205,10 +201,9 @@ class StableDiffusion: generator=generator, ).images - if inputs["upscaler"] != "": + if self.upscaler != "": uplcaled_images = self.upscale( base_images=base_images, - model_name="RealESRGAN_x4plus", scale_factor=4, half_precision=False, tile=700, @@ -227,7 +222,6 @@ class StableDiffusion: def upscale( self, base_images: list[Image.Image], - model_name: str = "RealESRGAN_x4plus", scale_factor: float = 4, half_precision: bool = False, tile: int = 0, @@ -245,6 +239,7 @@ class StableDiffusion: from realesrgan import 
RealESRGANer from tqdm import tqdm + model_name = self.upscaler if model_name == "RealESRGAN_x4plus": upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) netscale = 4 @@ -272,14 +267,35 @@ class StableDiffusion: gpu_id=None, ) + from gfpgan import GFPGANer + + if self.use_face_enhancer: + face_enhancer = GFPGANer( + model_path=os.path.join(BASE_CACHE_PATH, "esrgan", "GFPGANv1.3.pth"), + upscale=netscale, + arch="clean", + channel_multiplier=2, + bg_upsampler=upsampler, + ) + torch.cuda.empty_cache() upscaled_imgs = [] with tqdm(total=len(base_images)) as progress_bar: for i, img in enumerate(base_images): img = numpy.array(img) - enhance_result = upsampler.enhance(img)[0] + if self.use_face_enhancer: + _, _, enhance_result = face_enhancer.enhance( + img, + has_aligned=False, + only_center_face=False, + paste_back=True, + ) + else: + enhance_result, _ = upsampler.enhance(img) + upscaled_imgs.append(Image.fromarray(enhance_result)) progress_bar.update(1) + torch.cuda.empty_cache() return upscaled_imgs @@ -289,7 +305,6 @@ class StableDiffusion: def entrypoint( prompt: str, n_prompt: str, - upscaler: str, height: int = 512, width: int = 512, samples: int = 5, @@ -302,6 +317,7 @@ def entrypoint( The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ + import util inputs: dict[str, int | str] = { "prompt": prompt, @@ -311,7 +327,6 @@ def entrypoint( "samples": samples, "batch_size": batch_size, "steps": steps, - "upscaler": upscaler, "seed": seed, }