Merge branch 'main' of github.com:hodanov/a-script-for-running-sd-on-modal

hodanov 2023-06-12 11:29:55 +09:00
commit ce406d7def
7 changed files with 220 additions and 80 deletions

.gitignore vendored

@@ -1,4 +1,5 @@
 .DS_Store
+.mypy_cache/
 __pycache__/
 outputs/
 .env

Dockerfile

@@ -1,5 +1,10 @@
 FROM python:3.11.3-slim-bullseye
 COPY requirements.txt /
 RUN apt update \
-    && apt install -y wget git \
-    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 --pre xformers
+    && apt install -y wget git libgl1-mesa-glx libglib2.0-0 \
+    && pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117 \
+    && mkdir -p /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P /vol/cache/esrgan \
+    && wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P /vol/cache/esrgan
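The image now bakes all four Real-ESRGAN weight files into /vol/cache/esrgan at build time, so containers never fetch them on a cold start (libgl1-mesa-glx and libglib2.0-0 satisfy opencv-python's system dependencies). A minimal sketch of how a runtime lookup against that layout could work; the helper name is hypothetical, but the path layout mirrors model_path in sd_cli.py's upscale method:

# Hypothetical helper; mirrors os.path.join(BASE_CACHE_PATH, "esrgan", f"{model_name}.pth") in sd_cli.py.
import os

BASE_CACHE_PATH = "/vol/cache"

def esrgan_weight_path(model_name: str) -> str:
    """Return the path of a weight file baked into the image, failing fast if it is missing."""
    path = os.path.join(BASE_CACHE_PATH, "esrgan", f"{model_name}.pth")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Not baked into the image: {path}")
    return path

# esrgan_weight_path("RealESRGAN_x4plus_anime_6B")
# -> "/vol/cache/esrgan/RealESRGAN_x4plus_anime_6B.pth"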

Makefile

@@ -1,7 +1,9 @@
 run:
 	modal run sd_cli.py \
-	--prompt "a woman with bob hair" \
+	--prompt "A woman with bob hair" \
 	--n-prompt "" \
 	--height 768 \
 	--width 512 \
-	--samples 5
+	--samples 5 \
+	--steps 50 \
+	--upscaler "RealESRGAN_x4plus_anime_6B"
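modal run maps each --flag onto the keyword argument of the same name on the local entrypoint (underscores become dashes, hence --n-prompt), so the target above is roughly equivalent to this direct call. A sketch only; Modal actually invokes the function through its own CLI machinery:

# Sketch of the arguments the Makefile target ends up passing to entrypoint() in sd_cli.py.
entrypoint(
    prompt="A woman with bob hair",
    n_prompt="",
    height=768,
    width=512,
    samples=5,
    steps=50,
    upscaler="RealESRGAN_x4plus_anime_6B",
)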

README.md

@@ -6,7 +6,7 @@ This is the script to execute Stable Diffusion on [Modal](https://modal.com/).
 The app requires the following to run:
-- python: v3.10 >
+- python: > 3.10
 - modal-client
 - A token for Modal.

requirements.txt

@ -1,9 +1,17 @@
accelerate accelerate
scipy diffusers[torch]==0.16.1
diffusers[torch] onnxruntime==1.15.0
safetensors safetensors==0.3.1
torch==2.0.1+cu117 torch==2.0.1+cu117
transformers==4.29.2
xformers==0.0.20
realesrgan
basicsr>=1.4.2
facexlib>=0.2.5
gfpgan>=1.3.5
numpy
opencv-python
Pillow
torchvision torchvision
torchmetrics tqdm
omegaconf
transformers
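The pins hang together as a set: torch 2.0.1+cu117 matches the cu117 extra index in the Dockerfile, and xformers 0.0.20 is built against torch 2.0.x. An illustrative sanity check one could run inside the built image (not part of the commit):

# Illustrative check that the pinned stack imported cleanly.
import diffusers
import torch

assert torch.__version__.startswith("2.0.1"), torch.__version__
assert diffusers.__version__ == "0.16.1", diffusers.__version__
print("CUDA available:", torch.cuda.is_available())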

sd_cli.py

@@ -1,12 +1,12 @@
 from __future__ import annotations
 import io
 import os
 import time
-from datetime import date
-from pathlib import Path
-from modal import Image, Secret, Stub, method, Mount
-stub = Stub("stable-diffusion-cli")
+from modal import Image, Mount, Secret, Stub, method
+import util
 BASE_CACHE_PATH = "/vol/cache"
@@ -18,10 +18,17 @@ def download_models():
     """
     import diffusers
-    hugging_face_token = os.environ["HUGGINGFACE_TOKEN"]
+    hugging_face_token = os.environ["HUGGING_FACE_TOKEN"]
     model_repo_id = os.environ["MODEL_REPO_ID"]
     cache_path = os.path.join(BASE_CACHE_PATH, os.environ["MODEL_NAME"])
+    vae = diffusers.AutoencoderKL.from_pretrained(
+        "stabilityai/sd-vae-ft-mse",
+        use_auth_token=hugging_face_token,
+        cache_dir=cache_path,
+    )
+    vae.save_pretrained(cache_path, safe_serialization=True)
     scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
         model_repo_id,
         subfolder="scheduler",
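download_models now also caches stabilityai/sd-vae-ft-mse, a fine-tuned VAE commonly swapped into SD 1.x pipelines for cleaner faces and fine detail, serialized as .safetensors via safe_serialization=True. The general diffusers pattern, sketched with an illustrative base model ID standing in for MODEL_REPO_ID:

# Sketch: pairing a Stable Diffusion 1.x pipeline with the improved sd-vae-ft-mse VAE.
# "runwayml/stable-diffusion-v1-5" is an illustrative stand-in for MODEL_REPO_ID.
import diffusers
import torch

vae = diffusers.AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
)
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    vae=vae,
    torch_dtype=torch.float16,
)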
@@ -45,6 +52,7 @@ stub_image = Image.from_dockerfile(
     download_models,
     secrets=[Secret.from_dotenv(__file__)],
 )
+stub = Stub("stable-diffusion-cli")
 stub.image = stub_image
@@ -67,6 +75,11 @@ class StableDiffusion:
         torch.backends.cuda.matmul.allow_tf32 = True
+        vae = diffusers.AutoencoderKL.from_pretrained(
+            cache_path,
+            subfolder="vae",
+        )
         scheduler = diffusers.EulerAncestralDiscreteScheduler.from_pretrained(
             cache_path,
             subfolder="scheduler",
@@ -75,21 +88,14 @@
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
             cache_path,
             scheduler=scheduler,
+            vae=vae,
             custom_pipeline="lpw_stable_diffusion",
+            torch_dtype=torch.float16,
         ).to("cuda")
         self.pipe.enable_xformers_memory_efficient_attention()

     @method()
-    def run_inference(
-        self,
-        prompt: str,
-        n_prompt: str,
-        steps: int = 30,
-        batch_size: int = 1,
-        height: int = 512,
-        width: int = 512,
-        max_embeddings_multiples: int = 1,
-    ) -> list[bytes]:
+    def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
@@ -97,82 +103,134 @@ class StableDiffusion:
         with torch.inference_mode():
             with torch.autocast("cuda"):
-                images = self.pipe(
-                    [prompt] * batch_size,
-                    negative_prompt=[n_prompt] * batch_size,
-                    height=height,
-                    width=width,
-                    num_inference_steps=steps,
+                base_images = self.pipe(
+                    [inputs["prompt"]] * int(inputs["batch_size"]),
+                    negative_prompt=[inputs["n_prompt"]] * int(inputs["batch_size"]),
+                    height=inputs["height"],
+                    width=inputs["width"],
+                    num_inference_steps=inputs["steps"],
                     guidance_scale=7.5,
-                    max_embeddings_multiples=max_embeddings_multiples,
+                    max_embeddings_multiples=inputs["max_embeddings_multiples"],
                 ).images
+        if inputs["upscaler"] != "":
+            upscaled_images = self.upscale(
+                base_images=base_images,
+                model_name="RealESRGAN_x4plus",
+                scale_factor=4,
+                half_precision=False,
+                tile=700,
+            )
+            base_images.extend(upscaled_images)
         image_output = []
-        for image in images:
+        for image in base_images:
             with io.BytesIO() as buf:
                 image.save(buf, format="PNG")
                 image_output.append(buf.getvalue())
         return image_output
+
+    @method()
+    def upscale(
+        self,
+        base_images: list[Image.Image],
+        model_name: str = "RealESRGAN_x4plus",
+        scale_factor: float = 4,
+        half_precision: bool = False,
+        tile: int = 0,
+        tile_pad: int = 10,
+        pre_pad: int = 0,
+    ) -> list[Image.Image]:
+        """
+        Upscales the given images using the given model.
+        https://github.com/xinntao/Real-ESRGAN
+        """
+        import numpy
+        import torch
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from PIL import Image
+        from realesrgan import RealESRGANer
+        from tqdm import tqdm
+
+        if model_name == "RealESRGAN_x4plus":
+            upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+            netscale = 4
+        elif model_name == "RealESRNet_x4plus":
+            upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+            netscale = 4
+        elif model_name == "RealESRGAN_x4plus_anime_6B":
+            upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+            netscale = 4
+        elif model_name == "RealESRGAN_x2plus":
+            upscale_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+            netscale = 2
+        else:
+            raise NotImplementedError("Model name not supported")
+
+        upsampler = RealESRGANer(
+            scale=netscale,
+            model_path=os.path.join(BASE_CACHE_PATH, "esrgan", f"{model_name}.pth"),
+            dni_weight=None,
+            model=upscale_model,
+            tile=tile,
+            tile_pad=tile_pad,
+            pre_pad=pre_pad,
+            half=half_precision,
+            gpu_id=None,
+        )
+
+        torch.cuda.empty_cache()
+        upscaled_imgs = []
+        with tqdm(total=len(base_images)) as progress_bar:
+            for i, img in enumerate(base_images):
+                img = numpy.array(img)
+                enhance_result = upsampler.enhance(img)[0]
+                upscaled_imgs.append(Image.fromarray(enhance_result))
+                progress_bar.update(1)
+        torch.cuda.empty_cache()
+
+        return upscaled_imgs
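One wrinkle worth noting: run_inference gates on inputs["upscaler"] but then passes a fixed model_name="RealESRGAN_x4plus", so the name chosen via --upscaler only toggles upscaling on. A hypothetical variant that forwards the selected model (not part of this commit):

# Hypothetical variant: forward the selected upscaler name instead of a fixed one.
if inputs["upscaler"] != "":
    upscaled_images = self.upscale(
        base_images=base_images,
        model_name=str(inputs["upscaler"]),  # e.g. "RealESRGAN_x4plus_anime_6B"
        scale_factor=4,
        half_precision=False,
        tile=700,
    )
    base_images.extend(upscaled_images)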
 @stub.local_entrypoint()
 def entrypoint(
     prompt: str,
     n_prompt: str,
-    samples: int = 5,
-    steps: int = 30,
-    batch_size: int = 1,
     height: int = 512,
     width: int = 512,
+    samples: int = 5,
+    batch_size: int = 1,
+    steps: int = 20,
+    upscaler: str = "",
 ):
     """
     This function is the entrypoint for the Runway CLI.
     The function pass the given prompt to StableDiffusion on Modal,
     gets back a list of images and outputs images to local.
-    The function is called with the following arguments:
-    - prompt: the prompt to run inference on
-    - n_prompt: the negative prompt to run inference on
-    - samples: the number of samples to generate
-    - steps: the number of steps to run inference for
-    - batch_size: the batch size to use
-    - height: the height of the output image
-    - width: the width of the output image
     """
-    print(f"steps => {steps}, sapmles => {samples}, batch_size => {batch_size}")
-    max_embeddings_multiples = 1
-    token_count = len(prompt.split())
-    if token_count > 77:
-        max_embeddings_multiples = token_count // 77 + 1
-    print(
-        f"token_count => {token_count}, max_embeddings_multiples => {max_embeddings_multiples}"
-    )
-    directory = Path(f"./outputs/{date.today().strftime('%Y-%m-%d')}")
-    if not directory.exists():
-        directory.mkdir(exist_ok=True, parents=True)
-    stable_diffusion = StableDiffusion()
+    inputs: dict[str, int | str] = {
+        "prompt": prompt,
+        "n_prompt": n_prompt,
+        "height": height,
+        "width": width,
+        "samples": samples,
+        "batch_size": batch_size,
+        "steps": steps,
+        "upscaler": upscaler,  # sd_x2_latent_upscaler, sd_x4_upscaler
+        # seed=-1
+    }
+    inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
+    directory = util.make_directory()
+    sd = StableDiffusion()
     for i in range(samples):
         start_time = time.time()
-        images = stable_diffusion.run_inference.call(
-            prompt,
-            n_prompt,
-            steps,
-            batch_size,
-            height,
-            width,
-            max_embeddings_multiples,
-        )
+        images = sd.run_inference.call(inputs)
+        util.save_images(directory, images, i)
         total_time = time.time() - start_time
-        print(
-            f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image)."
-        )
-        for j, image_bytes in enumerate(images):
-            formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
-            output_path = directory / f"{formatted_time}_{i}_{j}.png"
-            print(f"Saving it to {output_path}")
-            with open(output_path, "wb") as file:
-                file.write(image_bytes)
+        print(f"Sample {i} took {total_time:.3f}s ({(total_time)/len(images):.3f}s / image).")
+
+    util.save_prompts(inputs)

util.py Normal file

@@ -0,0 +1,66 @@
+""" Utility functions for the script. """
+import time
+from datetime import date
+from pathlib import Path
+
+from PIL import Image
+
+OUTPUT_DIRECTORY = "outputs"
+DATE_TODAY = date.today().strftime("%Y-%m-%d")
+
+
+def make_directory() -> Path:
+    """
+    Make a directory for saving outputs.
+    """
+    directory = Path(f"{OUTPUT_DIRECTORY}/{DATE_TODAY}")
+    if not directory.exists():
+        directory.mkdir(exist_ok=True, parents=True)
+        print(f"Make directory: {directory}")
+    return directory
+
+
+def save_prompts(inputs: dict):
+    """
+    Save prompts to a file.
+    """
+    prompts_filename = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+    with open(
+        file=f"{OUTPUT_DIRECTORY}/{DATE_TODAY}/prompts_{prompts_filename}.txt", mode="w", encoding="utf-8"
+    ) as file:
+        for name, value in inputs.items():
+            file.write(f"{name} = {repr(value)}\n")
+    print(f"Save prompts: {prompts_filename}.txt")
+
+
+def count_token(p: str, n: str) -> int:
+    """
+    Count the number of tokens in the prompt and negative prompt.
+    """
+    token_count_p = len(p.split())
+    token_count_n = len(n.split())
+    if token_count_p >= token_count_n:
+        token_count = token_count_p
+    else:
+        token_count = token_count_n
+    max_embeddings_multiples = 1
+    if token_count > 77:
+        max_embeddings_multiples = token_count // 77 + 1
+    print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")
+    return max_embeddings_multiples
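count_token approximates CLIP token counts by whitespace splitting, takes the longer of the two prompts, and sizes max_embeddings_multiples so the lpw_stable_diffusion pipeline can extend the 77-token window. A worked example with hypothetical prompts:

# Worked example (hypothetical prompts); whitespace words stand in for CLIP tokens.
long_prompt = " ".join(["word"] * 90)  # 90 "tokens" > 77
assert count_token(p=long_prompt, n="blurry") == 2  # 90 // 77 + 1 == 2
assert count_token(p="a cat", n="") == 1            # short prompts stay at 1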
+
+
+def save_images(directory: Path, images: list[bytes], i: int):
+    """
+    Save images to a file.
+    """
+    for j, image_bytes in enumerate(images):
+        formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        output_path = directory / f"{formatted_time}_{i}_{j}.png"
+        print(f"Saving it to {output_path}")
+        with open(output_path, "wb") as file:
+            file.write(image_bytes)