diff --git a/sdcli/__init__.py b/sdcli/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/sdcli/txt2img.py b/sdcli/txt2img.py index b731758..f97379a 100644 --- a/sdcli/txt2img.py +++ b/sdcli/txt2img.py @@ -1,6 +1,7 @@ import time import modal +import util stub = modal.Stub("run-stable-diffusion-cli") stub.run_inference = modal.Function.from_name("stable-diffusion-cli", "StableDiffusion.run_inference") @@ -25,8 +26,6 @@ def main( The function pass the given prompt to StableDiffusion on Modal, gets back a list of images and outputs images to local. """ - import util - directory = util.make_directory() seed_generated = seed for i in range(samples): diff --git a/setup_files/__main__.py b/setup_files/__main__.py new file mode 100644 index 0000000..0075491 --- /dev/null +++ b/setup_files/__main__.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from setup import stub +from txt2img import StableDiffusion + + +def new_stable_diffusion() -> StableDiffusion: + return StableDiffusion() + + +@stub.function(gpu="A10G") +def main(): + sd = new_stable_diffusion() + print(isinstance(sd, StableDiffusion)) + + +if __name__ == "__main__": + main() diff --git a/setup_files/main.py b/setup_files/txt2img.py similarity index 89% rename from setup_files/main.py rename to setup_files/txt2img.py index 1374d0f..bc12a6d 100644 --- a/setup_files/main.py +++ b/setup_files/txt2img.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import io import os @@ -7,18 +8,48 @@ import diffusers import PIL.Image import torch from modal import Secret, method -from modal.cls import ClsMixin from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET, BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION, stub) +class StableDiffusionInterface(metaclass=abc.ABCMeta): + """ + A StableDiffusionInterface is an interface that will be used for StableDiffusion class creation. + """ + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "run_inference") and callable(subclass.run_inference) + + @abc.abstractmethod + @method() + def run_inference( + self, + prompt: str, + n_prompt: str, + height: int = 512, + width: int = 512, + samples: int = 1, + batch_size: int = 1, + steps: int = 30, + seed: int = 1, + upscaler: str = "", + use_face_enhancer: bool = False, + fix_by_controlnet_tile: bool = False, + ) -> list[bytes]: + """ + Run inference. + """ + raise NotImplementedError + + @stub.cls( gpu="A10G", secrets=[Secret.from_dotenv(__file__)], ) -class StableDiffusion(ClsMixin): +class StableDiffusion(StableDiffusionInterface): """ A class that wraps the Stable Diffusion pipeline and scheduler. """ @@ -97,8 +128,7 @@ class StableDiffusion(ClsMixin): self.controlnet_pipe.to("cuda") self.controlnet_pipe.enable_xformers_memory_efficient_attention() - @method() - def count_token(self, p: str, n: str) -> int: + def _count_token(self, p: str, n: str) -> int: """ Count the number of tokens in the prompt and negative prompt. """ @@ -142,7 +172,7 @@ class StableDiffusion(ClsMixin): Runs the Stable Diffusion pipeline on the given prompt and outputs images. """ - max_embeddings_multiples = self.count_token(p=prompt, n=n_prompt) + max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt) generator = torch.Generator("cuda").manual_seed(seed) with torch.inference_mode(): with torch.autocast("cuda"): @@ -165,7 +195,7 @@ class StableDiffusion(ClsMixin): """ if fix_by_controlnet_tile: for image in base_images: - image = self.resize_image(image=image, scale_factor=2) + image = self._resize_image(image=image, scale_factor=2) with torch.inference_mode(): with torch.autocast("cuda"): fixed_by_controlnet = self.controlnet_pipe( @@ -182,7 +212,7 @@ class StableDiffusion(ClsMixin): base_images = fixed_by_controlnet if upscaler != "": - upscaled = self.upscale( + upscaled = self._upscale( base_images=base_images, half_precision=False, tile=700, @@ -199,15 +229,13 @@ class StableDiffusion(ClsMixin): return image_output - @method() - def resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image: + def _resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image: image = image.convert("RGB") width, height = image.size img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS) return img - @method() - def upscale( + def _upscale( self, base_images: list[PIL.Image], half_precision: bool = False,