Split a class 'StableDiffusion' from main.py

This commit is contained in:
hodanov 2023-07-09 23:22:30 +09:00
parent 24733b3bf7
commit d28959cf3c
4 changed files with 58 additions and 13 deletions

View File

View File

@ -1,6 +1,7 @@
import time import time
import modal import modal
import util
stub = modal.Stub("run-stable-diffusion-cli") stub = modal.Stub("run-stable-diffusion-cli")
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference") stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
@ -25,8 +26,6 @@ def main(
The function pass the given prompt to StableDiffusion on Modal, The function pass the given prompt to StableDiffusion on Modal,
gets back a list of images and outputs images to local. gets back a list of images and outputs images to local.
""" """
import util
directory = util.make_directory() directory = util.make_directory()
seed_generated = seed seed_generated = seed
for i in range(samples): for i in range(samples):

18
setup_files/__main__.py Normal file
View File

@ -0,0 +1,18 @@
from __future__ import annotations
from setup import stub
from txt2img import StableDiffusion, StableDiffusionInterface
def new_stable_diffusion() -> StableDiffusion:
return StableDiffusion()
@stub.function(gpu="A10G")
def main():
sd = new_stable_diffusion()
print(isinstance(sd, StableDiffusion))
if __name__ == "__main__":
main()

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import abc
import io import io
import os import os
@ -7,18 +8,48 @@ import diffusers
import PIL.Image import PIL.Image
import torch import torch
from modal import Secret, method from modal import Secret, method
from modal.cls import ClsMixin
from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET, from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET,
BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION, BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION,
stub) stub)
class StableDiffusionInterface(metaclass=abc.ABCMeta):
"""
A StableDiffusionInterface is an interface that will be used for StableDiffusion class creation.
"""
@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "run_inference") and callable(subclass.run_inference)
@abc.abstractmethod
@method()
def run_inference(
self,
prompt: str,
n_prompt: str,
height: int = 512,
width: int = 512,
samples: int = 1,
batch_size: int = 1,
steps: int = 30,
seed: int = 1,
upscaler: str = "",
use_face_enhancer: bool = False,
fix_by_controlnet_tile: bool = False,
) -> list[bytes]:
"""
Run inference.
"""
raise NotImplementedError
@stub.cls( @stub.cls(
gpu="A10G", gpu="A10G",
secrets=[Secret.from_dotenv(__file__)], secrets=[Secret.from_dotenv(__file__)],
) )
class StableDiffusion(ClsMixin): class StableDiffusion(StableDiffusionInterface):
""" """
A class that wraps the Stable Diffusion pipeline and scheduler. A class that wraps the Stable Diffusion pipeline and scheduler.
""" """
@ -97,8 +128,7 @@ class StableDiffusion(ClsMixin):
self.controlnet_pipe.to("cuda") self.controlnet_pipe.to("cuda")
self.controlnet_pipe.enable_xformers_memory_efficient_attention() self.controlnet_pipe.enable_xformers_memory_efficient_attention()
@method() def _count_token(self, p: str, n: str) -> int:
def count_token(self, p: str, n: str) -> int:
""" """
Count the number of tokens in the prompt and negative prompt. Count the number of tokens in the prompt and negative prompt.
""" """
@ -142,7 +172,7 @@ class StableDiffusion(ClsMixin):
Runs the Stable Diffusion pipeline on the given prompt and outputs images. Runs the Stable Diffusion pipeline on the given prompt and outputs images.
""" """
max_embeddings_multiples = self.count_token(p=prompt, n=n_prompt) max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
generator = torch.Generator("cuda").manual_seed(seed) generator = torch.Generator("cuda").manual_seed(seed)
with torch.inference_mode(): with torch.inference_mode():
with torch.autocast("cuda"): with torch.autocast("cuda"):
@ -165,7 +195,7 @@ class StableDiffusion(ClsMixin):
""" """
if fix_by_controlnet_tile: if fix_by_controlnet_tile:
for image in base_images: for image in base_images:
image = self.resize_image(image=image, scale_factor=2) image = self._resize_image(image=image, scale_factor=2)
with torch.inference_mode(): with torch.inference_mode():
with torch.autocast("cuda"): with torch.autocast("cuda"):
fixed_by_controlnet = self.controlnet_pipe( fixed_by_controlnet = self.controlnet_pipe(
@ -182,7 +212,7 @@ class StableDiffusion(ClsMixin):
base_images = fixed_by_controlnet base_images = fixed_by_controlnet
if upscaler != "": if upscaler != "":
upscaled = self.upscale( upscaled = self._upscale(
base_images=base_images, base_images=base_images,
half_precision=False, half_precision=False,
tile=700, tile=700,
@ -199,15 +229,13 @@ class StableDiffusion(ClsMixin):
return image_output return image_output
@method() def _resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image:
def resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image:
image = image.convert("RGB") image = image.convert("RGB")
width, height = image.size width, height = image.size
img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS) img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS)
return img return img
@method() def _upscale(
def upscale(
self, self,
base_images: list[PIL.Image], base_images: list[PIL.Image],
half_precision: bool = False, half_precision: bool = False,