Merge branch 'image_understanding_spark' of https://github.com/Kilig947/gpt_academic into Kilig947-image_understanding_spark
This commit is contained in:
		
						commit
						6ed88fe848
					
				@ -16,28 +16,13 @@ import base64
 | 
			
		||||
import os
 | 
			
		||||
import glob
 | 
			
		||||
 | 
			
		||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, update_ui_lastest_msg, get_max_token
 | 
			
		||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files
 | 
			
		||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
 | 
			
		||||
    get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
 | 
			
		||||
 | 
			
		||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
 | 
			
		||||
                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
 | 
			
		||||
 | 
			
		||||
def have_any_recent_upload_image_files(chatbot):
    """Check whether the user uploaded image files within the last 5 minutes.

    Reads the "most_recent_uploaded" record from the chatbot cookies and, if
    the record is fresh enough, scans the recorded upload folder for images.

    Args:
        chatbot: chatbot object carrying a `_cookies` dict (may be None).

    Returns:
        Tuple (bool, list | None): (True, file_manifest) when recent
        .jpg/.jpeg/.png uploads exist, (False, None) otherwise.
    """
    _5min = 5 * 60
    if chatbot is None:
        return False, None    # chatbot is None
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False, None   # most_recent_uploaded is None
    if time.time() - most_recent_uploaded["time"] >= _5min:
        return False, None  # most_recent_uploaded is too old
    # Fresh upload: collect image files recursively from the recorded folder.
    # (The original re-read the cookie a second time here, which was redundant.)
    path = most_recent_uploaded['path']
    file_manifest = []
    for ext in ('jpg', 'jpeg', 'png'):
        file_manifest += glob.glob(f'{path}/**/*.{ext}', recursive=True)
    if len(file_manifest) == 0:
        return False, None
    return True, file_manifest # most_recent_uploaded is new
 | 
			
		||||
 | 
			
		||||
def report_invalid_key(key):
 | 
			
		||||
    if get_conf("BLOCK_INVALID_APIKEY"): 
 | 
			
		||||
@ -258,10 +243,6 @@ def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg,
 | 
			
		||||
        chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
 | 
			
		||||
    return chatbot, history
 | 
			
		||||
 | 
			
		||||
# Helper: read an image from disk and return it as base64 text.
def encode_image(image_path):
    """Return the base64-encoded (utf-8 string) contents of *image_path*."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
 | 
			
		||||
 | 
			
		||||
def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
from toolbox import get_conf
 | 
			
		||||
from toolbox import get_conf, get_pictures_list, encode_image
 | 
			
		||||
import base64
 | 
			
		||||
import datetime
 | 
			
		||||
import hashlib
 | 
			
		||||
@ -65,6 +65,7 @@ class SparkRequestInstance():
 | 
			
		||||
        self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
 | 
			
		||||
        self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
 | 
			
		||||
        self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat"
 | 
			
		||||
        self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image"
 | 
			
		||||
 | 
			
		||||
        self.time_to_yield_event = threading.Event()
 | 
			
		||||
        self.time_to_exit_event = threading.Event()
 | 
			
		||||
@ -92,7 +93,11 @@ class SparkRequestInstance():
 | 
			
		||||
            gpt_url = self.gpt_url_v3
 | 
			
		||||
        else:
 | 
			
		||||
            gpt_url = self.gpt_url
 | 
			
		||||
 | 
			
		||||
        file_manifest = []
 | 
			
		||||
        if llm_kwargs.get('most_recent_uploaded'):
 | 
			
		||||
            if llm_kwargs['most_recent_uploaded'].get('path'):
 | 
			
		||||
                file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
 | 
			
		||||
                gpt_url = self.gpt_url_img
 | 
			
		||||
        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
 | 
			
		||||
        websocket.enableTrace(False)
 | 
			
		||||
        wsUrl = wsParam.create_url()
 | 
			
		||||
@ -101,9 +106,8 @@ class SparkRequestInstance():
 | 
			
		||||
        def on_open(ws):
 | 
			
		||||
            import _thread as thread
 | 
			
		||||
            thread.start_new_thread(run, (ws,))
 | 
			
		||||
 | 
			
		||||
        def run(ws, *args):
 | 
			
		||||
            data = json.dumps(gen_params(ws.appid, *ws.all_args))
 | 
			
		||||
            data = json.dumps(gen_params(ws.appid, *ws.all_args, file_manifest))
 | 
			
		||||
            ws.send(data)
 | 
			
		||||
 | 
			
		||||
        # 收到websocket消息的处理
 | 
			
		||||
@ -142,9 +146,18 @@ class SparkRequestInstance():
 | 
			
		||||
        ws.all_args = (inputs, llm_kwargs, history, system_prompt)
 | 
			
		||||
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
 | 
			
		||||
 | 
			
		||||
def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
def generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest):
 | 
			
		||||
    conversation_cnt = len(history) // 2
 | 
			
		||||
    messages = [{"role": "system", "content": system_prompt}]
 | 
			
		||||
    messages = []
 | 
			
		||||
    if file_manifest:
 | 
			
		||||
        base64_images = []
 | 
			
		||||
        for image_path in file_manifest:
 | 
			
		||||
            base64_images.append(encode_image(image_path))
 | 
			
		||||
        for img_s in base64_images:
 | 
			
		||||
            if img_s not in str(messages):
 | 
			
		||||
                messages.append({"role": "user", "content": img_s, "content_type": "image"})
 | 
			
		||||
    else:
 | 
			
		||||
        messages = [{"role": "system", "content": system_prompt}]
 | 
			
		||||
    if conversation_cnt:
 | 
			
		||||
        for index in range(0, 2*conversation_cnt, 2):
 | 
			
		||||
            what_i_have_asked = {}
 | 
			
		||||
@ -167,7 +180,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
    return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest):
 | 
			
		||||
    """
 | 
			
		||||
    通过appid和用户的提问来生成请参数
 | 
			
		||||
    """
 | 
			
		||||
@ -176,6 +189,8 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
        "sparkv2": "generalv2",
 | 
			
		||||
        "sparkv3": "generalv3",
 | 
			
		||||
    }
 | 
			
		||||
    domains_select = domains[llm_kwargs['llm_model']]
 | 
			
		||||
    if file_manifest: domains_select = 'image'
 | 
			
		||||
    data = {
 | 
			
		||||
        "header": {
 | 
			
		||||
            "app_id": appid,
 | 
			
		||||
@ -183,7 +198,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
        },
 | 
			
		||||
        "parameter": {
 | 
			
		||||
            "chat": {
 | 
			
		||||
                "domain": domains[llm_kwargs['llm_model']],
 | 
			
		||||
                "domain": domains_select,
 | 
			
		||||
                "temperature": llm_kwargs["temperature"],
 | 
			
		||||
                "random_threshold": 0.5,
 | 
			
		||||
                "max_tokens": 4096,
 | 
			
		||||
@ -192,7 +207,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
 | 
			
		||||
        },
 | 
			
		||||
        "payload": {
 | 
			
		||||
            "message": {
 | 
			
		||||
                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
 | 
			
		||||
                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										98
									
								
								toolbox.py
									
									
									
									
									
								
							
							
						
						
									
										98
									
								
								toolbox.py
									
									
									
									
									
								
							@ -74,6 +74,7 @@ def ArgsGeneralWrapper(f):
 | 
			
		||||
            'max_length': max_length,
 | 
			
		||||
            'temperature':temperature,
 | 
			
		||||
            'client_ip': request.client.host,
 | 
			
		||||
            'most_recent_uploaded': cookies.get('most_recent_uploaded')
 | 
			
		||||
        }
 | 
			
		||||
        plugin_kwargs = {
 | 
			
		||||
            "advanced_arg": plugin_advanced_arg,
 | 
			
		||||
@ -577,6 +578,64 @@ def del_outdated_uploads(outdate_time_seconds):
 | 
			
		||||
            except: pass
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def html_local_file(file):
    """Convert an existing local file path into a gradio `file=` URL.

    Paths under the project directory are rewritten relative to ".";
    non-existent paths are returned unchanged.
    """
    base_path = os.path.dirname(__file__)  # project root directory
    if not os.path.exists(str(file)):
        return file
    return f'file={file.replace(base_path, ".")}'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    """Render a local image either as a markdown image tag or an HTML block.

    Args:
        __file: path to the image file (rewritten via html_local_file).
        layout: horizontal alignment of the wrapping <div>.
        max_width: optional CSS max-width (e.g. "100px").
        max_height: optional CSS max-height.
        md: when True, emit a markdown image reference instead of raw HTML.

    Returns:
        A markdown or HTML string embedding the image.
    """
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        # BUGFIX: this branch assigned an empty f-string, silently dropping
        # the image; emit a markdown image reference instead.
        a = f'![{__file}]({__file})'
    return a
 | 
			
		||||
 | 
			
		||||
def file_manifest_filter_type(file_list, filter_: list = None):
    """Replace image entries in *file_list* with inline HTML <img> snippets.

    Files whose extension (text after the last '.') is in *filter_*
    (default: png/jpg/jpeg) are rendered through html_local_img with
    md=False; every other entry is passed through unchanged.
    """
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg']

    def _render(file):
        # Extension check mirrors the original: last dot-separated component.
        ext = str(os.path.basename(file)).split('.')[-1]
        return html_local_img(file, md=False) if ext in filter_ else file

    return [_render(file) for file in file_list]
 | 
			
		||||
 | 
			
		||||
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """Build a markdown table from a header row and lists of values.

    Args:
        head: table header cells, e.g. ['File'].
        tabs: table values: [[col1], [col2], [col3], [col4]].
        alignment: ':---' left-aligned, ':---:' centered, '---:' right-aligned.
        column: True to keep data in columns, False to keep data in rows (default).

    Returns:
        A string representation of the markdown table.
    """
    rows = list(map(list, zip(*tabs))) if column else tabs
    # The longest column determines how many table rows are emitted;
    # shorter columns are padded with empty cells.
    longest = max(len(col) for col in rows)

    cell = "| %s "
    pieces = ["".join(cell % h for h in head) + '|\n',
              "".join(cell % alignment for _ in head) + '|\n']
    for idx in range(longest):
        raw_row = [col[idx] if idx < len(col) else '' for col in rows]
        # Image paths are swapped for inline <img> tags before rendering.
        rendered = file_manifest_filter_type(raw_row, filter_=None)
        pieces.append("".join(cell % value for value in rendered) + '|\n')
    return "".join(pieces)
 | 
			
		||||
 | 
			
		||||
def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies):
 | 
			
		||||
    """
 | 
			
		||||
    当文件被上传时的回调函数
 | 
			
		||||
@ -602,16 +661,15 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
 | 
			
		||||
        this_file_path = pj(target_path_base, file_origin_name)
 | 
			
		||||
        shutil.move(file.name, this_file_path)
 | 
			
		||||
        upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract')
 | 
			
		||||
    
 | 
			
		||||
    # 整理文件集合
 | 
			
		||||
    moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
 | 
			
		||||
 | 
			
		||||
    if "浮动输入区" in checkboxes: 
 | 
			
		||||
        txt, txt2 = "", target_path_base
 | 
			
		||||
    else:
 | 
			
		||||
        txt, txt2 = target_path_base, ""
 | 
			
		||||
 | 
			
		||||
    # 输出消息
 | 
			
		||||
    moved_files_str = '\t\n\n'.join(moved_files)
 | 
			
		||||
    # 整理文件集合 输出消息
 | 
			
		||||
    moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
 | 
			
		||||
    moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
 | 
			
		||||
    chatbot.append(['我上传了文件,请查收', 
 | 
			
		||||
                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
 | 
			
		||||
                    f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
 | 
			
		||||
@ -1151,6 +1209,36 @@ def get_chat_default_kwargs():
 | 
			
		||||
 | 
			
		||||
    return default_chat_kwargs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pictures_list(path):
    """Recursively collect every .jpg, .jpeg and .png file below *path*."""
    file_manifest = []
    # Same extension order as before: jpg, then jpeg, then png.
    for suffix in ('jpg', 'jpeg', 'png'):
        file_manifest += glob.glob(f'{path}/**/*.{suffix}', recursive=True)
    return file_manifest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import base64
 | 
			
		||||
def have_any_recent_upload_image_files(chatbot):
    """Return (True, file_manifest) if images were uploaded within 5 minutes.

    Looks up the "most_recent_uploaded" cookie on *chatbot* and, when the
    record is fresh, lists the image files found in the recorded folder.

    Args:
        chatbot: chatbot object carrying a `_cookies` dict (may be None).

    Returns:
        Tuple (bool, list | None): (True, images) for a fresh, non-empty
        upload folder; (False, None) otherwise.
    """
    _5min = 5 * 60
    if chatbot is None:
        return False, None    # chatbot is None
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False, None   # most_recent_uploaded is None
    if time.time() - most_recent_uploaded["time"] >= _5min:
        return False, None  # most_recent_uploaded is too old
    # Fresh upload: scan the folder once. (The original re-read the cookie a
    # second time here, which was redundant.)
    file_manifest = get_pictures_list(most_recent_uploaded['path'])
    if len(file_manifest) == 0:
        return False, None
    return True, file_manifest # most_recent_uploaded is new
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Function to encode the image
def encode_image(image_path):
    """Base64-encode the file at *image_path*; return it as a utf-8 string."""
    image_file = open(image_path, "rb")
    try:
        return base64.b64encode(image_file.read()).decode('utf-8')
    finally:
        image_file.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_max_token(llm_kwargs):
    """Look up the configured max_token limit for the currently selected model."""
    # Imported lazily to avoid a circular import at module load time.
    from request_llms.bridge_all import model_info
    info = model_info[llm_kwargs['llm_model']]
    return info['max_token']
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user