微调对话裁剪 (Fine-tune conversation clipping)
parent 676fe40d39
commit 0785ff2aed
@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
                     if "reduce the length" in error_msg:
                         if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
                         history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
-                                               max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一
+                                               max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
                         chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
                         # history = []    # 清除历史
                     elif "does not exist" in error_msg:
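The only functional change in this hunk is dropping the //2: the caller now hands clip_history the model's full max_token budget, and the margins for the input and the reply are reserved inside clip_history instead (see the toolbox.py hunks below). A rough numeric sketch of the effect, assuming an illustrative 4096-token context window and a 1000-token input (neither number is read from model_info):

```python
# Illustrative numbers only: a 4096-token context window and a 1000-token input.
max_token = 4096
input_token_num = 1000

# Before this commit: the caller halved the budget up front, regardless of input size.
old_history_limit = max_token // 2                      # 2048 tokens left for history

# After this commit: clip_history receives the full budget and subtracts the input
# plus a 128-token reserve for the reply (steps 1-3 in the toolbox.py hunk below).
new_history_limit = max_token - input_token_num - 128   # 2968 tokens left for history

print(old_history_limit, new_history_limit)
```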
toolbox.py (15 lines changed)
@@ -555,12 +555,12 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
 
 def clip_history(inputs, history, tokenizer, max_token_limit):
     """
-    reduce the length of input/history by clipping.
+    reduce the length of history by clipping.
     this function search for the longest entries to clip, little by little,
-    until the number of token of input/history is reduced under threshold.
-    通过剪辑来缩短输入/历史记录的长度。
+    until the number of token of history is reduced under threshold.
+    通过裁剪来缩短历史记录的长度。
     此函数逐渐地搜索最长的条目进行剪辑,
-    直到输入/历史记录的标记数量降低到阈值以下。
+    直到历史记录的标记数量降低到阈值以下。
     """
     import numpy as np
     from request_llm.bridge_all import model_info
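The clipping loop that does this sits further down in toolbox.py and is not part of this diff. Purely as an illustration of the behaviour the docstring describes (trim the longest history entry a little at a time until the total token count falls under the threshold), here is a self-contained sketch that swaps the project's tokenizer for a simple whitespace splitter; clip_history_sketch and chunk are hypothetical names:

```python
# Self-contained sketch only: the real clip_history trims encoded tokens from the
# model's tokenizer, while a whitespace word count stands in for that here.
def clip_history_sketch(history, max_token_limit, chunk=16):
    def get_token_num(txt):
        return len(txt.split())

    entries = list(history)
    while entries and sum(get_token_num(e) for e in entries) > max_token_limit:
        # pick the longest entry and shave a small chunk off its tail,
        # little by little, so no single message vanishes in one step
        longest = max(range(len(entries)), key=lambda i: get_token_num(entries[i]))
        words = entries[longest].split()
        entries[longest] = " ".join(words[:max(0, len(words) - chunk)])
    return entries

print(clip_history_sketch(["short question", "word " * 300], max_token_limit=50))
```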
@@ -568,10 +568,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
         return len(tokenizer.encode(txt, disallowed_special=()))
     input_token_num = get_token_num(inputs)
     if input_token_num < max_token_limit * 3 / 4:
-        # 当输入部分的token占比小于限制的3/4时,在裁剪时把input的余量留出来
+        # 当输入部分的token占比小于限制的3/4时,裁剪时
+        # 1. 把input的余量留出来
         max_token_limit = max_token_limit - input_token_num
+        # 2. 把输出用的余量留出来
+        max_token_limit = max_token_limit - 128
+        # 3. 如果余量太小了,直接清除历史
         if max_token_limit < 128:
-            # 余量太小了,直接清除历史
             history = []
             return history
     else:
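Taken together with the first hunk, the new budget handling reads as follows. This is a minimal sketch under stated assumptions, not the toolbox.py implementation: history_budget is a hypothetical helper, and since the else branch is cut off in this diff, treating it as "clear the history" is an assumption.

```python
# Minimal sketch of the new budget steps; the 3/4 test and 128-token reserves mirror
# the diff, while the helper name and the cut-off else branch are assumptions.
def history_budget(max_token_limit, input_token_num, reply_reserve=128):
    """Return the token budget left for history, or None to signal 'clear it'."""
    if input_token_num < max_token_limit * 3 / 4:
        budget = max_token_limit - input_token_num   # 1. leave room for the input
        budget -= reply_reserve                      # 2. leave room for the reply
        if budget < 128:                             # 3. too little left for history
            return None
        return budget
    return None  # assumption: the input alone dominates the window, so drop history

print(history_budget(4096, 1000))   # 2968 tokens available to clip history into
print(history_budget(4096, 3900))   # None -> history would simply be cleared
```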