Refactor __main__.py and txt2img.py.
This commit is contained in:
parent
9f5d93f213
commit
e4528a6884
@ -1,10 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from setup import stub
|
|
||||||
from txt2img import StableDiffusion
|
from txt2img import StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
@stub.function(gpu="A10G")
|
|
||||||
def main():
|
def main():
|
||||||
StableDiffusion
|
StableDiffusion
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,6 @@ import diffusers
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
from modal import Secret, method
|
from modal import Secret, method
|
||||||
|
|
||||||
from setup import (
|
from setup import (
|
||||||
BASE_CACHE_PATH,
|
BASE_CACHE_PATH,
|
||||||
BASE_CACHE_PATH_CONTROLNET,
|
BASE_CACHE_PATH_CONTROLNET,
|
||||||
@ -80,8 +79,6 @@ class StableDiffusion:
|
|||||||
print(f"The directory '{path}' does not exist. Need to execute 'modal deploy' first.")
|
print(f"The directory '{path}' does not exist. Need to execute 'modal deploy' first.")
|
||||||
self.pipe.load_textual_inversion(path)
|
self.pipe.load_textual_inversion(path)
|
||||||
|
|
||||||
self.pipe = self.pipe.to("cuda")
|
|
||||||
|
|
||||||
# TODO: Repair the controlnet loading.
|
# TODO: Repair the controlnet loading.
|
||||||
controlnets = config.get("controlnets")
|
controlnets = config.get("controlnets")
|
||||||
if controlnets is not None:
|
if controlnets is not None:
|
||||||
@ -97,7 +94,6 @@ class StableDiffusion:
|
|||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
)
|
)
|
||||||
self.controlnet_pipe = self.controlnet_pipe.to("cuda")
|
|
||||||
|
|
||||||
def _count_token(self, p: str, n: str) -> int:
|
def _count_token(self, p: str, n: str) -> int:
|
||||||
"""
|
"""
|
||||||
@ -143,6 +139,7 @@ class StableDiffusion:
|
|||||||
"""
|
"""
|
||||||
max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
|
max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
|
||||||
generator = torch.Generator("cuda").manual_seed(seed)
|
generator = torch.Generator("cuda").manual_seed(seed)
|
||||||
|
self.pipe = self.pipe.to("cuda")
|
||||||
self.pipe.enable_vae_tiling()
|
self.pipe.enable_vae_tiling()
|
||||||
self.pipe.enable_xformers_memory_efficient_attention()
|
self.pipe.enable_xformers_memory_efficient_attention()
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
@ -164,6 +161,7 @@ class StableDiffusion:
|
|||||||
https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
|
https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
|
||||||
"""
|
"""
|
||||||
if fix_by_controlnet_tile:
|
if fix_by_controlnet_tile:
|
||||||
|
self.controlnet_pipe = self.controlnet_pipe.to("cuda")
|
||||||
self.controlnet_pipe.enable_vae_tiling()
|
self.controlnet_pipe.enable_vae_tiling()
|
||||||
self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
||||||
for image in base_images:
|
for image in base_images:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user