修改变量命名,整理配置清单
This commit is contained in:
		
							parent
							
								
									a208782049
								
							
						
					
					
						commit
						89de49f31e
					
				
							
								
								
									
										20
									
								
								config.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								config.py
									
									
									
									
									
								
							@ -67,8 +67,10 @@ WEB_PORT = -1
 | 
				
			|||||||
# 如果OpenAI不响应(网络卡顿、代理失败、KEY失效),重试的次数限制
 | 
					# 如果OpenAI不响应(网络卡顿、代理失败、KEY失效),重试的次数限制
 | 
				
			||||||
MAX_RETRY = 2
 | 
					MAX_RETRY = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 插件分类默认选项
 | 
					# 插件分类默认选项
 | 
				
			||||||
default_plugin = ['学术优化', '多功能插件', '代码解析']
 | 
					DEFAULT_FN_GROUPS = ['学术优化', '多功能插件', '代码解析']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
 | 
					# 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
 | 
				
			||||||
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
 | 
					LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
 | 
				
			||||||
@ -85,7 +87,7 @@ BAIDU_CLOUD_QIANFAN_MODEL = 'ERNIE-Bot'    # 可选 "ERNIE-Bot"(文心一言), "
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 如果使用ChatGLM2微调模型,请把 LLM_MODEL="chatglmft",并在此处指定模型路径
 | 
					# 如果使用ChatGLM2微调模型,请把 LLM_MODEL="chatglmft",并在此处指定模型路径
 | 
				
			||||||
ChatGLM_PTUNING_CHECKPOINT = "" # 例如"/home/hmp/ChatGLM2-6B/ptuning/output/6b-pt-128-1e-2/checkpoint-100"
 | 
					CHATGLM_PTUNING_CHECKPOINT = "" # 例如"/home/hmp/ChatGLM2-6B/ptuning/output/6b-pt-128-1e-2/checkpoint-100"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 | 
					# 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 | 
				
			||||||
@ -101,7 +103,7 @@ CONCURRENT_COUNT = 100
 | 
				
			|||||||
AUTO_CLEAR_TXT = False
 | 
					AUTO_CLEAR_TXT = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 色彩主体,可选 ["Default", "Chuanhu-Small-and-Beautiful"]
 | 
					# 色彩主题,可选 ["Default", "Chuanhu-Small-and-Beautiful"]
 | 
				
			||||||
THEME = "Chuanhu-Small-and-Beautiful"
 | 
					THEME = "Chuanhu-Small-and-Beautiful"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -216,6 +218,18 @@ ALLOW_RESET_CONFIG = False
 | 
				
			|||||||
    └── NEWBING_COOKIES
 | 
					    └── NEWBING_COOKIES
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					用户图形界面布局依赖关系示意图
 | 
				
			||||||
 | 
					│
 | 
				
			||||||
 | 
					├── CHATBOT_HEIGHT 对话窗的高度
 | 
				
			||||||
 | 
					├── CODE_HIGHLIGHT 代码高亮
 | 
				
			||||||
 | 
					├── LAYOUT 窗口布局
 | 
				
			||||||
 | 
					├── DARK_MODE 暗色模式 / 亮色模式
 | 
				
			||||||
 | 
					├── DEFAULT_FN_GROUPS 插件分类默认选项
 | 
				
			||||||
 | 
					├── THEME 色彩主题
 | 
				
			||||||
 | 
					├── AUTO_CLEAR_TXT 是否在提交时自动清空输入框
 | 
				
			||||||
 | 
					├── ADD_WAIFU 加一个live2d装饰
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
插件在线服务配置依赖关系示意图
 | 
					插件在线服务配置依赖关系示意图
 | 
				
			||||||
│
 | 
					│
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										8
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								main.py
									
									
									
									
									
								
							@ -34,7 +34,7 @@ def main():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # 高级函数插件
 | 
					    # 高级函数插件
 | 
				
			||||||
    from crazy_functional import get_crazy_functions
 | 
					    from crazy_functional import get_crazy_functions
 | 
				
			||||||
    default_plugin, = get_conf('default_plugin')
 | 
					    DEFAULT_FN_GROUPS, = get_conf('DEFAULT_FN_GROUPS')
 | 
				
			||||||
    crazy_fns_role = get_crazy_functions()
 | 
					    crazy_fns_role = get_crazy_functions()
 | 
				
			||||||
    crazy_classification = [i for i in crazy_fns_role]
 | 
					    crazy_classification = [i for i in crazy_fns_role]
 | 
				
			||||||
    crazy_fns = {}
 | 
					    crazy_fns = {}
 | 
				
			||||||
@ -93,7 +93,7 @@ def main():
 | 
				
			|||||||
                    with gr.Row():
 | 
					                    with gr.Row():
 | 
				
			||||||
                        gr.Markdown("插件可读取“输入区”文本/路径作为参数(上传文件自动修正路径)")
 | 
					                        gr.Markdown("插件可读取“输入区”文本/路径作为参数(上传文件自动修正路径)")
 | 
				
			||||||
                    plugin_dropdown = gr.Dropdown(choices=crazy_classification, label='选择插件分类',
 | 
					                    plugin_dropdown = gr.Dropdown(choices=crazy_classification, label='选择插件分类',
 | 
				
			||||||
                                                       value=default_plugin,
 | 
					                                                       value=DEFAULT_FN_GROUPS,
 | 
				
			||||||
                                                       multiselect=True, interactive=True,
 | 
					                                                       multiselect=True, interactive=True,
 | 
				
			||||||
                                                       elem_classes='normal_mut_select'
 | 
					                                                       elem_classes='normal_mut_select'
 | 
				
			||||||
                                                       ).style(container=False)
 | 
					                                                       ).style(container=False)
 | 
				
			||||||
@ -101,7 +101,7 @@ def main():
 | 
				
			|||||||
                        for role in crazy_fns_role:
 | 
					                        for role in crazy_fns_role:
 | 
				
			||||||
                            for k in crazy_fns_role[role]:
 | 
					                            for k in crazy_fns_role[role]:
 | 
				
			||||||
                                if not crazy_fns_role[role][k].get("AsButton", True): continue
 | 
					                                if not crazy_fns_role[role][k].get("AsButton", True): continue
 | 
				
			||||||
                                if role not in default_plugin:
 | 
					                                if role not in DEFAULT_FN_GROUPS:
 | 
				
			||||||
                                    variant = crazy_fns_role[role][k]["Color"] if "Color" in crazy_fns_role[role][
 | 
					                                    variant = crazy_fns_role[role][k]["Color"] if "Color" in crazy_fns_role[role][
 | 
				
			||||||
                                        k] else "secondary"
 | 
					                                        k] else "secondary"
 | 
				
			||||||
                                    crazy_fns_role[role][k]['Button'] = gr.Button(k, variant=variant,
 | 
					                                    crazy_fns_role[role][k]['Button'] = gr.Button(k, variant=variant,
 | 
				
			||||||
@ -115,7 +115,7 @@ def main():
 | 
				
			|||||||
                        with gr.Accordion("更多函数插件", open=True):
 | 
					                        with gr.Accordion("更多函数插件", open=True):
 | 
				
			||||||
                            dropdown_fn_list = []
 | 
					                            dropdown_fn_list = []
 | 
				
			||||||
                            for role in crazy_fns_role:
 | 
					                            for role in crazy_fns_role:
 | 
				
			||||||
                                if role in default_plugin:
 | 
					                                if role in DEFAULT_FN_GROUPS:
 | 
				
			||||||
                                    for k in crazy_fns_role[role]:
 | 
					                                    for k in crazy_fns_role[role]:
 | 
				
			||||||
                                        if not crazy_fns_role[role][k].get("AsButton", True):
 | 
					                                        if not crazy_fns_role[role][k].get("AsButton", True):
 | 
				
			||||||
                                            dropdown_fn_list.append(k)
 | 
					                                            dropdown_fn_list.append(k)
 | 
				
			||||||
 | 
				
			|||||||
@ -63,9 +63,9 @@ class GetGLMFTHandle(Process):
 | 
				
			|||||||
                    # if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
 | 
					                    # if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
 | 
				
			||||||
                    # with open(conf, 'r', encoding='utf8') as f:
 | 
					                    # with open(conf, 'r', encoding='utf8') as f:
 | 
				
			||||||
                    #     model_args = json.loads(f.read())
 | 
					                    #     model_args = json.loads(f.read())
 | 
				
			||||||
                    ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
 | 
					                    CHATGLM_PTUNING_CHECKPOINT, = get_conf('CHATGLM_PTUNING_CHECKPOINT')
 | 
				
			||||||
                    assert os.path.exists(ChatGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
 | 
					                    assert os.path.exists(CHATGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
 | 
				
			||||||
                    conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
 | 
					                    conf = os.path.join(CHATGLM_PTUNING_CHECKPOINT, "config.json")
 | 
				
			||||||
                    with open(conf, 'r', encoding='utf8') as f:
 | 
					                    with open(conf, 'r', encoding='utf8') as f:
 | 
				
			||||||
                        model_args = json.loads(f.read())
 | 
					                        model_args = json.loads(f.read())
 | 
				
			||||||
                    if 'model_name_or_path' not in model_args:
 | 
					                    if 'model_name_or_path' not in model_args:
 | 
				
			||||||
@ -78,9 +78,9 @@ class GetGLMFTHandle(Process):
 | 
				
			|||||||
                    config.pre_seq_len = model_args['pre_seq_len']
 | 
					                    config.pre_seq_len = model_args['pre_seq_len']
 | 
				
			||||||
                    config.prefix_projection = model_args['prefix_projection']
 | 
					                    config.prefix_projection = model_args['prefix_projection']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
 | 
					                    print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
 | 
				
			||||||
                    model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
 | 
					                    model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
 | 
				
			||||||
                    prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
 | 
					                    prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
 | 
				
			||||||
                    new_prefix_state_dict = {}
 | 
					                    new_prefix_state_dict = {}
 | 
				
			||||||
                    for k, v in prefix_state_dict.items():
 | 
					                    for k, v in prefix_state_dict.items():
 | 
				
			||||||
                        if k.startswith("transformer.prefix_encoder."):
 | 
					                        if k.startswith("transformer.prefix_encoder."):
 | 
				
			||||||
 | 
				
			|||||||
@ -1001,7 +1001,7 @@ def get_plugin_default_kwargs():
 | 
				
			|||||||
    chatbot = ChatBotWithCookies(llm_kwargs)
 | 
					    chatbot = ChatBotWithCookies(llm_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port
 | 
					    # txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port
 | 
				
			||||||
    default_plugin_kwargs = {
 | 
					    DEFAULT_FN_GROUPS_kwargs = {
 | 
				
			||||||
        "main_input": "./README.md",
 | 
					        "main_input": "./README.md",
 | 
				
			||||||
        "llm_kwargs": llm_kwargs,
 | 
					        "llm_kwargs": llm_kwargs,
 | 
				
			||||||
        "plugin_kwargs": {},
 | 
					        "plugin_kwargs": {},
 | 
				
			||||||
@ -1010,7 +1010,7 @@ def get_plugin_default_kwargs():
 | 
				
			|||||||
        "system_prompt": "You are a good AI.", 
 | 
					        "system_prompt": "You are a good AI.", 
 | 
				
			||||||
        "web_port": WEB_PORT
 | 
					        "web_port": WEB_PORT
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return default_plugin_kwargs
 | 
					    return DEFAULT_FN_GROUPS_kwargs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_chat_default_kwargs():
 | 
					def get_chat_default_kwargs():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user