Add a face enhancer.
This commit is contained in:
parent
82acd39a6e
commit
888dbe3dbc
18
.env.example
18
.env.example
@ -7,16 +7,26 @@ USE_VAE="false"
|
|||||||
|
|
||||||
# Add LoRA if you want to use one. You can use a download link of civitai.
|
# Add LoRA if you want to use one. You can use a download link of civitai.
|
||||||
# ex)
|
# ex)
|
||||||
# - `LORA_NAMES="hogehoge.safetensors"`
|
# - `LORA_NAMES="hogehoge.safetensors"`
|
||||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx"`
|
||||||
#
|
#
|
||||||
# If you have multiple LoRAs you want to use, separate by commas like the below:
|
# If you have multiple LoRAs you want to use, separate by commas like the below:
|
||||||
# ex)
|
# ex)
|
||||||
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
# - `LORA_NAMES="hogehoge.safetensors,mogumogu.safetensors"`
|
||||||
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
# - `LORA_DOWNLOAD_URLS="https://civitai.com/api/download/models/xxxxxx,https://civitai.com/api/download/models/xxxxxx"`
|
||||||
LORA_NAMES=""
|
LORA_NAMES=""
|
||||||
LORA_DOWNLOAD_URLS=""
|
LORA_DOWNLOAD_URLS=""
|
||||||
|
|
||||||
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
|
# Add Textual Inversion you wan to use. Usage is the same as `LORA_NAMES` and `LORA_DOWNLOAD_URLS`.
|
||||||
TEXTUAL_INVERSION_NAMES=""
|
TEXTUAL_INVERSION_NAMES=""
|
||||||
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
|
TEXTUAL_INVERSION_DOWNLOAD_URLS=""
|
||||||
|
|
||||||
|
# `UPSCALER` is a name of upscaler you want to use.
|
||||||
|
# Set `true` if you want to use a face enhancer too.
|
||||||
|
# You can use upscalers the below:
|
||||||
|
# - `RealESRGAN_x4plus`
|
||||||
|
# - `RealESRNet_x4plus`
|
||||||
|
# - `RealESRGAN_x4plus_anime_6B`
|
||||||
|
# - `RealESRGAN_x2plus`
|
||||||
|
UPSCALER="RealESRGAN_x4plus_anime_6B"
|
||||||
|
USE_FACE_ENHANCER="false"
|
||||||
|
|||||||
@ -7,4 +7,5 @@ RUN apt update \
|
|||||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
|
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
|
||||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
|
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
|
||||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
|
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
|
||||||
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
|
&& wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan \
|
||||||
|
&& wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P /vol/cache/esrgan
|
||||||
|
|||||||
3
Makefile
3
Makefile
@ -2,8 +2,7 @@ run:
|
|||||||
modal run sd_cli.py \
|
modal run sd_cli.py \
|
||||||
--prompt "A woman with bob hair" \
|
--prompt "A woman with bob hair" \
|
||||||
--n-prompt "" \
|
--n-prompt "" \
|
||||||
--upscaler "RealESRGAN_x4plus_anime_6B" \
|
|
||||||
--height 768 \
|
--height 768 \
|
||||||
--width 512 \
|
--width 512 \
|
||||||
--samples 5 \
|
--samples 5 \
|
||||||
--steps 50
|
--steps 30
|
||||||
|
|||||||
53
sd_cli.py
53
sd_cli.py
@ -7,8 +7,6 @@ from urllib.request import Request, urlopen
|
|||||||
|
|
||||||
from modal import Image, Mount, Secret, Stub, method
|
from modal import Image, Mount, Secret, Stub, method
|
||||||
|
|
||||||
import util
|
|
||||||
|
|
||||||
BASE_CACHE_PATH = "/vol/cache"
|
BASE_CACHE_PATH = "/vol/cache"
|
||||||
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
BASE_CACHE_PATH_LORA = "/vol/cache/lora"
|
||||||
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
|
BASE_CACHE_PATH_TEXTUAL_INVERSION = "/vol/cache/textual_inversion"
|
||||||
@ -49,14 +47,6 @@ def download_models():
|
|||||||
)
|
)
|
||||||
vae.save_pretrained(cache_path, safe_serialization=True)
|
vae.save_pretrained(cache_path, safe_serialization=True)
|
||||||
|
|
||||||
scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
|
||||||
model_repo_id,
|
|
||||||
subfolder="scheduler",
|
|
||||||
use_auth_token=hugging_face_token,
|
|
||||||
cache_dir=cache_path,
|
|
||||||
)
|
|
||||||
scheduler.save_pretrained(cache_path, safe_serialization=True)
|
|
||||||
|
|
||||||
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
|
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
|
||||||
model_repo_id,
|
model_repo_id,
|
||||||
use_auth_token=hugging_face_token,
|
use_auth_token=hugging_face_token,
|
||||||
@ -107,6 +97,10 @@ class StableDiffusion:
|
|||||||
import diffusers
|
import diffusers
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
use_vae = os.environ["USE_VAE"] == "true"
|
||||||
|
self.upscaler = os.environ["UPSCALER"]
|
||||||
|
self.use_face_enhancer = os.environ["USE_FACE_ENHANCER"] == "true"
|
||||||
|
|
||||||
cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
|
cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
|
||||||
if os.path.exists(cache_path):
|
if os.path.exists(cache_path):
|
||||||
print(f"The directory '{cache_path}' exists.")
|
print(f"The directory '{cache_path}' exists.")
|
||||||
@ -122,12 +116,14 @@ class StableDiffusion:
|
|||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
# TODO: Add support for other schedulers.
|
||||||
|
# self.pipe.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
|
||||||
|
self.pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
|
||||||
cache_path,
|
cache_path,
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.environ["USE_VAE"] == "true":
|
if use_vae:
|
||||||
self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
|
self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
|
||||||
cache_path,
|
cache_path,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
@ -194,7 +190,7 @@ class StableDiffusion:
|
|||||||
generator = torch.Generator("cuda").manual_seed(inputs["seed"])
|
generator = torch.Generator("cuda").manual_seed(inputs["seed"])
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
base_images = self.pipe(
|
base_images = self.pipe.text2img(
|
||||||
[inputs["prompt"]] * int(inputs["batch_size"]),
|
[inputs["prompt"]] * int(inputs["batch_size"]),
|
||||||
negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
|
negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
|
||||||
height=inputs["height"],
|
height=inputs["height"],
|
||||||
@ -205,10 +201,9 @@ class StableDiffusion:
|
|||||||
generator=generator,
|
generator=generator,
|
||||||
).images
|
).images
|
||||||
|
|
||||||
if inputs["upscaler"] != "":
|
if self.upscaler != "":
|
||||||
uplcaled_images = self.upscale(
|
uplcaled_images = self.upscale(
|
||||||
base_images=base_images,
|
base_images=base_images,
|
||||||
model_name="RealESRGAN_x4plus",
|
|
||||||
scale_factor=4,
|
scale_factor=4,
|
||||||
half_precision=False,
|
half_precision=False,
|
||||||
tile=700,
|
tile=700,
|
||||||
@ -227,7 +222,6 @@ class StableDiffusion:
|
|||||||
def upscale(
|
def upscale(
|
||||||
self,
|
self,
|
||||||
base_images: list[Image.Image],
|
base_images: list[Image.Image],
|
||||||
model_name: str = "RealESRGAN_x4plus",
|
|
||||||
scale_factor: float = 4,
|
scale_factor: float = 4,
|
||||||
half_precision: bool = False,
|
half_precision: bool = False,
|
||||||
tile: int = 0,
|
tile: int = 0,
|
||||||
@ -245,6 +239,7 @@ class StableDiffusion:
|
|||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
model_name = self.upscaler
|
||||||
if model_name == "RealESRGAN_x4plus":
|
if model_name == "RealESRGAN_x4plus":
|
||||||
upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
netscale = 4
|
netscale = 4
|
||||||
@ -272,14 +267,35 @@ class StableDiffusion:
|
|||||||
gpu_id=None,
|
gpu_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
|
if self.use_face_enhancer:
|
||||||
|
face_enhancer = GFPGANer(
|
||||||
|
model_path=os.path.join(BASE_CACHE_PATH, "esrgan", "GFPGANv1.3.pth"),
|
||||||
|
upscale=netscale,
|
||||||
|
arch="clean",
|
||||||
|
channel_multiplier=2,
|
||||||
|
bg_upsampler=upsampler,
|
||||||
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
upscaled_imgs = []
|
upscaled_imgs = []
|
||||||
with tqdm(total=len(base_images)) as progress_bar:
|
with tqdm(total=len(base_images)) as progress_bar:
|
||||||
for i, img in enumerate(base_images):
|
for i, img in enumerate(base_images):
|
||||||
img = numpy.array(img)
|
img = numpy.array(img)
|
||||||
enhance_result = upsampler.enhance(img)[0]
|
if self.use_face_enhancer:
|
||||||
|
_, _, enhance_result = face_enhancer.enhance(
|
||||||
|
img,
|
||||||
|
has_aligned=False,
|
||||||
|
only_center_face=False,
|
||||||
|
paste_back=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
enhance_result, _ = upsampler.enhance(img)
|
||||||
|
|
||||||
upscaled_imgs.append(Image.fromarray(enhance_result))
|
upscaled_imgs.append(Image.fromarray(enhance_result))
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return upscaled_imgs
|
return upscaled_imgs
|
||||||
@ -289,7 +305,6 @@ class StableDiffusion:
|
|||||||
def entrypoint(
|
def entrypoint(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
n_prompt: str,
|
n_prompt: str,
|
||||||
upscaler: str,
|
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
samples: int = 5,
|
samples: int = 5,
|
||||||
@ -302,6 +317,7 @@ def entrypoint(
|
|||||||
The function pass the given prompt to StableDiffusion on Modal,
|
The function pass the given prompt to StableDiffusion on Modal,
|
||||||
gets back a list of images and outputs images to local.
|
gets back a list of images and outputs images to local.
|
||||||
"""
|
"""
|
||||||
|
import util
|
||||||
|
|
||||||
inputs: dict[str, int | str] = {
|
inputs: dict[str, int | str] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@ -311,7 +327,6 @@ def entrypoint(
|
|||||||
"samples": samples,
|
"samples": samples,
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
"steps": steps,
|
"steps": steps,
|
||||||
"upscaler": upscaler,
|
|
||||||
"seed": seed,
|
"seed": seed,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user