适配 google gemini 优化为从用户input中提取文件 (#1419)

适配 google gemini 优化为从用户input中提取文件
This commit is contained in:
XIao 2023-12-31 17:13:50 +08:00 committed by qingxu fu
parent a96f842b3a
commit a7c960dcb0
5 changed files with 472 additions and 95 deletions

View File

@ -89,12 +89,14 @@ DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体']
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓ LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview", AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4", "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
"chatglm3", "moss", "claude-2"] "gemini-pro", "chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random" # P.S. 其他可用的模型还包括 [
# "qwen-turbo", "qwen-plus", "qwen-max"
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama" # "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
# “qwen-turbo", "qwen-plus", "qwen-max"] # ]
# 定义界面上“询问多个GPT模型”插件应该使用哪些模型请从AVAIL_LLM_MODELS中选择并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4" # 定义界面上“询问多个GPT模型”插件应该使用哪些模型请从AVAIL_LLM_MODELS中选择并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
@ -204,6 +206,10 @@ ANTHROPIC_API_KEY = ""
CUSTOM_API_KEY_PATTERN = "" CUSTOM_API_KEY_PATTERN = ""
# Google Gemini API-Key
GEMINI_API_KEY = ''
# HUGGINGFACE的TOKEN下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens # HUGGINGFACE的TOKEN下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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" HUGGINGFACE_ACCESS_TOKEN = "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"
@ -292,6 +298,9 @@ NUM_CUSTOM_BASIC_BTN = 4
"qwen-turbo" 等通义千问大模型 "qwen-turbo" 等通义千问大模型
DASHSCOPE_API_KEY DASHSCOPE_API_KEY
"Gemini"
GEMINI_API_KEY
"newbing" Newbing接口不再稳定不推荐使用 "newbing" Newbing接口不再稳定不推荐使用
NEWBING_STYLE NEWBING_STYLE
NEWBING_COOKIES NEWBING_COOKIES

View File

@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui from .bridge_qianfan import predict as qianfan_ui
from .bridge_google_gemini import predict as genai_ui
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044'] colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
class LazyloadTiktoken(object): class LazyloadTiktoken(object):
@ -246,6 +249,22 @@ model_info = {
"tokenizer": tokenizer_gpt35, "tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35, "token_cnt": get_token_num_gpt35,
}, },
"gemini-pro": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro-vision": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
} }
# -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=- # -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=-

View File

@ -0,0 +1,101 @@
# encoding: utf-8
# @Time : 2023/12/21
# @Author : Spike
# @Descr :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
console_slience=False):
# 检查API_KEY
if get_conf("GEMINI_API_KEY") == "":
raise ValueError(f"请配置 GEMINI_API_KEY。")
genai = GoogleChatInit()
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
gpt_replying_buffer = ''
stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
for response in stream_response:
results = response.decode()
match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
if match:
try:
paraphrase = json.loads('{"text": "%s"}' % match.group(1))
except:
raise ValueError(f"解析GEMINI消息出错。")
buffer = paraphrase['text']
gpt_replying_buffer += buffer
if len(observe_window) >= 1:
observe_window[0] = gpt_replying_buffer
if len(observe_window) >= 2:
if (time.time() - observe_window[1]) > watch_dog_patience: raise RuntimeError("程序终止。")
if error_match:
raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
return gpt_replying_buffer
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
# 检查API_KEY
if get_conf("GEMINI_API_KEY") == "":
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
return
chatbot.append((inputs, ""))
yield from update_ui(chatbot=chatbot, history=history)
genai = GoogleChatInit()
retry = 0
while True:
try:
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
break
except Exception as e:
retry += 1
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg) # 刷新界面
if retry > MAX_RETRY: raise TimeoutError
gpt_replying_buffer = ""
gpt_security_policy = ""
history.extend([inputs, ''])
for response in stream_response:
results = response.decode("utf-8") # 被这个解码给耍了。。
gpt_security_policy += results
match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
if match:
try:
paraphrase = json.loads('{"text": "%s"}' % match.group(1))
except:
raise ValueError(f"解析GEMINI消息出错。")
gpt_replying_buffer += paraphrase['text'] # 使用 json 解析库进行处理
chatbot[-1] = (inputs, gpt_replying_buffer)
history[-1] = gpt_replying_buffer
yield from update_ui(chatbot=chatbot, history=history)
if error_match:
history = history[-2] # 错误的不纳入对话
chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误请查看message\n\n```\n{error_match.group(1)}\n```")
yield from update_ui(chatbot=chatbot, history=history)
raise RuntimeError('对话错误')
if not gpt_replying_buffer:
history = history[-2] # 错误的不纳入对话
chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略没有回答\n\n```\n{gpt_security_policy}\n```")
yield from update_ui(chatbot=chatbot, history=history)
if __name__ == '__main__':
import sys
llm_kwargs = {'llm_model': 'gemini-pro'}
result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
for i in result:
print(i)

198
request_llms/com_google.py Normal file
View File

@ -0,0 +1,198 @@
# encoding: utf-8
# @Time : 2023/12/25
# @Author : Spike
# @Descr :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
"""
========================================================================
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件并解析
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
link_mtime_to_md 文件增加本地时间参数避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
"""
def files_filter_handler(file_list):
new_list = []
filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
for file in file_list:
file = str(file).replace('file=', '')
if os.path.exists(file):
if str(os.path.basename(file)).split('.')[-1] in filter_:
new_list.append(file)
return new_list
def input_encode_handler(inputs):
md_encode = []
pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
matches_path = re.findall(pattern_md_file, inputs)
for md_path in matches_path:
pattern_file = r"\((file=.*)\)"
matches_path = re.findall(pattern_file, md_path)
encode_file = files_filter_handler(file_list=matches_path)
if encode_file:
md_encode.extend([{
"data": encode_image(i),
"type": os.path.splitext(i)[1].replace('.', '')
} for i in encode_file])
inputs = inputs.replace(md_path, '')
return inputs, md_encode
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
new_list = []
if not filter_:
filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
for file in file_list:
if str(os.path.basename(file)).split('.')[-1] in filter_:
new_list.append(html_local_img(file, md=md_type))
elif os.path.exists(file):
new_list.append(link_mtime_to_md(file))
else:
new_list.append(file)
return new_list
def link_mtime_to_md(file):
link_local = html_local_file(file)
link_name = os.path.basename(file)
a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
return a
def html_local_file(file):
base_path = os.path.dirname(__file__) # 项目目录
if os.path.exists(str(file)):
file = f'file={file.replace(base_path, ".")}'
return file
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
style = ''
if max_width is not None:
style += f"max-width: {max_width};"
if max_height is not None:
style += f"max-height: {max_height};"
__file = html_local_file(__file)
a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
if md:
a = f'![{__file}]({__file})'
return a
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
"""
Args:
head: 表头[]
tabs: 表值[[列1], [列2], [列3], [列4]]
alignment: :--- 左对齐 :---: 居中对齐 ---: 右对齐
column: True to keep data in columns, False to keep data in rows (default).
Returns:
A string representation of the markdown table.
"""
if column:
transposed_tabs = list(map(list, zip(*tabs)))
else:
transposed_tabs = tabs
# Find the maximum length among the columns
max_len = max(len(column) for column in transposed_tabs)
tab_format = "| %s "
tabs_list = "".join([tab_format % i for i in head]) + '|\n'
tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
for i in range(max_len):
row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
row_data = file_manifest_filter_html(row_data, filter_=None)
tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
return tabs_list
class GoogleChatInit:
def __init__(self):
self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
def __conversation_user(self, user_input):
what_i_have_asked = {"role": "user", "parts": []}
if 'vision' not in self.url_gemini:
input_ = user_input
encode_img = []
else:
input_, encode_img = input_encode_handler(user_input)
what_i_have_asked['parts'].append({'text': input_})
if encode_img:
for data in encode_img:
what_i_have_asked['parts'].append(
{'inline_data': {
"mime_type": f"image/{data['type']}",
"data": data['data']
}})
return what_i_have_asked
def __conversation_history(self, history):
messages = []
conversation_cnt = len(history) // 2
if conversation_cnt:
for index in range(0, 2 * conversation_cnt, 2):
what_i_have_asked = self.__conversation_user(history[index])
what_gpt_answer = {
"role": "model",
"parts": [{"text": history[index + 1]}]
}
messages.append(what_i_have_asked)
messages.append(what_gpt_answer)
return messages
def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
return response.iter_lines()
def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
messages = [
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
# {"role": "user", "parts": [{"text": ""}]},
# {"role": "model", "parts": [{"text": ""}]}
]
self.url_gemini = self.url_gemini.replace(
'%m', llm_kwargs['llm_model']).replace(
'%k', get_conf('GEMINI_API_KEY')
)
header = {'Content-Type': 'application/json'}
if 'vision' not in self.url_gemini: # 不是vision 才处理history
messages.extend(self.__conversation_history(history)) # 处理 history
messages.append(self.__conversation_user(inputs)) # 处理用户对话
payload = {
"contents": messages,
"generationConfig": {
"stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
"temperature": llm_kwargs.get('temperature', 1),
# "maxOutputTokens": 800,
"topP": llm_kwargs.get('top_p', 0.8),
"topK": 10
}
}
return header, payload
if __name__ == '__main__':
google = GoogleChatInit()
# print(gootle.generate_message_payload('你好呀', {},
# ['123123', '3123123'], ''))
# gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')

View File

@ -11,8 +11,10 @@ import glob
import math import math
from latex2mathml.converter import convert as tex2mathml from latex2mathml.converter import convert as tex2mathml
from functools import wraps, lru_cache from functools import wraps, lru_cache
pj = os.path.join pj = os.path.join
default_user_name = 'default_user' default_user_name = 'default_user'
""" """
======================================================================== ========================================================================
第一部分 第一部分
@ -26,6 +28,7 @@ default_user_name = 'default_user'
======================================================================== ========================================================================
""" """
class ChatBotWithCookies(list): class ChatBotWithCookies(list):
def __init__(self, cookie): def __init__(self, cookie):
""" """
@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f):
else: else:
user_name = default_user_name user_name = default_user_name
cookies.update({ cookies.update({
'top_p':top_p, 'top_p': top_p,
'api_key': cookies['api_key'], 'api_key': cookies['api_key'],
'llm_model': llm_model, 'llm_model': llm_model,
'temperature':temperature, 'temperature': temperature,
'user_name': user_name, 'user_name': user_name,
}) })
llm_kwargs = { llm_kwargs = {
'api_key': cookies['api_key'], 'api_key': cookies['api_key'],
'llm_model': llm_model, 'llm_model': llm_model,
'top_p':top_p, 'top_p': top_p,
'max_length': max_length, 'max_length': max_length,
'temperature':temperature, 'temperature': temperature,
'client_ip': request.client.host, 'client_ip': request.client.host,
'most_recent_uploaded': cookies.get('most_recent_uploaded') 'most_recent_uploaded': cookies.get('most_recent_uploaded')
} }
@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f):
final_cookies = chatbot_with_cookie.get_cookies() final_cookies = chatbot_with_cookie.get_cookies()
# len(args) != 0 代表“提交”键对话通道,或者基础功能通道 # len(args) != 0 代表“提交”键对话通道,或者基础功能通道
if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0: if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
chatbot_with_cookie.append(["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"]) chatbot_with_cookie.append(
["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档") yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档")
return decorated return decorated
@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
yield cookies, chatbot_gr, history, msg yield cookies, chatbot_gr, history, msg
def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面 def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面
""" """
刷新用户界面 刷新用户界面
@ -147,6 +153,7 @@ def trimmed_format_exc():
replace_path = "." replace_path = "."
return str.replace(current_path, replace_path) return str.replace(current_path, replace_path)
def CatchException(f): def CatchException(f):
""" """
装饰器函数捕捉函数f中的异常并封装到一个生成器中返回并显示到聊天当中 装饰器函数捕捉函数f中的异常并封装到一个生成器中返回并显示到聊天当中
@ -164,9 +171,9 @@ def CatchException(f):
if len(chatbot_with_cookie) == 0: if len(chatbot_with_cookie) == 0:
chatbot_with_cookie.clear() chatbot_with_cookie.clear()
chatbot_with_cookie.append(["插件调度异常", "异常原因"]) chatbot_with_cookie.append(["插件调度异常", "异常原因"])
chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}") yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
return decorated return decorated
@ -209,6 +216,7 @@ def HotReload(f):
======================================================================== ========================================================================
""" """
def get_reduce_token_percent(text): def get_reduce_token_percent(text):
""" """
* 此函数未来将被弃用 * 此函数未来将被弃用
@ -220,9 +228,9 @@ def get_reduce_token_percent(text):
EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题 EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
max_limit = float(match[0]) - EXCEED_ALLO max_limit = float(match[0]) - EXCEED_ALLO
current_tokens = float(match[1]) current_tokens = float(match[1])
ratio = max_limit/current_tokens ratio = max_limit / current_tokens
assert ratio > 0 and ratio < 1 assert ratio > 0 and ratio < 1
return ratio, str(int(current_tokens-max_limit)) return ratio, str(int(current_tokens - max_limit))
except: except:
return 0.5, '不详' return 0.5, '不详'
@ -268,8 +276,6 @@ def regular_txt_to_markdown(text):
return text return text
def report_exception(chatbot, history, a, b): def report_exception(chatbot, history, a, b):
""" """
向chatbot中添加错误信息 向chatbot中添加错误信息
@ -352,7 +358,8 @@ def markdown_convertion(txt):
""" """
解决一个mdx_math的bug$包裹begin命令时多余<script> 解决一个mdx_math的bug$包裹begin命令时多余<script>
""" """
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">') content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
'<script type="math/tex; mode=display">')
content = content.replace('</script>\n</script>', '</script>') content = content.replace('</script>\n</script>', '</script>')
return content return content
@ -363,16 +370,16 @@ def markdown_convertion(txt):
if '```' in txt and '```reference' not in txt: return False if '```' in txt and '```reference' not in txt: return False
if '$' not in txt and '\\[' not in txt: return False if '$' not in txt and '\\[' not in txt: return False
mathpatterns = { mathpatterns = {
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$ r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$ r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\] r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\]
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\) # r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end # r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$ # r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
} }
matches = [] matches = []
for pattern, property in mathpatterns.items(): for pattern, property in mathpatterns.items():
flags = re.ASCII|re.DOTALL if property['allow_multi_lines'] else re.ASCII flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
matches.extend(re.findall(pattern, txt, flags)) matches.extend(re.findall(pattern, txt, flags))
if len(matches) == 0: return False if len(matches) == 0: return False
contain_any_eq = False contain_any_eq = False
@ -389,7 +396,7 @@ def markdown_convertion(txt):
def fix_markdown_indent(txt): def fix_markdown_indent(txt):
# fix markdown indent # fix markdown indent
if (' - ' not in txt) or ('. ' not in txt): if (' - ' not in txt) or ('. ' not in txt):
return txt # do not need to fix, fast escape return txt # do not need to fix, fast escape
# walk through the lines and fix non-standard indentation # walk through the lines and fix non-standard indentation
lines = txt.split("\n") lines = txt.split("\n")
pattern = re.compile(r'^\s+-') pattern = re.compile(r'^\s+-')
@ -401,7 +408,7 @@ def markdown_convertion(txt):
stripped_string = line.lstrip() stripped_string = line.lstrip()
num_spaces = len(line) - len(stripped_string) num_spaces = len(line) - len(stripped_string)
if (num_spaces % 4) == 3: if (num_spaces % 4) == 3:
num_spaces_should_be = math.ceil(num_spaces/4) * 4 num_spaces_should_be = math.ceil(num_spaces / 4) * 4
lines[i] = ' ' * num_spaces_should_be + stripped_string lines[i] = ' ' * num_spaces_should_be + stripped_string
return '\n'.join(lines) return '\n'.join(lines)
@ -409,7 +416,8 @@ def markdown_convertion(txt):
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识 if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
# convert everything to html format # convert everything to html format
split = markdown.markdown(text='---') split = markdown.markdown(text='---')
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'], extension_configs=markdown_extension_configs) convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'],
extension_configs=markdown_extension_configs)
convert_stage_1 = markdown_bug_hunt(convert_stage_1) convert_stage_1 = markdown_bug_hunt(convert_stage_1)
# 1. convert to easy-to-copy tex (do not render math) # 1. convert to easy-to-copy tex (do not render math)
convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL) convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL)
@ -441,8 +449,7 @@ def close_up_code_segment_during_stream(gpt_reply):
segments = gpt_reply.split('```') segments = gpt_reply.split('```')
n_mark = len(segments) - 1 n_mark = len(segments) - 1
if n_mark % 2 == 1: if n_mark % 2 == 1:
# print('输出代码片段中!') return gpt_reply + '\n```' # 输出代码片段中!
return gpt_reply+'\n```'
else: else:
return gpt_reply return gpt_reply
@ -559,6 +566,7 @@ def file_already_in_downloadzone(file, user_path):
except: except:
return False return False
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None): def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
# 将文件复制一份到下载区 # 将文件复制一份到下载区
import shutil import shutil
@ -581,8 +589,10 @@ def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
if not os.path.exists(new_path): shutil.copyfile(file, new_path) if not os.path.exists(new_path): shutil.copyfile(file, new_path)
# 将文件添加到chatbot cookie中 # 将文件添加到chatbot cookie中
if chatbot is not None: if chatbot is not None:
if 'files_to_promote' in chatbot._cookies: current = chatbot._cookies['files_to_promote'] if 'files_to_promote' in chatbot._cookies:
else: current = [] current = chatbot._cookies['files_to_promote']
else:
current = []
if new_path not in current: # 避免把同一个文件添加多次 if new_path not in current: # 避免把同一个文件添加多次
chatbot._cookies.update({'files_to_promote': [new_path] + current}) chatbot._cookies.update({'files_to_promote': [new_path] + current})
return new_path return new_path
@ -605,8 +615,10 @@ def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
for subdirectory in glob.glob(f'{user_upload_dir}/*'): for subdirectory in glob.glob(f'{user_upload_dir}/*'):
subdirectory_time = os.path.getmtime(subdirectory) subdirectory_time = os.path.getmtime(subdirectory)
if subdirectory_time < one_hour_ago: if subdirectory_time < one_hour_ago:
try: shutil.rmtree(subdirectory) try:
except: pass shutil.rmtree(subdirectory)
except:
pass
return return
@ -681,7 +693,7 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
os.makedirs(target_path_base, exist_ok=True) os.makedirs(target_path_base, exist_ok=True)
# 移除过时的旧文件从而节省空间&保护隐私 # 移除过时的旧文件从而节省空间&保护隐私
outdate_time_seconds = 3600 # 一小时 outdate_time_seconds = 3600 # 一小时
del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name)) del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
# 逐个文件转移到目标路径 # 逐个文件转移到目标路径
@ -690,12 +702,7 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
file_origin_name = os.path.basename(file.orig_name) file_origin_name = os.path.basename(file.orig_name)
this_file_path = pj(target_path_base, file_origin_name) this_file_path = pj(target_path_base, file_origin_name)
shutil.move(file.name, this_file_path) shutil.move(file.name, this_file_path)
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract') upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
if "浮动输入区" in checkboxes:
txt, txt2 = "", target_path_base
else:
txt, txt2 = target_path_base, ""
# 整理文件集合 输出消息 # 整理文件集合 输出消息
moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)] moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
@ -703,7 +710,11 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
chatbot.append(['我上传了文件,请查收', chatbot.append(['我上传了文件,请查收',
f'[Local Message] 收到以下文件: \n\n{moved_files_str}' + f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
f'\n\n调用路径参数已自动修正到: \n\n{txt}' + f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数'+upload_msg]) f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数' + upload_msg])
txt, txt2 = target_path_base, ""
if "浮动输入区" in checkboxes:
txt, txt2 = txt2, txt
# 记录近期文件 # 记录近期文件
cookies.update({ cookies.update({
@ -732,34 +743,40 @@ def on_report_generated(cookies, files, chatbot):
chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}']) chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}'])
return cookies, report_files, chatbot return cookies, report_files, chatbot
def load_chat_cookies(): def load_chat_cookies():
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY') API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN') AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
# deal with azure openai key # deal with azure openai key
if is_any_api_key(AZURE_API_KEY): if is_any_api_key(AZURE_API_KEY):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY if is_any_api_key(API_KEY):
else: API_KEY = AZURE_API_KEY API_KEY = API_KEY + ',' + AZURE_API_KEY
else:
API_KEY = AZURE_API_KEY
if len(AZURE_CFG_ARRAY) > 0: if len(AZURE_CFG_ARRAY) > 0:
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items(): for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
if not azure_model_name.startswith('azure'): if not azure_model_name.startswith('azure'):
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头") raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"] AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
if is_any_api_key(AZURE_API_KEY_): if is_any_api_key(AZURE_API_KEY_):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_ if is_any_api_key(API_KEY):
else: API_KEY = AZURE_API_KEY_ API_KEY = API_KEY + ',' + AZURE_API_KEY_
else:
API_KEY = AZURE_API_KEY_
customize_fn_overwrite_ = {} customize_fn_overwrite_ = {}
for k in range(NUM_CUSTOM_BASIC_BTN): for k in range(NUM_CUSTOM_BASIC_BTN):
customize_fn_overwrite_.update({ customize_fn_overwrite_.update({
"自定义按钮" + str(k+1):{ "自定义按钮" + str(k+1):{
"Title": r"", "Title": r"",
"Prefix": r"请在自定义菜单中定义提示词前缀.", "Prefix": r"请在自定义菜单中定义提示词前缀.",
"Suffix": r"请在自定义菜单中定义提示词后缀", "Suffix": r"请在自定义菜单中定义提示词后缀",
} }
}) })
return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_} return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
def is_openai_api_key(key): def is_openai_api_key(key):
CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN') CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
if len(CUSTOM_API_KEY_PATTERN) != 0: if len(CUSTOM_API_KEY_PATTERN) != 0:
@ -768,14 +785,17 @@ def is_openai_api_key(key):
API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key) API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
return bool(API_MATCH_ORIGINAL) return bool(API_MATCH_ORIGINAL)
def is_azure_api_key(key): def is_azure_api_key(key):
API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key) API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_AZURE) return bool(API_MATCH_AZURE)
def is_api2d_key(key): def is_api2d_key(key):
API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key) API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_API2D) return bool(API_MATCH_API2D)
def is_any_api_key(key): def is_any_api_key(key):
if ',' in key: if ',' in key:
keys = key.split(',') keys = key.split(',')
@ -785,8 +805,9 @@ def is_any_api_key(key):
else: else:
return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key) return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key)
def what_keys(keys): def what_keys(keys):
avail_key_list = {'OpenAI Key':0, "Azure Key":0, "API2D Key":0} avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
key_list = keys.split(',') key_list = keys.split(',')
for k in key_list: for k in key_list:
@ -803,6 +824,7 @@ def what_keys(keys):
return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']}" return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']}"
def select_api_key(keys, llm_model): def select_api_key(keys, llm_model):
import random import random
avail_key_list = [] avail_key_list = []
@ -826,6 +848,7 @@ def select_api_key(keys, llm_model):
api_key = random.choice(avail_key_list) # 随机负载均衡 api_key = random.choice(avail_key_list) # 随机负载均衡
return api_key return api_key
def read_env_variable(arg, default_value): def read_env_variable(arg, default_value):
""" """
环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先)也可以直接是`CONFIG` 环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先)也可以直接是`CONFIG`
@ -856,7 +879,7 @@ def read_env_variable(arg, default_value):
env_arg = env_arg.strip() env_arg = env_arg.strip()
if env_arg == 'True': r = True if env_arg == 'True': r = True
elif env_arg == 'False': r = False elif env_arg == 'False': r = False
else: print('enter True or False, but have:', env_arg); r = default_value else: print('Enter True or False, but have:', env_arg); r = default_value
elif isinstance(default_value, int): elif isinstance(default_value, int):
r = int(env_arg) r = int(env_arg)
elif isinstance(default_value, float): elif isinstance(default_value, float):
@ -880,12 +903,13 @@ def read_env_variable(arg, default_value):
print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}") print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
return r return r
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg): def read_single_conf_with_lru_cache(arg):
from colorful import print亮红, print亮绿, print亮蓝 from colorful import print亮红, print亮绿, print亮蓝
try: try:
# 优先级1. 获取环境变量作为配置 # 优先级1. 获取环境变量作为配置
default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考 default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
r = read_env_variable(arg, default_ref) r = read_env_variable(arg, default_ref)
except: except:
try: try:
@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg):
if arg == 'API_URL_REDIRECT': if arg == 'API_URL_REDIRECT':
oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明` oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
if oai_rd and not oai_rd.endswith('/completions'): if oai_rd and not oai_rd.endswith('/completions'):
print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。") print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
time.sleep(5) time.sleep(5)
if arg == 'API_KEY': if arg == 'API_KEY':
print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key如API_KEY=\"openai-key1,openai-key2,azure-key3\"") print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
@ -907,9 +931,9 @@ def read_single_conf_with_lru_cache(arg):
if is_any_api_key(r): if is_any_api_key(r):
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功") print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
else: else:
print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式请在config文件中修改API密钥之后再运行。") print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式请在config文件中修改API密钥之后再运行。")
if arg == 'proxies': if arg == 'proxies':
if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY防止proxies单独起作用 if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY防止proxies单独起作用
if r is None: if r is None:
print亮红('[PROXY] 网络代理状态未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议检查USE_PROXY选项是否修改。') print亮红('[PROXY] 网络代理状态未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议检查USE_PROXY选项是否修改。')
else: else:
@ -953,17 +977,20 @@ class DummyWith():
在上下文执行开始的情况下__enter__()方法会在代码块被执行前被调用 在上下文执行开始的情况下__enter__()方法会在代码块被执行前被调用
而在上下文执行结束时__exit__()方法则会被调用 而在上下文执行结束时__exit__()方法则会被调用
""" """
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
return return
def run_gradio_in_subpath(demo, auth, port, custom_path): def run_gradio_in_subpath(demo, auth, port, custom_path):
""" """
把gradio的运行地址更改到指定的二次路径上 把gradio的运行地址更改到指定的二次路径上
""" """
def is_path_legal(path: str)->bool:
def is_path_legal(path: str) -> bool:
''' '''
check path for sub url check path for sub url
path: path to check path: path to check
@ -1039,14 +1066,15 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
while n_token > max_token_limit: while n_token > max_token_limit:
where = np.argmax(everything_token) where = np.argmax(everything_token)
encoded = tokenizer.encode(everything[where], disallowed_special=()) encoded = tokenizer.encode(everything[where], disallowed_special=())
clipped_encoded = encoded[:len(encoded)-delta] clipped_encoded = encoded[:len(encoded) - delta]
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
everything_token[where] = get_token_num(everything[where]) everything_token[where] = get_token_num(everything[where])
n_token = get_token_num('\n'.join(everything)) n_token = get_token_num('\n'.join(everything))
history = everything[1:] history = everything[1:]
return history return history
""" """
======================================================================== ========================================================================
第三部分 第三部分
@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
======================================================================== ========================================================================
""" """
def zip_folder(source_folder, dest_folder, zip_name): def zip_folder(source_folder, dest_folder, zip_name):
import zipfile import zipfile
import os import os
@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name):
print(f"Zip file created at {zip_file}") print(f"Zip file created at {zip_file}")
def zip_result(folder): def zip_result(folder):
t = gen_time_str() t = gen_time_str()
zip_folder(folder, get_log_folder(), f'{t}-result.zip') zip_folder(folder, get_log_folder(), f'{t}-result.zip')
return pj(get_log_folder(), f'{t}-result.zip') return pj(get_log_folder(), f'{t}-result.zip')
def gen_time_str(): def gen_time_str():
import time import time
return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
def get_log_folder(user=default_user_name, plugin_name='shared'): def get_log_folder(user=default_user_name, plugin_name='shared'):
if user is None: user = default_user_name if user is None: user = default_user_name
PATH_LOGGING = get_conf('PATH_LOGGING') PATH_LOGGING = get_conf('PATH_LOGGING')
@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'):
if not os.path.exists(_dir): os.makedirs(_dir) if not os.path.exists(_dir): os.makedirs(_dir)
return _dir return _dir
def get_upload_folder(user=default_user_name, tag=None): def get_upload_folder(user=default_user_name, tag=None):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD') PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
if user is None: user = default_user_name if user is None: user = default_user_name
if tag is None or len(tag)==0: if tag is None or len(tag) == 0:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user) target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
else: else:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag) target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
return target_path_base return target_path_base
def is_the_upload_folder(string): def is_the_upload_folder(string):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD') PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$' pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD) pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
if re.match(pattern, string): return True if re.match(pattern, string):
else: return False return True
else:
return False
def get_user(chatbotwithcookies): def get_user(chatbotwithcookies):
return chatbotwithcookies._cookies.get('user_name', default_user_name) return chatbotwithcookies._cookies.get('user_name', default_user_name)
class ProxyNetworkActivate(): class ProxyNetworkActivate():
""" """
这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理 这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
""" """
def __init__(self, task=None) -> None: def __init__(self, task=None) -> None:
self.task = task self.task = task
if not task: if not task:
@ -1158,12 +1197,14 @@ class ProxyNetworkActivate():
if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY') if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
return return
def objdump(obj, file='objdump.tmp'): def objdump(obj, file='objdump.tmp'):
import pickle import pickle
with open(file, 'wb+') as f: with open(file, 'wb+') as f:
pickle.dump(obj, f) pickle.dump(obj, f)
return return
def objload(file='objdump.tmp'): def objload(file='objdump.tmp'):
import pickle, os import pickle, os
if not os.path.exists(file): if not os.path.exists(file):
@ -1171,6 +1212,7 @@ def objload(file='objdump.tmp'):
with open(file, 'rb') as f: with open(file, 'rb') as f:
return pickle.load(f) return pickle.load(f)
def Singleton(cls): def Singleton(cls):
""" """
一个单实例装饰器 一个单实例装饰器
@ -1184,6 +1226,7 @@ def Singleton(cls):
return _singleton return _singleton
""" """
======================================================================== ========================================================================
第四部分 第四部分
@ -1197,6 +1240,7 @@ def Singleton(cls):
======================================================================== ========================================================================
""" """
def set_conf(key, value): def set_conf(key, value):
from toolbox import read_single_conf_with_lru_cache, get_conf from toolbox import read_single_conf_with_lru_cache, get_conf
read_single_conf_with_lru_cache.cache_clear() read_single_conf_with_lru_cache.cache_clear()
@ -1205,10 +1249,12 @@ def set_conf(key, value):
altered = get_conf(key) altered = get_conf(key)
return altered return altered
def set_multi_conf(dic): def set_multi_conf(dic):
for k, v in dic.items(): set_conf(k, v) for k, v in dic.items(): set_conf(k, v)
return return
def get_plugin_handle(plugin_name): def get_plugin_handle(plugin_name):
""" """
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言' e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name):
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name) f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
return f_hot_reload return f_hot_reload
def get_chat_handle(): def get_chat_handle():
""" """
""" """
from request_llms.bridge_all import predict_no_ui_long_connection from request_llms.bridge_all import predict_no_ui_long_connection
return predict_no_ui_long_connection return predict_no_ui_long_connection
def get_plugin_default_kwargs(): def get_plugin_default_kwargs():
""" """
""" """
@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs():
llm_kwargs = { llm_kwargs = {
'api_key': cookies['api_key'], 'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'], 'llm_model': cookies['llm_model'],
'top_p':1.0, 'top_p': 1.0,
'max_length': None, 'max_length': None,
'temperature':1.0, 'temperature': 1.0,
} }
chatbot = ChatBotWithCookies(llm_kwargs) chatbot = ChatBotWithCookies(llm_kwargs)
@ -1252,6 +1300,7 @@ def get_plugin_default_kwargs():
} }
return DEFAULT_FN_GROUPS_kwargs return DEFAULT_FN_GROUPS_kwargs
def get_chat_default_kwargs(): def get_chat_default_kwargs():
""" """
""" """
@ -1259,9 +1308,9 @@ def get_chat_default_kwargs():
llm_kwargs = { llm_kwargs = {
'api_key': cookies['api_key'], 'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'], 'llm_model': cookies['llm_model'],
'top_p':1.0, 'top_p': 1.0,
'max_length': None, 'max_length': None,
'temperature':1.0, 'temperature': 1.0,
} }
default_chat_kwargs = { default_chat_kwargs = {
"inputs": "Hello there, are you ready?", "inputs": "Hello there, are you ready?",
@ -1284,15 +1333,15 @@ def get_pictures_list(path):
def have_any_recent_upload_image_files(chatbot): def have_any_recent_upload_image_files(chatbot):
_5min = 5 * 60 _5min = 5 * 60
if chatbot is None: return False, None # chatbot is None if chatbot is None: return False, None # chatbot is None
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None) most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
if not most_recent_uploaded: return False, None # most_recent_uploaded is None if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if time.time() - most_recent_uploaded["time"] < _5min: if time.time() - most_recent_uploaded["time"] < _5min:
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None) most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
path = most_recent_uploaded['path'] path = most_recent_uploaded['path']
file_manifest = get_pictures_list(path) file_manifest = get_pictures_list(path)
if len(file_manifest) == 0: return False, None if len(file_manifest) == 0: return False, None
return True, file_manifest # most_recent_uploaded is new return True, file_manifest # most_recent_uploaded is new
else: else:
return False, None # most_recent_uploaded is too old return False, None # most_recent_uploaded is too old
@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs):
from request_llms.bridge_all import model_info from request_llms.bridge_all import model_info
return model_info[llm_kwargs['llm_model']]['max_token'] return model_info[llm_kwargs['llm_model']]['max_token']
def check_packages(packages=[]): def check_packages(packages=[]):
import importlib.util import importlib.util
for p in packages: for p in packages: