Split a class 'StableDiffusion' from main.py
This commit is contained in:
parent
24733b3bf7
commit
d28959cf3c
@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import modal
|
import modal
|
||||||
|
import util
|
||||||
|
|
||||||
stub = modal.Stub("run-stable-diffusion-cli")
|
stub = modal.Stub("run-stable-diffusion-cli")
|
||||||
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
|
stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference")
|
||||||
@ -25,8 +26,6 @@ def main(
|
|||||||
The function pass the given prompt to StableDiffusion on Modal,
|
The function pass the given prompt to StableDiffusion on Modal,
|
||||||
gets back a list of images and outputs images to local.
|
gets back a list of images and outputs images to local.
|
||||||
"""
|
"""
|
||||||
import util
|
|
||||||
|
|
||||||
directory = util.make_directory()
|
directory = util.make_directory()
|
||||||
seed_generated = seed
|
seed_generated = seed
|
||||||
for i in range(samples):
|
for i in range(samples):
|
||||||
|
|||||||
18
setup_files/__main__.py
Normal file
18
setup_files/__main__.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from setup import stub
|
||||||
|
from txt2img import StableDiffusion, StableDiffusionInterface
|
||||||
|
|
||||||
|
|
||||||
|
def new_stable_diffusion() -> StableDiffusion:
|
||||||
|
return StableDiffusion()
|
||||||
|
|
||||||
|
|
||||||
|
@stub.function(gpu="A10G")
|
||||||
|
def main():
|
||||||
|
sd = new_stable_diffusion()
|
||||||
|
print(isinstance(sd, StableDiffusion))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -7,18 +8,48 @@ import diffusers
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
from modal import Secret, method
|
from modal import Secret, method
|
||||||
from modal.cls import ClsMixin
|
|
||||||
|
|
||||||
from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET,
|
from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET,
|
||||||
BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION,
|
||||||
stub)
|
stub)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionInterface(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
A StableDiffusionInterface is an interface that will be used for StableDiffusion class creation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __subclasshook__(cls, subclass):
|
||||||
|
return hasattr(subclass, "run_inference") and callable(subclass.run_inference)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
@method()
|
||||||
|
def run_inference(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
n_prompt: str,
|
||||||
|
height: int = 512,
|
||||||
|
width: int = 512,
|
||||||
|
samples: int = 1,
|
||||||
|
batch_size: int = 1,
|
||||||
|
steps: int = 30,
|
||||||
|
seed: int = 1,
|
||||||
|
upscaler: str = "",
|
||||||
|
use_face_enhancer: bool = False,
|
||||||
|
fix_by_controlnet_tile: bool = False,
|
||||||
|
) -> list[bytes]:
|
||||||
|
"""
|
||||||
|
Run inference.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@stub.cls(
|
@stub.cls(
|
||||||
gpu="A10G",
|
gpu="A10G",
|
||||||
secrets=[Secret.from_dotenv(__file__)],
|
secrets=[Secret.from_dotenv(__file__)],
|
||||||
)
|
)
|
||||||
class StableDiffusion(ClsMixin):
|
class StableDiffusion(StableDiffusionInterface):
|
||||||
"""
|
"""
|
||||||
A class that wraps the Stable Diffusion pipeline and scheduler.
|
A class that wraps the Stable Diffusion pipeline and scheduler.
|
||||||
"""
|
"""
|
||||||
@ -97,8 +128,7 @@ class StableDiffusion(ClsMixin):
|
|||||||
self.controlnet_pipe.to("cuda")
|
self.controlnet_pipe.to("cuda")
|
||||||
self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
@method()
|
def _count_token(self, p: str, n: str) -> int:
|
||||||
def count_token(self, p: str, n: str) -> int:
|
|
||||||
"""
|
"""
|
||||||
Count the number of tokens in the prompt and negative prompt.
|
Count the number of tokens in the prompt and negative prompt.
|
||||||
"""
|
"""
|
||||||
@ -142,7 +172,7 @@ class StableDiffusion(ClsMixin):
|
|||||||
Runs the Stable Diffusion pipeline on the given prompt and outputs images.
|
Runs the Stable Diffusion pipeline on the given prompt and outputs images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_embeddings_multiples = self.count_token(p=prompt, n=n_prompt)
|
max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
|
||||||
generator = torch.Generator("cuda").manual_seed(seed)
|
generator = torch.Generator("cuda").manual_seed(seed)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
@ -165,7 +195,7 @@ class StableDiffusion(ClsMixin):
|
|||||||
"""
|
"""
|
||||||
if fix_by_controlnet_tile:
|
if fix_by_controlnet_tile:
|
||||||
for image in base_images:
|
for image in base_images:
|
||||||
image = self.resize_image(image=image, scale_factor=2)
|
image = self._resize_image(image=image, scale_factor=2)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
fixed_by_controlnet = self.controlnet_pipe(
|
fixed_by_controlnet = self.controlnet_pipe(
|
||||||
@ -182,7 +212,7 @@ class StableDiffusion(ClsMixin):
|
|||||||
base_images = fixed_by_controlnet
|
base_images = fixed_by_controlnet
|
||||||
|
|
||||||
if upscaler != "":
|
if upscaler != "":
|
||||||
upscaled = self.upscale(
|
upscaled = self._upscale(
|
||||||
base_images=base_images,
|
base_images=base_images,
|
||||||
half_precision=False,
|
half_precision=False,
|
||||||
tile=700,
|
tile=700,
|
||||||
@ -199,15 +229,13 @@ class StableDiffusion(ClsMixin):
|
|||||||
|
|
||||||
return image_output
|
return image_output
|
||||||
|
|
||||||
@method()
|
def _resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image:
|
||||||
def resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image:
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS)
|
img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@method()
|
def _upscale(
|
||||||
def upscale(
|
|
||||||
self,
|
self,
|
||||||
base_images: list[PIL.Image],
|
base_images: list[PIL.Image],
|
||||||
half_precision: bool = False,
|
half_precision: bool = False,
|
||||||
Loading…
x
Reference in New Issue
Block a user