微调对话裁剪

This commit is contained in:
binary-husky 2023-04-23 17:45:56 +08:00
parent 676fe40d39
commit 0785ff2aed
2 changed files with 11 additions and 8 deletions

View File

@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
if "reduce the length" in error_msg: if "reduce the length" in error_msg:
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入history[-2] 是本次输入, history[-1] 是本次输出 if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入history[-2] 是本次输入, history[-1] 是本次输出
history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'], history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一 max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)") chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
# history = [] # 清除历史 # history = [] # 清除历史
elif "does not exist" in error_msg: elif "does not exist" in error_msg:

View File

@@ -555,12 +555,12 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
def clip_history(inputs, history, tokenizer, max_token_limit): def clip_history(inputs, history, tokenizer, max_token_limit):
""" """
reduce the length of input/history by clipping. reduce the length of history by clipping.
this function search for the longest entries to clip, little by little, this function search for the longest entries to clip, little by little,
until the number of token of input/history is reduced under threshold. until the number of token of history is reduced under threshold.
通过裁剪来缩短输入/历史记录的长度 通过裁剪来缩短历史记录的长度
此函数逐渐地搜索最长的条目进行剪辑 此函数逐渐地搜索最长的条目进行剪辑
直到输入/历史记录的标记数量降低到阈值以下 直到历史记录的标记数量降低到阈值以下
""" """
import numpy as np import numpy as np
from request_llm.bridge_all import model_info from request_llm.bridge_all import model_info
@@ -568,10 +568,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
return len(tokenizer.encode(txt, disallowed_special=())) return len(tokenizer.encode(txt, disallowed_special=()))
input_token_num = get_token_num(inputs) input_token_num = get_token_num(inputs)
if input_token_num < max_token_limit * 3 / 4: if input_token_num < max_token_limit * 3 / 4:
# 当输入部分的token占比小于限制的3/4时在裁剪时把input的余量留出来 # 当输入部分的token占比小于限制的3/4时裁剪时
# 1. 把input的余量留出来
max_token_limit = max_token_limit - input_token_num max_token_limit = max_token_limit - input_token_num
# 2. 把输出用的余量留出来
max_token_limit = max_token_limit - 128
# 3. 如果余量太小了,直接清除历史
if max_token_limit < 128: if max_token_limit < 128:
# 余量太小了,直接清除历史
history = [] history = []
return history return history
else: else: