diff --git a/.gitignore b/.gitignore
index 4ab8942..421016c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 .DS_Store
+.mypy_cache/
 __pycache__/
 outputs/
 .env
diff --git a/Dockerfile b/Dockerfile
index 464b462..76d77a0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,10 @@
 FROM python:3.11.3-slim-bullseye
 COPY requirements.txt /
 RUN apt update \
-    && apt install -y wget git \
-    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 --pre xformers
+    && apt install -y wget git libgl1-mesa-glx libglib2.0-0 \
+    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 \
+    && mkdir -p /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
diff --git a/Makefile b/Makefile
index 584f953..16c1533 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,9 @@
 run:
 	modal run sd_cli.py \
-	--prompt "a woman with bob hair" \
-	--n-prompt "" \
-	--height 768 \
-	--width 512 \
-	--samples 5
+	--prompt "A woman with bob hair" \
+	--n-prompt "" \
+	--height 768 \
+	--width 512 \
+	--samples 5 \
+	--steps 50 \
+	--upscaler "RealESRGAN_x4plus_anime_6B"
diff --git a/README.md b/README.md
index 7fc3861..c5badc1 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 This is the script to execute Stable Diffusion on [Modal](https://modal.com/).
 
 The app requires the following to run:
 
-- python: v3.10 >
+- python: > 3.10
 - modal-client
 - A token for Modal.
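
Note: the Dockerfile change above pre-fetches all four Real-ESRGAN weight files at image build time, so containers never download them during inference. A minimal sketch (not part of the diff; the path mirrors the Dockerfile's /vol/cache/esrgan) for sanity-checking the baked-in cache from Python:

    from pathlib import Path

    # Weight files the Dockerfile downloads during the image build.
    EXPECTED_WEIGHTS = [
        "RealESRGAN_x4plus.pth",
        "RealESRNet_x4plus.pth",
        "RealESRGAN_x4plus_anime_6B.pth",
        "RealESRGAN_x2plus.pth",
    ]

    def check_esrgan_cache(cache_dir: str = "/vol/cache/esrgan") -> None:
        # Print one line per expected file so a missing download is obvious.
        for name in EXPECTED_WEIGHTS:
            path = Path(cache_dir) / name
            print(f"{'ok' if path.exists() else 'MISSING'}: {path}")
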
diff --git a/requirements.txt b/requirements.txt
index 98c8d5e..01919f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,17 @@
 accelerate
-scipy
-diffusers[torch]
-safetensors
+diffusers[torch]==0.16.1
+onnxruntime==1.15.0
+safetensors==0.3.1
 torch==2.0.1+cu117
+transformers==4.29.2
+xformers==0.0.20
+
+realesrgan
+basicsr>=1.4.2
+facexlib>=0.2.5
+gfpgan>=1.3.5
+numpy
+opencv-python
+Pillow
 torchvision
-torchmetrics
-omegaconf
-transformers
+tqdm
diff --git a/sd_cli.py b/sd_cli.py
index 6cc570e..0795a74 100644
--- a/sd_cli.py
+++ b/sd_cli.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
+
 import io
 import os
 import time
-from datetime import date
-from pathlib import Path
-from modal import Image, Secret, Stub, method, Mount
-stub = Stub("stable-diffusion-cli")
+from modal import Image, Mount, Secret, Stub, method
+
+import util
 
 BASE_CACHE_PATH = "/vol/cache"
@@ -18,10 +18,17 @@ def download_models():
     """
     import diffusers
 
-    hugging_face_token = os.environ["HUGGINGFACE_TOKEN"]
+    hugging_face_token = os.environ["HUGGING_FACE_TOKEN"]
     model_repo_id = os.environ["MODEL_REPO_ID"]
     cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
 
+    vae = diffusers.AutoencoderKL.from_pretrained(
+        "stabilityai/sd-vae-ft-mse",
+        use_auth_token=hugging_face_token,
+        cache_dir=cache_path,
+    )
+    vae.save_pretrained(cache_path, safe_serialization=True)
+
     scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
         model_repo_id,
         subfolder="scheduler",
@@ -45,6 +52,7 @@ stub_image = Image.from_dockerfile(
     download_models,
     secrets=[Secret.from_dotenv(__file__)],
 )
+stub = Stub("stable-diffusion-cli")
 stub.image = stub_image
 
@@ -67,6 +75,11 @@ class StableDiffusion:
 
         torch.backends.cuda.matmul.allow_tf32 = True
 
+        vae = diffusers.AutoencoderKL.from_pretrained(
+            cache_path,
+            subfolder="vae",
+        )
+
         scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
             cache_path,
             subfolder="scheduler",
@@ -75,21 +88,14 @@
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
             cache_path,
             scheduler=scheduler,
+            vae=vae,
             custom_pipeline="lpw_stable_diffusion",
+            torch_dtype=torch.float16,
         ).to("cuda")
         self.pipe.enable_xformers_memory_efficient_attention()
 
     @method()
-    def run_inference(
-        self,
-        prompt: str,
-        n_prompt: str,
-        steps: int = 30,
-        batch_size: int = 1,
-        height: int = 512,
-        width: int = 512,
-        max_embeddings_multiples: int = 1,
-    ) -> list[bytes]:
+    def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
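
Note: the hunks above download the stabilityai/sd-vae-ft-mse VAE alongside the base model, wire it and an Euler Ancestral scheduler into the pipeline, and switch inference to fp16. A standalone sketch of the same wiring outside Modal (the repo id below is a placeholder assumption; the app actually reads MODEL_REPO_ID from the environment):

    import diffusers
    import torch

    repo_id = "runwayml/stable-diffusion-v1-5"  # placeholder; the app uses MODEL_REPO_ID

    # Fine-tuned VAE swapped in for the base model's autoencoder.
    vae = diffusers.AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
        repo_id, subfolder="scheduler"
    )
    pipe = diffusers.StableDiffusionPipeline.from_pretrained(
        repo_id,
        scheduler=scheduler,
        vae=vae,
        custom_pipeline="lpw_stable_diffusion",  # long-prompt weighting, as in the diff
        torch_dtype=torch.float16,
    ).to("cuda")
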
""" @@ -97,82 +103,134 @@ class StableDiffusion: with torch.inference_mode(): with torch.autocast("cuda"): - images = self.pipe( - [prompt] * batch_size, - negative_prompt=[n_prompt] * batch_size, - height=height, - width=width, - num_inference_steps=steps, + base_images = self.pipe( + [inputs["prompt"]] * int(inputs["batch_size"]), + negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]), + height=inputs["height"], + width=inputs["width"], + num_inference_steps=inputs["steps"], guidance_scale=7.5, - max_embeddings_multiples=max_embeddings_multiples, + max_embeddings_multiples=inputs["max_embeddings_multiples"], ).images + if inputs["upscaler"] != "": + uplcaled_images = self.upscale( + base_images=base_images, + model_name="RealESRGAN_x4plus", + scale_factor=4, + half_precision=False, + tile=700, + ) + base_images.extend(uplcaled_images) + image_output = [] - for image in images: + for image in base_images: with io.BytesIO() as buf: image.save(buf, format="PNG") image_output.append(buf.getvalue()) + return image_output + @method() + def upscale( + self, + base_images: list[Image.Image], + model_name: str = "RealESRGAN_x4plus", + scale_factor: float = 4, + half_precision: bool = False, + tile: int = 0, + tile_pad: int = 10, + pre_pad: int = 0, + ) -> list[Image.Image]: + """ + Upscales the given images using the given model. + https://github.com/xinntao/Real-ESRGAN + """ + import numpy + import torch + from basicsr.archs.rrdbnet_arch import RRDBNet + from PIL import Image + from realesrgan import RealESRGANer + from tqdm import tqdm + + if model_name == "RealESRGAN_x4plus": + upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + elif model_name == "RealESRNet_x4plus": + upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + elif model_name == "RealESRGAN_x4plus_anime_6B": + upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + netscale = 4 + elif model_name == "RealESRGAN_x2plus": + upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + netscale = 2 + else: + raise NotImplementedError("Model name not supported") + + upsampler = RealESRGANer( + scale=netscale, + model_path=os.path.join(BASE_CACHE_PATH, "esrgan", f"{model_name}.pth"), + dni_weight=None, + model=upscale_model, + tile=tile, + tile_pad=tile_pad, + pre_pad=pre_pad, + half=half_precision, + gpu_id=None, + ) + + torch.cuda.empty_cache() + upscaled_imgs = [] + with tqdm(total=len(base_images)) as progress_bar: + for i, img in enumerate(base_images): + img = numpy.array(img) + enhance_result = upsampler.enhance(img)[0] + upscaled_imgs.append(Image.fromarray(enhance_result)) + progress_bar.update(1) + torch.cuda.empty_cache() + + return upscaled_imgs + @stub.local_entrypoint() def entrypoint( prompt: str, n_prompt: str, - samples: int = 5, - steps: int = 30, - batch_size: int = 1, height: int = 512, width: int = 512, + samples: int = 5, + batch_size: int = 1, + steps: int = 20, + upscaler: str = "", ): """ This function is the entrypoint for the Runway CLI. The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. 
diff --git a/util.py b/util.py
new file mode 100644
index 0000000..affb2a0
--- /dev/null
+++ b/util.py
@@ -0,0 +1,66 @@
+""" Utility functions for the script. """
+import time
+from datetime import date
+from pathlib import Path
+
+from PIL import Image
+
+OUTPUT_DIRECTORY = "outputs"
+DATE_TODAY = date.today().strftime("%Y-%m-%d")
+
+
+def make_directory() -> Path:
+    """
+    Make a directory for saving outputs.
+    """
+    directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
+    if not directory.exists():
+        directory.mkdir(exist_ok=True, parents=True)
+        print(f"Make directory: {directory}")
+
+    return directory
+
+
+def save_prompts(inputs: dict):
+    """
+    Save prompts to a file.
+    """
+    prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+    with open(
+        file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8"
+    ) as file:
+        for name, value in inputs.items():
+            file.write(f"{name} = {repr(value)}\n")
+    print(f"Save prompts: {prompts_filename}.txt")
+
+
+def count_token(p: str, n: str) -> int:
+    """
+    Count the number of tokens in the prompt and negative prompt.
+ """ + token_count_p = len(p.split()) + token_count_n = len(n.split()) + if token_count_p >= token_count_n: + token_count = token_count_p + else: + token_count = token_count_n + + max_embeddings_multiples = 1 + if token_count > 77: + max_embeddings_multiples = token_count // 77 + 1 + + print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}") + + return max_embeddings_multiples + + +def save_images(directory: Path, images: list[bytes], i: int): + """ + Save images to a file. + """ + for j, image_bytes in enumerate(images): + formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) + output_path = directory / f"{formatted_time}_{i}_{j}.png" + print(f"Saving it to {output_path}") + with open(output_path, "wb") as file: + file.write(image_bytes)