Modify setup.py to use a safetensors file.
Parent: 41817006cf
Commit: 04d5255912
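The change in a nutshell: model weights are now fetched as a single .safetensors checkpoint and cached locally, instead of pulling a diffusers repo by repo_id. A minimal sketch of that pattern (not part of the diff), assuming a diffusers release that ships from_single_file (roughly 0.17 or later); the cache directory below is illustrative:

import diffusers
import torch

# Checkpoint URL taken from the new config below; the cache path is a hypothetical example.
MODEL_URL = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"
CACHE_DIR = "/tmp/cache/stable-diffusion-1-5"

# Build the pipeline straight from the single-file checkpoint...
pipe = diffusers.StableDiffusionPipeline.from_single_file(
    pretrained_model_link_or_path=MODEL_URL,
    torch_dtype=torch.float16,
)
# ...then re-serialize it so later from_pretrained(..., use_safetensors=True) loads work.
pipe.save_pretrained(CACHE_DIR, safe_serialization=True)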
@@ -1,14 +1,13 @@
 from __future__ import annotations
 
 from setup import stub
-from txt2img import new_stable_diffusion
+from txt2img import StableDiffusion
 
 
 @stub.function(gpu="A10G")
 def main():
-    sd = new_stable_diffusion()
-    print(f"Deploy '{sd.__class__.__name__}'.")
+    StableDiffusion
 
 
 if __name__ == "__main__":
-    main()
+    main.local()
@@ -7,28 +7,28 @@
 ##########
 # You can use a diffusers model and VAE on hugging face.
 model:
-  name: stable-diffusion-2-1
-  repo_id: stabilityai/stable-diffusion-2-1
+  name: stable-diffusion-1-5
+  url: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors
 vae:
   name: sd-vae-ft-mse
-  repo_id: stabilityai/sd-vae-ft-mse
+  url: https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors
 ##########
 # Add LoRA if you want to use one. You can use a download url such as the below.
 # ex)
 # loras:
 #   - name: hogehoge.safetensors
-#     download_url: https://hogehoge/xxxx
+#     url: https://hogehoge/xxxx
 #   - name: fugafuga.safetensors
-#     download_url: https://fugafuga/xxxx
+#     url: https://fugafuga/xxxx
 
 ##########
 # You can use Textual Inversion and ControlNet also. Usage is the same as `loras`.
 # ex)
 # textual_inversions:
 #   - name: hogehoge
-#     download_url: https://hogehoge/xxxx
+#     url: https://hogehoge/xxxx
 #   - name: fugafuga
-#     download_url: https://fugafuga/xxxx
+#     url: https://fugafuga/xxxx
 controlnets:
   - name: control_v11f1e_sd15_tile
     repo_id: lllyasviel/control_v11f1e_sd15_tile
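For reference, a short sketch of how a config in this shape can be consumed; the file name and yaml loading are assumptions, but the config.get(...) dispatch mirrors build_image() in the hunks below:

import yaml

# Hypothetical config file name; the repo's actual file name is not shown in this diff.
with open("config.yml") as f:
    config = yaml.safe_load(f)

model = config.get("model")
if model is not None:
    # e.g. name "stable-diffusion-1-5" and the single-file checkpoint URL
    print(model["name"], model["url"])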
@@ -38,26 +38,26 @@ def download_controlnet(name: str, repo_id: str, token: str):
     controlnet.save_pretrained(cache_path, safe_serialization=True)
 
 
-def download_vae(name: str, repo_id: str, token: str):
+def download_vae(name: str, model_url: str, token: str):
     """
     Download a vae.
     """
     cache_path = os.path.join(BASE_CACHE_PATH, name)
-    vae = diffusers.AutoencoderKL.from_pretrained(
-        repo_id,
+    vae = diffusers.AutoencoderKL.from_single_file(
+        pretrained_model_link_or_path=model_url,
         use_auth_token=token,
         cache_dir=cache_path,
     )
     vae.save_pretrained(cache_path, safe_serialization=True)
 
 
-def download_model(name: str, repo_id: str, token: str):
+def download_model(name: str, model_url: str, token: str):
     """
     Download a model.
     """
     cache_path = os.path.join(BASE_CACHE_PATH, name)
-    pipe = diffusers.StableDiffusionPipeline.from_pretrained(
-        repo_id,
+    pipe = diffusers.StableDiffusionPipeline.from_single_file(
+        pretrained_model_link_or_path=model_url,
         use_auth_token=token,
         cache_dir=cache_path,
     )

@@ -77,11 +77,11 @@ def build_image():
 
     model = config.get("model")
     if model is not None:
-        download_model(name=model["name"], repo_id=model["repo_id"], token=token)
+        download_model(name=model["name"], model_url=model["url"], token=token)
 
     vae = config.get("vae")
     if vae is not None:
-        download_vae(name=model["name"], repo_id=vae["repo_id"], token=token)
+        download_vae(name=model["name"], model_url=vae["url"], token=token)
 
     controlnets = config.get("controlnets")
     if controlnets is not None:

@@ -92,7 +92,7 @@ def build_image():
     if loras is not None:
         for lora in loras:
             download_file(
-                url=lora["download_url"],
+                url=lora["url"],
                 file_name=lora["name"],
                 file_path=BASE_CACHE_PATH_LORA,
             )

@@ -101,7 +101,7 @@ def build_image():
     if textual_inversions is not None:
         for textual_inversion in textual_inversions:
             download_file(
-                url=textual_inversion["download_url"],
+                url=textual_inversion["url"],
                 file_name=textual_inversion["name"],
                 file_path=BASE_CACHE_PATH_TEXTUAL_INVERSION,
            )
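The LoRA and Textual Inversion hunks above call download_file(url=..., file_name=..., file_path=...), whose body is not part of this diff. One plausible implementation, shown only as a hedged sketch (streaming with requests is an assumption, not the repo's confirmed code):

import os
import requests

def download_file(url: str, file_name: str, file_path: str) -> None:
    """Stream a file (e.g. a LoRA .safetensors) into file_path/file_name."""
    os.makedirs(file_path, exist_ok=True)
    destination = os.path.join(file_path, file_name)
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(destination, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)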
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import abc
 import io
 import os
 

@@ -9,51 +8,20 @@ import PIL.Image
 import torch
 from modal import Secret, method
 
-from setup import (BASE_CACHE_PATH, BASE_CACHE_PATH_CONTROLNET,
-                   BASE_CACHE_PATH_LORA, BASE_CACHE_PATH_TEXTUAL_INVERSION,
-                   stub)
-
-
-def new_stable_diffusion() -> StableDiffusionInterface:
-    return StableDiffusion()
-
-
-class StableDiffusionInterface(metaclass=abc.ABCMeta):
-    """
-    A StableDiffusionInterface is an interface that will be used for StableDiffusion class creation.
-    """
-
-    @classmethod
-    def __subclasshook__(cls, subclass):
-        return hasattr(subclass, "run_inference") and callable(subclass.run_inference)
-
-    @abc.abstractmethod
-    @method()
-    def run_inference(
-        self,
-        prompt: str,
-        n_prompt: str,
-        height: int = 512,
-        width: int = 512,
-        samples: int = 1,
-        batch_size: int = 1,
-        steps: int = 30,
-        seed: int = 1,
-        upscaler: str = "",
-        use_face_enhancer: bool = False,
-        fix_by_controlnet_tile: bool = False,
-    ) -> list[bytes]:
-        """
-        Run inference.
-        """
-        raise NotImplementedError
+from setup import (
+    BASE_CACHE_PATH,
+    BASE_CACHE_PATH_CONTROLNET,
+    BASE_CACHE_PATH_LORA,
+    BASE_CACHE_PATH_TEXTUAL_INVERSION,
+    stub,
+)
 
 
 @stub.cls(
     gpu="A10G",
     secrets=[Secret.from_dotenv(__file__)],
 )
-class StableDiffusion(StableDiffusionInterface):
+class StableDiffusion:
     """
     A class that wraps the Stable Diffusion pipeline and scheduler.
     """

@@ -70,12 +38,11 @@ class StableDiffusion(StableDiffusionInterface):
         else:
             print(f"The directory '{self.cache_path}' does not exist.")
 
-        # torch.cuda.memory._set_allocator_settings("max_split_size_mb:256")
-
         self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
             self.cache_path,
             custom_pipeline="lpw_stable_diffusion",
             torch_dtype=torch.float16,
+            use_safetensors=True,
         )
 
         # TODO: Add support for other schedulers.

@@ -90,8 +57,8 @@ class StableDiffusion(StableDiffusionInterface):
             self.pipe.vae = diffusers.AutoencoderKL.from_pretrained(
                 self.cache_path,
                 subfolder="vae",
+                use_safetensors=True,
             )
-        self.pipe.to("cuda")
 
         loras = config.get("loras")
         if loras is not None:

@@ -113,7 +80,7 @@ class StableDiffusion(StableDiffusionInterface):
                 print(f"The directory '{path}' does not exist. Need to execute 'modal deploy' first.")
             self.pipe.load_textual_inversion(path)
 
-        self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe = self.pipe.to("cuda")
 
         # TODO: Repair the controlnet loading.
         controlnets = config.get("controlnets")

@@ -128,9 +95,9 @@ class StableDiffusion(StableDiffusionInterface):
                 scheduler=self.pipe.scheduler,
                 vae=self.pipe.vae,
                 torch_dtype=torch.float16,
+                use_safetensors=True,
             )
-            self.controlnet_pipe.to("cuda")
-            self.controlnet_pipe.enable_xformers_memory_efficient_attention()
+            self.controlnet_pipe = self.controlnet_pipe.to("cuda")
 
     def _count_token(self, p: str, n: str) -> int:
         """
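The constructor hunk above rebuilds the ControlNet pipeline while reusing the scheduler and VAE of the already-loaded text-to-image pipeline. A condensed sketch of that sharing pattern; the concrete pipeline class is not visible in this diff, so StableDiffusionControlNetImg2ImgPipeline and the cache paths are assumptions:

import diffusers
import torch

def build_controlnet_pipe(pipe, controlnet_cache_path: str, base_cache_path: str):
    controlnet = diffusers.ControlNetModel.from_pretrained(
        controlnet_cache_path,
        torch_dtype=torch.float16,
    )
    # Passing loaded components as kwargs lets the new pipeline share the existing VAE and scheduler.
    return diffusers.StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        base_cache_path,
        controlnet=controlnet,
        scheduler=pipe.scheduler,
        vae=pipe.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )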
@@ -164,7 +131,6 @@ class StableDiffusion(StableDiffusionInterface):
         n_prompt: str,
         height: int = 512,
         width: int = 512,
-        samples: int = 1,
         batch_size: int = 1,
         steps: int = 30,
         seed: int = 1,

@@ -175,21 +141,21 @@ class StableDiffusion(StableDiffusionInterface):
         """
         Runs the Stable Diffusion pipeline on the given prompt and outputs images.
         """
 
         max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
         generator = torch.Generator("cuda").manual_seed(seed)
-        with torch.inference_mode():
-            with torch.autocast("cuda"):
-                generated_images = self.pipe(
-                    prompt * batch_size,
-                    negative_prompt=n_prompt * batch_size,
-                    height=height,
-                    width=width,
-                    num_inference_steps=steps,
-                    guidance_scale=7.5,
-                    max_embeddings_multiples=max_embeddings_multiples,
-                    generator=generator,
-                ).images
-
+        self.pipe.enable_vae_tiling()
+        self.pipe.enable_xformers_memory_efficient_attention()
+        with torch.autocast("cuda"):
+            generated_images = self.pipe(
+                prompt * batch_size,
+                negative_prompt=n_prompt * batch_size,
+                height=height,
+                width=width,
+                num_inference_steps=steps,
+                guidance_scale=7.5,
+                max_embeddings_multiples=max_embeddings_multiples,
+                generator=generator,
+            ).images
 
         base_images = generated_images

@@ -198,20 +164,21 @@ class StableDiffusion(StableDiffusionInterface):
         https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
         """
         if fix_by_controlnet_tile:
+            self.controlnet_pipe.enable_vae_tiling()
+            self.controlnet_pipe.enable_xformers_memory_efficient_attention()
             for image in base_images:
                 image = self._resize_image(image=image, scale_factor=2)
-                with torch.inference_mode():
-                    with torch.autocast("cuda"):
-                        fixed_by_controlnet = self.controlnet_pipe(
-                            prompt=prompt * batch_size,
-                            negative_prompt=n_prompt * batch_size,
-                            num_inference_steps=steps,
-                            strength=0.3,
-                            guidance_scale=7.5,
-                            max_embeddings_multiples=max_embeddings_multiples,
-                            generator=generator,
-                            image=image,
-                        ).images
+                with torch.autocast("cuda"):
+                    fixed_by_controlnet = self.controlnet_pipe(
+                        prompt=prompt * batch_size,
+                        negative_prompt=n_prompt * batch_size,
+                        num_inference_steps=steps,
+                        strength=0.3,
+                        guidance_scale=7.5,
+                        max_embeddings_multiples=max_embeddings_multiples,
+                        generator=generator,
+                        image=image,
+                    ).images
                 generated_images.extend(fixed_by_controlnet)
                 base_images = fixed_by_controlnet
 
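Taken together, the run_inference hunks replace the nested torch.inference_mode() blocks with per-call memory savers ahead of a single autocast region. A condensed sketch of the resulting call pattern (the helper name, prompt handling, and sizes are illustrative, not the repo's exact code):

import torch

def generate(pipe, prompt: str, n_prompt: str, seed: int = 1, steps: int = 30):
    generator = torch.Generator("cuda").manual_seed(seed)
    pipe.enable_vae_tiling()                           # decode latents in tiles to cap VRAM use
    pipe.enable_xformers_memory_efficient_attention()  # memory-efficient attention kernels
    with torch.autocast("cuda"):
        return pipe(
            prompt,
            negative_prompt=n_prompt,
            height=512,
            width=512,
            num_inference_steps=steps,
            guidance_scale=7.5,
            generator=generator,
        ).images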