Merge pull request #20 from hodanov/feature/fix_by_controlnet_tile

Add fix_by_controlnet_tile
This commit is contained in:
hodanov 2023-07-05 10:15:37 +09:00 committed by GitHub
commit 8df4050b28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 68 deletions

View File

@ -17,4 +17,4 @@ run:
--steps 50 \ --steps 50 \
--upscaler "" \ --upscaler "" \
--use-face-enhancer "False" \ --use-face-enhancer "False" \
--use-hires-fix "False" --fix-by-controlnet-tile "False"

View File

@ -18,7 +18,7 @@ def main(
seed: int = -1, seed: int = -1,
upscaler: str = "", upscaler: str = "",
use_face_enhancer: str = "False", use_face_enhancer: str = "False",
use_hires_fix: str = "False", fix_by_controlnet_tile: str = "False",
): ):
""" """
This function is the entrypoint for the Runway CLI. This function is the entrypoint for the Runway CLI.
@ -43,7 +43,7 @@ def main(
seed=seed_generated, seed=seed_generated,
upscaler=upscaler, upscaler=upscaler,
use_face_enhancer=use_face_enhancer == "True", use_face_enhancer=use_face_enhancer == "True",
use_hires_fix=use_hires_fix == "True", fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
) )
util.save_images(directory, images, seed_generated, i) util.save_images(directory, images, seed_generated, i)
total_time = time.time() - start_time total_time = time.time() - start_time

View File

@ -29,10 +29,6 @@ vae:
# download_url: https://hogehoge/xxxx # download_url: https://hogehoge/xxxx
# - name: fugafuga # - name: fugafuga
# download_url: https://fugafuga/xxxx # download_url: https://fugafuga/xxxx
# cotrolnets: # controlnets:
# - name: control_v11f1e_sd15_tile # - name: control_v11f1e_sd15_tile
# repo_id: lllyasviel/control_v11f1e_sd15_tile # repo_id: lllyasviel/control_v11f1e_sd15_tile
# upscaler:
# name: RealESRGAN_x2plus
# use_face_enhancer: false
# use_hires_fix: false

View File

@ -196,28 +196,22 @@ class StableDiffusion(ClsMixin):
self.pipe.enable_xformers_memory_efficient_attention() self.pipe.enable_xformers_memory_efficient_attention()
# TODO: Add support for controlnets. # TODO: Repair the controlnet loading.
# controlnet = diffusers.ControlNetModel.from_pretrained( controlnets = config.get("controlnets")
# "lllyasviel/control_v11f1e_sd15_tile", if controlnets is not None:
# # "lllyasviel/sd-controlnet-canny", for controlnet in controlnets:
# # self.cache_path, path = os.path.join(BASE_CACHE_PATH_CONTROLNET, controlnet["name"])
# # subfolder="controlnet", controlnet = diffusers.ControlNetModel.from_pretrained(path, torch_dtype=torch.float16)
# torch_dtype=torch.float16, self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained(
# ) self.cache_path,
controlnet=controlnet,
# self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained( custom_pipeline="lpw_stable_diffusion",
# self.cache_path, scheduler=self.pipe.scheduler,
# controlnet=controlnet, vae=self.pipe.vae,
# custom_pipeline="lpw_stable_diffusion", torch_dtype=torch.float16,
# # custom_pipeline="stable_diffusion_controlnet_img2img", )
# scheduler=self.pipe.scheduler, self.controlnet_pipe.to("cuda")
# vae=self.pipe.vae, self.controlnet_pipe.enable_xformers_memory_efficient_attention()
# torch_dtype=torch.float16,
# )
# self.controlnet_pipe.to("cuda")
# self.controlnet_pipe.enable_xformers_memory_efficient_attention()
@method() @method()
def count_token(self, p: str, n: str) -> int: def count_token(self, p: str, n: str) -> int:
@ -258,7 +252,7 @@ class StableDiffusion(ClsMixin):
seed: int = 1, seed: int = 1,
upscaler: str = "", upscaler: str = "",
use_face_enhancer: bool = False, use_face_enhancer: bool = False,
use_hires_fix: bool = False, fix_by_controlnet_tile: bool = False,
) -> list[bytes]: ) -> list[bytes]:
""" """
Runs the Stable Diffusion pipeline on the given prompt and outputs images. Runs the Stable Diffusion pipeline on the given prompt and outputs images.
@ -269,7 +263,7 @@ class StableDiffusion(ClsMixin):
generator = torch.Generator("cuda").manual_seed(seed) generator = torch.Generator("cuda").manual_seed(seed)
with torch.inference_mode(): with torch.inference_mode():
with torch.autocast("cuda"): with torch.autocast("cuda"):
base_images = self.pipe.text2img( generated_images = self.pipe.text2img(
prompt * batch_size, prompt * batch_size,
negative_prompt=n_prompt * batch_size, negative_prompt=n_prompt * batch_size,
height=height, height=height,
@ -280,21 +274,29 @@ class StableDiffusion(ClsMixin):
generator=generator, generator=generator,
).images ).images
# for image in base_images: base_images = generated_images
# image = self.resize_image(image=image, scale_factor=2)
# with torch.inference_mode(): """
# with torch.autocast("cuda"): Fix the generated images by the control_v11f1e_sd15_tile when `fix_by_controlnet_tile` is `True`.
# generatedWithControlnet = self.controlnet_pipe( https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
# prompt=prompt * batch_size, """
# negative_prompt=n_prompt * batch_size, if fix_by_controlnet_tile:
# num_inference_steps=steps, for image in base_images:
# strength=0.3, image = self.resize_image(image=image, scale_factor=2)
# guidance_scale=7.5, with torch.inference_mode():
# max_embeddings_multiples=max_embeddings_multiples, with torch.autocast("cuda"):
# generator=generator, fixed_by_controlnet = self.controlnet_pipe(
# image=image, prompt=prompt * batch_size,
# ).images negative_prompt=n_prompt * batch_size,
# base_images.extend(generatedWithControlnet) num_inference_steps=steps,
strength=0.3,
guidance_scale=7.5,
max_embeddings_multiples=max_embeddings_multiples,
generator=generator,
image=image,
).images
generated_images.extend(fixed_by_controlnet)
base_images = fixed_by_controlnet
if upscaler != "": if upscaler != "":
upscaled = self.upscale( upscaled = self.upscale(
@ -303,28 +305,11 @@ class StableDiffusion(ClsMixin):
tile=700, tile=700,
upscaler=upscaler, upscaler=upscaler,
use_face_enhancer=use_face_enhancer, use_face_enhancer=use_face_enhancer,
use_hires_fix=use_hires_fix,
) )
base_images.extend(upscaled) generated_images.extend(upscaled)
if use_hires_fix:
for img in upscaled:
with torch.inference_mode():
with torch.autocast("cuda"):
hires_fixed = self.pipe.img2img(
prompt=prompt * batch_size,
negative_prompt=n_prompt * batch_size,
num_inference_steps=steps,
strength=0.3,
guidance_scale=7.5,
max_embeddings_multiples=max_embeddings_multiples,
generator=generator,
image=img,
).images
base_images.extend(hires_fixed)
image_output = [] image_output = []
for image in base_images: for image in generated_images:
with io.BytesIO() as buf: with io.BytesIO() as buf:
image.save(buf, format="PNG") image.save(buf, format="PNG")
image_output.append(buf.getvalue()) image_output.append(buf.getvalue())
@ -350,10 +335,14 @@ class StableDiffusion(ClsMixin):
pre_pad: int = 0, pre_pad: int = 0,
upscaler: str = "", upscaler: str = "",
use_face_enhancer: bool = False, use_face_enhancer: bool = False,
use_hires_fix: bool = False,
) -> list[Image.Image]: ) -> list[Image.Image]:
""" """
Upscales the given images using a upscaler. Upscale the generated images by the upscaler when `upscaler` is selected.
The upscaler can be selected from the following list:
- `RealESRGAN_x4plus`
- `RealESRNet_x4plus`
- `RealESRGAN_x4plus_anime_6B`
- `RealESRGAN_x2plus`
https://github.com/xinntao/Real-ESRGAN https://github.com/xinntao/Real-ESRGAN
""" """
import numpy import numpy