diff --git a/app/stable_diffusion_xl.py b/app/stable_diffusion_xl.py index 1eec066..c1fd09f 100644 --- a/app/stable_diffusion_xl.py +++ b/app/stable_diffusion_xl.py @@ -32,15 +32,14 @@ class SDXLTxt2Img: else: print(f"The directory '{self.cache_path}' does not exist.") - self.pipe = diffusers.DiffusionPipeline.from_pretrained( + self.pipe = diffusers.StableDiffusionXLPipeline.from_pretrained( self.cache_path, torch_dtype=torch.float16, use_safetensors=True, ) - self.upscaler_cache_path = self.cache_path - self.upscaler = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained( - self.upscaler_cache_path, + self.refiner = diffusers.StableDiffusionXLImg2ImgPipeline.from_pretrained( + self.cache_path, torch_dtype=torch.float16, use_safetensors=True, ) @@ -92,7 +91,7 @@ class SDXLTxt2Img: self.pipe.to("cuda") self.pipe.enable_vae_tiling() self.pipe.enable_xformers_memory_efficient_attention() - generated_images = self.pipe( + generated_image = self.pipe( prompt=prompt, negative_prompt=n_prompt, guidance_scale=7, @@ -100,25 +99,25 @@ class SDXLTxt2Img: width=width, generator=generator, num_inference_steps=steps, - ).images + ).images[0] + + generated_images = [generated_image] if use_upscaler: - base_images = generated_images - for image in base_images: - image = self._resize_image(image=image, scale_factor=2) - self.upscaler.to("cuda") - self.upscaler.enable_vae_tiling() - self.upscaler.enable_xformers_memory_efficient_attention() - upscaled_images = self.upscaler( - prompt=prompt, - negative_prompt=n_prompt, - num_inference_steps=steps, - strength=0.3, - guidance_scale=7, - generator=generator, - image=image, - ).images - generated_images.extend(upscaled_images) + self.refiner.to("cuda") + self.refiner.enable_vae_tiling() + self.refiner.enable_xformers_memory_efficient_attention() + base_image = self._double_image_size(generated_image) + image = self.refiner( + prompt=prompt, + negative_prompt=n_prompt, + num_inference_steps=50, + strength=0.3, + guidance_scale=7.5, + generator=generator, + image=base_image, + ).images[0] + generated_images.append(image) image_output = [] for image in generated_images: @@ -128,8 +127,7 @@ class SDXLTxt2Img: return image_output - def _resize_image(self, image: PIL.Image.Image, scale_factor: int) -> PIL.Image.Image: + def _double_image_size(self, image: PIL.Image.Image) -> PIL.Image.Image: image = image.convert("RGB") width, height = image.size - img = image.resize((width * scale_factor, height * scale_factor), resample=PIL.Image.LANCZOS) - return img + return image.resize((width * 2, height * 2), resample=PIL.Image.LANCZOS)