merge success
This commit is contained in:
parent babb775cfb
commit fd549fb986
@@ -70,7 +70,7 @@ MAX_RETRY = 2
 # Model selection (note: LLM_MODEL is the default model; it *must* be included in the AVAIL_LLM_MODELS list)
 LLM_MODEL = "gpt-3.5-turbo" # options ↓↓↓
-AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
+AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "internlm", "newbing", "stack-claude"]
 # P.S. other available models also include ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "claude-1-100k", "claude-2", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
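
The comment above requires the default LLM_MODEL to appear in AVAIL_LLM_MODELS, so making the new backend the default amounts to the following. This is a minimal sketch of that constraint, not part of the diff; the assert is illustrative only.

# Minimal sketch of the config constraint stated in the comment above
# (the assert is illustrative and not part of the project code).
LLM_MODEL = "internlm"
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo",
                    "gpt-4", "api2d-gpt-4", "chatglm", "moss", "internlm", "newbing", "stack-claude"]
assert LLM_MODEL in AVAIL_LLM_MODELS, f"{LLM_MODEL} must be listed in AVAIL_LLM_MODELS"
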
@@ -248,7 +248,6 @@ if "moss" in AVAIL_LLM_MODELS:
 if "stack-claude" in AVAIL_LLM_MODELS:
     from .bridge_stackclaude import predict_no_ui_long_connection as claude_noui
     from .bridge_stackclaude import predict as claude_ui
-    # claude
     model_info.update({
         "stack-claude": {
             "fn_with_ui": claude_ui,
@@ -263,7 +262,6 @@ if "newbing-free" in AVAIL_LLM_MODELS:
     try:
         from .bridge_newbingfree import predict_no_ui_long_connection as newbingfree_noui
         from .bridge_newbingfree import predict as newbingfree_ui
-        # claude
         model_info.update({
             "newbing-free": {
                 "fn_with_ui": newbingfree_ui,
@@ -280,7 +278,6 @@ if "newbing" in AVAIL_LLM_MODELS:   # same with newbing-free
     try:
         from .bridge_newbingfree import predict_no_ui_long_connection as newbingfree_noui
         from .bridge_newbingfree import predict as newbingfree_ui
-        # claude
         model_info.update({
             "newbing": {
                 "fn_with_ui": newbingfree_ui,
@@ -297,7 +294,6 @@ if "chatglmft" in AVAIL_LLM_MODELS:   # same with newbing-free
     try:
         from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
         from .bridge_chatglmft import predict as chatglmft_ui
-        # claude
         model_info.update({
             "chatglmft": {
                 "fn_with_ui": chatglmft_ui,
@@ -310,7 +306,22 @@ if "chatglmft" in AVAIL_LLM_MODELS:   # same with newbing-free
         })
     except:
         print(trimmed_format_exc())

+if "internlm" in AVAIL_LLM_MODELS:
+    try:
+        from .bridge_internlm import predict_no_ui_long_connection as internlm_noui
+        from .bridge_internlm import predict as internlm_ui
+        model_info.update({
+            "internlm": {
+                "fn_with_ui": internlm_ui,
+                "fn_without_ui": internlm_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+
 def LLM_CATCH_EXCEPTION(f):
     """

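
For orientation, each entry registered this way is a dict of callables and limits keyed by model name. The sketch below shows how such a table can be consulted; only the key names come from the hunk above, while the stub backend and the dispatch helper are assumptions.

# Hypothetical dispatcher over a model_info-style registry. Only the key names
# ("fn_without_ui", "endpoint", "max_token", "token_cnt", ...) come from the
# hunk above; the stub backend and dispatch() are illustrative.
def fake_internlm_noui(inputs, history, sys_prompt):
    return f"[internlm stub] {inputs}"

model_info = {
    "internlm": {
        "fn_without_ui": fake_internlm_noui,  # non-UI entry point
        "endpoint": None,                     # local model, no HTTP endpoint
        "max_token": 4096,                    # context-window budget kept in the registry
        "token_cnt": len,                     # stand-in token counter (characters, not tokens)
    }
}

def dispatch(model, inputs, history=(), sys_prompt=""):
    entry = model_info[model]
    if entry["token_cnt"](inputs) > entry["max_token"]:
        raise ValueError("prompt exceeds max_token")
    return entry["fn_without_ui"](inputs, list(history), sys_prompt)

print(dispatch("internlm", "你好"))
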
@@ -12,6 +12,22 @@ load_message = f"{model_name}尚未加载,加载需要一段时间。注意,
 def try_to_import_special_deps():
     import sentencepiece

+user_prompt = "<|User|>:{user}<eoh>\n"
+robot_prompt = "<|Bot|>:{robot}<eoa>\n"
+cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
+
+
+def combine_history(prompt, hist):
+    messages = hist
+    total_prompt = ""
+    for message in messages:
+        cur_content = message
+        cur_prompt = user_prompt.replace("{user}", cur_content[0])
+        total_prompt += cur_prompt
+        cur_prompt = robot_prompt.replace("{robot}", cur_content[1])
+        total_prompt += cur_prompt
+    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    return total_prompt
+
+
 @Singleton
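
The added templates and combine_history flatten a chat history into InternLM's tagged prompt format. Below is a self-contained copy with a quick check; the [(user, bot), ...] history shape is inferred from the cur_content[0]/cur_content[1] indexing above.

# Standalone check of the prompt assembly shown in the hunk above.
user_prompt = "<|User|>:{user}<eoh>\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

def combine_history(prompt, hist):
    total_prompt = ""
    for user_msg, bot_msg in hist:   # history assumed to be [(user, bot), ...] pairs
        total_prompt += user_prompt.replace("{user}", user_msg)
        total_prompt += robot_prompt.replace("{robot}", bot_msg)
    return total_prompt + cur_query_prompt.replace("{user}", prompt)

print(combine_history("请问什么是质子?", [("你好", "我好!")]))
# <|User|>:你好<eoh>
# <|Bot|>:我好!<eoa>
# <|User|>:请问什么是质子?<eoh>
# <|Bot|>:
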
@@ -44,10 +60,10 @@ class GetInternlmHandle(Process):
             else:
                 model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()

-            self._model = self._model.eval()
+            model = model.eval()
         return model, tokenizer

-    def llm_stream_generator(self, kwargs):
+    def llm_stream_generator(self, **kwargs):
         import torch
         import logging
         import copy
@@ -63,12 +79,16 @@ class GetInternlmHandle(Process):
             max_length = kwargs['max_length']
             top_p = kwargs['top_p']
             temperature = kwargs['temperature']
-            return model, tokenizer, prompt, max_length, top_p, temperature
+            history = kwargs['history']
+            real_prompt = combine_history(prompt, history)
+            return model, tokenizer, real_prompt, max_length, top_p, temperature

         model, tokenizer, prompt, max_length, top_p, temperature = adaptor()
         prefix_allowed_tokens_fn = None
-        logger = logging.get_logger(__name__)
         logits_processor = None
         stopping_criteria = None
         additional_eos_token_id = 103028
         generation_config = None
         # 🏃♂️🏃♂️🏃♂️ runs in the child process
         # 🏃♂️🏃♂️🏃♂️ https://github.com/InternLM/InternLM/blob/efbf5335709a8c8faeac6eaf07193973ff1d56a1/web_demo.py#L25

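
The adaptor simply unpacks the streaming kwargs and swaps in the history-expanded prompt before generation. A reduced, runnable sketch of that flow follows; key names other than max_length/top_p/temperature/history (e.g. 'query') and the stub streaming loop are assumptions.

# Reduced sketch of the kwargs-adaptor pattern in the hunk above.
def combine_history(prompt, hist):
    # stand-in for the combine_history defined earlier in this file
    return "".join(f"<|User|>:{u}<eoh>\n<|Bot|>:{b}<eoa>\n" for u, b in hist) \
        + f"<|User|>:{prompt}<eoh>\n<|Bot|>:"

def llm_stream_generator(**kwargs):
    def adaptor():
        prompt = kwargs['query']            # key name assumed, not shown in the hunk
        history = kwargs['history']
        real_prompt = combine_history(prompt, history)
        return real_prompt, kwargs['max_length'], kwargs['top_p'], kwargs['temperature']

    real_prompt, max_length, top_p, temperature = adaptor()
    # a real implementation would stream model tokens here
    yield f"[stub] top_p={top_p}, temperature={temperature}\n{real_prompt}"

for chunk in llm_stream_generator(query="请问什么是质子?", max_length=512,
                                  top_p=0.8, temperature=1, history=[("你好", "我好!")]):
    print(chunk)
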
@@ -98,7 +118,7 @@ class GetInternlmHandle(Process):
         elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
-                logger.warn(
+                logging.warn(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                     "Please refer to the documentation for more information. "
@@ -108,7 +128,7 @@ class GetInternlmHandle(Process):

         if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "input_ids"
-            logger.warning(
+            logging.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                 " increasing `max_new_tokens`."

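
These two warning hunks replace calls on the `logger` object, which the earlier hunk no longer creates, with direct calls on the standard-library logging module; the effect is simply:

import logging

# Effect of the change above: emit the warning via the stdlib logging module
# directly instead of a pre-built logger object (message shortened here).
# Note that logging.warn() is a deprecated alias of logging.warning().
logging.warning("Input length exceeds `max_length`; consider increasing `max_new_tokens`.")
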
@@ -23,45 +23,8 @@ if __name__ == "__main__":
         'temperature': 1,
     }

-    result = predict_no_ui_long_connection(inputs="你好",
-                                        llm_kwargs=llm_kwargs,
-                                        history=[],
-                                        sys_prompt="")
+    result = predict_no_ui_long_connection( inputs="请问什么是质子?",
+                                            llm_kwargs=llm_kwargs,
+                                            history=["你好", "我好!"],
+                                            sys_prompt="")
     print('final result:', result)
-
-
-
-    # # print(result)
-    # from multiprocessing import Process, Pipe
-    # class GetGLMHandle(Process):
-    #     def __init__(self):
-    #         super().__init__(daemon=True)
-    #         pass
-    #     def run(self):
-    #         # executed in the child process
-    #         # on first run, load the parameters
-    #         def validate_path():
-    #             import os, sys
-    #             dir_name = os.path.dirname(__file__)
-    #             root_dir_assume = os.path.abspath(os.path.dirname(__file__) +  '/..')
-    #             os.chdir(root_dir_assume + '/request_llm/jittorllms')
-    #             sys.path.append(root_dir_assume + '/request_llm/jittorllms')
-    #         validate_path() # validate path so you can run from base directory
-    #         jittorllms_model = None
-    #         import types
-    #         try:
-    #             if jittorllms_model is None:
-    #                 from models import get_model
-    #                 # availabel_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
-    #                 args_dict = {'model': 'chatrwkv'}
-    #                 print('self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))')
-    #                 jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
-    #                 print('done get model')
-    #         except:
-    #             # self.child.send('[Local Message] Call jittorllms fail 不能正常加载jittorllms的参数。')
-    #             raise RuntimeError("不能正常加载jittorllms的参数!")
-    # x = GetGLMHandle()
-    # x.start()
-
-
-    # input()