diff --git a/config.py b/config.py index 861bbed..c202ca0 100644 --- a/config.py +++ b/config.py @@ -89,12 +89,14 @@ DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体'] LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓ AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", - "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k', "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4", - "chatglm3", "moss", "claude-2"] -# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random" + "gemini-pro", "chatglm3", "moss", "claude-2"] +# P.S. 其他可用的模型还包括 [ +# "qwen-turbo", "qwen-plus", "qwen-max" +# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", +# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k', # "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama" -# “qwen-turbo", "qwen-plus", "qwen-max"] +# ] # 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4" @@ -204,6 +206,10 @@ ANTHROPIC_API_KEY = "" CUSTOM_API_KEY_PATTERN = "" +# Google Gemini API-Key +GEMINI_API_KEY = '' + + # HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens HUGGINGFACE_ACCESS_TOKEN = "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" @@ -292,6 +298,9 @@ NUM_CUSTOM_BASIC_BTN = 4 ├── "qwen-turbo" 等通义千问大模型 │ └── DASHSCOPE_API_KEY │ +├── "Gemini" +│ └── GEMINI_API_KEY +│ └── "newbing" Newbing接口不再稳定,不推荐使用 ├── NEWBING_STYLE └── NEWBING_COOKIES diff --git a/request_llms/bridge_all.py b/request_llms/bridge_all.py index 689b1f9..61e58a0 100644 --- a/request_llms/bridge_all.py +++ b/request_llms/bridge_all.py @@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui from .bridge_qianfan import predict as qianfan_ui +from .bridge_google_gemini import predict as genai_ui +from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui + colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044'] class LazyloadTiktoken(object): @@ -246,6 +249,22 @@ model_info = { "tokenizer": tokenizer_gpt35, "token_cnt": get_token_num_gpt35, }, + "gemini-pro": { + "fn_with_ui": genai_ui, + "fn_without_ui": genai_noui, + "endpoint": None, + "max_token": 1024 * 32, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, + }, + "gemini-pro-vision": { + "fn_with_ui": genai_ui, + "fn_without_ui": genai_noui, + "endpoint": None, + "max_token": 1024 * 32, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, + }, } # -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=- diff --git a/request_llms/bridge_google_gemini.py b/request_llms/bridge_google_gemini.py new file mode 100644 index 0000000..2438e09 --- /dev/null +++ b/request_llms/bridge_google_gemini.py @@ -0,0 +1,101 @@ +# encoding: utf-8 +# @Time : 2023/12/21 +# @Author : Spike +# @Descr : +import json +import re +import time +from request_llms.com_google import GoogleChatInit +from toolbox import get_conf, update_ui, update_ui_lastest_msg + +proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY') +timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \ + '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。' + + +def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, + console_slience=False): + # 检查API_KEY + if get_conf("GEMINI_API_KEY") == "": + raise ValueError(f"请配置 GEMINI_API_KEY。") + + genai = GoogleChatInit() + watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可 + gpt_replying_buffer = '' + stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt) + for response in stream_response: + results = response.decode() + match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL) + error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL) + if match: + try: + paraphrase = json.loads('{"text": "%s"}' % match.group(1)) + except: + raise ValueError(f"解析GEMINI消息出错。") + buffer = paraphrase['text'] + gpt_replying_buffer += buffer + if len(observe_window) >= 1: + observe_window[0] = gpt_replying_buffer + if len(observe_window) >= 2: + if (time.time() - observe_window[1]) > watch_dog_patience: raise RuntimeError("程序终止。") + if error_match: + raise RuntimeError(f'{gpt_replying_buffer} 对话错误') + return gpt_replying_buffer + + +def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None): + # 检查API_KEY + if get_conf("GEMINI_API_KEY") == "": + yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0) + return + + chatbot.append((inputs, "")) + yield from update_ui(chatbot=chatbot, history=history) + genai = GoogleChatInit() + retry = 0 + while True: + try: + stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt) + break + except Exception as e: + retry += 1 + chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg)) + retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else "" + yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg) # 刷新界面 + if retry > MAX_RETRY: raise TimeoutError + gpt_replying_buffer = "" + gpt_security_policy = "" + history.extend([inputs, '']) + for response in stream_response: + results = response.decode("utf-8") # 被这个解码给耍了。。 + gpt_security_policy += results + match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL) + error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL) + if match: + try: + paraphrase = json.loads('{"text": "%s"}' % match.group(1)) + except: + raise ValueError(f"解析GEMINI消息出错。") + gpt_replying_buffer += paraphrase['text'] # 使用 json 解析库进行处理 + chatbot[-1] = (inputs, gpt_replying_buffer) + history[-1] = gpt_replying_buffer + yield from update_ui(chatbot=chatbot, history=history) + if error_match: + history = history[-2] # 错误的不纳入对话 + chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```") + yield from update_ui(chatbot=chatbot, history=history) + raise RuntimeError('对话错误') + if not gpt_replying_buffer: + history = history[-2] # 错误的不纳入对话 + chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```") + yield from update_ui(chatbot=chatbot, history=history) + + + +if __name__ == '__main__': + import sys + + llm_kwargs = {'llm_model': 'gemini-pro'} + result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, []) + for i in result: + print(i) diff --git a/request_llms/com_google.py b/request_llms/com_google.py new file mode 100644 index 0000000..7981908 --- /dev/null +++ b/request_llms/com_google.py @@ -0,0 +1,198 @@ +# encoding: utf-8 +# @Time : 2023/12/25 +# @Author : Spike +# @Descr : +import json +import os +import re +import requests +from typing import List, Dict, Tuple +from toolbox import get_conf, encode_image + +proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS') + +""" +======================================================================== +第五部分 一些文件处理方法 +files_filter_handler 根据type过滤文件 +input_encode_handler 提取input中的文件,并解析 +file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本 +link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件 +html_view_blank 超链接 +html_local_file 本地文件取相对路径 +to_markdown_tabs 文件list 转换为 md tab +""" + + +def files_filter_handler(file_list): + new_list = [] + filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps'] + for file in file_list: + file = str(file).replace('file=', '') + if os.path.exists(file): + if str(os.path.basename(file)).split('.')[-1] in filter_: + new_list.append(file) + return new_list + + +def input_encode_handler(inputs): + md_encode = [] + pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))" + matches_path = re.findall(pattern_md_file, inputs) + for md_path in matches_path: + pattern_file = r"\((file=.*)\)" + matches_path = re.findall(pattern_file, md_path) + encode_file = files_filter_handler(file_list=matches_path) + if encode_file: + md_encode.extend([{ + "data": encode_image(i), + "type": os.path.splitext(i)[1].replace('.', '') + } for i in encode_file]) + inputs = inputs.replace(md_path, '') + return inputs, md_encode + + +def file_manifest_filter_html(file_list, filter_: list = None, md_type=False): + new_list = [] + if not filter_: + filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps'] + for file in file_list: + if str(os.path.basename(file)).split('.')[-1] in filter_: + new_list.append(html_local_img(file, md=md_type)) + elif os.path.exists(file): + new_list.append(link_mtime_to_md(file)) + else: + new_list.append(file) + return new_list + + +def link_mtime_to_md(file): + link_local = html_local_file(file) + link_name = os.path.basename(file) + a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})" + return a + + +def html_local_file(file): + base_path = os.path.dirname(__file__) # 项目目录 + if os.path.exists(str(file)): + file = f'file={file.replace(base_path, ".")}' + return file + + +def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True): + style = '' + if max_width is not None: + style += f"max-width: {max_width};" + if max_height is not None: + style += f"max-height: {max_height};" + __file = html_local_file(__file) + a = f'
' + if md: + a = f'![{__file}]({__file})' + return a + + +def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False): + """ + Args: + head: 表头:[] + tabs: 表值:[[列1], [列2], [列3], [列4]] + alignment: :--- 左对齐, :---: 居中对齐, ---: 右对齐 + column: True to keep data in columns, False to keep data in rows (default). + Returns: + A string representation of the markdown table. + """ + if column: + transposed_tabs = list(map(list, zip(*tabs))) + else: + transposed_tabs = tabs + # Find the maximum length among the columns + max_len = max(len(column) for column in transposed_tabs) + + tab_format = "| %s " + tabs_list = "".join([tab_format % i for i in head]) + '|\n' + tabs_list += "".join([tab_format % alignment for i in head]) + '|\n' + + for i in range(max_len): + row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs] + row_data = file_manifest_filter_html(row_data, filter_=None) + tabs_list += "".join([tab_format % i for i in row_data]) + '|\n' + + return tabs_list + + +class GoogleChatInit: + + def __init__(self): + self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k' + + def __conversation_user(self, user_input): + what_i_have_asked = {"role": "user", "parts": []} + if 'vision' not in self.url_gemini: + input_ = user_input + encode_img = [] + else: + input_, encode_img = input_encode_handler(user_input) + what_i_have_asked['parts'].append({'text': input_}) + if encode_img: + for data in encode_img: + what_i_have_asked['parts'].append( + {'inline_data': { + "mime_type": f"image/{data['type']}", + "data": data['data'] + }}) + return what_i_have_asked + + def __conversation_history(self, history): + messages = [] + conversation_cnt = len(history) // 2 + if conversation_cnt: + for index in range(0, 2 * conversation_cnt, 2): + what_i_have_asked = self.__conversation_user(history[index]) + what_gpt_answer = { + "role": "model", + "parts": [{"text": history[index + 1]}] + } + messages.append(what_i_have_asked) + messages.append(what_gpt_answer) + return messages + + def generate_chat(self, inputs, llm_kwargs, history, system_prompt): + headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt) + response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload), + stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS) + return response.iter_lines() + + def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]: + messages = [ + # {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。 + # {"role": "user", "parts": [{"text": ""}]}, + # {"role": "model", "parts": [{"text": ""}]} + ] + self.url_gemini = self.url_gemini.replace( + '%m', llm_kwargs['llm_model']).replace( + '%k', get_conf('GEMINI_API_KEY') + ) + header = {'Content-Type': 'application/json'} + if 'vision' not in self.url_gemini: # 不是vision 才处理history + messages.extend(self.__conversation_history(history)) # 处理 history + messages.append(self.__conversation_user(inputs)) # 处理用户对话 + payload = { + "contents": messages, + "generationConfig": { + "stopSequences": str(llm_kwargs.get('stop', '')).split(' '), + "temperature": llm_kwargs.get('temperature', 1), + # "maxOutputTokens": 800, + "topP": llm_kwargs.get('top_p', 0.8), + "topK": 10 + } + } + return header, payload + + +if __name__ == '__main__': + google = GoogleChatInit() + # print(gootle.generate_message_payload('你好呀', {}, + # ['123123', '3123123'], '')) + # gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)') \ No newline at end of file diff --git a/toolbox.py b/toolbox.py index 154b54c..632279b 100644 --- a/toolbox.py +++ b/toolbox.py @@ -11,8 +11,10 @@ import glob import math from latex2mathml.converter import convert as tex2mathml from functools import wraps, lru_cache + pj = os.path.join default_user_name = 'default_user' + """ ======================================================================== 第一部分 @@ -26,6 +28,7 @@ default_user_name = 'default_user' ======================================================================== """ + class ChatBotWithCookies(list): def __init__(self, cookie): """ @@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f): else: user_name = default_user_name cookies.update({ - 'top_p':top_p, + 'top_p': top_p, 'api_key': cookies['api_key'], 'llm_model': llm_model, - 'temperature':temperature, + 'temperature': temperature, 'user_name': user_name, }) llm_kwargs = { 'api_key': cookies['api_key'], 'llm_model': llm_model, - 'top_p':top_p, + 'top_p': top_p, 'max_length': max_length, - 'temperature':temperature, + 'temperature': temperature, 'client_ip': request.client.host, 'most_recent_uploaded': cookies.get('most_recent_uploaded') } @@ -87,7 +90,7 @@ def ArgsGeneralWrapper(f): } chatbot_with_cookie = ChatBotWithCookies(cookies) chatbot_with_cookie.write_list(chatbot) - + if cookies.get('lock_plugin', None) is None: # 正常状态 if len(args) == 0: # 插件通道 @@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f): final_cookies = chatbot_with_cookie.get_cookies() # len(args) != 0 代表“提交”键对话通道,或者基础功能通道 if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0: - chatbot_with_cookie.append(["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"]) + chatbot_with_cookie.append( + ["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"]) yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档") + return decorated @@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面 yield cookies, chatbot_gr, history, msg + def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面 """ 刷新用户界面 @@ -147,6 +153,7 @@ def trimmed_format_exc(): replace_path = "." return str.replace(current_path, replace_path) + def CatchException(f): """ 装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。 @@ -164,9 +171,9 @@ def CatchException(f): if len(chatbot_with_cookie) == 0: chatbot_with_cookie.clear() chatbot_with_cookie.append(["插件调度异常", "异常原因"]) - chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], - f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}") - yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面 + chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n") + yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面 + return decorated @@ -209,6 +216,7 @@ def HotReload(f): ======================================================================== """ + def get_reduce_token_percent(text): """ * 此函数未来将被弃用 @@ -220,9 +228,9 @@ def get_reduce_token_percent(text): EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题 max_limit = float(match[0]) - EXCEED_ALLO current_tokens = float(match[1]) - ratio = max_limit/current_tokens + ratio = max_limit / current_tokens assert ratio > 0 and ratio < 1 - return ratio, str(int(current_tokens-max_limit)) + return ratio, str(int(current_tokens - max_limit)) except: return 0.5, '不详' @@ -242,7 +250,7 @@ def write_history_to_file(history, file_basename=None, file_fullname=None, auto_ with open(file_fullname, 'w', encoding='utf8') as f: f.write('# GPT-Academic Report\n') for i, content in enumerate(history): - try: + try: if type(content) != str: content = str(content) except: continue @@ -268,8 +276,6 @@ def regular_txt_to_markdown(text): return text - - def report_exception(chatbot, history, a, b): """ 向chatbot中添加错误信息 @@ -286,7 +292,7 @@ def text_divide_paragraph(text): suf = '' if text.startswith(pre) and text.endswith(suf): return text - + if '```' in text: # careful input return text @@ -312,7 +318,7 @@ def markdown_convertion(txt): if txt.startswith(pre) and txt.endswith(suf): # print('警告,输入了已经经过转化的字符串,二次转化可能出问题') return txt # 已经被转化过,不需要再次转化 - + markdown_extension_configs = { 'mdx_math': { 'enable_dollar_delimiter': True, @@ -352,7 +358,8 @@ def markdown_convertion(txt): """ 解决一个mdx_math的bug(单$包裹begin命令时多余\n', '') return content @@ -363,16 +370,16 @@ def markdown_convertion(txt): if '```' in txt and '```reference' not in txt: return False if '$' not in txt and '\\[' not in txt: return False mathpatterns = { - r'(? 0: for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items(): - if not azure_model_name.startswith('azure'): + if not azure_model_name.startswith('azure'): raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头") AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"] if is_any_api_key(AZURE_API_KEY_): - if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_ - else: API_KEY = AZURE_API_KEY_ + if is_any_api_key(API_KEY): + API_KEY = API_KEY + ',' + AZURE_API_KEY_ + else: + API_KEY = AZURE_API_KEY_ customize_fn_overwrite_ = {} for k in range(NUM_CUSTOM_BASIC_BTN): - customize_fn_overwrite_.update({ + customize_fn_overwrite_.update({ "自定义按钮" + str(k+1):{ - "Title": r"", - "Prefix": r"请在自定义菜单中定义提示词前缀.", - "Suffix": r"请在自定义菜单中定义提示词后缀", + "Title": r"", + "Prefix": r"请在自定义菜单中定义提示词前缀.", + "Suffix": r"请在自定义菜单中定义提示词后缀", } }) return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_} + def is_openai_api_key(key): CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN') if len(CUSTOM_API_KEY_PATTERN) != 0: @@ -768,14 +785,17 @@ def is_openai_api_key(key): API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key) return bool(API_MATCH_ORIGINAL) + def is_azure_api_key(key): API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key) return bool(API_MATCH_AZURE) + def is_api2d_key(key): API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key) return bool(API_MATCH_API2D) + def is_any_api_key(key): if ',' in key: keys = key.split(',') @@ -785,24 +805,26 @@ def is_any_api_key(key): else: return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key) + def what_keys(keys): - avail_key_list = {'OpenAI Key':0, "Azure Key":0, "API2D Key":0} + avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0} key_list = keys.split(',') for k in key_list: - if is_openai_api_key(k): + if is_openai_api_key(k): avail_key_list['OpenAI Key'] += 1 for k in key_list: - if is_api2d_key(k): + if is_api2d_key(k): avail_key_list['API2D Key'] += 1 for k in key_list: - if is_azure_api_key(k): + if is_azure_api_key(k): avail_key_list['Azure Key'] += 1 return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']} 个" + def select_api_key(keys, llm_model): import random avail_key_list = [] @@ -826,6 +848,7 @@ def select_api_key(keys, llm_model): api_key = random.choice(avail_key_list) # 随机负载均衡 return api_key + def read_env_variable(arg, default_value): """ 环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先),也可以直接是`CONFIG` @@ -843,10 +866,10 @@ def read_env_variable(arg, default_value): set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")] """ from colorful import print亮红, print亮绿 - arg_with_prefix = "GPT_ACADEMIC_" + arg - if arg_with_prefix in os.environ: + arg_with_prefix = "GPT_ACADEMIC_" + arg + if arg_with_prefix in os.environ: env_arg = os.environ[arg_with_prefix] - elif arg in os.environ: + elif arg in os.environ: env_arg = os.environ[arg] else: raise KeyError @@ -856,7 +879,7 @@ def read_env_variable(arg, default_value): env_arg = env_arg.strip() if env_arg == 'True': r = True elif env_arg == 'False': r = False - else: print('enter True or False, but have:', env_arg); r = default_value + else: print('Enter True or False, but have:', env_arg); r = default_value elif isinstance(default_value, int): r = int(env_arg) elif isinstance(default_value, float): @@ -880,13 +903,14 @@ def read_env_variable(arg, default_value): print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}") return r + @lru_cache(maxsize=128) def read_single_conf_with_lru_cache(arg): from colorful import print亮红, print亮绿, print亮蓝 try: # 优先级1. 获取环境变量作为配置 - default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考 - r = read_env_variable(arg, default_ref) + default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考 + r = read_env_variable(arg, default_ref) except: try: # 优先级2. 获取config_private中的配置 @@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg): if arg == 'API_URL_REDIRECT': oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的,请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明` if oai_rd and not oai_rd.endswith('/completions'): - print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。") + print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。") time.sleep(5) if arg == 'API_KEY': print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"") @@ -907,9 +931,9 @@ def read_single_conf_with_lru_cache(arg): if is_any_api_key(r): print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功") else: - print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。") + print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。") if arg == 'proxies': - if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用 + if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用 if r is None: print亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议:检查USE_PROXY选项是否修改。') else: @@ -953,17 +977,20 @@ class DummyWith(): 在上下文执行开始的情况下,__enter__()方法会在代码块被执行前被调用, 而在上下文执行结束时,__exit__()方法则会被调用。 """ + def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): return + def run_gradio_in_subpath(demo, auth, port, custom_path): """ 把gradio的运行地址更改到指定的二次路径上 """ - def is_path_legal(path: str)->bool: + + def is_path_legal(path: str) -> bool: ''' check path for sub url path: path to check @@ -988,7 +1015,7 @@ def run_gradio_in_subpath(demo, auth, port, custom_path): app = FastAPI() if custom_path != "/": @app.get("/") - def read_main(): + def read_main(): return {"message": f"Gradio is running at: {custom_path}"} app = gr.mount_gradio_app(app, demo, path=custom_path) uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth @@ -999,13 +1026,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit): reduce the length of history by clipping. this function search for the longest entries to clip, little by little, until the number of token of history is reduced under threshold. - 通过裁剪来缩短历史记录的长度。 + 通过裁剪来缩短历史记录的长度。 此函数逐渐地搜索最长的条目进行剪辑, 直到历史记录的标记数量降低到阈值以下。 """ import numpy as np from request_llms.bridge_all import model_info - def get_token_num(txt): + def get_token_num(txt): return len(tokenizer.encode(txt, disallowed_special=())) input_token_num = get_token_num(inputs) @@ -1039,14 +1066,15 @@ def clip_history(inputs, history, tokenizer, max_token_limit): while n_token > max_token_limit: where = np.argmax(everything_token) encoded = tokenizer.encode(everything[where], disallowed_special=()) - clipped_encoded = encoded[:len(encoded)-delta] - everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char + clipped_encoded = encoded[:len(encoded) - delta] + everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char everything_token[where] = get_token_num(everything[where]) n_token = get_token_num('\n'.join(everything)) history = everything[1:] return history + """ ======================================================================== 第三部分 @@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit): ======================================================================== """ + def zip_folder(source_folder, dest_folder, zip_name): import zipfile import os @@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name): print(f"Zip file created at {zip_file}") + def zip_result(folder): t = gen_time_str() zip_folder(folder, get_log_folder(), f'{t}-result.zip') return pj(get_log_folder(), f'{t}-result.zip') + def gen_time_str(): import time return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + def get_log_folder(user=default_user_name, plugin_name='shared'): if user is None: user = default_user_name PATH_LOGGING = get_conf('PATH_LOGGING') @@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'): if not os.path.exists(_dir): os.makedirs(_dir) return _dir + def get_upload_folder(user=default_user_name, tag=None): PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD') if user is None: user = default_user_name - if tag is None or len(tag)==0: + if tag is None or len(tag) == 0: target_path_base = pj(PATH_PRIVATE_UPLOAD, user) else: target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag) return target_path_base + def is_the_upload_folder(string): PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD') pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$' pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD) - if re.match(pattern, string): return True - else: return False + if re.match(pattern, string): + return True + else: + return False + def get_user(chatbotwithcookies): return chatbotwithcookies._cookies.get('user_name', default_user_name) + class ProxyNetworkActivate(): """ 这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理 """ + def __init__(self, task=None) -> None: self.task = task if not task: @@ -1158,32 +1197,36 @@ class ProxyNetworkActivate(): if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY') return + def objdump(obj, file='objdump.tmp'): import pickle with open(file, 'wb+') as f: pickle.dump(obj, f) return + def objload(file='objdump.tmp'): import pickle, os - if not os.path.exists(file): + if not os.path.exists(file): return with open(file, 'rb') as f: return pickle.load(f) - + + def Singleton(cls): """ 一个单实例装饰器 """ _instance = {} - + def _singleton(*args, **kargs): if cls not in _instance: _instance[cls] = cls(*args, **kargs) return _instance[cls] - + return _singleton + """ ======================================================================== 第四部分 @@ -1197,6 +1240,7 @@ def Singleton(cls): ======================================================================== """ + def set_conf(key, value): from toolbox import read_single_conf_with_lru_cache, get_conf read_single_conf_with_lru_cache.cache_clear() @@ -1205,10 +1249,12 @@ def set_conf(key, value): altered = get_conf(key) return altered + def set_multi_conf(dic): for k, v in dic.items(): set_conf(k, v) return + def get_plugin_handle(plugin_name): """ e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言' @@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name): f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name) return f_hot_reload + def get_chat_handle(): """ """ from request_llms.bridge_all import predict_no_ui_long_connection return predict_no_ui_long_connection + def get_plugin_default_kwargs(): """ """ @@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs(): llm_kwargs = { 'api_key': cookies['api_key'], 'llm_model': cookies['llm_model'], - 'top_p':1.0, + 'top_p': 1.0, 'max_length': None, - 'temperature':1.0, + 'temperature': 1.0, } chatbot = ChatBotWithCookies(llm_kwargs) @@ -1247,11 +1295,12 @@ def get_plugin_default_kwargs(): "plugin_kwargs": {}, "chatbot_with_cookie": chatbot, "history": [], - "system_prompt": "You are a good AI.", + "system_prompt": "You are a good AI.", "web_port": None } return DEFAULT_FN_GROUPS_kwargs + def get_chat_default_kwargs(): """ """ @@ -1259,9 +1308,9 @@ def get_chat_default_kwargs(): llm_kwargs = { 'api_key': cookies['api_key'], 'llm_model': cookies['llm_model'], - 'top_p':1.0, + 'top_p': 1.0, 'max_length': None, - 'temperature':1.0, + 'temperature': 1.0, } default_chat_kwargs = { "inputs": "Hello there, are you ready?", @@ -1284,15 +1333,15 @@ def get_pictures_list(path): def have_any_recent_upload_image_files(chatbot): _5min = 5 * 60 - if chatbot is None: return False, None # chatbot is None + if chatbot is None: return False, None # chatbot is None most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None) - if not most_recent_uploaded: return False, None # most_recent_uploaded is None + if not most_recent_uploaded: return False, None # most_recent_uploaded is None if time.time() - most_recent_uploaded["time"] < _5min: most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None) path = most_recent_uploaded['path'] file_manifest = get_pictures_list(path) if len(file_manifest) == 0: return False, None - return True, file_manifest # most_recent_uploaded is new + return True, file_manifest # most_recent_uploaded is new else: return False, None # most_recent_uploaded is too old @@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs): from request_llms.bridge_all import model_info return model_info[llm_kwargs['llm_model']]['max_token'] + def check_packages(packages=[]): import importlib.util for p in packages: