diff --git a/Dockerfile b/Dockerfile
index 464b462..76d77a0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,10 @@
 FROM python:3.11.3-slim-bullseye
 COPY requirements.txt /
 RUN apt update \
-    && apt install -y wget git \
-    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 --pre xformers
+    && apt install -y wget git libgl1-mesa-glx libglib2.0-0 \
+    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 \
+    && mkdir -p /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
diff --git a/Makefile b/Makefile
index c340b3e..16c1533 100644
--- a/Makefile
+++ b/Makefile
@@ -1,9 +1,9 @@
 run:
 	modal run sd_cli.py \
-	--prompt "a woman with bob hair" \
+	--prompt "A woman with bob hair" \
 	--n-prompt "" \
 	--height 768 \
 	--width 512 \
 	--samples 5 \
-	--steps 20 \
-	--upscaler "sd_x2_latent_upscaler"
+	--steps 50 \
+	--upscaler "RealESRGAN_x4plus_anime_6B"
diff --git a/requirements.txt b/requirements.txt
index 98c8d5e..01919f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,17 @@
 accelerate
-scipy
-diffusers[torch]
-safetensors
+diffusers[torch]==0.16.1
+onnxruntime==1.15.0
+safetensors==0.3.1
 torch==2.0.1+cu117
+transformers==4.29.2
+xformers==0.0.20
+
+realesrgan
+basicsr>=1.4.2
+facexlib>=0.2.5
+gfpgan>=1.3.5
+numpy
+opencv-python
+Pillow
 torchvision
-torchmetrics
-omegaconf
-transformers
+tqdm
diff --git a/sd_cli.py b/sd_cli.py
index 1275609..0795a74 100644
--- a/sd_cli.py
+++ b/sd_cli.py
@@ -22,6 +22,13 @@ def download_models():
     model_repo_id = os.environ["MODEL_REPO_ID"]
     cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
 
+    vae = diffusers.AutoencoderKL.from_pretrained(
+        "stabilityai/sd-vae-ft-mse",
+        use_auth_token=hugging_face_token,
+        cache_dir=cache_path,
+    )
+    vae.save_pretrained(cache_path, safe_serialization=True)
+
     scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
         model_repo_id,
         subfolder="scheduler",
@@ -68,6 +75,11 @@ class StableDiffusion:
 
         torch.backends.cuda.matmul.allow_tf32 = True
 
+        vae = diffusers.AutoencoderKL.from_pretrained(
+            cache_path,
+            subfolder="vae",
+        )
+
         scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
             cache_path,
             subfolder="scheduler",
@@ -76,21 +88,12 @@ class StableDiffusion:
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
             cache_path,
             scheduler=scheduler,
+            vae=vae,
             custom_pipeline="lpw_stable_diffusion",
+            torch_dtype=torch.float16,
         ).to("cuda")
         self.pipe.enable_xformers_memory_efficient_attention()
 
-        self.upscaler = diffusers.StableDiffusionLatentUpscalePipeline.from_pretrained(
-            "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
-        ).to("cuda")
-        self.upscaler.enable_xformers_memory_efficient_attention()
-
-        # model_id = "stabilityai/stable-diffusion-x4-upscaler"
-        # self.upscaler = diffusers.StableDiffusionUpscalePipeline.from_pretrained(
-        #     , revision="fp16", torch_dtype=torch.float16
-        # ).to("cuda")
-        # self.upscaler.enable_xformers_memory_efficient_attention()
-
     @method()
     def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
         """
@@ -100,7 +103,7 @@ class StableDiffusion:
 
         with torch.inference_mode():
             with torch.autocast("cuda"):
-                images = self.pipe(
+                base_images = self.pipe(
                     [inputs["prompt"]] * int(inputs["batch_size"]),
                     negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
                     height=inputs["height"],
@@ -110,26 +113,110 @@ class StableDiffusion:
                     max_embeddings_multiples=inputs["max_embeddings_multiples"],
                 ).images
 
+        if inputs["upscaler"] != "":
+            # The "upscaler" input names the Real-ESRGAN model to apply.
+            upscaled_images = self.upscale(
+                base_images=base_images,
+                model_name=str(inputs["upscaler"]),
+                scale_factor=4,
+                half_precision=False,
+                tile=700,
+            )
+            base_images.extend(upscaled_images)
+
         image_output = []
-        for image in images:
+        for image in base_images:
             with io.BytesIO() as buf:
                 image.save(buf, format="PNG")
                 image_output.append(buf.getvalue())
 
-        if inputs["upscaler"] != "":
-            upscaled_images = self.upscaler(
-                prompt=inputs["prompt"],
-                image=images,
-                num_inference_steps=inputs["steps"],
-                guidance_scale=0,
-            ).images
-            for image in upscaled_images:
-                with io.BytesIO() as buf:
-                    image.save(buf, format="PNG")
-                    image_output.append(buf.getvalue())
-
         return image_output
 
+    @method()
+    def upscale(
+        self,
+        base_images: list[Image.Image],
+        model_name: str = "RealESRGAN_x4plus",
+        scale_factor: float = 4,
+        half_precision: bool = False,
+        tile: int = 0,
+        tile_pad: int = 10,
+        pre_pad: int = 0,
+    ) -> list[Image.Image]:
+        """
+        Upscale images with a Real-ESRGAN model.
+        https://github.com/xinntao/Real-ESRGAN
+
+        Args:
+            base_images: Images to upscale; the list itself is not modified.
+            model_name: Stem of the weight file under BASE_CACHE_PATH/esrgan.
+            scale_factor: Final magnification, passed to enhance() as outscale.
+            half_precision: Forwarded to RealESRGANer as half.
+            tile: Forwarded to RealESRGANer as tile.
+            tile_pad: Forwarded to RealESRGANer as tile_pad.
+            pre_pad: Forwarded to RealESRGANer as pre_pad.
+
+        Returns:
+            One upscaled image per input image, in input order.
+
+        Raises:
+            NotImplementedError: If model_name is not a supported model.
+        """
+        import numpy
+        import torch
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from PIL import Image
+        from realesrgan import RealESRGANer
+        from tqdm import tqdm
+
+        # Supported weight files -> (RRDBNet num_block, network scale).
+        architectures = {
+            "RealESRGAN_x4plus": (23, 4),
+            "RealESRNet_x4plus": (23, 4),
+            "RealESRGAN_x4plus_anime_6B": (6, 4),
+            "RealESRGAN_x2plus": (23, 2),
+        }
+        if model_name not in architectures:
+            raise NotImplementedError("Model name not supported")
+        num_block, netscale = architectures[model_name]
+        upscale_model = RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=num_block,
+            num_grow_ch=32,
+            scale=netscale,
+        )
+
+        upsampler = RealESRGANer(
+            scale=netscale,
+            model_path=os.path.join(BASE_CACHE_PATH, "esrgan", f"{model_name}.pth"),
+            dni_weight=None,
+            model=upscale_model,
+            tile=tile,
+            tile_pad=tile_pad,
+            pre_pad=pre_pad,
+            half=half_precision,
+            gpu_id=None,
+        )
+
+        torch.cuda.empty_cache()
+        upscaled_imgs = []
+        for img in tqdm(base_images):
+            pixels = numpy.array(img)
+            # enhance() uses OpenCV's BGR channel order while PIL arrays are
+            # RGB, so 3-channel images are flipped on the way in and out.
+            is_color = pixels.ndim == 3 and pixels.shape[2] == 3
+            if is_color:
+                pixels = numpy.ascontiguousarray(pixels[:, :, ::-1])
+            enhanced = upsampler.enhance(pixels, outscale=scale_factor)[0]
+            if is_color:
+                enhanced = numpy.ascontiguousarray(enhanced[:, :, ::-1])
+            upscaled_imgs.append(Image.fromarray(enhanced))
+            torch.cuda.empty_cache()
+
+        return upscaled_imgs
+
 
 @stub.local_entrypoint()
 def entrypoint(