Add a method to count tokens using the tokenizer.
commit 2585aa8ec0
parent 07d6f97cb1
sd_cli.py: 25 changed lines
@@ -136,6 +136,29 @@ class StableDiffusion:
 
         self.pipe.enable_xformers_memory_efficient_attention()
 
+    @method()
+    def count_token(self, p: str, n: str) -> int:
+        """
+        Count the number of tokens in the prompt and negative prompt.
+        """
+        from transformers import CLIPTokenizer
+
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        token_size_p = len(tokenizer.tokenize(p))
+        token_size_n = len(tokenizer.tokenize(n))
+        token_size = token_size_p
+        if token_size_p <= token_size_n:
+            token_size = token_size_n
+
+        max_embeddings_multiples = 1
+        max_length = tokenizer.model_max_length - 2
+        if token_size > max_length:
+            max_embeddings_multiples = token_size // max_length + 1
+
+        print(f"token_size: {token_size}, max_embeddings_multiples: {max_embeddings_multiples}")
+
+        return max_embeddings_multiples
+
     @method()
     def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
         """
@@ -267,10 +290,10 @@ def entrypoint(
         "seed": seed,
     }
 
-    inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
     directory = util.make_directory()
 
     sd = StableDiffusion()
+    inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt)
     for i in range(samples):
        if seed == -1:
            inputs["seed"] = util.generate_seed()
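As a usage note (not part of the commit): the new counting logic can be exercised on its own to sanity-check the value the entrypoint now passes along. The sketch below mirrors the diff's logic under the assumption that the transformers package and the same openai/clip-vit-large-patch14 checkpoint are available; the helper name estimate_multiples and the example prompt are illustrative.

# Minimal standalone sketch of the new counting logic (illustrative, not the commit's code).
from transformers import CLIPTokenizer

def estimate_multiples(prompt: str, negative_prompt: str) -> int:
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # Use whichever of the two prompts tokenizes to more CLIP tokens.
    token_size = max(len(tokenizer.tokenize(prompt)), len(tokenizer.tokenize(negative_prompt)))
    # model_max_length is 77 for this checkpoint; two slots are reserved for the
    # begin/end-of-text markers, leaving 75 usable tokens per chunk.
    max_length = tokenizer.model_max_length - 2
    if token_size <= max_length:
        return 1
    return token_size // max_length + 1

print(estimate_multiples("a watercolor fox in a misty forest", ""))  # short prompts give 1

A prompt that tokenizes past 75 tokens pushes the result to 2 or more, which is the value the entrypoint now stores in inputs["max_embeddings_multiples"] via sd.count_token.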
util.py: 20 changed lines
@@ -43,26 +43,6 @@ def save_prompts(inputs: dict):
     print(f"Save prompts: {prompts_filename}.txt")
 
 
-def count_token(p: str, n: str) -> int:
-    """
-    Count the number of tokens in the prompt and negative prompt.
-    """
-    token_count_p = len(p.split())
-    token_count_n = len(n.split())
-    if token_count_p >= token_count_n:
-        token_count = token_count_p
-    else:
-        token_count = token_count_n
-
-    max_embeddings_multiples = 1
-    if token_count > 77:
-        max_embeddings_multiples = token_count // 77 + 1
-
-    print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")
-
-    return max_embeddings_multiples
-
-
 def save_images(directory: Path, images: list[bytes], seed: int, i: int):
     """
     Save images to a file.
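For context on the removal: the old util.count_token counted whitespace-separated words, while the replacement counts CLIP BPE tokens, which is what the 77-token limit actually applies to. A rough comparison, assuming the same checkpoint as above (the prompt string here is made up for illustration):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a hyperdetailed photorealistic cityscape at dusk, volumetric lighting, 8k"
print("words :", len(prompt.split()))              # basis of the removed util.count_token
print("tokens:", len(tokenizer.tokenize(prompt)))  # basis of the new StableDiffusion.count_token
# Rare words and punctuation split into several BPE tokens, so the token count is
# usually higher than the word count; thresholding on it is what now decides
# max_embeddings_multiples.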