Merge pull request #8 from hodanov/feature/add_a_method_to_count_tokens
Add a method to count tokens using tokenizer.
This commit is contained in:
commit
d61e212d49
25
sd_cli.py
25
sd_cli.py
@@ -136,6 +136,29 @@ class StableDiffusion:
|
|||||||
|
|
||||||
self.pipe.enable_xformers_memory_efficient_attention()
|
self.pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
@method()
def count_token(self, p: str, n: str) -> int:
    """Return the max_embeddings_multiples needed for the given prompts.

    Tokenizes the prompt and negative prompt with the CLIP tokenizer and
    computes how many embedding windows (multiples of the tokenizer's
    usable context length) are needed to fit the longer of the two.

    Args:
        p: The prompt text.
        n: The negative prompt text.

    Returns:
        The number of embedding multiples (always >= 1).
    """
    # Imported lazily so the heavy transformers package is only loaded
    # when this remote method actually runs.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # The longer of the two prompts determines how many windows are needed.
    token_size = max(len(tokenizer.tokenize(p)), len(tokenizer.tokenize(n)))

    # Two slots of the model context are reserved for special tokens
    # (begin/end of text), hence the -2.
    max_length = tokenizer.model_max_length - 2
    max_embeddings_multiples = 1
    if token_size > max_length:
        max_embeddings_multiples = token_size // max_length + 1

    print(f"token_size: {token_size}, max_embeddings_multiples: {max_embeddings_multiples}")

    return max_embeddings_multiples
|
||||||
|
|
||||||
@method()
|
@method()
|
||||||
def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
|
def run_inference(self, inputs: dict[str, int | str]) -> list[bytes]:
|
||||||
"""
|
"""
|
||||||
@@ -267,10 +290,10 @@ def entrypoint(
|
|||||||
"seed": seed,
|
"seed": seed,
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs["max_embeddings_multiples"] = util.count_token(p=prompt, n=n_prompt)
|
|
||||||
directory = util.make_directory()
|
directory = util.make_directory()
|
||||||
|
|
||||||
sd = StableDiffusion()
|
sd = StableDiffusion()
|
||||||
|
inputs["max_embeddings_multiples"] = sd.count_token(p=prompt, n=n_prompt)
|
||||||
for i in range(samples):
|
for i in range(samples):
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
inputs["seed"] = util.generate_seed()
|
inputs["seed"] = util.generate_seed()
|
||||||
|
|||||||
20
util.py
20
util.py
@@ -43,26 +43,6 @@ def save_prompts(inputs: dict):
|
|||||||
print(f"Save prompts: {prompts_filename}.txt")
|
print(f"Save prompts: {prompts_filename}.txt")
|
||||||
|
|
||||||
|
|
||||||
def count_token(p: str, n: str) -> int:
    """Return the max_embeddings_multiples needed for the given prompts.

    Approximates token counts by whitespace-splitting the prompt and the
    negative prompt, then computes how many 77-token embedding windows are
    required to fit the longer of the two.

    Args:
        p: The prompt text.
        n: The negative prompt text.

    Returns:
        The number of embedding multiples (always >= 1).
    """
    # The longer of the two prompts determines how many windows are needed.
    token_count = max(len(p.split()), len(n.split()))

    # 77 is the CLIP text-encoder context length used by Stable Diffusion.
    max_embeddings_multiples = 1
    if token_count > 77:
        max_embeddings_multiples = token_count // 77 + 1

    print(f"token_count: {token_count}, max_embeddings_multiples: {max_embeddings_multiples}")

    return max_embeddings_multiples
|
|
||||||
|
|
||||||
|
|
||||||
def save_images(directory: Path, images: list[bytes], seed: int, i: int):
|
def save_images(directory: Path, images: list[bytes], seed: int, i: int):
|
||||||
"""
|
"""
|
||||||
Save images to a file.
|
Save images to a file.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user