Merge pull request #20 from hodanov/feature/fix_by_controlnet_tile
Add fix_by_controlnet_tile
This commit is contained in:
commit
8df4050b28
2
Makefile
2
Makefile
@ -17,4 +17,4 @@ run:
|
||||
--steps 50 \
|
||||
--upscaler "" \
|
||||
--use-face-enhancer "False" \
|
||||
--use-hires-fix "False"
|
||||
--fix-by-controlnet-tile "False"
|
||||
|
||||
@ -18,7 +18,7 @@ def main(
|
||||
seed: int = -1,
|
||||
upscaler: str = "",
|
||||
use_face_enhancer: str = "False",
|
||||
use_hires_fix: str = "False",
|
||||
fix_by_controlnet_tile: str = "False",
|
||||
):
|
||||
"""
|
||||
This function is the entrypoint for the Runway CLI.
|
||||
@ -43,7 +43,7 @@ def main(
|
||||
seed=seed_generated,
|
||||
upscaler=upscaler,
|
||||
use_face_enhancer=use_face_enhancer == "True",
|
||||
use_hires_fix=use_hires_fix == "True",
|
||||
fix_by_controlnet_tile=fix_by_controlnet_tile == "True",
|
||||
)
|
||||
util.save_images(directory, images, seed_generated, i)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
@ -29,10 +29,6 @@ vae:
|
||||
# download_url: https://hogehoge/xxxx
|
||||
# - name: fugafuga
|
||||
# download_url: https://fugafuga/xxxx
|
||||
# cotrolnets:
|
||||
# controlnets:
|
||||
# - name: control_v11f1e_sd15_tile
|
||||
# repo_id: lllyasviel/control_v11f1e_sd15_tile
|
||||
# upscaler:
|
||||
# name: RealESRGAN_x2plus
|
||||
# use_face_enhancer: false
|
||||
# use_hires_fix: false
|
||||
|
||||
@ -196,28 +196,22 @@ class StableDiffusion(ClsMixin):
|
||||
|
||||
self.pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# TODO: Add support for controlnets.
|
||||
# controlnet = diffusers.ControlNetModel.from_pretrained(
|
||||
# "lllyasviel/control_v11f1e_sd15_tile",
|
||||
# # "lllyasviel/sd-controlnet-canny",
|
||||
# # self.cache_path,
|
||||
# # subfolder="controlnet",
|
||||
# torch_dtype=torch.float16,
|
||||
# )
|
||||
|
||||
# self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained(
|
||||
# self.cache_path,
|
||||
# controlnet=controlnet,
|
||||
# custom_pipeline="lpw_stable_diffusion",
|
||||
# # custom_pipeline="stable_diffusion_controlnet_img2img",
|
||||
# scheduler=self.pipe.scheduler,
|
||||
# vae=self.pipe.vae,
|
||||
# torch_dtype=torch.float16,
|
||||
# )
|
||||
|
||||
# self.controlnet_pipe.to("cuda")
|
||||
|
||||
# self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
||||
# TODO: Repair the controlnet loading.
|
||||
controlnets = config.get("controlnets")
|
||||
if controlnets is not None:
|
||||
for controlnet in controlnets:
|
||||
path = os.path.join(BASE_CACHE_PATH_CONTROLNET, controlnet["name"])
|
||||
controlnet = diffusers.ControlNetModel.from_pretrained(path, torch_dtype=torch.float16)
|
||||
self.controlnet_pipe = diffusers.StableDiffusionControlNetPipeline.from_pretrained(
|
||||
self.cache_path,
|
||||
controlnet=controlnet,
|
||||
custom_pipeline="lpw_stable_diffusion",
|
||||
scheduler=self.pipe.scheduler,
|
||||
vae=self.pipe.vae,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
self.controlnet_pipe.to("cuda")
|
||||
self.controlnet_pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
@method()
|
||||
def count_token(self, p: str, n: str) -> int:
|
||||
@ -258,7 +252,7 @@ class StableDiffusion(ClsMixin):
|
||||
seed: int = 1,
|
||||
upscaler: str = "",
|
||||
use_face_enhancer: bool = False,
|
||||
use_hires_fix: bool = False,
|
||||
fix_by_controlnet_tile: bool = False,
|
||||
) -> list[bytes]:
|
||||
"""
|
||||
Runs the Stable Diffusion pipeline on the given prompt and outputs images.
|
||||
@ -269,7 +263,7 @@ class StableDiffusion(ClsMixin):
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
with torch.inference_mode():
|
||||
with torch.autocast("cuda"):
|
||||
base_images = self.pipe.text2img(
|
||||
generated_images = self.pipe.text2img(
|
||||
prompt * batch_size,
|
||||
negative_prompt=n_prompt * batch_size,
|
||||
height=height,
|
||||
@ -280,21 +274,29 @@ class StableDiffusion(ClsMixin):
|
||||
generator=generator,
|
||||
).images
|
||||
|
||||
# for image in base_images:
|
||||
# image = self.resize_image(image=image, scale_factor=2)
|
||||
# with torch.inference_mode():
|
||||
# with torch.autocast("cuda"):
|
||||
# generatedWithControlnet = self.controlnet_pipe(
|
||||
# prompt=prompt * batch_size,
|
||||
# negative_prompt=n_prompt * batch_size,
|
||||
# num_inference_steps=steps,
|
||||
# strength=0.3,
|
||||
# guidance_scale=7.5,
|
||||
# max_embeddings_multiples=max_embeddings_multiples,
|
||||
# generator=generator,
|
||||
# image=image,
|
||||
# ).images
|
||||
# base_images.extend(generatedWithControlnet)
|
||||
base_images = generated_images
|
||||
|
||||
"""
|
||||
Fix the generated images by the control_v11f1e_sd15_tile when `fix_by_controlnet_tile` is `True`.
|
||||
https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
|
||||
"""
|
||||
if fix_by_controlnet_tile:
|
||||
for image in base_images:
|
||||
image = self.resize_image(image=image, scale_factor=2)
|
||||
with torch.inference_mode():
|
||||
with torch.autocast("cuda"):
|
||||
fixed_by_controlnet = self.controlnet_pipe(
|
||||
prompt=prompt * batch_size,
|
||||
negative_prompt=n_prompt * batch_size,
|
||||
num_inference_steps=steps,
|
||||
strength=0.3,
|
||||
guidance_scale=7.5,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
generator=generator,
|
||||
image=image,
|
||||
).images
|
||||
generated_images.extend(fixed_by_controlnet)
|
||||
base_images = fixed_by_controlnet
|
||||
|
||||
if upscaler != "":
|
||||
upscaled = self.upscale(
|
||||
@ -303,28 +305,11 @@ class StableDiffusion(ClsMixin):
|
||||
tile=700,
|
||||
upscaler=upscaler,
|
||||
use_face_enhancer=use_face_enhancer,
|
||||
use_hires_fix=use_hires_fix,
|
||||
)
|
||||
base_images.extend(upscaled)
|
||||
|
||||
if use_hires_fix:
|
||||
for img in upscaled:
|
||||
with torch.inference_mode():
|
||||
with torch.autocast("cuda"):
|
||||
hires_fixed = self.pipe.img2img(
|
||||
prompt=prompt * batch_size,
|
||||
negative_prompt=n_prompt * batch_size,
|
||||
num_inference_steps=steps,
|
||||
strength=0.3,
|
||||
guidance_scale=7.5,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
generator=generator,
|
||||
image=img,
|
||||
).images
|
||||
base_images.extend(hires_fixed)
|
||||
generated_images.extend(upscaled)
|
||||
|
||||
image_output = []
|
||||
for image in base_images:
|
||||
for image in generated_images:
|
||||
with io.BytesIO() as buf:
|
||||
image.save(buf, format="PNG")
|
||||
image_output.append(buf.getvalue())
|
||||
@ -350,10 +335,14 @@ class StableDiffusion(ClsMixin):
|
||||
pre_pad: int = 0,
|
||||
upscaler: str = "",
|
||||
use_face_enhancer: bool = False,
|
||||
use_hires_fix: bool = False,
|
||||
) -> list[Image.Image]:
|
||||
"""
|
||||
Upscales the given images using a upscaler.
|
||||
Upscale the generated images by the upscaler when `upscaler` is selected.
|
||||
The upscaler can be selected from the following list:
|
||||
- `RealESRGAN_x4plus`
|
||||
- `RealESRNet_x4plus`
|
||||
- `RealESRGAN_x4plus_anime_6B`
|
||||
- `RealESRGAN_x2plus`
|
||||
https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
import numpy
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user