diff --git a/sd_cli.py b/sd_cli.py index 72ade14..3240a3d 100644 --- a/sd_cli.py +++ b/sd_cli.py @@ -136,6 +136,29 @@ class StableDiffusion: self.pipe.enable_xformers_memory_efficient_attention() + @method() + def count_token(self, p: str, n: str) -> int: + """ + Count the number of tokens in the prompt and negative prompt. + """ + from transformers import CLIPTokenizer + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + token_size_p = len(tokenizer.tokenize(p)) + token_size_n = len(tokenizer.tokenize(n)) + token_size = token_size_p + if token_size_p <= token_size_n: + token_size = token_size_n + + max_embeddings_multiples = 1 + max_length = tokenizer.model_max_length - 2 + if token_size > max_length: + max_embeddings_multiples = token_size // max_length + 1 + + print(f"token_size: {token_size}, max_embeddings_multiples: {max_embeddings_multiples}") + + return max_embeddings_multiples + @method() def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]: """ @@ -267,10 +290,10 @@ def entrypoint( "seed": seed, } - inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt) directory = util.make_directory() sd = StableDiffusion() + inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt) for i in range(samples): if seed == -1: inputs["seed"] = util.generate_seed() diff --git a/util.py b/util.py index c7c1d75..32708c3 100644 --- a/util.py +++ b/util.py @@ -43,26 +43,6 @@ def save_prompts(inputs: dict): print(f"Save prompts: {prompts_filename}.txt") -def count_token(p: str, n: str) -> int: - """ - Count the number of tokens in the prompt and negative prompt. - """ - token_count_p = len(p.split()) - token_count_n = len(n.split()) - if token_count_p >= token_count_n: - token_count = token_count_p - else: - token_count = token_count_n - - max_embeddings_multiples = 1 - if token_count > 77: - max_embeddings_multiples = token_count // 77 + 1 - - print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}") - - return max_embeddings_multiples - - def save_images(directory: Path, images: list[bytes], seed: int, i: int): """ Save images to a file.