Merge pull request #103 from hodanov/feature/sdxl

Modify stable_diffusion_xl application.
Authored by hodanov on 2024-05-06 12:54:17 +09:00; committed by GitHub.
commit 608cf88991
4 changed files with 29 additions and 83 deletions

View File

@@ -29,7 +29,10 @@ img_by_sd15_img2img:
 img_by_sdxl_txt2img:
 	cd ./cmd && modal run sdxl_txt2img.py \
 	--prompt "A dog is running on the grass" \
+	--n-prompt "" \
 	--height 1024 \
 	--width 1024 \
 	--samples 1 \
+	--steps 30 \
+	--use-upscaler "True" \
 	--output-format "avif"
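
The three new flags match the switch from a refiner to an optional img2img upscaler (see the Python changes below). A minimal sketch of a matching entrypoint: modal run maps the keyword arguments of a local_entrypoint to CLI options such as --n-prompt and --use-upscaler, but the parameter names, defaults, and app name here are illustrative, not the repository's actual signature.

import modal

app = modal.App("sdxl-txt2img-example")  # illustrative app name, not the repo's


@app.local_entrypoint()
def main(
    prompt: str,
    n_prompt: str = "",
    height: int = 1024,
    width: int = 1024,
    samples: int = 1,
    steps: int = 30,
    use_upscaler: str = "True",  # passed as a string by the Makefile target
    output_format: str = "avif",
):
    # The real entrypoint would forward these to the SDXLTxt2Img class shown
    # further down in this diff; the exact call site is not part of the change.
    print(prompt, n_prompt, height, width, samples, steps, use_upscaler, output_format)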

View File

@@ -86,13 +86,6 @@ def download_model_sdxl(name: str, model_url: str, token: str):
     )
     pipe.save_pretrained(cache_path, safe_serialization=True)
 
-    refiner_cache_path = cache_path + "-refiner"
-    refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_single_file(
-        "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors",
-        cache_dir=refiner_cache_path,
-    )
-    refiner.save_pretrained(refiner_cache_path, safe_serialization=True)
-
 
 def build_image():
     """
@@ -125,11 +118,7 @@ def build_image():
     loras = config.get("loras")
     if loras is not None:
         for lora in loras:
-            download_file(
-                url=lora["url"],
-                file_name=lora["name"],
-                file_path=BASE_CACHE_PATH_LORA,
-            )
+            download_file(url=lora["url"], file_name=lora["name"], file_path=BASE_CACHE_PATH_LORA)
 
     textual_inversions = config.get("textual_inversions")
     if textual_inversions is not None:
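
download_file() is a helper defined elsewhere in this repository; its body is not part of the diff. A hypothetical stand-in matching the call above could look like this (the real helper may add retries, auth headers, or progress reporting):

import os
import urllib.request


def download_file(url: str, file_name: str, file_path: str) -> str:
    """Download `url` into `file_path/file_name`, skipping files that already exist."""
    os.makedirs(file_path, exist_ok=True)
    destination = os.path.join(file_path, file_name)
    if not os.path.exists(destination):
        urllib.request.urlretrieve(url, destination)
    return destination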

View File

@@ -32,36 +32,19 @@ class SDXLTxt2Img:
         else:
             print(f"The directory '{self.cache_path}' does not exist.")
 
-        self.pipe = diffusers.AutoPipelineForText2Image.from_pretrained(
+        self.pipe = diffusers.DiffusionPipeline.from_pretrained(
             self.cache_path,
             torch_dtype=torch.float16,
             use_safetensors=True,
-            variant="fp16",
         )
 
-        self.refiner_cache_path = self.cache_path + "-refiner"
-        self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
-            self.refiner_cache_path,
+        self.upscaler_cache_path = self.cache_path
+        self.upscaler = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            self.upscaler_cache_path,
             torch_dtype=torch.float16,
             use_safetensors=True,
-            variant="fp16",
         )
 
-        # controlnets = config.get("controlnets")
-        # if controlnets is not None:
-        #     for controlnet in controlnets:
-        #         path = os.path.join(BASE_CACHE_PATH_CONTROLNET, controlnet["name"])
-        #         controlnet = diffusers.ControlNetModel.from_pretrained(path, torch_dtype=torch.float16)
-        #         self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained(
-        #             self.cache_path,
-        #             controlnet=controlnet,
-        #             custom_pipeline="lpw_stable_diffusion",
-        #             scheduler=self.pipe.scheduler,
-        #             vae=self.pipe.vae,
-        #             torch_dtype=torch.float16,
-        #             use_safetensors=True,
-        #         )
-
     def _count_token(self, p: str, n: str) -> int:
         """
         Count the number of tokens in the prompt and negative prompt.
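
Stripped of the diff markers, the constructor now plausibly reads as follows (reconstructed from the added lines above; the surrounding config handling and cache-existence check are assumed and simplified):

import diffusers
import torch


class SDXLTxt2Img:
    def __init__(self, cache_path: str):
        self.cache_path = cache_path

        # Base SDXL text-to-image pipeline, loaded from the local model cache.
        self.pipe = diffusers.DiffusionPipeline.from_pretrained(
            self.cache_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # The img2img "upscaler" now reuses the same cached weights instead of
        # the separate refiner checkpoint the old code downloaded.
        self.upscaler_cache_path = self.cache_path
        self.upscaler = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained(
            self.upscaler_cache_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

Loading both pipelines with from_pretrained from the same directory reads the weights from disk twice; reusing the base pipeline's modules (for example StableDiffusionXLImg2ImgPipeline(**self.pipe.components)) is the usual diffusers way to avoid that, but it is not what this change does.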
@@ -107,63 +90,35 @@
         generator = torch.Generator("cuda").manual_seed(seed)
 
         self.pipe.to("cuda")
-        self.pipe.enable_vae_tiling()
-        self.pipe.enable_xformers_memory_efficient_attention()
         generated_images = self.pipe(
             prompt=prompt,
             negative_prompt=n_prompt,
-            guidance_scale=7,
             height=height,
             width=width,
             generator=generator,
-            num_inference_steps=steps,
         ).images
 
-        base_images = generated_images
-        for image in base_images:
-            image = self._resize_image(image=image, scale_factor=2)
-            self.refiner.to("cuda")
-            refined_images = self.refiner(
-                prompt=prompt,
-                negative_prompt=n_prompt,
-                num_inference_steps=steps,
-                strength=0.1,
-                # guidance_scale=7.5,
-                generator=generator,
-                image=image,
-            ).images
-            generated_images.extend(refined_images)
-            base_images = refined_images
-
-        """
-        Fix the generated images by the control_v11f1e_sd15_tile when `fix_by_controlnet_tile` is `True`.
-        https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
-        """
-        # if fix_by_controlnet_tile:
-        #     max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
-        #     self.controlnet_pipe.to("cuda")
-        #     self.controlnet_pipe.enable_vae_tiling()
-        #     self.controlnet_pipe.enable_xformers_memory_efficient_attention()
-        #     for image in base_images:
-        #         image = self._resize_image(image=image, scale_factor=2)
-        #         with torch.autocast("cuda"):
-        #             fixed_by_controlnet = self.controlnet_pipe(
-        #                 prompt=prompt * batch_size,
-        #                 negative_prompt=n_prompt * batch_size,
-        #                 num_inference_steps=steps,
-        #                 strength=0.3,
-        #                 guidance_scale=7.5,
-        #                 max_embeddings_multiples=max_embeddings_multiples,
-        #                 generator=generator,
-        #                 image=image,
-        #             ).images
-        #         generated_images.extend(fixed_by_controlnet)
-        #         base_images = fixed_by_controlnet
-
-        # if use_upscaler:
-        #     upscaled = self._upscale(
-        #         base_images=base_images,
-        #         half_precision=False,
-        #         tile=700,
-        #         upscaler=upscaler,
-        #     )
-        #     generated_images.extend(upscaled)
+        if use_upscaler:
+            base_images = generated_images
+            for image in base_images:
+                image = self._resize_image(image=image, scale_factor=2)
+                self.upscaler.to("cuda")
+                self.upscaler.enable_vae_tiling()
+                self.upscaler.enable_xformers_memory_efficient_attention()
+                upscaled_images = self.upscaler(
+                    prompt=prompt,
+                    negative_prompt=n_prompt,
+                    num_inference_steps=steps,
+                    strength=0.3,
+                    guidance_scale=7,
+                    generator=generator,
+                    image=image,
+                ).images
+                generated_images.extend(upscaled_images)
 
         image_output = []
         for image in generated_images:
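
The new upscale path is: generate at the requested resolution, enlarge each PIL image by a factor of two, then run a low-strength (0.3) SDXL img2img pass over the enlarged image. The _resize_image() helper itself is not part of this diff; a hypothetical sketch of such a helper, assuming a plain Pillow resize:

from PIL import Image


def _resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
    """Enlarge a PIL image by `scale_factor` with a high-quality resampling filter."""
    new_size = (image.width * scale_factor, image.height * scale_factor)
    return image.resize(new_size, resample=Image.LANCZOS)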

View File

@@ -1,5 +1,4 @@
 import time
 
 import modal
 import util