Interface with LLaMa2 from huggingface
parent 8b3b883fce
commit 9720bec5e5
@@ -27,7 +27,7 @@ To translate this project to arbitrary language with GPT, read and run [`multi_la
 
 Feature (⭐ = recently added) | Description
 --- | ---
-⭐[Plug in new models](https://github.com/binary-husky/gpt_academic/wiki/%E5%A6%82%E4%BD%95%E5%88%87%E6%8D%A2%E6%A8%A1%E5%9E%8B)! | ⭐Alibaba DAMO Academy [Qwen](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary), Shanghai AI-Lab [InternLM](https://github.com/InternLM/InternLM), iFlytek [Spark](https://xinghuo.xfyun.cn/)
+⭐[Plug in new models](https://github.com/binary-husky/gpt_academic/wiki/%E5%A6%82%E4%BD%95%E5%88%87%E6%8D%A2%E6%A8%A1%E5%9E%8B)! | ⭐Alibaba DAMO Academy [Qwen](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary), Shanghai AI-Lab [InternLM](https://github.com/InternLM/InternLM), iFlytek [Spark](https://xinghuo.xfyun.cn/), [LLaMa2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 One-click polishing | Polish text and check papers for grammar errors with one click
 One-click Chinese-English translation | Translate between Chinese and English with one click
 One-click code explanation | Display, explain, generate, and comment code
@@ -149,4 +149,8 @@ ANTHROPIC_API_KEY = ""
 
 
 # Custom API KEY format
 CUSTOM_API_KEY_PATTERN = ""
+
+
+# Hugging Face token, needed when downloading LLaMA  https://huggingface.co/docs/hub/security-tokens
+HUGGINGFACE_ACCESS_TOKEN = "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"
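The two new keys are read by the LLaMa2 bridge via `get_conf('HUGGINGFACE_ACCESS_TOKEN', 'LOCAL_MODEL_DEVICE')` (see bridge_llama2.py below). A minimal sketch of the user-side settings, assuming the project's usual pattern of overriding config.py in a private config; the token value and model list here are illustrative placeholders, not part of this commit:

    # hypothetical user override (same keys as config.py above)
    HUGGINGFACE_ACCESS_TOKEN = "hf_..."             # a read-scoped token from https://huggingface.co/docs/hub/security-tokens
    LOCAL_MODEL_DEVICE = "cuda"                     # read together with the token when the local model is loaded
    AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "llama2"]  # "llama2" must be listed, or the registration below is skipped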
@@ -385,6 +385,22 @@ if "spark" in AVAIL_LLM_MODELS:   # iFlytek Spark cognitive LLM
         })
     except:
         print(trimmed_format_exc())
+if "llama2" in AVAIL_LLM_MODELS:   # LLaMa2 from Hugging Face
+    try:
+        from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui
+        from .bridge_llama2 import predict as llama2_ui
+        model_info.update({
+            "llama2": {
+                "fn_with_ui": llama2_ui,
+                "fn_without_ui": llama2_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
 
 
 
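Once this registration runs, "llama2" is dispatched like any other backend: callers look the entry up in model_info and invoke its fn_with_ui / fn_without_ui handles. A rough usage sketch; the keyword names follow the convention of the other local bridges and are an assumption here, not something introduced by this commit:

    # hypothetical caller: fetch the handle registered above and request a reply without the UI
    from request_llm.bridge_all import model_info

    llm_kwargs = {'llm_model': 'llama2', 'max_length': 512, 'top_p': 0.9, 'temperature': 0.7}
    predict_noui = model_info['llama2']['fn_without_ui']   # -> bridge_llama2.predict_no_ui_long_connection
    reply = predict_noui(inputs="Summarize the attention mechanism in one sentence.",
                         llm_kwargs=llm_kwargs, history=[], sys_prompt="",
                         observe_window=[], console_slience=True)
    print(reply)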
request_llm/bridge_llama2.py  (new file, 91 lines)
@@ -0,0 +1,91 @@
model_name = "LLaMA"
cmd_to_install = "`pip install -r request_llm/requirements_chatglm.txt`"


from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from toolbox import update_ui, get_conf, ProxyNetworkActivate
from multiprocessing import Process, Pipe
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns, SingletonLocalLLM
from threading import Thread


# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Local Model
# ------------------------------------------------------------------------------------------------------------------------
@SingletonLocalLLM
class GetONNXGLMHandle(LocalLLMHandle):

    def load_model_info(self):
        # 🏃♂️🏃♂️🏃♂️ runs in the subprocess
        self.model_name = model_name
        self.cmd_to_install = cmd_to_install

    def load_model_and_tokenizer(self):
        # 🏃♂️🏃♂️🏃♂️ runs in the subprocess
        import os
        huggingface_token, device = get_conf('HUGGINGFACE_ACCESS_TOKEN', 'LOCAL_MODEL_DEVICE')
        assert len(huggingface_token) != 0, "HUGGINGFACE_ACCESS_TOKEN is not set"
        # cache the token where the transformers download machinery expects it
        with open(os.path.expanduser('~/.cache/huggingface/token'), 'w') as f:
            f.write(huggingface_token)
        model_id = 'meta-llama/Llama-2-7b-chat-hf'
        with ProxyNetworkActivate():
            self._tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=huggingface_token)
            # use fp16
            model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=huggingface_token).eval()
            if device.startswith('cuda'): model = model.half().to(device)
            self._model = model

            return self._model, self._tokenizer

    def llm_stream_generator(self, **kwargs):
        # 🏃♂️🏃♂️🏃♂️ runs in the subprocess
        def adaptor(kwargs):
            query = kwargs['query']
            max_length = kwargs['max_length']
            top_p = kwargs['top_p']
            temperature = kwargs['temperature']
            history = kwargs['history']
            console_slience = kwargs.get('console_slience', True)
            return query, max_length, top_p, temperature, history, console_slience

        def convert_messages_to_prompt(query, history):
            # flatten the chat history into Llama-2's [INST] dialogue format
            prompt = ""
            for a, b in history:
                prompt += f"\n[INST]{a}[/INST]"
                prompt += f"\n{b}"
            prompt += f"\n[INST]{query}[/INST]"
            return prompt

        query, max_length, top_p, temperature, history, console_slience = adaptor(kwargs)
        prompt = convert_messages_to_prompt(query, history)
        # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-
        # code from transformers.llama
        streamer = TextIteratorStreamer(self._tokenizer)
        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = self._tokenizer([prompt], return_tensors="pt")
        prompt_tk_back = self._tokenizer.batch_decode(inputs['input_ids'])[0]

        generation_kwargs = dict(inputs.to(self._model.device), streamer=streamer, max_new_tokens=max_length)
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            if not console_slience: print(new_text, end='')
            # drop the echoed prompt prefix (str.lstrip would treat the argument as a character set, not a prefix)
            output = generated_text[len(prompt_tk_back):] if generated_text.startswith(prompt_tk_back) else generated_text
            yield output.rstrip("</s>")
        if not console_slience: print()
        # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-

    def try_to_import_special_deps(self, **kwargs):
        # import something that will raise an error if the user has not installed requirement_*.txt
        # 🏃♂️🏃♂️🏃♂️ runs in the main process
        import importlib
        importlib.import_module('transformers')


# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 GPT-Academic Interface
# ------------------------------------------------------------------------------------------------------------------------
predict_no_ui_long_connection, predict = get_local_llm_predict_fns(GetONNXGLMHandle, model_name)
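For reference, convert_messages_to_prompt above flattens the chat history into Llama-2's [INST] dialogue format before tokenization. A standalone illustration with made-up messages shows the exact string it builds:

    # same helper as in llm_stream_generator above, shown with example inputs
    def convert_messages_to_prompt(query, history):
        prompt = ""
        for a, b in history:
            prompt += f"\n[INST]{a}[/INST]"
            prompt += f"\n{b}"
        prompt += f"\n[INST]{query}[/INST]"
        return prompt

    print(convert_messages_to_prompt("And in German?", [("Say hello", "Hello!")]))
    # output (note the leading newline):
    #
    # [INST]Say hello[/INST]
    # Hello!
    # [INST]And in German?[/INST]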