Add a face enhancer.

This commit is contained in:
hodanov 2023-06-20 22:24:19 +09:00
parent 82acd39a6e
commit 888dbe3dbc
4 changed files with 51 additions and 26 deletions

View File

@ -7,16 +7,26 @@ USE_VAE="false"
# Add LoRA if you want to use one. You can use a download link of civitai.
# ex)
# - `LORA_NAMES="hogehoge.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
# - `LORA_NAMES="hogehoge.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
#
# If you have multiple LoRAs you want to use, separate by commas like the below:
# ex)
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
LORA_NAMES=""
LORA_DOWNLOAD_URLS=""
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
TEXTUAL_INVERSION_NAMES=""
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
# `UPSCALER` is a name of upscaler you want to use.
# Set `true` if you want to use a face enhancer too.
# You can use upscalers the below:
# - `RealESRGAN_x4plus`
# - `RealESRNet_x4plus`
# - `RealESRGAN_x4plus_anime_6B`
# - `RealESRGAN_x2plus`
UPSCALER="RealESRGAN_x4plus_anime_6B"
USE_FACE_ENHANCER="false"

View File

@ -7,4 +7,5 @@ RUN apt update \
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan \
&& wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P /vol/cache/esrgan

View File

@ -2,8 +2,7 @@ run:
modal run sd_cli.py \
--prompt "A woman with bob hair" \
--n-prompt "" \
--upscaler "RealESRGAN_x4plus_anime_6B" \
--height 768 \
--width 512 \
--samples 5 \
--steps 50
--steps 30

View File

@ -7,8 +7,6 @@ from urllib.request import Request, urlopen
from modal import Image, Mount, Secret, Stub, method
import util
BASE_CACHE_PATH = "/vol/cache"
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
@ -49,14 +47,6 @@ def download_models():
)
vae.save_pretrained(cache_path, safe_serialization=True)
scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
model_repo_id,
subfolder="scheduler",
use_auth_token=hugging_face_token,
cache_dir=cache_path,
)
scheduler.save_pretrained(cache_path, safe_serialization=True)
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
model_repo_id,
use_auth_token=hugging_face_token,
@ -107,6 +97,10 @@ class StableDiffusion:
import diffusers
import torch
use_vae = os.environ["USE_VAE"] == "true"
self.upscaler = os.environ["UPSCALER"]
self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
if os.path.exists(cache_path):
print(f"The directory '{cache_path}' exists.")
@ -122,12 +116,14 @@ class StableDiffusion:
torch_dtype=torch.float16,
)
self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
# TODO: Add support for other schedulers.
# self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
cache_path,
subfolder="scheduler",
)
if os.environ["USE_VAE"] == "true":
if use_vae:
self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
cache_path,
subfolder="vae",
@ -194,7 +190,7 @@ class StableDiffusion:
generator = torch.Generator("cuda").manual_seed(inputs["seed"])
with torch.inference_mode():
with torch.autocast("cuda"):
base_images = self.pipe(
base_images = self.pipe.text2img(
[inputs["prompt"]] * int(inputs["batch_size"]),
negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
height=inputs["height"],
@ -205,10 +201,9 @@ class StableDiffusion:
generator=generator,
).images
if inputs["upscaler"] != "":
if self.upscaler != "":
uplcaled_images = self.upscale(
base_images=base_images,
model_name="RealESRGAN_x4plus",
scale_factor=4,
half_precision=False,
tile=700,
@ -227,7 +222,6 @@ class StableDiffusion:
def upscale(
self,
base_images: list[Image.Image],
model_name: str = "RealESRGAN_x4plus",
scale_factor: float = 4,
half_precision: bool = False,
tile: int = 0,
@ -245,6 +239,7 @@ class StableDiffusion:
from realesrgan import RealESRGANer
from tqdm import tqdm
model_name = self.upscaler
if model_name == "RealESRGAN_x4plus":
upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
@ -272,14 +267,35 @@ class StableDiffusion:
gpu_id=None,
)
from gfpgan import GFPGANer
if self.use_face_enhancer:
face_enhancer = GFPGANer(
model_path=os.path.join(BASE_CACHE_PATH, "esrgan", "GFPGANv1.3.pth"),
upscale=netscale,
arch="clean",
channel_multiplier=2,
bg_upsampler=upsampler,
)
torch.cuda.empty_cache()
upscaled_imgs = []
with tqdm(total=len(base_images)) as progress_bar:
for i, img in enumerate(base_images):
img = numpy.array(img)
enhance_result = upsampler.enhance(img)[0]
if self.use_face_enhancer:
_, _, enhance_result = face_enhancer.enhance(
img,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
else:
enhance_result, _ = upsampler.enhance(img)
upscaled_imgs.append(Image.fromarray(enhance_result))
progress_bar.update(1)
torch.cuda.empty_cache()
return upscaled_imgs
@ -289,7 +305,6 @@ class StableDiffusion:
def entrypoint(
prompt: str,
n_prompt: str,
upscaler: str,
height: int = 512,
width: int = 512,
samples: int = 5,
@ -302,6 +317,7 @@ def entrypoint(
The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local.
"""
import util
inputs: dict[str, int | str] = {
"prompt": prompt,
@ -311,7 +327,6 @@ def entrypoint(
"samples": samples,
"batch_size": batch_size,
"steps": steps,
"upscaler": upscaler,
"seed": seed,
}