From e4528a688457c31936e21cbd40a6cae1243a3b07 Mon Sep 17 00:00:00 2001
From: hodanov <1031hoda@gmail.com>
Date: Mon, 27 Nov 2023 20:41:47 +0900
Subject: [PATCH] Refactor __main__.py and txt2img.py.

---
 setup_files/__main__.py | 2 --
 setup_files/txt2img.py  | 6 ++----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/setup_files/__main__.py b/setup_files/__main__.py
index d1ff313..7fb70fe 100644
--- a/setup_files/__main__.py
+++ b/setup_files/__main__.py
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
-from setup import stub
 from txt2img import StableDiffusion
 
 
-@stub.function(gpu="A10G")
 def main():
     StableDiffusion
 
diff --git a/setup_files/txt2img.py b/setup_files/txt2img.py
index 4a0e110..0ed9d65 100644
--- a/setup_files/txt2img.py
+++ b/setup_files/txt2img.py
@@ -7,7 +7,6 @@ import diffusers
 import PIL.Image
 import torch
 from modal import Secret, method
-
 from setup import (
     BASE_CACHE_PATH,
     BASE_CACHE_PATH_CONTROLNET,
@@ -80,8 +79,6 @@ class StableDiffusion:
                 print(f"The directory '{path}' does not exist. Need to execute 'modal deploy' first.")
             self.pipe.load_textual_inversion(path)
 
-        self.pipe = self.pipe.to("cuda")
-
         # TODO: Repair the controlnet loading.
         controlnets = config.get("controlnets")
         if controlnets is not None:
@@ -97,7 +94,6 @@ class StableDiffusion:
                 torch_dtype=torch.float16,
                 use_safetensors=True,
             )
-            self.controlnet_pipe = self.controlnet_pipe.to("cuda")
 
     def _count_token(self, p: str, n: str) -> int:
         """
@@ -143,6 +139,7 @@ class StableDiffusion:
         """
         max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
         generator = torch.Generator("cuda").manual_seed(seed)
+        self.pipe = self.pipe.to("cuda")
         self.pipe.enable_vae_tiling()
         self.pipe.enable_xformers_memory_efficient_attention()
         with torch.autocast("cuda"):
@@ -164,6 +161,7 @@ class StableDiffusion:
         https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
         """
         if fix_by_controlnet_tile:
+            self.controlnet_pipe = self.controlnet_pipe.to("cuda")
             self.controlnet_pipe.enable_vae_tiling()
             self.controlnet_pipe.enable_xformers_memory_efficient_attention()
             for image in base_images: