From a7c960dcb07e46735743cb566443d835e44f15f9 Mon Sep 17 00:00:00 2001
From: XIao <46100050+Kilig947@users.noreply.github.com>
Date: Sun, 31 Dec 2023 17:13:50 +0800
Subject: [PATCH] Adapt Google Gemini: extract files from the user input (#1419)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adapt Google Gemini; optimize the vision path to extract files from the user input.
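
Usage (a minimal sketch; only options actually touched by this patch are shown):

    # config.py or config_private.py
    GEMINI_API_KEY = "your-google-ai-studio-key"   # new option added by this patch
    LLM_MODEL = "gemini-pro"                       # "gemini-pro" is added to AVAIL_LLM_MODELS
    # "gemini-pro-vision" is also registered in request_llms/bridge_all.py; with it, image files
    # referenced in the user input are extracted by com_google.input_encode_handler and sent as
    # inline_data parts.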
---
config.py | 17 +-
request_llms/bridge_all.py | 19 +++
request_llms/bridge_google_gemini.py | 101 ++++++++++++
request_llms/com_google.py | 198 +++++++++++++++++++++++
toolbox.py | 232 ++++++++++++++++-----------
5 files changed, 472 insertions(+), 95 deletions(-)
create mode 100644 request_llms/bridge_google_gemini.py
create mode 100644 request_llms/com_google.py
diff --git a/config.py b/config.py
index 861bbed..c202ca0 100644
--- a/config.py
+++ b/config.py
@@ -89,12 +89,14 @@ DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体']
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
- "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
- "chatglm3", "moss", "claude-2"]
-# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
+ "gemini-pro", "chatglm3", "moss", "claude-2"]
+# P.S. 其他可用的模型还包括 [
+# "qwen-turbo", "qwen-plus", "qwen-max"
+# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
+# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
-# “qwen-turbo", "qwen-plus", "qwen-max"]
+# ]
# 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
@@ -204,6 +206,10 @@ ANTHROPIC_API_KEY = ""
CUSTOM_API_KEY_PATTERN = ""
+# Google Gemini API-Key
+GEMINI_API_KEY = ''
+
+
# HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"
@@ -292,6 +298,9 @@ NUM_CUSTOM_BASIC_BTN = 4
├── "qwen-turbo" 等通义千问大模型
│ └── DASHSCOPE_API_KEY
│
+├── "Gemini"
+│ └── GEMINI_API_KEY
+│
└── "newbing" Newbing接口不再稳定,不推荐使用
├── NEWBING_STYLE
└── NEWBING_COOKIES
diff --git a/request_llms/bridge_all.py b/request_llms/bridge_all.py
index 689b1f9..61e58a0 100644
--- a/request_llms/bridge_all.py
+++ b/request_llms/bridge_all.py
@@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui
+from .bridge_google_gemini import predict as genai_ui
+from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
+
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
class LazyloadTiktoken(object):
@@ -246,6 +249,22 @@ model_info = {
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
+ "gemini-pro": {
+ "fn_with_ui": genai_ui,
+ "fn_without_ui": genai_noui,
+ "endpoint": None,
+ "max_token": 1024 * 32,
+ "tokenizer": tokenizer_gpt35,
+ "token_cnt": get_token_num_gpt35,
+ },
+ "gemini-pro-vision": {
+ "fn_with_ui": genai_ui,
+ "fn_without_ui": genai_noui,
+ "endpoint": None,
+ "max_token": 1024 * 32,
+ "tokenizer": tokenizer_gpt35,
+ "token_cnt": get_token_num_gpt35,
+ },
}
# -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=-
diff --git a/request_llms/bridge_google_gemini.py b/request_llms/bridge_google_gemini.py
new file mode 100644
index 0000000..2438e09
--- /dev/null
+++ b/request_llms/bridge_google_gemini.py
@@ -0,0 +1,101 @@
+# encoding: utf-8
+# @Time : 2023/12/21
+# @Author : Spike
+# @Descr :
+import json
+import re
+import time
+from request_llms.com_google import GoogleChatInit
+from toolbox import get_conf, update_ui, update_ui_lastest_msg
+
+proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
+timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
+ '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
+
+
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
+ console_slience=False):
+ # 检查API_KEY
+ if get_conf("GEMINI_API_KEY") == "":
+ raise ValueError(f"请配置 GEMINI_API_KEY。")
+
+ genai = GoogleChatInit()
+ watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
+ gpt_replying_buffer = ''
+ stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
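+ # Iterate over the raw streamGenerateContent chunks; each chunk is scanned below for a
+ # "text" field (normal reply) or a "message" field (error payload).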
+ for response in stream_response:
+ results = response.decode()
+ match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
+ error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
+ if match:
+ try:
+ paraphrase = json.loads('{"text": "%s"}' % match.group(1))
+ except Exception:
+ raise ValueError("解析GEMINI消息出错。")
+ buffer = paraphrase['text']
+ gpt_replying_buffer += buffer
+ if observe_window is not None and len(observe_window) >= 1:
+ observe_window[0] = gpt_replying_buffer
+ if observe_window is not None and len(observe_window) >= 2:
+ if (time.time() - observe_window[1]) > watch_dog_patience: raise RuntimeError("程序终止。")
+ if error_match:
+ raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
+ return gpt_replying_buffer
+
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
+ # 检查API_KEY
+ if get_conf("GEMINI_API_KEY") == "":
+ yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
+ return
+
+ chatbot.append((inputs, ""))
+ yield from update_ui(chatbot=chatbot, history=history)
+ genai = GoogleChatInit()
+ retry = 0
+ while True:
+ try:
+ stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
+ break
+ except Exception as e:
+ retry += 1
+ chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
+ retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
+ yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg) # 刷新界面
+ if retry > MAX_RETRY: raise TimeoutError
+ gpt_replying_buffer = ""
+ gpt_security_policy = ""
+ history.extend([inputs, ''])
+ for response in stream_response:
+ results = response.decode("utf-8") # 被这个解码给耍了。。
+ gpt_security_policy += results
+ match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
+ error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
+ if match:
+ try:
+ paraphrase = json.loads('{"text": "%s"}' % match.group(1))
+ except Exception:
+ raise ValueError("解析GEMINI消息出错。")
+ gpt_replying_buffer += paraphrase['text'] # 使用 json 解析库进行处理
+ chatbot[-1] = (inputs, gpt_replying_buffer)
+ history[-1] = gpt_replying_buffer
+ yield from update_ui(chatbot=chatbot, history=history)
+ if error_match:
+ history = history[:-2] # 错误的不纳入对话
+ chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
+ yield from update_ui(chatbot=chatbot, history=history)
+ raise RuntimeError('对话错误')
+ if not gpt_replying_buffer:
+ history = history[:-2] # 错误的不纳入对话
+ chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
+ yield from update_ui(chatbot=chatbot, history=history)
+
+
+
+if __name__ == '__main__':
+ llm_kwargs = {'llm_model': 'gemini-pro'}
+ result = predict('Write a long story about a magic backpack.', llm_kwargs, {}, [])
+ for i in result:
+ print(i)
diff --git a/request_llms/com_google.py b/request_llms/com_google.py
new file mode 100644
index 0000000..7981908
--- /dev/null
+++ b/request_llms/com_google.py
@@ -0,0 +1,198 @@
+# encoding: utf-8
+# @Time : 2023/12/25
+# @Author : Spike
+# @Descr :
+import json
+import os
+import re
+import requests
+from typing import List, Dict, Tuple
+from toolbox import get_conf, encode_image
+
+proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
+
+"""
+========================================================================
+第五部分 一些文件处理方法
+files_filter_handler 根据type过滤文件
+input_encode_handler 提取input中的文件,并解析
+file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
+link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
+html_view_blank 超链接
+html_local_file 本地文件取相对路径
+to_markdown_tabs 文件list 转换为 md tab
+"""
+
+
+def files_filter_handler(file_list):
+ new_list = []
+ filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+ for file in file_list:
+ file = str(file).replace('file=', '')
+ if os.path.exists(file):
+ if str(os.path.basename(file)).split('.')[-1] in filter_:
+ new_list.append(file)
+ return new_list
+
+
+def input_encode_handler(inputs):
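+ # Pull markdown-style file links (e.g. ![img](file=...)) out of the user input, base64-encode
+ # the matched image files via toolbox.encode_image, and return the cleaned text plus the
+ # encoded attachments.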
+ md_encode = []
+ pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
+ matches_path = re.findall(pattern_md_file, inputs)
+ for md_path in matches_path:
+ pattern_file = r"\((file=.*)\)"
+ matches_path = re.findall(pattern_file, md_path)
+ encode_file = files_filter_handler(file_list=matches_path)
+ if encode_file:
+ md_encode.extend([{
+ "data": encode_image(i),
+ "type": os.path.splitext(i)[1].replace('.', '')
+ } for i in encode_file])
+ inputs = inputs.replace(md_path, '')
+ return inputs, md_encode
+
+
+def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
+ new_list = []
+ if not filter_:
+ filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+ for file in file_list:
+ if str(os.path.basename(file)).split('.')[-1] in filter_:
+ new_list.append(html_local_img(file, md=md_type))
+ elif os.path.exists(file):
+ new_list.append(link_mtime_to_md(file))
+ else:
+ new_list.append(file)
+ return new_list
+
+
+def link_mtime_to_md(file):
+ link_local = html_local_file(file)
+ link_name = os.path.basename(file)
+ a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
+ return a
+
+
+def html_local_file(file):
+ base_path = os.path.dirname(__file__) # 项目目录
+ if os.path.exists(str(file)):
+ file = f'file={file.replace(base_path, ".")}'
+ return file
+
+
+def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
+ style = ''
+ if max_width is not None:
+ style += f"max-width: {max_width};"
+ if max_height is not None:
+ style += f"max-height: {max_height};"
+ __file = html_local_file(__file)
+ a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
+ if md:
+ a = f'![{__file}]({__file})'
+ return a
+
+
+def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
+ """
+ Args:
+ head: 表头:[]
+ tabs: 表值:[[列1], [列2], [列3], [列4]]
+ alignment: :--- 左对齐, :---: 居中对齐, ---: 右对齐
+ column: True to keep data in columns, False to keep data in rows (default).
+ Returns:
+ A string representation of the markdown table.
+ """
+ if column:
+ transposed_tabs = list(map(list, zip(*tabs)))
+ else:
+ transposed_tabs = tabs
+ # Find the maximum length among the columns
+ max_len = max(len(col) for col in transposed_tabs) if transposed_tabs else 0
+
+ tab_format = "| %s "
+ tabs_list = "".join([tab_format % i for i in head]) + '|\n'
+ tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
+
+ for i in range(max_len):
+ row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
+ row_data = file_manifest_filter_html(row_data, filter_=None)
+ tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
+
+ return tabs_list
+
+
+class GoogleChatInit:
+
+ def __init__(self):
+ self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
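+ # %m and %k are placeholders that generate_message_payload() replaces with the model name
+ # (e.g. gemini-pro / gemini-pro-vision) and the GEMINI_API_KEY from the configuration.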
+
+ def __conversation_user(self, user_input):
+ what_i_have_asked = {"role": "user", "parts": []}
+ if 'vision' not in self.url_gemini:
+ input_ = user_input
+ encode_img = []
+ else:
+ input_, encode_img = input_encode_handler(user_input)
+ what_i_have_asked['parts'].append({'text': input_})
+ if encode_img:
+ for data in encode_img:
+ what_i_have_asked['parts'].append(
+ {'inline_data': {
+ "mime_type": f"image/{data['type']}",
+ "data": data['data']
+ }})
+ return what_i_have_asked
+
+ def __conversation_history(self, history):
+ messages = []
+ conversation_cnt = len(history) // 2
+ if conversation_cnt:
+ for index in range(0, 2 * conversation_cnt, 2):
+ what_i_have_asked = self.__conversation_user(history[index])
+ what_gpt_answer = {
+ "role": "model",
+ "parts": [{"text": history[index + 1]}]
+ }
+ messages.append(what_i_have_asked)
+ messages.append(what_gpt_answer)
+ return messages
+
+ def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
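+ # Build the request payload, POST it to the streaming endpoint, and hand back an iterator of
+ # raw response lines for the bridge layer to parse.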
+ headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
+ response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
+ stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
+ return response.iter_lines()
+
+ def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
+ messages = [
+ # {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
+ # {"role": "user", "parts": [{"text": ""}]},
+ # {"role": "model", "parts": [{"text": ""}]}
+ ]
+ self.url_gemini = self.url_gemini.replace(
+ '%m', llm_kwargs['llm_model']).replace(
+ '%k', get_conf('GEMINI_API_KEY')
+ )
+ header = {'Content-Type': 'application/json'}
+ if 'vision' not in self.url_gemini: # 不是vision 才处理history
+ messages.extend(self.__conversation_history(history)) # 处理 history
+ messages.append(self.__conversation_user(inputs)) # 处理用户对话
+ payload = {
+ "contents": messages,
+ "generationConfig": {
+ "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
+ "temperature": llm_kwargs.get('temperature', 1),
+ # "maxOutputTokens": 800,
+ "topP": llm_kwargs.get('top_p', 0.8),
+ "topK": 10
+ }
+ }
+ return header, payload
+
+
+if __name__ == '__main__':
+ google = GoogleChatInit()
+ # print(google.generate_message_payload('你好呀', {'llm_model': 'gemini-pro'},
+ # ['123123', '3123123'], ''))
+ # input_encode_handler('123123[123123](./123123), ')
\ No newline at end of file
diff --git a/toolbox.py b/toolbox.py
index 154b54c..632279b 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -11,8 +11,10 @@ import glob
import math
from latex2mathml.converter import convert as tex2mathml
from functools import wraps, lru_cache
+
pj = os.path.join
default_user_name = 'default_user'
+
"""
========================================================================
第一部分
@@ -26,6 +28,7 @@ default_user_name = 'default_user'
========================================================================
"""
+
class ChatBotWithCookies(list):
def __init__(self, cookie):
"""
@@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f):
else:
user_name = default_user_name
cookies.update({
- 'top_p':top_p,
+ 'top_p': top_p,
'api_key': cookies['api_key'],
'llm_model': llm_model,
- 'temperature':temperature,
+ 'temperature': temperature,
'user_name': user_name,
})
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': llm_model,
- 'top_p':top_p,
+ 'top_p': top_p,
'max_length': max_length,
- 'temperature':temperature,
+ 'temperature': temperature,
'client_ip': request.client.host,
'most_recent_uploaded': cookies.get('most_recent_uploaded')
}
@@ -87,7 +90,7 @@ def ArgsGeneralWrapper(f):
}
chatbot_with_cookie = ChatBotWithCookies(cookies)
chatbot_with_cookie.write_list(chatbot)
-
+
if cookies.get('lock_plugin', None) is None:
# 正常状态
if len(args) == 0: # 插件通道
@@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f):
final_cookies = chatbot_with_cookie.get_cookies()
# len(args) != 0 代表“提交”键对话通道,或者基础功能通道
if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
- chatbot_with_cookie.append(["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
+ chatbot_with_cookie.append(
+ ["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档")
+
return decorated
@@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
yield cookies, chatbot_gr, history, msg
+
def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面
"""
刷新用户界面
@@ -147,6 +153,7 @@ def trimmed_format_exc():
replace_path = "."
return str.replace(current_path, replace_path)
+
def CatchException(f):
"""
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
@@ -164,9 +171,9 @@ def CatchException(f):
if len(chatbot_with_cookie) == 0:
chatbot_with_cookie.clear()
chatbot_with_cookie.append(["插件调度异常", "异常原因"])
- chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0],
- f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
- yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
+ chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
+ yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
+
return decorated
@@ -209,6 +216,7 @@ def HotReload(f):
========================================================================
"""
+
def get_reduce_token_percent(text):
"""
* 此函数未来将被弃用
@@ -220,9 +228,9 @@ def get_reduce_token_percent(text):
EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
max_limit = float(match[0]) - EXCEED_ALLO
current_tokens = float(match[1])
- ratio = max_limit/current_tokens
+ ratio = max_limit / current_tokens
assert ratio > 0 and ratio < 1
- return ratio, str(int(current_tokens-max_limit))
+ return ratio, str(int(current_tokens - max_limit))
except:
return 0.5, '不详'
@@ -242,7 +250,7 @@ def write_history_to_file(history, file_basename=None, file_fullname=None, auto_
with open(file_fullname, 'w', encoding='utf8') as f:
f.write('# GPT-Academic Report\n')
for i, content in enumerate(history):
- try:
+ try:
if type(content) != str: content = str(content)
except:
continue
@@ -268,8 +276,6 @@ def regular_txt_to_markdown(text):
return text
-
-
def report_exception(chatbot, history, a, b):
"""
向chatbot中添加错误信息
@@ -286,7 +292,7 @@ def text_divide_paragraph(text):
suf = ''
if text.startswith(pre) and text.endswith(suf):
return text
-
+
if '```' in text:
# careful input
return text
@@ -312,7 +318,7 @@ def markdown_convertion(txt):
if txt.startswith(pre) and txt.endswith(suf):
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
return txt # 已经被转化过,不需要再次转化
-
+
markdown_extension_configs = {
'mdx_math': {
'enable_dollar_delimiter': True,
@@ -352,7 +358,8 @@ def markdown_convertion(txt):
+
 def markdown_bug_hunt(content):
 """
 解决一个mdx_math的bug(单$包裹begin命令时多余<script></script>)
 """
 content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">')
 content = content.replace('</script>\n</script>', '</script>')
 return content
@@ -363,16 +370,16 @@ def markdown_convertion(txt):
if '```' in txt and '```reference' not in txt: return False
if '$' not in txt and '\\[' not in txt: return False
mathpatterns = {
- r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$

 if len(AZURE_CFG_ARRAY) > 0:
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
- if not azure_model_name.startswith('azure'):
+ if not azure_model_name.startswith('azure'):
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
if is_any_api_key(AZURE_API_KEY_):
- if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
- else: API_KEY = AZURE_API_KEY_
+ if is_any_api_key(API_KEY):
+ API_KEY = API_KEY + ',' + AZURE_API_KEY_
+ else:
+ API_KEY = AZURE_API_KEY_
customize_fn_overwrite_ = {}
for k in range(NUM_CUSTOM_BASIC_BTN):
- customize_fn_overwrite_.update({
+ customize_fn_overwrite_.update({
"自定义按钮" + str(k+1):{
- "Title": r"",
- "Prefix": r"请在自定义菜单中定义提示词前缀.",
- "Suffix": r"请在自定义菜单中定义提示词后缀",
+ "Title": r"",
+ "Prefix": r"请在自定义菜单中定义提示词前缀.",
+ "Suffix": r"请在自定义菜单中定义提示词后缀",
}
})
return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
+
def is_openai_api_key(key):
CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
if len(CUSTOM_API_KEY_PATTERN) != 0:
@@ -768,14 +785,17 @@ def is_openai_api_key(key):
API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
return bool(API_MATCH_ORIGINAL)
+
def is_azure_api_key(key):
API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_AZURE)
+
def is_api2d_key(key):
API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_API2D)
+
def is_any_api_key(key):
if ',' in key:
keys = key.split(',')
@@ -785,24 +805,26 @@ def is_any_api_key(key):
else:
return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key)
+
def what_keys(keys):
- avail_key_list = {'OpenAI Key':0, "Azure Key":0, "API2D Key":0}
+ avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
key_list = keys.split(',')
for k in key_list:
- if is_openai_api_key(k):
+ if is_openai_api_key(k):
avail_key_list['OpenAI Key'] += 1
for k in key_list:
- if is_api2d_key(k):
+ if is_api2d_key(k):
avail_key_list['API2D Key'] += 1
for k in key_list:
- if is_azure_api_key(k):
+ if is_azure_api_key(k):
avail_key_list['Azure Key'] += 1
return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']} 个"
+
def select_api_key(keys, llm_model):
import random
avail_key_list = []
@@ -826,6 +848,7 @@ def select_api_key(keys, llm_model):
api_key = random.choice(avail_key_list) # 随机负载均衡
return api_key
+
def read_env_variable(arg, default_value):
"""
环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先),也可以直接是`CONFIG`
@@ -843,10 +866,10 @@ def read_env_variable(arg, default_value):
set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")]
"""
from colorful import print亮红, print亮绿
- arg_with_prefix = "GPT_ACADEMIC_" + arg
- if arg_with_prefix in os.environ:
+ arg_with_prefix = "GPT_ACADEMIC_" + arg
+ if arg_with_prefix in os.environ:
env_arg = os.environ[arg_with_prefix]
- elif arg in os.environ:
+ elif arg in os.environ:
env_arg = os.environ[arg]
else:
raise KeyError
@@ -856,7 +879,7 @@ def read_env_variable(arg, default_value):
env_arg = env_arg.strip()
if env_arg == 'True': r = True
elif env_arg == 'False': r = False
- else: print('enter True or False, but have:', env_arg); r = default_value
+ else: print('Enter True or False, but have:', env_arg); r = default_value
elif isinstance(default_value, int):
r = int(env_arg)
elif isinstance(default_value, float):
@@ -880,13 +903,14 @@ def read_env_variable(arg, default_value):
print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
return r
+
@lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg):
from colorful import print亮红, print亮绿, print亮蓝
try:
# 优先级1. 获取环境变量作为配置
- default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
- r = read_env_variable(arg, default_ref)
+ default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
+ r = read_env_variable(arg, default_ref)
except:
try:
# 优先级2. 获取config_private中的配置
@@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg):
if arg == 'API_URL_REDIRECT':
oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的,请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
if oai_rd and not oai_rd.endswith('/completions'):
- print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
+ print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
time.sleep(5)
if arg == 'API_KEY':
print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
@@ -907,9 +931,9 @@ def read_single_conf_with_lru_cache(arg):
if is_any_api_key(r):
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
else:
- print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
+ print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
if arg == 'proxies':
- if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用
+ if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用
if r is None:
print亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议:检查USE_PROXY选项是否修改。')
else:
@@ -953,17 +977,20 @@ class DummyWith():
在上下文执行开始的情况下,__enter__()方法会在代码块被执行前被调用,
而在上下文执行结束时,__exit__()方法则会被调用。
"""
+
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return
+
def run_gradio_in_subpath(demo, auth, port, custom_path):
"""
把gradio的运行地址更改到指定的二次路径上
"""
- def is_path_legal(path: str)->bool:
+
+ def is_path_legal(path: str) -> bool:
'''
check path for sub url
path: path to check
@@ -988,7 +1015,7 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
app = FastAPI()
if custom_path != "/":
@app.get("/")
- def read_main():
+ def read_main():
return {"message": f"Gradio is running at: {custom_path}"}
app = gr.mount_gradio_app(app, demo, path=custom_path)
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
@@ -999,13 +1026,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
reduce the length of history by clipping.
this function search for the longest entries to clip, little by little,
until the number of token of history is reduced under threshold.
- 通过裁剪来缩短历史记录的长度。
+ 通过裁剪来缩短历史记录的长度。
此函数逐渐地搜索最长的条目进行剪辑,
直到历史记录的标记数量降低到阈值以下。
"""
import numpy as np
from request_llms.bridge_all import model_info
- def get_token_num(txt):
+ def get_token_num(txt):
return len(tokenizer.encode(txt, disallowed_special=()))
input_token_num = get_token_num(inputs)
@@ -1039,14 +1066,15 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
while n_token > max_token_limit:
where = np.argmax(everything_token)
encoded = tokenizer.encode(everything[where], disallowed_special=())
- clipped_encoded = encoded[:len(encoded)-delta]
- everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
+ clipped_encoded = encoded[:len(encoded) - delta]
+ everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
everything_token[where] = get_token_num(everything[where])
n_token = get_token_num('\n'.join(everything))
history = everything[1:]
return history
+
"""
========================================================================
第三部分
@@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
========================================================================
"""
+
def zip_folder(source_folder, dest_folder, zip_name):
import zipfile
import os
@@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name):
print(f"Zip file created at {zip_file}")
+
def zip_result(folder):
t = gen_time_str()
zip_folder(folder, get_log_folder(), f'{t}-result.zip')
return pj(get_log_folder(), f'{t}-result.zip')
+
def gen_time_str():
import time
return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+
def get_log_folder(user=default_user_name, plugin_name='shared'):
if user is None: user = default_user_name
PATH_LOGGING = get_conf('PATH_LOGGING')
@@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'):
if not os.path.exists(_dir): os.makedirs(_dir)
return _dir
+
def get_upload_folder(user=default_user_name, tag=None):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
if user is None: user = default_user_name
- if tag is None or len(tag)==0:
+ if tag is None or len(tag) == 0:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
else:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
return target_path_base
+
def is_the_upload_folder(string):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
- if re.match(pattern, string): return True
- else: return False
+ if re.match(pattern, string):
+ return True
+ else:
+ return False
+
def get_user(chatbotwithcookies):
return chatbotwithcookies._cookies.get('user_name', default_user_name)
+
class ProxyNetworkActivate():
"""
这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
"""
+
def __init__(self, task=None) -> None:
self.task = task
if not task:
@@ -1158,32 +1197,36 @@ class ProxyNetworkActivate():
if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
return
+
def objdump(obj, file='objdump.tmp'):
import pickle
with open(file, 'wb+') as f:
pickle.dump(obj, f)
return
+
def objload(file='objdump.tmp'):
import pickle, os
- if not os.path.exists(file):
+ if not os.path.exists(file):
return
with open(file, 'rb') as f:
return pickle.load(f)
-
+
+
def Singleton(cls):
"""
一个单实例装饰器
"""
_instance = {}
-
+
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
-
+
return _singleton
+
"""
========================================================================
第四部分
@@ -1197,6 +1240,7 @@ def Singleton(cls):
========================================================================
"""
+
def set_conf(key, value):
from toolbox import read_single_conf_with_lru_cache, get_conf
read_single_conf_with_lru_cache.cache_clear()
@@ -1205,10 +1249,12 @@ def set_conf(key, value):
altered = get_conf(key)
return altered
+
def set_multi_conf(dic):
for k, v in dic.items(): set_conf(k, v)
return
+
def get_plugin_handle(plugin_name):
"""
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
@@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name):
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
return f_hot_reload
+
def get_chat_handle():
"""
"""
from request_llms.bridge_all import predict_no_ui_long_connection
return predict_no_ui_long_connection
+
def get_plugin_default_kwargs():
"""
"""
@@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
- 'top_p':1.0,
+ 'top_p': 1.0,
'max_length': None,
- 'temperature':1.0,
+ 'temperature': 1.0,
}
chatbot = ChatBotWithCookies(llm_kwargs)
@@ -1247,11 +1295,12 @@ def get_plugin_default_kwargs():
"plugin_kwargs": {},
"chatbot_with_cookie": chatbot,
"history": [],
- "system_prompt": "You are a good AI.",
+ "system_prompt": "You are a good AI.",
"web_port": None
}
return DEFAULT_FN_GROUPS_kwargs
+
def get_chat_default_kwargs():
"""
"""
@@ -1259,9 +1308,9 @@ def get_chat_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
- 'top_p':1.0,
+ 'top_p': 1.0,
'max_length': None,
- 'temperature':1.0,
+ 'temperature': 1.0,
}
default_chat_kwargs = {
"inputs": "Hello there, are you ready?",
@@ -1284,15 +1333,15 @@ def get_pictures_list(path):
def have_any_recent_upload_image_files(chatbot):
_5min = 5 * 60
- if chatbot is None: return False, None # chatbot is None
+ if chatbot is None: return False, None # chatbot is None
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
- if not most_recent_uploaded: return False, None # most_recent_uploaded is None
+ if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if time.time() - most_recent_uploaded["time"] < _5min:
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
path = most_recent_uploaded['path']
file_manifest = get_pictures_list(path)
if len(file_manifest) == 0: return False, None
- return True, file_manifest # most_recent_uploaded is new
+ return True, file_manifest # most_recent_uploaded is new
else:
return False, None # most_recent_uploaded is too old
@@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs):
from request_llms.bridge_all import model_info
return model_info[llm_kwargs['llm_model']]['max_token']
+
def check_packages(packages=[]):
import importlib.util
for p in packages: