Merge pull request #59 from hodanov/feature/enable_output_avif
Enable output avif format.
This commit is contained in:
		
						commit
						e1baa68b2d
					
				
							
								
								
									
										2
									
								
								.github/workflows/lint_python.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/lint_python.yml
									
									
									
									
										vendored
									
									
								
							@ -20,5 +20,5 @@ jobs:
 | 
				
			|||||||
          pip install pycodestyle pyflakes
 | 
					          pip install pycodestyle pyflakes
 | 
				
			||||||
      - name: Analysing the code with pycodestyle
 | 
					      - name: Analysing the code with pycodestyle
 | 
				
			||||||
        run: |
 | 
					        run: |
 | 
				
			||||||
          pycodestyle --first --ignore='E501' $(git ls-files '*.py')
 | 
					          pycodestyle --first --ignore='E501,E401' $(git ls-files '*.py')
 | 
				
			||||||
          pyflakes $(git ls-files '*.py')
 | 
					          pyflakes $(git ls-files '*.py')
 | 
				
			||||||
 | 
				
			|||||||
@ -43,13 +43,13 @@ def save_prompts(inputs: dict):
 | 
				
			|||||||
        print(f"Save prompts: {prompts_filename}.txt")
 | 
					        print(f"Save prompts: {prompts_filename}.txt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def save_images(directory: Path, images: list[bytes], seed: int, i: int):
 | 
					def save_images(directory: Path, images: list[bytes], seed: int, i: int, output_format: str = "png"):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Save images to a file.
 | 
					    Save images to a file.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    for j, image_bytes in enumerate(images):
 | 
					    for j, image_bytes in enumerate(images):
 | 
				
			||||||
        formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
 | 
					        formatted_time = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
 | 
				
			||||||
        output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.png"
 | 
					        output_path = directory / f"{formatted_time}_{seed}_{i}_{j}.{output_format}"
 | 
				
			||||||
        print(f"Saving it to {output_path}")
 | 
					        print(f"Saving it to {output_path}")
 | 
				
			||||||
        with open(output_path, "wb") as file:
 | 
					        with open(output_path, "wb") as file:
 | 
				
			||||||
            file.write(image_bytes)
 | 
					            file.write(image_bytes)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,10 @@
 | 
				
			|||||||
from __future__ import annotations
 | 
					from __future__ import annotations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from setup import stub
 | 
				
			||||||
from txt2img import StableDiffusion
 | 
					from txt2img import StableDiffusion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@stub.function(gpu="A10G")
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    StableDiffusion
 | 
					    StableDiffusion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -14,6 +14,7 @@ gfpgan>=1.3.8
 | 
				
			|||||||
scipy==1.11.4
 | 
					scipy==1.11.4
 | 
				
			||||||
opencv-python
 | 
					opencv-python
 | 
				
			||||||
Pillow
 | 
					Pillow
 | 
				
			||||||
 | 
					pillow-avif-plugin
 | 
				
			||||||
torchvision
 | 
					torchvision
 | 
				
			||||||
tqdm
 | 
					tqdm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3,9 +3,7 @@ from __future__ import annotations
 | 
				
			|||||||
import io
 | 
					import io
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import diffusers
 | 
					 | 
				
			||||||
import PIL.Image
 | 
					import PIL.Image
 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from modal import Secret, method
 | 
					from modal import Secret, method
 | 
				
			||||||
from setup import (
 | 
					from setup import (
 | 
				
			||||||
    BASE_CACHE_PATH,
 | 
					    BASE_CACHE_PATH,
 | 
				
			||||||
@ -26,6 +24,8 @@ class StableDiffusion:
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __enter__(self):
 | 
					    def __enter__(self):
 | 
				
			||||||
 | 
					        import diffusers
 | 
				
			||||||
 | 
					        import torch
 | 
				
			||||||
        import yaml
 | 
					        import yaml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        config = {}
 | 
					        config = {}
 | 
				
			||||||
@ -133,13 +133,17 @@ class StableDiffusion:
 | 
				
			|||||||
        upscaler: str = "",
 | 
					        upscaler: str = "",
 | 
				
			||||||
        use_face_enhancer: bool = False,
 | 
					        use_face_enhancer: bool = False,
 | 
				
			||||||
        fix_by_controlnet_tile: bool = False,
 | 
					        fix_by_controlnet_tile: bool = False,
 | 
				
			||||||
 | 
					        output_format: str = "png",
 | 
				
			||||||
    ) -> list[bytes]:
 | 
					    ) -> list[bytes]:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Runs the Stable Diffusion pipeline on the given prompt and outputs images.
 | 
					        Runs the Stable Diffusion pipeline on the given prompt and outputs images.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        import pillow_avif  # noqa: F401
 | 
				
			||||||
 | 
					        import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
 | 
					        max_embeddings_multiples = self._count_token(p=prompt, n=n_prompt)
 | 
				
			||||||
        generator = torch.Generator("cuda").manual_seed(seed)
 | 
					        generator = torch.Generator("cuda").manual_seed(seed)
 | 
				
			||||||
        self.pipe = self.pipe.to("cuda")
 | 
					        self.pipe.to("cuda")
 | 
				
			||||||
        self.pipe.enable_vae_tiling()
 | 
					        self.pipe.enable_vae_tiling()
 | 
				
			||||||
        self.pipe.enable_xformers_memory_efficient_attention()
 | 
					        self.pipe.enable_xformers_memory_efficient_attention()
 | 
				
			||||||
        with torch.autocast("cuda"):
 | 
					        with torch.autocast("cuda"):
 | 
				
			||||||
@ -161,7 +165,7 @@ class StableDiffusion:
 | 
				
			|||||||
        https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
 | 
					        https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if fix_by_controlnet_tile:
 | 
					        if fix_by_controlnet_tile:
 | 
				
			||||||
            self.controlnet_pipe = self.controlnet_pipe.to("cuda")
 | 
					            self.controlnet_pipe.to("cuda")
 | 
				
			||||||
            self.controlnet_pipe.enable_vae_tiling()
 | 
					            self.controlnet_pipe.enable_vae_tiling()
 | 
				
			||||||
            self.controlnet_pipe.enable_xformers_memory_efficient_attention()
 | 
					            self.controlnet_pipe.enable_xformers_memory_efficient_attention()
 | 
				
			||||||
            for image in base_images:
 | 
					            for image in base_images:
 | 
				
			||||||
@ -193,7 +197,7 @@ class StableDiffusion:
 | 
				
			|||||||
        image_output = []
 | 
					        image_output = []
 | 
				
			||||||
        for image in generated_images:
 | 
					        for image in generated_images:
 | 
				
			||||||
            with io.BytesIO() as buf:
 | 
					            with io.BytesIO() as buf:
 | 
				
			||||||
                image.save(buf, format="PNG")
 | 
					                image.save(buf, format=output_format)
 | 
				
			||||||
                image_output.append(buf.getvalue())
 | 
					                image_output.append(buf.getvalue())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return image_output
 | 
					        return image_output
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user