commit bfa6661367
parent d79dfe2fc7

up
@@ -34,7 +34,7 @@ WEB_PORT = -1
 MAX_RETRY = 2
 
 # Which OpenAI model to use (GPT-4 is currently only open to users whose access request has been approved)
-LLM_MODEL = "gpt-3.5-turbo"
+LLM_MODEL = "pygmalion-1.3b@localhost@7860" # "gpt-3.5-turbo"
 
 # OpenAI's API_URL
 API_URL = "https://api.openai.com/v1/chat/completions"
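The new LLM_MODEL value packs three fields into one string, "model_name@host@port", which the request_llm/bridge_tgui.py file added below unpacks with LLM_MODEL.split('@'). A minimal sketch of the convention, using the exact value from this diff (note that the port stays a string unless converted):

    # how bridge_tgui.py reads the packed setting
    LLM_MODEL = "pygmalion-1.3b@localhost@7860"
    model_name, addr, port = LLM_MODEL.split('@')
    assert (model_name, addr, port) == ("pygmalion-1.3b", "localhost", "7860")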
@@ -112,8 +112,7 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
     return result
 
 
-def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', 
-            stream = True, additional_fn=None):
+def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
     """
         Send the query to chatGPT and fetch the reply as a stream.
         Used for the basic dialog feature.
@@ -244,3 +243,8 @@ def generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
     return headers,payload
 
 
+if not LLM_MODEL.startswith('gpt'):
+    from request_llm.bridge_tgui import predict_tgui
+    predict = predict_tgui
+
+
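This tail-of-module hook is the dispatch point: whenever LLM_MODEL does not start with 'gpt', the module-level name predict is rebound to the TGUI bridge, so existing call sites switch backends without being touched. The decision is made once, at import time, from the config value; a sketch of the same pattern in isolation (both backend functions here are illustrative placeholders, not repo code):

    def predict_openai(*args, **kwargs): ...   # hypothetical default backend
    def predict_tgui(*args, **kwargs): ...     # hypothetical TGUI backend

    LLM_MODEL = "pygmalion-1.3b@localhost@7860"
    predict = predict_openai
    if not LLM_MODEL.startswith('gpt'):
        predict = predict_tgui                 # callers of predict() are rerouted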
request_llm/bridge_tgui.py (new file)
@@ -0,0 +1,137 @@
'''
Contributed by SagsMug. Modified by binary-husky
https://github.com/oobabooga/text-generation-webui/pull/175
'''

import asyncio
import json
import random
import string
import websockets
import logging
import time
import threading
import importlib  # fix: used by predict_tgui below but missing from the original import list
from toolbox import get_conf
LLM_MODEL, = get_conf('LLM_MODEL')

# LLM_MODEL is packed as "model_name@host@port", e.g. "pygmalion-1.3b@localhost@7860"
model_name, addr, port = LLM_MODEL.split('@')

def random_hash():
    # 9-character session id for the Gradio queue protocol
    letters = string.ascii_lowercase + string.digits
    return ''.join(random.choice(letters) for i in range(9))
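# Note: the trailing comma in `LLM_MODEL, = get_conf('LLM_MODEL')` above is
# tuple unpacking, implying toolbox.get_conf returns a tuple even for a single
# key. A hypothetical stand-in illustrating the contract (not the toolbox code):
#     def get_conf(*keys):
#         config = {'LLM_MODEL': 'pygmalion-1.3b@localhost@7860'}
#         return tuple(config[k] for k in keys)
#     LLM_MODEL, = get_conf('LLM_MODEL')   # unpacks the 1-tuple into a string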
async def run(context):
    params = {
        'max_new_tokens': 200,
        'do_sample': True,
        'temperature': 0.5,
        'top_p': 0.9,
        'typical_p': 1,
        'repetition_penalty': 1.05,
        'encoder_repetition_penalty': 1.0,
        'top_k': 0,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'seed': -1,
    }
    session = random_hash()

    async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
        while content := json.loads(await websocket.recv()):
            # note: `content :=` needs Python 3.8+ (the upstream PR used 3.10 match/case here)
            if content["msg"] == "send_hash":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": 12
                }))
            elif content["msg"] == "estimation":
                pass
            elif content["msg"] == "send_data":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": 12,
                    "data": [
                        context,
                        params['max_new_tokens'],
                        params['do_sample'],
                        params['temperature'],
                        params['top_p'],
                        params['typical_p'],
                        params['repetition_penalty'],
                        params['encoder_repetition_penalty'],
                        params['top_k'],
                        params['min_length'],
                        params['no_repeat_ngram_size'],
                        params['num_beams'],
                        params['penalty_alpha'],
                        params['length_penalty'],
                        params['early_stopping'],
                        params['seed'],
                    ]
                }))
            elif content["msg"] == "process_starts":
                pass
            elif content["msg"] in ["process_generating", "process_completed"]:
                yield content["output"]["data"][0]
                # To stop generation early, look for your desired end indicator
                # here and close the websocket
                if content["msg"] == "process_completed":
                    break
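# Protocol note: run() speaks the websocket half of Gradio's queue protocol as
# text-generation-webui exposed it around upstream PR #175. The server asks for
# a session hash ("send_hash"), may report queue position ("estimation"), then
# requests the payload ("send_data") and streams results ("process_generating")
# until "process_completed". Each streamed message carries the full text so far
# in output.data[0], not a delta, which is why the consumer below overwrites
# its buffer. fn_index=12 is an inherited assumption: it names one concrete
# endpoint of that specific Gradio app and breaks silently if the layout changes.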
def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
    """
        Send the query to chatGPT and fetch the reply as a stream.
        Used for the basic dialog feature.
        inputs is the input of the current query
        top_p, temperature are chatGPT's internal tuning parameters
        history is the list of previous dialog turns (note that either inputs or history, if too long, will trigger a token-count overflow error)
        chatbot is the dialog list shown in the WebUI; modify it, then yield it out, to update the dialog interface directly
        additional_fn indicates which button was clicked; see functional.py for the buttons
    """
    if additional_fn is not None:
        import functional
        importlib.reload(functional)    # hot-reload the prompts
        functional = functional.get_functionals()
        if "PreProcess" in functional[additional_fn]: inputs = functional[additional_fn]["PreProcess"](inputs)  # apply the pre-processing function, if one exists
        inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]

    raw_input = inputs
    logging.info(f'[raw_input] {raw_input}')
    history.append(inputs); history.append("")   # fix: history[-2]/history[-1] are read in the loop below, so both slots must exist
    chatbot.append((inputs, ""))
    yield chatbot, history, "Waiting for response"

    prompt = inputs
    tgui_say = ""
    status_text = "Waiting for response"         # fix: status_text was referenced below but never defined

    mutable = [""]
    def run_coroutine(mutable):
        async def get_result():
            async for response in run(prompt):
                # each response carries the full text so far, so overwrite
                # rather than append (fix: was `mutable += response`, which
                # extends the list character by character and never updates mutable[0])
                mutable[0] = response
        asyncio.run(get_result())

    thread_listen = threading.Thread(target=run_coroutine, args=(mutable,))
    thread_listen.start()

    while thread_listen.is_alive():
        time.sleep(1)
        # surface intermediate results in the UI
        if tgui_say != mutable[0]:
            tgui_say = mutable[0]
            history[-1] = tgui_say
            chatbot[-1] = (history[-2], history[-1])
            yield chatbot, history, status_text

    logging.info(f'[response] {tgui_say}')
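predict_tgui is a generator that yields (chatbot, history, status) tuples, so a driver only has to iterate it to stream partial replies into the UI. A minimal console driver, assuming a text-generation-webui instance is running at the address packed into LLM_MODEL (this loop is illustrative, not code from the repo):

    chatbot, history = [], []
    for chatbot, history, status in predict_tgui(
            "Hello there", top_p=0.9, temperature=0.5,
            chatbot=chatbot, history=history):
        print(f'[{status}] {chatbot[-1][1]}')   # latest partial reply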