From 0785ff2aedabbd1f4a4aee78b6b7c070ed30a069 Mon Sep 17 00:00:00 2001
From: binary-husky <505030475@qq.com>
Date: Sun, 23 Apr 2023 17:45:56 +0800
Subject: [PATCH] =?UTF-8?q?=E5=BE=AE=E8=B0=83=E5=AF=B9=E8=AF=9D=E8=A3=81?=
 =?UTF-8?q?=E5=89=AA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 request_llm/bridge_chatgpt.py |  2 +-
 toolbox.py                    | 17 ++++++++++-------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/request_llm/bridge_chatgpt.py b/request_llm/bridge_chatgpt.py
index a1614b7..5e32f45 100644
--- a/request_llm/bridge_chatgpt.py
+++ b/request_llm/bridge_chatgpt.py
@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
                     if "reduce the length" in error_msg:
                         if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
                         history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
-                                               max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一
+                                               max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
                         chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
                         # history = [] # 清除历史
                     elif "does not exist" in error_msg:
diff --git a/toolbox.py b/toolbox.py
index 4340814..c9dc207 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -555,23 +555,26 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
 
 def clip_history(inputs, history, tokenizer, max_token_limit):
     """
-    reduce the length of input/history by clipping.
+    reduce the length of history by clipping.
     this function search for the longest entries to clip, little by little,
-    until the number of token of input/history is reduced under threshold.
-    通过剪辑来缩短输入/历史记录的长度。
+    until the number of token of history is reduced under threshold.
+    通过裁剪来缩短历史记录的长度。
     此函数逐渐地搜索最长的条目进行剪辑,
-    直到输入/历史记录的标记数量降低到阈值以下。
+    直到历史记录的标记数量降低到阈值以下。
     """
     import numpy as np
     from request_llm.bridge_all import model_info
     def get_token_num(txt):
         return len(tokenizer.encode(txt, disallowed_special=()))
     input_token_num = get_token_num(inputs)
-    if input_token_num < max_token_limit * 3 / 4:
-        # 当输入部分的token占比小于限制的3/4时，在裁剪时把input的余量留出来
+    if input_token_num < max_token_limit * 3 / 4: 
+        # 当输入部分的token占比小于限制的3/4时，裁剪时
+        # 1. 把input的余量留出来
         max_token_limit = max_token_limit - input_token_num
+        # 2. 把输出用的余量留出来
+        max_token_limit = max_token_limit - 128
+        # 3. 如果余量太小了，直接清除历史
     if max_token_limit < 128:
-        # 余量太小了，直接清除历史
         history = []
         return history
     else:
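
Editor's note on the change: the caller in bridge_chatgpt.py no longer halves the model's max_token before calling clip_history; instead, clip_history itself reserves headroom when the input is under 3/4 of the limit (first for the input tokens, then 128 tokens for the reply) and clears the history outright if less than 128 tokens remain. The following is a minimal, self-contained sketch of that budget logic, not the repository's implementation: the function name clip_history_sketch, the whitespace tokenizer, and the "halve the longest entry" clipping step are stand-in assumptions; the real function uses the model tokenizer from request_llm.bridge_all.model_info and its own clipping routine.

def clip_history_sketch(inputs, history, max_token_limit):
    def get_token_num(txt):
        # Assumption: count whitespace-separated words instead of real BPE tokens.
        return len(txt.split())

    input_token_num = get_token_num(inputs)
    if input_token_num < max_token_limit * 3 / 4:
        # When the input uses less than 3/4 of the limit:
        # 1. reserve room for the current input,
        max_token_limit = max_token_limit - input_token_num
        # 2. reserve 128 tokens of room for the model's reply.
        max_token_limit = max_token_limit - 128

    # 3. If the remaining budget is too small, drop the history entirely.
    if max_token_limit < 128:
        return []

    # Otherwise clip the longest entries, little by little, until history fits.
    history = list(history)
    while sum(get_token_num(h) for h in history) > max_token_limit and any(history):
        longest = max(range(len(history)), key=lambda i: get_token_num(history[i]))
        words = history[longest].split()
        history[longest] = " ".join(words[:len(words) // 2])  # assumption: halve the longest entry
    return history


if __name__ == "__main__":
    hist = ["question " * 600, "answer " * 700]
    clipped = clip_history_sketch("new question", hist, max_token_limit=1024)
    print([len(h.split()) for h in clipped])  # entries shrink until they fit the reduced budget

In this sketch the output reservation mirrors the patch: with max_token_limit=1024 and a 2-token input, the effective history budget becomes 1024 - 2 - 128 = 894 tokens, and the two oversized entries are halved until their total falls under it.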