Add a face enhancer.
This commit is contained in:
parent
82acd39a6e
commit
888dbe3dbc
18
.env.example
18
.env.example
@ -7,16 +7,26 @@ USE_VAE="false"
|
||||
|
||||
# Add LoRA if you want to use one. You can use a download link of civitai.
|
||||
# ex)
|
||||
# - `LORA_NAMES="hogehoge.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
||||
# - `LORA_NAMES="hogehoge.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
||||
#
|
||||
# If you have multiple LoRAs you want to use, separate by commas like the below:
|
||||
# ex)
|
||||
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||
LORA_NAMES=""
|
||||
LORA_DOWNLOAD_URLS=""
|
||||
|
||||
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
|
||||
TEXTUAL_INVERSION_NAMES=""
|
||||
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
|
||||
|
||||
# `UPSCALER` is a name of upscaler you want to use.
|
||||
# Set `true` if you want to use a face enhancer too.
|
||||
# You can use upscalers the below:
|
||||
# - `RealESRGAN_x4plus`
|
||||
# - `RealESRNet_x4plus`
|
||||
# - `RealESRGAN_x4plus_anime_6B`
|
||||
# - `RealESRGAN_x2plus`
|
||||
UPSCALER="RealESRGAN_x4plus_anime_6B"
|
||||
USE_FACE_ENHANCER="false"
|
||||
|
||||
@ -7,4 +7,5 @@ RUN apt update \
|
||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
|
||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
|
||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
|
||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
|
||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan \
|
||||
&& wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P /vol/cache/esrgan
|
||||
|
||||
3
Makefile
3
Makefile
@ -2,8 +2,7 @@ run:
|
||||
modal run sd_cli.py \
|
||||
--prompt "A woman with bob hair" \
|
||||
--n-prompt "" \
|
||||
--upscaler "RealESRGAN_x4plus_anime_6B" \
|
||||
--height 768 \
|
||||
--width 512 \
|
||||
--samples 5 \
|
||||
--steps 50
|
||||
--steps 30
|
||||
|
||||
53
sd_cli.py
53
sd_cli.py
@ -7,8 +7,6 @@ from urllib.request import Request, urlopen
|
||||
|
||||
from modal import Image, Mount, Secret, Stub, method
|
||||
|
||||
import util
|
||||
|
||||
BASE_CACHE_PATH = "/vol/cache"
|
||||
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
||||
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
|
||||
@ -49,14 +47,6 @@ def download_models():
|
||||
)
|
||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_repo_id,
|
||||
subfolder="scheduler",
|
||||
use_auth_token=hugging_face_token,
|
||||
cache_dir=cache_path,
|
||||
)
|
||||
scheduler.save_pretrained(cache_path, safe_serialization=True)
|
||||
|
||||
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
|
||||
model_repo_id,
|
||||
use_auth_token=hugging_face_token,
|
||||
@ -107,6 +97,10 @@ class StableDiffusion:
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
use_vae = os.environ["USE_VAE"] == "true"
|
||||
self.upscaler = os.environ["UPSCALER"]
|
||||
self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
|
||||
|
||||
cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
|
||||
if os.path.exists(cache_path):
|
||||
print(f"The directory '{cache_path}' exists.")
|
||||
@ -122,12 +116,14 @@ class StableDiffusion:
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
# TODO: Add support for other schedulers.
|
||||
# self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
|
||||
cache_path,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if os.environ["USE_VAE"] == "true":
|
||||
if use_vae:
|
||||
self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
|
||||
cache_path,
|
||||
subfolder="vae",
|
||||
@ -194,7 +190,7 @@ class StableDiffusion:
|
||||
generator = torch.Generator("cuda").manual_seed(inputs["seed"])
|
||||
with torch.inference_mode():
|
||||
with torch.autocast("cuda"):
|
||||
base_images = self.pipe(
|
||||
base_images = self.pipe.text2img(
|
||||
[inputs["prompt"]] * int(inputs["batch_size"]),
|
||||
negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
|
||||
height=inputs["height"],
|
||||
@ -205,10 +201,9 @@ class StableDiffusion:
|
||||
generator=generator,
|
||||
).images
|
||||
|
||||
if inputs["upscaler"] != "":
|
||||
if self.upscaler != "":
|
||||
uplcaled_images = self.upscale(
|
||||
base_images=base_images,
|
||||
model_name="RealESRGAN_x4plus",
|
||||
scale_factor=4,
|
||||
half_precision=False,
|
||||
tile=700,
|
||||
@ -227,7 +222,6 @@ class StableDiffusion:
|
||||
def upscale(
|
||||
self,
|
||||
base_images: list[Image.Image],
|
||||
model_name: str = "RealESRGAN_x4plus",
|
||||
scale_factor: float = 4,
|
||||
half_precision: bool = False,
|
||||
tile: int = 0,
|
||||
@ -245,6 +239,7 @@ class StableDiffusion:
|
||||
from realesrgan import RealESRGANer
|
||||
from tqdm import tqdm
|
||||
|
||||
model_name = self.upscaler
|
||||
if model_name == "RealESRGAN_x4plus":
|
||||
upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
netscale = 4
|
||||
@ -272,14 +267,35 @@ class StableDiffusion:
|
||||
gpu_id=None,
|
||||
)
|
||||
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
if self.use_face_enhancer:
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=os.path.join(BASE_CACHE_PATH, "esrgan", "GFPGANv1.3.pth"),
|
||||
upscale=netscale,
|
||||
arch="clean",
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
upscaled_imgs = []
|
||||
with tqdm(total=len(base_images)) as progress_bar:
|
||||
for i, img in enumerate(base_images):
|
||||
img = numpy.array(img)
|
||||
enhance_result = upsampler.enhance(img)[0]
|
||||
if self.use_face_enhancer:
|
||||
_, _, enhance_result = face_enhancer.enhance(
|
||||
img,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
)
|
||||
else:
|
||||
enhance_result, _ = upsampler.enhance(img)
|
||||
|
||||
upscaled_imgs.append(Image.fromarray(enhance_result))
|
||||
progress_bar.update(1)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return upscaled_imgs
|
||||
@ -289,7 +305,6 @@ class StableDiffusion:
|
||||
def entrypoint(
|
||||
prompt: str,
|
||||
n_prompt: str,
|
||||
upscaler: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
samples: int = 5,
|
||||
@ -302,6 +317,7 @@ def entrypoint(
|
||||
The function pass the given prompt to StableDiffusion on Modal,
|
||||
gets back a list of images and outputs images to local.
|
||||
"""
|
||||
import util
|
||||
|
||||
inputs: dict[str, int | str] = {
|
||||
"prompt": prompt,
|
||||
@ -311,7 +327,6 @@ def entrypoint(
|
||||
"samples": samples,
|
||||
"batch_size": batch_size,
|
||||
"steps": steps,
|
||||
"upscaler": upscaler,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user