Add a method to count tokens using tokenizer.

This commit is contained in:
hodanov 2023-06-17 22:33:34 +09:00
parent 07d6f97cb1
commit 2585aa8ec0
2 changed files with 24 additions and 21 deletions

View File

@ -136,6 +136,29 @@ class StableDiffusion:
self.pipe.enable_xformers_memory_efficient_attention()
@method()
def count_token(self, p: str, n: str) -> int:
"""
Count the number of tokens in the prompt and negative prompt.
"""
from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
token_size_p = len(tokenizer.tokenize(p))
token_size_n = len(tokenizer.tokenize(n))
token_size = token_size_p
if token_size_p <= token_size_n:
token_size = token_size_n
max_embeddings_multiples = 1
max_length = tokenizer.model_max_length - 2
if token_size > max_length:
max_embeddings_multiples = token_size // max_length + 1
print(f"token_size: {token_size}, max_embeddings_multiples: {max_embeddings_multiples}")
return max_embeddings_multiples
@method()
def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
"""
@ -267,10 +290,10 @@ def entrypoint(
"seed": seed,
}
inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
directory = util.make_directory()
sd = StableDiffusion()
inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt)
for i in range(samples):
if seed == -1:
inputs["seed"] = util.generate_seed()

20
util.py
View File

@ -43,26 +43,6 @@ def save_prompts(inputs: dict):
print(f"Save prompts: {prompts_filename}.txt")
def count_token(p: str, n: str) -> int:
"""
Count the number of tokens in the prompt and negative prompt.
"""
token_count_p = len(p.split())
token_count_n = len(n.split())
if token_count_p >= token_count_n:
token_count = token_count_p
else:
token_count = token_count_n
max_embeddings_multiples = 1
if token_count > 77:
max_embeddings_multiples = token_count // 77 + 1
print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")
return max_embeddings_multiples
def save_images(directory: Path, images: list[bytes], seed: int, i: int):
"""
Save images to a file.