From 676fe40d39883b6cae70234652ec84c3db1b31e3 Mon Sep 17 00:00:00 2001
From: binary-husky <505030475@qq.com>
Date: Sun, 23 Apr 2023 17:32:44 +0800
Subject: [PATCH] Optimize the truncation strategy for chatgpt conversations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 crazy_functions/谷歌检索小助手.py |  3 +-
 request_llm/bridge_chatgpt.py     | 17 +++++++-----
 toolbox.py                        | 46 +++++++++++++++++++++++++++++++
 3 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/crazy_functions/谷歌检索小助手.py b/crazy_functions/谷歌检索小助手.py
index 94ef256..786b266 100644
--- a/crazy_functions/谷歌检索小助手.py
+++ b/crazy_functions/谷歌检索小助手.py
@@ -98,7 +98,8 @@ def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
         history.extend([ "第一批", gpt_say ])
         meta_paper_info_list = meta_paper_info_list[10:]
 
-    chatbot.append(["状态?", "已经全部完成"])
+    chatbot.append(["状态?",
+                    "已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
     msg = '正常'
     yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
     res = write_results_to_file(history)
diff --git a/request_llm/bridge_chatgpt.py b/request_llm/bridge_chatgpt.py
index c1a900b..a1614b7 100644
--- a/request_llm/bridge_chatgpt.py
+++ b/request_llm/bridge_chatgpt.py
@@ -21,7 +21,7 @@ import importlib
 
 # config_private.py放自己的秘密如API和代理网址
 # 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
-from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys
+from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
 proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
     get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
 
@@ -145,7 +145,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
         yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
         return
 
-    history.append(inputs); history.append(" ")
+    history.append(inputs); history.append("")
 
     retry = 0
     while True:
@@ -198,14 +198,17 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
                 chunk_decoded = chunk.decode()
                 error_msg = chunk_decoded
                 if "reduce the length" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长,或历史数据过长. 历史缓存数据现已释放,您可以请再次尝试.")
-                    history = []    # 清除历史
+                    if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
+                    history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
+                                           max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
+                    # history = []    # 清除历史
                 elif "does not exist" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在,或者您没有获得体验资格.")
+                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
                 elif "Incorrect API key" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务.")
                 elif "exceeded your current quota" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务.")
                 elif "bad forward key" in error_msg:
                     chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
                 elif "Not enough point" in error_msg:
diff --git a/toolbox.py b/toolbox.py
index 6625c43..4340814 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
         return {"message": f"Gradio is running at: {custom_path}"}
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
+
+
+def clip_history(inputs, history, tokenizer, max_token_limit):
+    """
+    reduce the length of input/history by clipping.
+    this function searches for the longest entries to clip, little by little,
+    until the token count of the input/history is reduced below the threshold.
+    通过剪辑来缩短输入/历史记录的长度。
+    此函数逐渐地搜索最长的条目进行剪辑,
+    直到输入/历史记录的标记数量降低到阈值以下。
+    """
+    import numpy as np
+    from request_llm.bridge_all import model_info
+    def get_token_num(txt):
+        return len(tokenizer.encode(txt, disallowed_special=()))
+    input_token_num = get_token_num(inputs)
+    if input_token_num < max_token_limit * 3 / 4:
+        # 当输入部分的token占比小于限制的3/4时,在裁剪时把input的余量留出来
+        max_token_limit = max_token_limit - input_token_num
+        if max_token_limit < 128:
+            # 余量太小了,直接清除历史
+            history = []
+            return history
+    else:
+        # 当输入部分的token占比 > 限制的3/4时,直接清除历史
+        history = []
+        return history
+
+    everything = ['']
+    everything.extend(history)
+    n_token = get_token_num('\n'.join(everything))
+    everything_token = [get_token_num(e) for e in everything]
+
+    # 截断时的颗粒度
+    delta = max(everything_token) // 16
+
+    while n_token > max_token_limit:
+        where = np.argmax(everything_token)
+        encoded = tokenizer.encode(everything[where], disallowed_special=())
+        clipped_encoded = encoded[:len(encoded)-delta]
+        everything[where] = tokenizer.decode(clipped_encoded)[:-1]    # -1 to remove the may-be illegal char
+        everything_token[where] = get_token_num(everything[where])
+        n_token = get_token_num('\n'.join(everything))
+
+    history = everything[1:]
+    return history
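
Note on the truncation strategy above: instead of wiping the whole history when
the API returns a "reduce the length" error, clip_history repeatedly shaves
tokens off the *longest* history entry, in steps of 1/16 of that entry's
length, until the total fits the budget; the budget itself is half the model's
context window (see the bridge_chatgpt.py hunk) minus the current input, and
history is dropped entirely when the input alone consumes more than 3/4 of the
limit or leaves fewer than 128 tokens. Below is a minimal, self-contained
sketch of that loop, not the patch itself. The WhitespaceTokenizer is a
stand-in invented for illustration (the real code pulls a tiktoken tokenizer
from request_llm.bridge_all.model_info, an import the patch carries but never
uses inside clip_history); the max(1, ...) guard on delta is also an addition
of this sketch, since the patch's `max(everything_token) // 16` can be 0 for
very short entries, which would make the while loop spin forever.

import numpy as np

class WhitespaceTokenizer:
    # Hypothetical stand-in for the tiktoken tokenizer used by the real patch.
    def encode(self, txt, disallowed_special=()):
        return txt.split()
    def decode(self, tokens):
        return ' '.join(tokens)

def clip_history_sketch(inputs, history, tokenizer, max_token_limit):
    def get_token_num(txt):
        return len(tokenizer.encode(txt, disallowed_special=()))

    # Reserve the input's share of the budget; if the input alone eats more
    # than 3/4 of the limit, or leaves under 128 tokens, drop all history.
    input_token_num = get_token_num(inputs)
    if input_token_num >= max_token_limit * 3 / 4:
        return []
    max_token_limit -= input_token_num
    if max_token_limit < 128:
        return []

    # Index 0 is a placeholder so history entries start at index 1,
    # mirroring the patch's `everything = ['']` layout.
    everything = [''] + list(history)
    everything_token = [get_token_num(e) for e in everything]
    n_token = get_token_num('\n'.join(everything))

    # Clipping granularity: 1/16 of the longest entry (guarded against 0).
    delta = max(1, max(everything_token) // 16)

    while n_token > max_token_limit:
        where = int(np.argmax(everything_token))   # pick the longest entry
        encoded = tokenizer.encode(everything[where], disallowed_special=())
        # The patch additionally trims the last decoded character ([:-1]) to
        # discard a possibly half-decoded multibyte char; omitted here since
        # whitespace tokens always decode cleanly.
        everything[where] = tokenizer.decode(encoded[:len(encoded) - delta])
        everything_token[where] = get_token_num(everything[where])
        n_token = get_token_num('\n'.join(everything))

    return everything[1:]

if __name__ == '__main__':
    tok = WhitespaceTokenizer()
    history = ['short question', 'word ' * 1000, 'short answer']
    clipped = clip_history_sketch('new input', history, tok, max_token_limit=512)
    print([len(tok.encode(h)) for h in clipped])   # long middle entry clipped, short ends intact

One design point worth noting: bridge_chatgpt.py first blanks the overflowing
turn (history[-2] is the current input, history[-1] the pending output) and
only then clips, so the freed budget comes from older turns first and a retry
can succeed without losing the entire conversation.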