Adapt Google Gemini: extract files from the user input (#1419)

Adapt Google Gemini: extract files from the user input
This commit is contained in:
XIao 2023-12-31 17:13:50 +08:00 committed by qingxu fu
parent a96f842b3a
commit a7c960dcb0
5 changed files with 472 additions and 95 deletions


@ -89,12 +89,14 @@ DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体']
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
"chatglm3", "moss", "claude-2"]
# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
"gemini-pro", "chatglm3", "moss", "claude-2"]
# P.S. other available models also include [
# "qwen-turbo", "qwen-plus", "qwen-max"
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
# “qwen-turbo", "qwen-plus", "qwen-max"]
# ]
# Define which models the "Ask multiple GPT models" plugin on the UI should use; pick from AVAIL_LLM_MODELS and separate models with '&', e.g. "gpt-3.5-turbo&chatglm3&azure-gpt-4"
@ -204,6 +206,10 @@ ANTHROPIC_API_KEY = ""
CUSTOM_API_KEY_PATTERN = ""
# Google Gemini API-Key
GEMINI_API_KEY = ''
# HuggingFace access token (used when downloading LLAMA): https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"
@ -292,6 +298,9 @@ NUM_CUSTOM_BASIC_BTN = 4
"qwen-turbo" 等通义千问大模型
DASHSCOPE_API_KEY
"Gemini"
GEMINI_API_KEY
"newbing" Newbing接口不再稳定不推荐使用
NEWBING_STYLE
NEWBING_COOKIES
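
For reference, a minimal configuration sketch tying the settings above together; the key value is a placeholder, not a real credential (Gemini keys are issued through Google AI Studio):

GEMINI_API_KEY = "your-gemini-api-key"  # placeholder value, not a real key
AVAIL_LLM_MODELS = ["gemini-pro", "gemini-pro-vision", "gpt-3.5-turbo"]
LLM_MODEL = "gemini-pro"  # default model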


@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui
from .bridge_google_gemini import predict as genai_ui
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
class LazyloadTiktoken(object):
@ -246,6 +249,22 @@ model_info = {
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"gemini-pro-vision": {
"fn_with_ui": genai_ui,
"fn_without_ui": genai_noui,
"endpoint": None,
"max_token": 1024 * 32,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
}
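
The two entries above register the Gemini bridge functions in the shared model registry. A minimal sketch of how such a registry is typically consumed (the dispatch helper below is hypothetical, not part of the repo):

from request_llms.bridge_all import model_info

def dispatch(llm_model, use_ui, *args, **kwargs):
    # look up the bridge functions registered for the requested model
    entry = model_info[llm_model]  # e.g. model_info["gemini-pro"]
    fn = entry["fn_with_ui"] if use_ui else entry["fn_without_ui"]
    return fn(*args, **kwargs)  # resolves to genai_ui / genai_noui for Gemini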
# -=-=-=-=-=-=- api2d alignment support -=-=-=-=-=-=-


@ -0,0 +1,101 @@
# encoding: utf-8
# @Time : 2023/12/21
# @Author : Spike
# @Descr :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py. ' + \
                  'Network error: check whether the proxy server is reachable and whether the proxy format is correct; it must be [protocol]://[address]:[port], none of which may be omitted.'
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY.")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience; 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("Failed to parse the Gemini message.")
            buffer = paraphrase['text']
            gpt_replying_buffer += buffer
            if observe_window is not None:  # guard against the default None before indexing
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("Program terminated.")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} Conversation error')
    return gpt_replying_buffer
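
To make the regex-based stream parsing above concrete, a self-contained sketch follows; the sample chunk is invented, but mimics the "text" field shape the streaming response carries:

import json
import re

chunk = '{"candidates": [{"content": {"parts": [{"text": "Hello\\nworld"}]}}]}'
match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', chunk, flags=re.DOTALL)
if match:
    # re-wrap the escaped fragment so json.loads decodes \n, \" and friends
    text = json.loads('{"text": "%s"}' % match.group(1))['text']
    print(text)  # prints "Hello" and "world" on separate lines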
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("Please configure GEMINI_API_KEY.", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f", retrying ({retry}/{MAX_RETRY}) ..." if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="Request timeout" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY: raise TimeoutError

    gpt_replying_buffer = ""
    gpt_security_policy = ""
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # decode explicitly as UTF-8 (a plain decode() caused trouble here before)
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("Failed to parse the Gemini message.")
            gpt_replying_buffer += paraphrase['text']  # parsed with the json library
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # drop the failed turn from the history (history[-2] would return a single string, not a list)
            chatbot[-1] = (inputs, gpt_replying_buffer + f"Conversation error, please check the message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('Conversation error')
    if not gpt_replying_buffer:
        history = history[:-2]  # drop the failed turn from the history
        chatbot[-1] = (inputs, gpt_replying_buffer + f"Google's safety policy was triggered; no answer was returned\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)


if __name__ == '__main__':
    import sys
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)

request_llms/com_google.py Normal file

@ -0,0 +1,198 @@
# encoding: utf-8
# @Time : 2023/12/25
# @Author : Spike
# @Descr :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
"""
========================================================================
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件并解析
file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
link_mtime_to_md 文件增加本地时间参数避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
"""
def files_filter_handler(file_list):
    new_list = []
    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        file = str(file).replace('file=', '')
        if os.path.exists(file):
            if str(os.path.basename(file)).split('.')[-1] in filter_:
                new_list.append(file)
    return new_list

def input_encode_handler(inputs):
    md_encode = []
    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
    matches_path = re.findall(pattern_md_file, inputs)
    for md_path in matches_path:
        pattern_file = r"\((file=.*)\)"
        matches_file = re.findall(pattern_file, md_path)  # renamed to avoid shadowing the outer list
        encode_file = files_filter_handler(file_list=matches_file)
        if encode_file:
            md_encode.extend([{
                "data": encode_image(i),
                "type": os.path.splitext(i)[1].replace('.', '')
            } for i in encode_file])
            inputs = inputs.replace(md_path, '')
    return inputs, md_encode
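
A usage sketch, assuming ./uploads/cat.png exists on disk (the path and file name are invented for illustration):

text, images = input_encode_handler('describe this image ![cat](file=./uploads/cat.png)')
# text   -> 'describe this image '  (the markdown link is stripped from the prompt)
# images -> [{'data': '<base64 of cat.png>', 'type': 'png'}]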
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    new_list = []
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        if str(os.path.basename(file)).split('.')[-1] in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list

def link_mtime_to_md(file):
    link_local = html_local_file(file)
    link_name = os.path.basename(file)
    a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
    return a


def html_local_file(file):
    base_path = os.path.dirname(__file__)  # project directory
    if os.path.exists(str(file)):
        file = f'file={file.replace(base_path, ".")}'
    return file

def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        a = f'![{__file}]({__file})'
    return a

def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """
    Args:
        head: list of header cells
        tabs: table values, one list per column, e.g. [[col1], [col2], [col3], [col4]]
        alignment: ':---' left-aligned, ':---:' centered, '---:' right-aligned
        column: True if tabs is given row-by-row and must be transposed; False (default) if tabs is already one list per column.
    Returns:
        A string representation of the markdown table.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Find the maximum length among the columns
    max_len = max(len(col) for col in transposed_tabs)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'

    return tabs_list
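
A quick illustration of the default column-wise input (the file names are invented; paths that do not exist pass through unchanged):

print(to_markdown_tabs(head=['name', 'size'], tabs=[['a.txt', 'b.txt'], ['1KB', '2KB']]))
# | name | size |
# | :---: | :---: |
# | a.txt | 1KB |
# | b.txt | 2KB |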
class GoogleChatInit:
    def __init__(self):
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked
    def __conversation_history(self, history):
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()
    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # Gemini rejects an even number of conversation turns, so this is unused for now; wait and see whether it gets supported...
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models carry history
            messages.extend(self.__conversation_history(history))  # process history
        messages.append(self.__conversation_user(inputs))  # process the current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload
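
For orientation, the request assembled above has the following shape (illustrative values, trimmed):

# POST https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=<GEMINI_API_KEY>
# {
#   "contents": [
#     {"role": "user",  "parts": [{"text": "previous question"}]},
#     {"role": "model", "parts": [{"text": "previous answer"}]},
#     {"role": "user",  "parts": [{"text": "current question"}]}
#   ],
#   "generationConfig": {"stopSequences": [""], "temperature": 1, "topP": 0.8, "topK": 10}
# }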
if __name__ == '__main__':
    google = GoogleChatInit()
    # print(google.generate_message_payload('Hello there', {},
    #                                       ['123123', '3123123'], ''))
    # input_encode_handler('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')


@ -11,8 +11,10 @@ import glob
import math
from latex2mathml.converter import convert as tex2mathml
from functools import wraps, lru_cache
pj = os.path.join
default_user_name = 'default_user'
"""
========================================================================
Part One
@ -26,6 +28,7 @@ default_user_name = 'default_user'
========================================================================
"""
class ChatBotWithCookies(list):
def __init__(self, cookie):
"""
@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f):
else:
user_name = default_user_name
cookies.update({
'top_p':top_p,
'top_p': top_p,
'api_key': cookies['api_key'],
'llm_model': llm_model,
'temperature':temperature,
'temperature': temperature,
'user_name': user_name,
})
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': llm_model,
'top_p':top_p,
'top_p': top_p,
'max_length': max_length,
'temperature':temperature,
'temperature': temperature,
'client_ip': request.client.host,
'most_recent_uploaded': cookies.get('most_recent_uploaded')
}
@ -87,7 +90,7 @@ def ArgsGeneralWrapper(f):
}
chatbot_with_cookie = ChatBotWithCookies(cookies)
chatbot_with_cookie.write_list(chatbot)
if cookies.get('lock_plugin', None) is None:
# normal state
if len(args) == 0: # plugin channel
@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f):
final_cookies = chatbot_with_cookie.get_cookies()
# len(args) != 0 means the "Submit" button dialogue channel, or a basic-function channel
if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
chatbot_with_cookie.append(["Detected **stranded cached documents**, please handle them promptly.", "Please click “**Save current conversation**” promptly to retrieve all stranded documents."])
chatbot_with_cookie.append(
["Detected **stranded cached documents**, please handle them promptly.", "Please click “**Save current conversation**” promptly to retrieve all stranded documents."])
yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="Stranded cached documents detected")
return decorated
@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # refresh the UI
yield cookies, chatbot_gr, history, msg
def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # refresh the UI
"""
Refresh the user interface
@ -147,6 +153,7 @@ def trimmed_format_exc():
replace_path = "."
return str.replace(current_path, replace_path)
def CatchException(f):
"""
A decorator that catches exceptions in function f, wraps them into a generator to return, and displays them in the chat
@ -164,9 +171,9 @@ def CatchException(f):
if len(chatbot_with_cookie) == 0:
chatbot_with_cookie.clear()
chatbot_with_cookie.append(["插件调度异常", "异常原因"])
chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0],
f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
return decorated
@ -209,6 +216,7 @@ def HotReload(f):
========================================================================
"""
def get_reduce_token_percent(text):
"""
* This function will be deprecated in the future
@ -220,9 +228,9 @@ def get_reduce_token_percent(text):
EXCEED_ALLO = 500 # leave a little headroom, otherwise replies may fail because too little quota remains
max_limit = float(match[0]) - EXCEED_ALLO
current_tokens = float(match[1])
ratio = max_limit/current_tokens
ratio = max_limit / current_tokens
assert ratio > 0 and ratio < 1
return ratio, str(int(current_tokens-max_limit))
return ratio, str(int(current_tokens - max_limit))
except:
return 0.5, 'unknown'
@ -242,7 +250,7 @@ def write_history_to_file(history, file_basename=None, file_fullname=None, auto_
with open(file_fullname, 'w', encoding='utf8') as f:
f.write('# GPT-Academic Report\n')
for i, content in enumerate(history):
try:
try:
if type(content) != str: content = str(content)
except:
continue
@ -268,8 +276,6 @@ def regular_txt_to_markdown(text):
return text
def report_exception(chatbot, history, a, b):
"""
Append error information to the chatbot
@ -286,7 +292,7 @@ def text_divide_paragraph(text):
suf = '</div>'
if text.startswith(pre) and text.endswith(suf):
return text
if '```' in text:
# careful input
return text
@ -312,7 +318,7 @@ def markdown_convertion(txt):
if txt.startswith(pre) and txt.endswith(suf):
# print('Warning: a string that has already been converted was passed in; converting it again may cause problems')
return txt # already converted, no need to convert again
markdown_extension_configs = {
'mdx_math': {
'enable_dollar_delimiter': True,
@ -352,7 +358,8 @@ def markdown_convertion(txt):
"""
Fix an mdx_math bug: a redundant <script> appears when $ wraps a begin command
"""
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">')
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
'<script type="math/tex; mode=display">')
content = content.replace('</script>\n</script>', '</script>')
return content
@ -363,16 +370,16 @@ def markdown_convertion(txt):
if '```' in txt and '```reference' not in txt: return False
if '$' not in txt and '\\[' not in txt: return False
mathpatterns = {
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\]
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\]
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
}
matches = []
for pattern, property in mathpatterns.items():
flags = re.ASCII|re.DOTALL if property['allow_multi_lines'] else re.ASCII
flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
matches.extend(re.findall(pattern, txt, flags))
if len(matches) == 0: return False
contain_any_eq = False
@ -380,16 +387,16 @@ def markdown_convertion(txt):
for match in matches:
if len(match) != 3: return False
eq_canidate = match[1]
if illegal_pattern.search(eq_canidate):
if illegal_pattern.search(eq_canidate):
return False
else:
else:
contain_any_eq = True
return contain_any_eq
def fix_markdown_indent(txt):
# fix markdown indent
if (' - ' not in txt) or ('. ' not in txt):
return txt # do not need to fix, fast escape
if (' - ' not in txt) or ('. ' not in txt):
return txt # do not need to fix, fast escape
# walk through the lines and fix non-standard indentation
lines = txt.split("\n")
pattern = re.compile(r'^\s+-')
@ -401,7 +408,7 @@ def markdown_convertion(txt):
stripped_string = line.lstrip()
num_spaces = len(line) - len(stripped_string)
if (num_spaces % 4) == 3:
num_spaces_should_be = math.ceil(num_spaces/4) * 4
num_spaces_should_be = math.ceil(num_spaces / 4) * 4
lines[i] = ' ' * num_spaces_should_be + stripped_string
return '\n'.join(lines)
@ -409,7 +416,8 @@ def markdown_convertion(txt):
if is_equation(txt): # contains $-delimited math and has no ``` code-fence markers
# convert everything to html format
split = markdown.markdown(text='---')
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'], extension_configs=markdown_extension_configs)
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'],
extension_configs=markdown_extension_configs)
convert_stage_1 = markdown_bug_hunt(convert_stage_1)
# 1. convert to easy-to-copy tex (do not render math)
convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL)
@ -441,8 +449,7 @@ def close_up_code_segment_during_stream(gpt_reply):
segments = gpt_reply.split('```')
n_mark = len(segments) - 1
if n_mark % 2 == 1:
# print('inside a code block!')
return gpt_reply+'\n```'
return gpt_reply + '\n```' # inside a code block
else:
return gpt_reply
@ -533,7 +540,7 @@ def find_recent_files(directory):
current_time = time.time()
one_minute_ago = current_time - 60
recent_files = []
if not os.path.exists(directory):
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
for filename in os.listdir(directory):
file_path = pj(directory, filename)
@ -559,6 +566,7 @@ def file_already_in_downloadzone(file, user_path):
except:
return False
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
# copy the file into the download area
import shutil
@ -581,8 +589,10 @@ def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
if not os.path.exists(new_path): shutil.copyfile(file, new_path)
# add the file to the chatbot cookies
if chatbot is not None:
if 'files_to_promote' in chatbot._cookies: current = chatbot._cookies['files_to_promote']
else: current = []
if 'files_to_promote' in chatbot._cookies:
current = chatbot._cookies['files_to_promote']
else:
current = []
if new_path not in current: # avoid adding the same file more than once
chatbot._cookies.update({'files_to_promote': [new_path] + current})
return new_path
@ -605,8 +615,10 @@ def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
for subdirectory in glob.glob(f'{user_upload_dir}/*'):
subdirectory_time = os.path.getmtime(subdirectory)
if subdirectory_time < one_hour_ago:
try: shutil.rmtree(subdirectory)
except: pass
try:
shutil.rmtree(subdirectory)
except:
pass
return
@ -679,9 +691,9 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
time_tag = gen_time_str()
target_path_base = get_upload_folder(user_name, tag=time_tag)
os.makedirs(target_path_base, exist_ok=True)
# remove outdated old files to save space & protect privacy
outdate_time_seconds = 3600 # one hour
outdate_time_seconds = 3600  # one hour
del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
# move the files one by one to the target path
@ -690,21 +702,20 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
file_origin_name = os.path.basename(file.orig_name)
this_file_path = pj(target_path_base, file_origin_name)
shutil.move(file.name, this_file_path)
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract')
if "浮动输入区" in checkboxes:
txt, txt2 = "", target_path_base
else:
txt, txt2 = target_path_base, ""
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
# collect the file set and emit a message
moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
chatbot.append(['I have uploaded files, please check',
chatbot.append(['I have uploaded files, please check',
f'[Local Message] Received the following files: \n\n{moved_files_str}' +
f'\n\nThe call-path parameter has been automatically corrected to: \n\n{txt}' +
f'\n\nFrom now on, when you click any function plugin, the files above will be used as input parameters'+upload_msg])
f'\n\nFrom now on, when you click any function plugin, the files above will be used as input parameters' + upload_msg])
txt, txt2 = target_path_base, ""
if "浮动输入区" in checkboxes:
txt, txt2 = txt2, txt
# record recent files
cookies.update({
'most_recent_uploaded': {
@ -732,34 +743,40 @@ def on_report_generated(cookies, files, chatbot):
chatbot.append(['How can I retrieve the report remotely?', f'The report has been added to the "file upload area" on the right (it may be collapsed); please check. {file_links}'])
return cookies, report_files, chatbot
def load_chat_cookies():
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
# deal with azure openai key
if is_any_api_key(AZURE_API_KEY):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
else: API_KEY = AZURE_API_KEY
if is_any_api_key(API_KEY):
API_KEY = API_KEY + ',' + AZURE_API_KEY
else:
API_KEY = AZURE_API_KEY
if len(AZURE_CFG_ARRAY) > 0:
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
if not azure_model_name.startswith('azure'):
if not azure_model_name.startswith('azure'):
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
if is_any_api_key(AZURE_API_KEY_):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
else: API_KEY = AZURE_API_KEY_
if is_any_api_key(API_KEY):
API_KEY = API_KEY + ',' + AZURE_API_KEY_
else:
API_KEY = AZURE_API_KEY_
customize_fn_overwrite_ = {}
for k in range(NUM_CUSTOM_BASIC_BTN):
customize_fn_overwrite_.update({
customize_fn_overwrite_.update({
"自定义按钮" + str(k+1):{
"Title": r"",
"Prefix": r"请在自定义菜单中定义提示词前缀.",
"Suffix": r"请在自定义菜单中定义提示词后缀",
"Title": r"",
"Prefix": r"请在自定义菜单中定义提示词前缀.",
"Suffix": r"请在自定义菜单中定义提示词后缀",
}
})
return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
def is_openai_api_key(key):
CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
if len(CUSTOM_API_KEY_PATTERN) != 0:
@ -768,14 +785,17 @@ def is_openai_api_key(key):
API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
return bool(API_MATCH_ORIGINAL)
def is_azure_api_key(key):
API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_AZURE)
def is_api2d_key(key):
API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_API2D)
def is_any_api_key(key):
if ',' in key:
keys = key.split(',')
@ -785,24 +805,26 @@ def is_any_api_key(key):
else:
return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key)
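
Format-only examples of what these matchers accept, assuming the default empty CUSTOM_API_KEY_PATTERN (all keys below are fabricated and invalid as credentials):

assert is_openai_api_key("sk-" + "a" * 48)  # OpenAI: 'sk-' followed by 48 alphanumerics
assert is_api2d_key("fkabcdef-" + "b" * 32)  # API2D: 'fk' + 6 chars + '-' + 32 alphanumerics
assert is_azure_api_key("c" * 32)  # Azure: exactly 32 alphanumerics
assert is_any_api_key("sk-" + "a" * 48 + "," + "c" * 32)  # comma-separated key lists also pass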
def what_keys(keys):
avail_key_list = {'OpenAI Key':0, "Azure Key":0, "API2D Key":0}
avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
key_list = keys.split(',')
for k in key_list:
if is_openai_api_key(k):
if is_openai_api_key(k):
avail_key_list['OpenAI Key'] += 1
for k in key_list:
if is_api2d_key(k):
if is_api2d_key(k):
avail_key_list['API2D Key'] += 1
for k in key_list:
if is_azure_api_key(k):
if is_azure_api_key(k):
avail_key_list['Azure Key'] += 1
return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']}"
def select_api_key(keys, llm_model):
import random
avail_key_list = []
@ -826,6 +848,7 @@ def select_api_key(keys, llm_model):
api_key = random.choice(avail_key_list) # random load balancing
return api_key
def read_env_variable(arg, default_value):
"""
The environment variable can be `GPT_ACADEMIC_CONFIG` (takes precedence) or simply `CONFIG`
@ -843,10 +866,10 @@ def read_env_variable(arg, default_value):
set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")]
"""
from colorful import print亮红, print亮绿
arg_with_prefix = "GPT_ACADEMIC_" + arg
if arg_with_prefix in os.environ:
arg_with_prefix = "GPT_ACADEMIC_" + arg
if arg_with_prefix in os.environ:
env_arg = os.environ[arg_with_prefix]
elif arg in os.environ:
elif arg in os.environ:
env_arg = os.environ[arg]
else:
raise KeyError
@ -856,7 +879,7 @@ def read_env_variable(arg, default_value):
env_arg = env_arg.strip()
if env_arg == 'True': r = True
elif env_arg == 'False': r = False
else: print('enter True or False, but have:', env_arg); r = default_value
else: print('Expected True or False, but got:', env_arg); r = default_value
elif isinstance(default_value, int):
r = int(env_arg)
elif isinstance(default_value, float):
@ -880,13 +903,14 @@ def read_env_variable(arg, default_value):
print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
return r
@lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg):
from colorful import print亮红, print亮绿, print亮蓝
try:
# Priority 1: read the environment variable as the config
default_ref = getattr(importlib.import_module('config'), arg) # read the default value as a reference for type conversion
r = read_env_variable(arg, default_ref)
default_ref = getattr(importlib.import_module('config'), arg)  # read the default value as a reference for type conversion
r = read_env_variable(arg, default_ref)
except:
try:
# Priority 2: read the config from config_private
@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg):
if arg == 'API_URL_REDIRECT':
oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # the API_URL_REDIRECT format is wrong; please read `https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
if oai_rd and not oai_rd.endswith('/completions'):
print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
time.sleep(5)
if arg == 'API_KEY':
print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
@ -907,9 +931,9 @@ def read_single_conf_with_lru_cache(arg):
if is_any_api_key(r):
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
else:
print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式请在config文件中修改API密钥之后再运行。")
print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式请在config文件中修改API密钥之后再运行。")
if arg == 'proxies':
if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # check USE_PROXY to prevent proxies from taking effect on its own
if not read_single_conf_with_lru_cache('USE_PROXY'): r = None  # check USE_PROXY to prevent proxies from taking effect on its own
if r is None:
print亮红('[PROXY] Network proxy is not configured. Without a proxy, OpenAI-family models are very likely unreachable. Check whether the USE_PROXY option has been modified.')
else:
@ -953,17 +977,20 @@ class DummyWith():
When the context begins, the __enter__() method is called before the code block executes,
and when the context ends, the __exit__() method is called.
"""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return
def run_gradio_in_subpath(demo, auth, port, custom_path):
"""
Change the gradio serving address to the specified secondary path
"""
def is_path_legal(path: str)->bool:
def is_path_legal(path: str) -> bool:
'''
check path for sub url
path: path to check
@ -988,7 +1015,7 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
app = FastAPI()
if custom_path != "/":
@app.get("/")
def read_main():
def read_main():
return {"message": f"Gradio is running at: {custom_path}"}
app = gr.mount_gradio_app(app, demo, path=custom_path)
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
@ -999,13 +1026,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
reduce the length of history by clipping.
this function search for the longest entries to clip, little by little,
until the number of token of history is reduced under threshold.
Shorten the history by clipping.
Shorten the history by clipping.
This function gradually searches for the longest entries to clip,
until the token count of the history drops below the threshold.
"""
import numpy as np
from request_llms.bridge_all import model_info
def get_token_num(txt):
def get_token_num(txt):
return len(tokenizer.encode(txt, disallowed_special=()))
input_token_num = get_token_num(inputs)
@ -1039,14 +1066,15 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
while n_token > max_token_limit:
where = np.argmax(everything_token)
encoded = tokenizer.encode(everything[where], disallowed_special=())
clipped_encoded = encoded[:len(encoded)-delta]
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
clipped_encoded = encoded[:len(encoded) - delta]
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
everything_token[where] = get_token_num(everything[where])
n_token = get_token_num('\n'.join(everything))
history = everything[1:]
return history
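
A usage sketch in the style of the bridge modules, which typically budget a fraction of the model's context window for history (the 3/4 ratio below is illustrative):

# tokenizer = model_info[llm_kwargs['llm_model']]['tokenizer']
# max_limit = model_info[llm_kwargs['llm_model']]['max_token'] * 3 // 4
# history = clip_history(inputs, history, tokenizer, max_token_limit=max_limit)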
"""
========================================================================
Part Three
@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
========================================================================
"""
def zip_folder(source_folder, dest_folder, zip_name):
import zipfile
import os
@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name):
print(f"Zip file created at {zip_file}")
def zip_result(folder):
t = gen_time_str()
zip_folder(folder, get_log_folder(), f'{t}-result.zip')
return pj(get_log_folder(), f'{t}-result.zip')
def gen_time_str():
import time
return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
def get_log_folder(user=default_user_name, plugin_name='shared'):
if user is None: user = default_user_name
PATH_LOGGING = get_conf('PATH_LOGGING')
@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'):
if not os.path.exists(_dir): os.makedirs(_dir)
return _dir
def get_upload_folder(user=default_user_name, tag=None):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
if user is None: user = default_user_name
if tag is None or len(tag)==0:
if tag is None or len(tag) == 0:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
else:
target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
return target_path_base
def is_the_upload_folder(string):
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
if re.match(pattern, string): return True
else: return False
if re.match(pattern, string):
return True
else:
return False
def get_user(chatbotwithcookies):
return chatbotwithcookies._cookies.get('user_name', default_user_name)
class ProxyNetworkActivate():
"""
This code defines a context manager named ProxyNetworkActivate, used to route a small section of code through the proxy
"""
def __init__(self, task=None) -> None:
self.task = task
if not task:
@ -1158,32 +1197,36 @@ class ProxyNetworkActivate():
if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
return
def objdump(obj, file='objdump.tmp'):
import pickle
with open(file, 'wb+') as f:
pickle.dump(obj, f)
return
def objload(file='objdump.tmp'):
import pickle, os
if not os.path.exists(file):
if not os.path.exists(file):
return
with open(file, 'rb') as f:
return pickle.load(f)
def Singleton(cls):
"""
A singleton decorator
"""
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
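
Usage sketch: a class decorated with Singleton always hands back the same instance.

@Singleton
class Connection:
    def __init__(self):
        self.opened = True

assert Connection() is Connection()  # the second call reuses the cached instance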
"""
========================================================================
Part Four
@ -1197,6 +1240,7 @@ def Singleton(cls):
========================================================================
"""
def set_conf(key, value):
from toolbox import read_single_conf_with_lru_cache, get_conf
read_single_conf_with_lru_cache.cache_clear()
@ -1205,10 +1249,12 @@ def set_conf(key, value):
altered = get_conf(key)
return altered
def set_multi_conf(dic):
for k, v in dic.items(): set_conf(k, v)
return
def get_plugin_handle(plugin_name):
"""
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name):
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
return f_hot_reload
def get_chat_handle():
"""
"""
from request_llms.bridge_all import predict_no_ui_long_connection
return predict_no_ui_long_connection
def get_plugin_default_kwargs():
"""
"""
@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
'top_p':1.0,
'top_p': 1.0,
'max_length': None,
'temperature':1.0,
'temperature': 1.0,
}
chatbot = ChatBotWithCookies(llm_kwargs)
@ -1247,11 +1295,12 @@ def get_plugin_default_kwargs():
"plugin_kwargs": {},
"chatbot_with_cookie": chatbot,
"history": [],
"system_prompt": "You are a good AI.",
"system_prompt": "You are a good AI.",
"web_port": None
}
return DEFAULT_FN_GROUPS_kwargs
def get_chat_default_kwargs():
"""
"""
@ -1259,9 +1308,9 @@ def get_chat_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
'top_p':1.0,
'top_p': 1.0,
'max_length': None,
'temperature':1.0,
'temperature': 1.0,
}
default_chat_kwargs = {
"inputs": "Hello there, are you ready?",
@ -1284,15 +1333,15 @@ def get_pictures_list(path):
def have_any_recent_upload_image_files(chatbot):
_5min = 5 * 60
if chatbot is None: return False, None # chatbot is None
if chatbot is None: return False, None # chatbot is None
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if time.time() - most_recent_uploaded["time"] < _5min:
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
path = most_recent_uploaded['path']
file_manifest = get_pictures_list(path)
if len(file_manifest) == 0: return False, None
return True, file_manifest # most_recent_uploaded is new
return True, file_manifest # most_recent_uploaded is new
else:
return False, None # most_recent_uploaded is too old
@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs):
from request_llms.bridge_all import model_info
return model_info[llm_kwargs['llm_model']]['max_token']
def check_packages(packages=[]):
import importlib.util
for p in packages: