Adapt Google Gemini: extract files from the user input (#1419)

parent a96f842b3a
commit a7c960dcb0
config.py (17 changes)

@@ -89,12 +89,14 @@ DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体']
 LLM_MODEL = "gpt-3.5-turbo"  # options ↓↓↓
 AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
                     "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
-                    "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
                     "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
-                    "chatglm3", "moss", "claude-2"]
+                    "gemini-pro", "chatglm3", "moss", "claude-2"]
-# P.S. other available models also include ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
+# P.S. other available models also include [
+# "qwen-turbo", "qwen-plus", "qwen-max"
+# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
+# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
 # "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
-# “qwen-turbo", "qwen-plus", "qwen-max"]
+# ]

 # Define which models the "ask multiple GPT models" plugin should use: choose from AVAIL_LLM_MODELS and separate entries with `&`, e.g. "gpt-3.5-turbo&chatglm3&azure-gpt-4"

@@ -204,6 +206,10 @@ ANTHROPIC_API_KEY = ""
 CUSTOM_API_KEY_PATTERN = ""


+# Google Gemini API-Key
+GEMINI_API_KEY = ''
+
+
 # HUGGINGFACE token, takes effect when downloading LLAMA  https://huggingface.co/docs/hub/security-tokens
 HUGGINGFACE_ACCESS_TOKEN = "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"


@@ -292,6 +298,9 @@ NUM_CUSTOM_BASIC_BTN = 4
 ├── "qwen-turbo" and other Tongyi Qianwen models
 │   └── DASHSCOPE_API_KEY
 │
+├── "Gemini"
+│   └── GEMINI_API_KEY
+│
 └── "newbing" the Newbing API is no longer stable and is not recommended
     ├── NEWBING_STYLE
     └── NEWBING_COOKIES
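Note: the new GEMINI_API_KEY entry is read through the project's get_conf helper like every other key in config.py. A minimal sketch of the guard performed by the new bridge, mirroring the check added in request_llms/bridge_google_gemini.py below:

    from toolbox import get_conf

    # get_conf returns the value set in config.py, or its environment override;
    # the Gemini bridge refuses to run while the key is still empty.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY in config.py.")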
request_llms/bridge_all.py

@@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui
 from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
 from .bridge_qianfan import predict as qianfan_ui

+from .bridge_google_gemini import predict as genai_ui
+from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
+
 colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

 class LazyloadTiktoken(object):

@@ -246,6 +249,22 @@ model_info = {
         "tokenizer": tokenizer_gpt35,
         "token_cnt": get_token_num_gpt35,
     },
+    "gemini-pro": {
+        "fn_with_ui": genai_ui,
+        "fn_without_ui": genai_noui,
+        "endpoint": None,
+        "max_token": 1024 * 32,
+        "tokenizer": tokenizer_gpt35,
+        "token_cnt": get_token_num_gpt35,
+    },
+    "gemini-pro-vision": {
+        "fn_with_ui": genai_ui,
+        "fn_without_ui": genai_noui,
+        "endpoint": None,
+        "max_token": 1024 * 32,
+        "tokenizer": tokenizer_gpt35,
+        "token_cnt": get_token_num_gpt35,
+    },
 }

 # -=-=-=-=-=-=- api2d alignment support -=-=-=-=-=-=-
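For context, model_info is the dispatch table that routes each model name to its bridge functions. The sketch below is a hypothetical illustration of that lookup, not code from this commit; the real dispatcher lives elsewhere in bridge_all.py:

    # Hypothetical: resolve a model name to its registered bridge and call it.
    def route_request(llm_model, inputs, llm_kwargs, history, sys_prompt):
        entry = model_info[llm_model]   # e.g. model_info["gemini-pro"]
        fn = entry["fn_without_ui"]     # genai_noui for the two Gemini entries
        return fn(inputs, llm_kwargs, history=history, sys_prompt=sys_prompt)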
request_llms/bridge_google_gemini.py (new file, 101 lines)

# encoding: utf-8
# @Time   : 2023/12/21
# @Author : Spike
# @Descr  :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    # check the API key
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience: 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("解析GEMINI消息出错。")
            buffer = paraphrase['text']
            gpt_replying_buffer += buffer
            if len(observe_window) >= 1:
                observe_window[0] = gpt_replying_buffer
            if len(observe_window) >= 2:
                if (time.time() - observe_window[1]) > watch_dog_patience:
                    raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    # check the API key
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY: raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # this decode step is an easy place to slip up
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("解析GEMINI消息出错。")
            gpt_replying_buffer += paraphrase['text']  # let the json parser handle the unescaping
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # drop the failed exchange from the history
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')
    if not gpt_replying_buffer:
        history = history[:-2]  # drop the failed exchange from the history
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)


if __name__ == '__main__':
    import sys
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
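A minimal usage sketch of the new non-UI bridge (the prompt text is made up; observe_window carries the partial reply in slot 0 and a watchdog timestamp in slot 1, matching how the function reads it above):

    import time
    from request_llms.bridge_google_gemini import predict_no_ui_long_connection

    observe_window = ["", time.time()]   # [partial reply, watchdog timestamp]
    reply = predict_no_ui_long_connection(
        inputs="Write a story about a magic backpack.",
        llm_kwargs={'llm_model': 'gemini-pro'},
        history=[],
        sys_prompt="",
        observe_window=observe_window,
    )
    print(reply)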
request_llms/com_google.py (new file, 198 lines)

# encoding: utf-8
# @Time   : 2023/12/25
# @Author : Spike
# @Descr  :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image

proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

"""
========================================================================
Part 5: some file-handling helpers
    files_filter_handler        filter files by type
    input_encode_handler        extract files from the input and parse them
    file_manifest_filter_html   filter files by type, and render them as html or md text
    link_mtime_to_md            append the file's local mtime to the link so a cached copy is not downloaded
    html_view_blank             hyperlink helper
    html_local_file             convert a local file path to a relative path
    to_markdown_tabs            convert file lists into a markdown table
========================================================================
"""


def files_filter_handler(file_list):
    new_list = []
    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        file = str(file).replace('file=', '')
        if os.path.exists(file):
            if str(os.path.basename(file)).split('.')[-1] in filter_:
                new_list.append(file)
    return new_list


def input_encode_handler(inputs):
    md_encode = []
    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
    matches_path = re.findall(pattern_md_file, inputs)
    for md_path in matches_path:
        pattern_file = r"\((file=.*)\)"
        matches_path = re.findall(pattern_file, md_path)
        encode_file = files_filter_handler(file_list=matches_path)
        if encode_file:
            md_encode.extend([{
                "data": encode_image(i),
                "type": os.path.splitext(i)[1].replace('.', '')
            } for i in encode_file])
        inputs = inputs.replace(md_path, '')
    return inputs, md_encode


def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    new_list = []
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        if str(os.path.basename(file)).split('.')[-1] in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list


def link_mtime_to_md(file):
    link_local = html_local_file(file)
    link_name = os.path.basename(file)
    a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
    return a


def html_local_file(file):
    base_path = os.path.dirname(__file__)  # project directory
    if os.path.exists(str(file)):
        file = f'file={file.replace(base_path, ".")}'
    return file


def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        a = f'![{__file}]({__file})'
    return a


def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """
    Args:
        head: the table header: []
        tabs: the table values: [[column 1], [column 2], [column 3], [column 4]]
        alignment: :--- left-aligned, :---: centre-aligned, ---: right-aligned
        column: True to keep data in columns, False to keep data in rows (default).
    Returns:
        A string representation of the markdown table.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Find the maximum length among the columns
    max_len = max(len(column) for column in transposed_tabs)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'

    return tabs_list


class GoogleChatInit:

    def __init__(self):
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked

    def __conversation_history(self, history):
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini does not allow an even number of turns, so this is unused for now; revisit when supported upstream
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models handle history
            messages.extend(self.__conversation_history(history))  # process the history
        messages.append(self.__conversation_user(inputs))  # process the current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload


if __name__ == '__main__':
    google = GoogleChatInit()
    # print(google.generate_message_payload('你好呀', {},
    #                                       ['123123', '3123123'], ''))
    # google.input_encode_handle('123123[123123](./123123), ')
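Since the commit's headline change is extracting files from the user's input, here is a small sketch of input_encode_handler in action (the file path is hypothetical; encode_image base64-encodes the file, so only paths that actually exist are picked up):

    from request_llms.com_google import input_encode_handler

    # A gradio-style markdown file link embedded in the user input.
    user_input = 'Describe this picture ![shot](file=./private_upload/demo/shot.png)'
    clean_text, images = input_encode_handler(user_input)
    # clean_text: the input with the markdown link stripped out
    # images: [{"data": <base64>, "type": "png"}, ...], ready to become
    #         inline_data parts in a gemini-pro-vision payload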
toolbox.py (140 changes)

@@ -11,8 +11,10 @@ import glob
 import math
 from latex2mathml.converter import convert as tex2mathml
 from functools import wraps, lru_cache
+
 pj = os.path.join
 default_user_name = 'default_user'
+
 """
 ========================================================================
 Part 1

@@ -26,6 +28,7 @@ default_user_name = 'default_user'
 ========================================================================
 """

+
 class ChatBotWithCookies(list):
     def __init__(self, cookie):
         """

@@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f):
     else:
         user_name = default_user_name
     cookies.update({
-        'top_p':top_p,
+        'top_p': top_p,
         'api_key': cookies['api_key'],
         'llm_model': llm_model,
-        'temperature':temperature,
+        'temperature': temperature,
         'user_name': user_name,
     })
     llm_kwargs = {
         'api_key': cookies['api_key'],
         'llm_model': llm_model,
-        'top_p':top_p,
+        'top_p': top_p,
         'max_length': max_length,
-        'temperature':temperature,
+        'temperature': temperature,
         'client_ip': request.client.host,
         'most_recent_uploaded': cookies.get('most_recent_uploaded')
     }

@@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f):
     final_cookies = chatbot_with_cookie.get_cookies()
     # len(args) != 0 means the "submit" dialogue channel or the basic-function channel
     if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
-        chatbot_with_cookie.append(["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
+        chatbot_with_cookie.append(
+            ["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
         yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档")

     return decorated

@@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs):  # refresh the UI

     yield cookies, chatbot_gr, history, msg

+
 def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1):  # refresh the UI
     """
     Refresh the user interface

@@ -147,6 +153,7 @@ def trimmed_format_exc():
     replace_path = "."
     return str.replace(current_path, replace_path)

+
 def CatchException(f):
     """
     A decorator that catches exceptions raised inside f, wraps them into a generator, and shows them in the chat.

@@ -164,9 +171,9 @@ def CatchException(f):
         if len(chatbot_with_cookie) == 0:
             chatbot_with_cookie.clear()
             chatbot_with_cookie.append(["插件调度异常", "异常原因"])
-        chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0],
-                                   f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
+        chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
         yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}')  # refresh the UI

     return decorated

@@ -209,6 +216,7 @@ def HotReload(f):
 ========================================================================
 """

+
 def get_reduce_token_percent(text):
     """
     * This function will be deprecated in the future

@@ -220,9 +228,9 @@ def get_reduce_token_percent(text):
         EXCEED_ALLO = 500  # leave a little headroom, otherwise the reply fails when too little budget remains
         max_limit = float(match[0]) - EXCEED_ALLO
         current_tokens = float(match[1])
-        ratio = max_limit/current_tokens
+        ratio = max_limit / current_tokens
         assert ratio > 0 and ratio < 1
-        return ratio, str(int(current_tokens-max_limit))
+        return ratio, str(int(current_tokens - max_limit))
     except:
         return 0.5, '不详'

@@ -268,8 +276,6 @@ def regular_txt_to_markdown(text):
     return text


-
-
 def report_exception(chatbot, history, a, b):
     """
     Append an error message to the chatbot

@@ -352,7 +358,8 @@ def markdown_convertion(txt):
         """
         Work around an mdx_math bug (an extra <script> when a begin command is wrapped in a single $)
         """
-        content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">')
+        content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
+                                  '<script type="math/tex; mode=display">')
         content = content.replace('</script>\n</script>', '</script>')
         return content

@@ -372,7 +379,7 @@ def markdown_convertion(txt):
         }
         matches = []
         for pattern, property in mathpatterns.items():
-            flags = re.ASCII|re.DOTALL if property['allow_multi_lines'] else re.ASCII
+            flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
             matches.extend(re.findall(pattern, txt, flags))
         if len(matches) == 0: return False
         contain_any_eq = False

@@ -401,7 +408,7 @@ def markdown_convertion(txt):
             stripped_string = line.lstrip()
             num_spaces = len(line) - len(stripped_string)
             if (num_spaces % 4) == 3:
-                num_spaces_should_be = math.ceil(num_spaces/4) * 4
+                num_spaces_should_be = math.ceil(num_spaces / 4) * 4
                 lines[i] = ' ' * num_spaces_should_be + stripped_string
         return '\n'.join(lines)

@@ -409,7 +416,8 @@ def markdown_convertion(txt):
     if is_equation(txt):  # formulas marked with $, and no ``` code-fence markers
         # convert everything to html format
         split = markdown.markdown(text='---')
-        convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'], extension_configs=markdown_extension_configs)
+        convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'],
+                                            extension_configs=markdown_extension_configs)
         convert_stage_1 = markdown_bug_hunt(convert_stage_1)
         # 1. convert to easy-to-copy tex (do not render math)
         convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL)

@@ -441,8 +449,7 @@ def close_up_code_segment_during_stream(gpt_reply):
     segments = gpt_reply.split('```')
     n_mark = len(segments) - 1
     if n_mark % 2 == 1:
-        # print('a code segment is streaming!')
-        return gpt_reply+'\n```'
+        return gpt_reply + '\n```'  # a code segment is still streaming
     else:
         return gpt_reply

@@ -559,6 +566,7 @@ def file_already_in_downloadzone(file, user_path):
     except:
         return False

+
 def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
     # copy the file into the download zone
     import shutil

@@ -581,8 +589,10 @@ def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
     if not os.path.exists(new_path): shutil.copyfile(file, new_path)
     # add the file to the chatbot cookie
     if chatbot is not None:
-        if 'files_to_promote' in chatbot._cookies: current = chatbot._cookies['files_to_promote']
-        else: current = []
+        if 'files_to_promote' in chatbot._cookies:
+            current = chatbot._cookies['files_to_promote']
+        else:
+            current = []
         if new_path not in current:  # avoid adding the same file twice
             chatbot._cookies.update({'files_to_promote': [new_path] + current})
     return new_path

@@ -605,8 +615,10 @@ def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
     for subdirectory in glob.glob(f'{user_upload_dir}/*'):
         subdirectory_time = os.path.getmtime(subdirectory)
         if subdirectory_time < one_hour_ago:
-            try: shutil.rmtree(subdirectory)
-            except: pass
+            try:
+                shutil.rmtree(subdirectory)
+            except:
+                pass
     return


@@ -690,12 +702,7 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
         file_origin_name = os.path.basename(file.orig_name)
         this_file_path = pj(target_path_base, file_origin_name)
         shutil.move(file.name, this_file_path)
-        upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract')
+        upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
-
-    if "浮动输入区" in checkboxes:
-        txt, txt2 = "", target_path_base
-    else:
-        txt, txt2 = target_path_base, ""

     # collect the files and emit the message
     moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]

@@ -703,7 +710,11 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
     chatbot.append(['我上传了文件,请查收',
                     f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
                     f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
-                    f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数'+upload_msg])
+                    f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数' + upload_msg])
+
+    txt, txt2 = target_path_base, ""
+    if "浮动输入区" in checkboxes:
+        txt, txt2 = txt2, txt

     # record recent files
     cookies.update({
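The hunk above moves the txt/txt2 assignment below the chatbot message and reduces it to a swap. A worked example of the resulting behavior (values are hypothetical; "浮动输入区" is the floating-input-area checkbox label):

    checkboxes = ["浮动输入区"]           # the floating input area is enabled
    target_path_base = "private_upload/default_user/2023-12-25-12-00-00"
    txt, txt2 = target_path_base, ""
    if "浮动输入区" in checkboxes:
        txt, txt2 = txt2, txt             # the floating box receives the path
    print(txt, txt2)                      # -> '' private_upload/default_user/...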
@@ -732,22 +743,27 @@ def on_report_generated(cookies, files, chatbot):
     chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}'])
     return cookies, report_files, chatbot

+
 def load_chat_cookies():
     API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
     AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
+
     # deal with azure openai key
     if is_any_api_key(AZURE_API_KEY):
-        if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
-        else: API_KEY = AZURE_API_KEY
+        if is_any_api_key(API_KEY):
+            API_KEY = API_KEY + ',' + AZURE_API_KEY
+        else:
+            API_KEY = AZURE_API_KEY
     if len(AZURE_CFG_ARRAY) > 0:
         for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
             if not azure_model_name.startswith('azure'):
                 raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
             AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
             if is_any_api_key(AZURE_API_KEY_):
-                if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
-                else: API_KEY = AZURE_API_KEY_
+                if is_any_api_key(API_KEY):
+                    API_KEY = API_KEY + ',' + AZURE_API_KEY_
+                else:
+                    API_KEY = AZURE_API_KEY_
+
     customize_fn_overwrite_ = {}
     for k in range(NUM_CUSTOM_BASIC_BTN):

@@ -760,6 +776,7 @@ def load_chat_cookies():
         })
     return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}

+
 def is_openai_api_key(key):
     CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
     if len(CUSTOM_API_KEY_PATTERN) != 0:

@@ -768,14 +785,17 @@ def is_openai_api_key(key):
         API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
         return bool(API_MATCH_ORIGINAL)

+
 def is_azure_api_key(key):
     API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
     return bool(API_MATCH_AZURE)

+
 def is_api2d_key(key):
     API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
     return bool(API_MATCH_API2D)

+
 def is_any_api_key(key):
     if ',' in key:
         keys = key.split(',')

@@ -785,8 +805,9 @@ def is_any_api_key(key):
     else:
         return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key)

+
 def what_keys(keys):
-    avail_key_list = {'OpenAI Key':0, "Azure Key":0, "API2D Key":0}
+    avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
     key_list = keys.split(',')

     for k in key_list:

@@ -803,6 +824,7 @@ def what_keys(keys):

     return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']} 个"

+
 def select_api_key(keys, llm_model):
     import random
     avail_key_list = []

@@ -826,6 +848,7 @@ def select_api_key(keys, llm_model):
     api_key = random.choice(avail_key_list)  # random load balancing
     return api_key

+
 def read_env_variable(arg, default_value):
     """
     The environment variable can be `GPT_ACADEMIC_CONFIG` (takes priority) or simply `CONFIG`

@@ -856,7 +879,7 @@ def read_env_variable(arg, default_value):
         env_arg = env_arg.strip()
         if env_arg == 'True': r = True
         elif env_arg == 'False': r = False
-        else: print('enter True or False, but have:', env_arg); r = default_value
+        else: print('Enter True or False, but have:', env_arg); r = default_value
     elif isinstance(default_value, int):
         r = int(env_arg)
     elif isinstance(default_value, float):

@@ -880,6 +903,7 @@ def read_env_variable(arg, default_value):
     print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
     return r

+
 @lru_cache(maxsize=128)
 def read_single_conf_with_lru_cache(arg):
     from colorful import print亮红, print亮绿, print亮蓝

@@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg):
     if arg == 'API_URL_REDIRECT':
         oai_rd = r.get("https://api.openai.com/v1/chat/completions", None)  # wrong API_URL_REDIRECT format; read `https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
         if oai_rd and not oai_rd.endswith('/completions'):
-            print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
+            print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
             time.sleep(5)
     if arg == 'API_KEY':
         print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")

@@ -907,7 +931,7 @@ def read_single_conf_with_lru_cache(arg):
         if is_any_api_key(r):
             print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
         else:
-            print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
+            print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
     if arg == 'proxies':
         if not read_single_conf_with_lru_cache('USE_PROXY'): r = None  # check USE_PROXY so proxies cannot take effect on its own
         if r is None:

@@ -953,17 +977,20 @@ class DummyWith():
     When the context starts, __enter__() is called before the block runs,
     and __exit__() is called when the context ends.
     """
+
     def __enter__(self):
         return self
+
     def __exit__(self, exc_type, exc_value, traceback):
         return

+
 def run_gradio_in_subpath(demo, auth, port, custom_path):
     """
     Move gradio's serving address to the given sub-path
     """
-    def is_path_legal(path: str)->bool:
+
+    def is_path_legal(path: str) -> bool:
         '''
         check path for sub url
         path: path to check

@@ -1039,7 +1066,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
     while n_token > max_token_limit:
         where = np.argmax(everything_token)
         encoded = tokenizer.encode(everything[where], disallowed_special=())
-        clipped_encoded = encoded[:len(encoded)-delta]
+        clipped_encoded = encoded[:len(encoded) - delta]
         everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 to remove the may-be illegal char
         everything_token[where] = get_token_num(everything[where])
         n_token = get_token_num('\n'.join(everything))

@@ -1047,6 +1074,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
     history = everything[1:]
     return history

+
 """
 ========================================================================
 Part 3

@@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
 ========================================================================
 """

+
 def zip_folder(source_folder, dest_folder, zip_name):
     import zipfile
     import os

@@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name):

     print(f"Zip file created at {zip_file}")

+
 def zip_result(folder):
     t = gen_time_str()
     zip_folder(folder, get_log_folder(), f'{t}-result.zip')
     return pj(get_log_folder(), f'{t}-result.zip')

+
 def gen_time_str():
     import time
     return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

+
 def get_log_folder(user=default_user_name, plugin_name='shared'):
     if user is None: user = default_user_name
     PATH_LOGGING = get_conf('PATH_LOGGING')

@@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'):
     if not os.path.exists(_dir): os.makedirs(_dir)
     return _dir

+
 def get_upload_folder(user=default_user_name, tag=None):
     PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
     if user is None: user = default_user_name
-    if tag is None or len(tag)==0:
+    if tag is None or len(tag) == 0:
         target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
     else:
         target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
     return target_path_base

+
 def is_the_upload_folder(string):
     PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
     pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
     pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
-    if re.match(pattern, string): return True
-    else: return False
+    if re.match(pattern, string):
+        return True
+    else:
+        return False

+
 def get_user(chatbotwithcookies):
     return chatbotwithcookies._cookies.get('user_name', default_user_name)

+
 class ProxyNetworkActivate():
     """
     Defines an empty context manager named ProxyNetworkActivate, used to route a small piece of code through the proxy
     """
+
     def __init__(self, task=None) -> None:
         self.task = task
         if not task:

@@ -1158,12 +1197,14 @@ class ProxyNetworkActivate():
         if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
         return

+
 def objdump(obj, file='objdump.tmp'):
     import pickle
     with open(file, 'wb+') as f:
         pickle.dump(obj, f)
     return

+
 def objload(file='objdump.tmp'):
     import pickle, os
     if not os.path.exists(file):

@@ -1171,6 +1212,7 @@ def objload(file='objdump.tmp'):
     with open(file, 'rb') as f:
         return pickle.load(f)

+
 def Singleton(cls):
     """
     A singleton decorator

@@ -1184,6 +1226,7 @@ def Singleton(cls):

     return _singleton

+
 """
 ========================================================================
 Part 4

@@ -1197,6 +1240,7 @@ def Singleton(cls):
 ========================================================================
 """

+
 def set_conf(key, value):
     from toolbox import read_single_conf_with_lru_cache, get_conf
     read_single_conf_with_lru_cache.cache_clear()

@@ -1205,10 +1249,12 @@ def set_conf(key, value):
     altered = get_conf(key)
     return altered

+
 def set_multi_conf(dic):
     for k, v in dic.items(): set_conf(k, v)
     return

+
 def get_plugin_handle(plugin_name):
     """
     e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'

@@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name):
     f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
     return f_hot_reload

+
 def get_chat_handle():
     """
     """
     from request_llms.bridge_all import predict_no_ui_long_connection
     return predict_no_ui_long_connection

+
 def get_plugin_default_kwargs():
     """
     """

@@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs():
     llm_kwargs = {
         'api_key': cookies['api_key'],
         'llm_model': cookies['llm_model'],
-        'top_p':1.0,
+        'top_p': 1.0,
         'max_length': None,
-        'temperature':1.0,
+        'temperature': 1.0,
     }
     chatbot = ChatBotWithCookies(llm_kwargs)

@@ -1252,6 +1300,7 @@ def get_plugin_default_kwargs():
     }
     return DEFAULT_FN_GROUPS_kwargs

+
 def get_chat_default_kwargs():
     """
     """

@@ -1259,9 +1308,9 @@ def get_chat_default_kwargs():
     llm_kwargs = {
         'api_key': cookies['api_key'],
         'llm_model': cookies['llm_model'],
-        'top_p':1.0,
+        'top_p': 1.0,
         'max_length': None,
-        'temperature':1.0,
+        'temperature': 1.0,
     }
     default_chat_kwargs = {
         "inputs": "Hello there, are you ready?",

@@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs):
     from request_llms.bridge_all import model_info
     return model_info[llm_kwargs['llm_model']]['max_token']

+
 def check_packages(packages=[]):
     import importlib.util
     for p in packages: