Merge pull request #1282 from Kilig947/image_understanding_spark

Image understanding spark
This commit is contained in:
Hao Ma 2023-11-22 16:19:22 +08:00 committed by GitHub
commit 203d5f7296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 34 deletions

View File

@ -15,29 +15,16 @@ import requests
import base64 import base64
import os import os
import glob import glob
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, \
update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, update_ui_lastest_msg, get_max_token
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \ proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY') get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \ timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。' '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
def have_any_recent_upload_image_files(chatbot):
    """Check whether the user uploaded image files within the last 5 minutes.

    Reads the "most_recent_uploaded" record from the chatbot cookies and,
    if it is fresh enough, scans the recorded upload directory for images.

    Args:
        chatbot: chatbot object carrying a ``_cookies`` dict (may be None).

    Returns:
        (True, file_manifest) when a recent upload directory contains at
        least one .jpg/.jpeg/.png file; (False, None) otherwise.
    """
    _5min = 5 * 60
    if chatbot is None:
        return False, None  # no chatbot at all
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False, None  # nothing was ever uploaded
    if time.time() - most_recent_uploaded["time"] >= _5min:
        return False, None  # most_recent_uploaded is too old
    # fixed: the original re-read the same cookie a second time here, to no effect
    path = most_recent_uploaded['path']
    # collect every supported image type under the upload directory
    file_manifest = []
    for pattern in ('*.jpg', '*.jpeg', '*.png'):
        file_manifest += glob.glob(f'{path}/**/{pattern}', recursive=True)
    if not file_manifest:
        return False, None  # upload directory holds no images
    return True, file_manifest  # most_recent_uploaded is new
def report_invalid_key(key): def report_invalid_key(key):
if get_conf("BLOCK_INVALID_APIKEY"): if get_conf("BLOCK_INVALID_APIKEY"):
@ -258,10 +245,6 @@ def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg,
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}") chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
return chatbot, history return chatbot, history
# Function to encode the image
def encode_image(image_path):
    """Read the image at *image_path* and return it as a base64 utf-8 string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths): def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths):
""" """

View File

@ -1,4 +1,4 @@
from toolbox import get_conf from toolbox import get_conf, get_pictures_list, encode_image
import base64 import base64
import datetime import datetime
import hashlib import hashlib
@ -65,6 +65,7 @@ class SparkRequestInstance():
self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat" self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat" self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat"
self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image"
self.time_to_yield_event = threading.Event() self.time_to_yield_event = threading.Event()
self.time_to_exit_event = threading.Event() self.time_to_exit_event = threading.Event()
@ -92,7 +93,11 @@ class SparkRequestInstance():
gpt_url = self.gpt_url_v3 gpt_url = self.gpt_url_v3
else: else:
gpt_url = self.gpt_url gpt_url = self.gpt_url
file_manifest = []
if llm_kwargs.get('most_recent_uploaded'):
if llm_kwargs['most_recent_uploaded'].get('path'):
file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
gpt_url = self.gpt_url_img
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url) wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
websocket.enableTrace(False) websocket.enableTrace(False)
wsUrl = wsParam.create_url() wsUrl = wsParam.create_url()
@ -101,9 +106,8 @@ class SparkRequestInstance():
def on_open(ws): def on_open(ws):
import _thread as thread import _thread as thread
thread.start_new_thread(run, (ws,)) thread.start_new_thread(run, (ws,))
def run(ws, *args): def run(ws, *args):
data = json.dumps(gen_params(ws.appid, *ws.all_args)) data = json.dumps(gen_params(ws.appid, *ws.all_args, file_manifest))
ws.send(data) ws.send(data)
# 收到websocket消息的处理 # 收到websocket消息的处理
@ -142,8 +146,17 @@ class SparkRequestInstance():
ws.all_args = (inputs, llm_kwargs, history, system_prompt) ws.all_args = (inputs, llm_kwargs, history, system_prompt)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def generate_message_payload(inputs, llm_kwargs, history, system_prompt): def generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest):
conversation_cnt = len(history) // 2 conversation_cnt = len(history) // 2
messages = []
if file_manifest:
base64_images = []
for image_path in file_manifest:
base64_images.append(encode_image(image_path))
for img_s in base64_images:
if img_s not in str(messages):
messages.append({"role": "user", "content": img_s, "content_type": "image"})
else:
messages = [{"role": "system", "content": system_prompt}] messages = [{"role": "system", "content": system_prompt}]
if conversation_cnt: if conversation_cnt:
for index in range(0, 2*conversation_cnt, 2): for index in range(0, 2*conversation_cnt, 2):
@ -167,7 +180,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
return messages return messages
def gen_params(appid, inputs, llm_kwargs, history, system_prompt): def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest):
""" """
通过appid和用户的提问来生成请参数 通过appid和用户的提问来生成请参数
""" """
@ -176,6 +189,8 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
"sparkv2": "generalv2", "sparkv2": "generalv2",
"sparkv3": "generalv3", "sparkv3": "generalv3",
} }
domains_select = domains[llm_kwargs['llm_model']]
if file_manifest: domains_select = 'image'
data = { data = {
"header": { "header": {
"app_id": appid, "app_id": appid,
@ -183,7 +198,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
}, },
"parameter": { "parameter": {
"chat": { "chat": {
"domain": domains[llm_kwargs['llm_model']], "domain": domains_select,
"temperature": llm_kwargs["temperature"], "temperature": llm_kwargs["temperature"],
"random_threshold": 0.5, "random_threshold": 0.5,
"max_tokens": 4096, "max_tokens": 4096,
@ -192,7 +207,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
}, },
"payload": { "payload": {
"message": { "message": {
"text": generate_message_payload(inputs, llm_kwargs, history, system_prompt) "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest)
} }
} }
} }

View File

@ -4,6 +4,7 @@ import time
import inspect import inspect
import re import re
import os import os
import base64
import gradio import gradio
import shutil import shutil
import glob import glob
@ -79,6 +80,7 @@ def ArgsGeneralWrapper(f):
'max_length': max_length, 'max_length': max_length,
'temperature':temperature, 'temperature':temperature,
'client_ip': request.client.host, 'client_ip': request.client.host,
'most_recent_uploaded': cookies.get('most_recent_uploaded')
} }
plugin_kwargs = { plugin_kwargs = {
"advanced_arg": plugin_advanced_arg, "advanced_arg": plugin_advanced_arg,
@ -602,6 +604,64 @@ def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
except: pass except: pass
return return
def html_local_file(file):
    """Turn an existing local file path into a gradio ``file=`` URL.

    Paths located under the project directory are rewritten relative to
    "."; paths that do not exist are returned unchanged.
    """
    project_dir = os.path.dirname(__file__)  # project root directory
    if not os.path.exists(str(file)):
        return file
    return 'file=' + file.replace(project_dir, ".")
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    """Render a local image path as markdown (default) or an HTML ``<img>`` tag.

    Args:
        __file: path to the image; resolved through html_local_file().
        layout: horizontal alignment, used only in the HTML variant.
        max_width / max_height: optional CSS size limits (HTML variant only).
        md: when True return markdown ``![...](...)``; otherwise an HTML div.

    Returns:
        str: the markdown or HTML snippet referencing the image.
    """
    __file = html_local_file(__file)
    if md:
        # fixed: the original always built the HTML string first and then
        # discarded it when md=True; compute only the branch that is used
        return f'![{__file}]({__file})'
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    return f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
def file_manifest_filter_type(file_list, filter_: list = None):
    """Replace image paths in *file_list* with inline HTML ``<img>`` tags.

    Entries whose extension is in *filter_* (default: png/jpg/jpeg) become
    HTML image tags via html_local_img(); every other entry passes through
    unchanged.
    """
    allowed = filter_ if filter_ else ['png', 'jpg', 'jpeg']
    result = []
    for item in file_list:
        extension = str(os.path.basename(item)).split('.')[-1]
        if extension in allowed:
            result.append(html_local_img(item, md=False))
        else:
            result.append(item)
    return result
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """
    Args:
        head: list of column headers
        tabs: table values, e.g. [[col1], [col2], [col3], [col4]]
        alignment: :--- left aligned, :---: centered, ---: right aligned
        column: True to keep data in columns, False to keep data in rows (default).
    Returns:
        A string representation of the markdown table.
    """
    rows = list(map(list, zip(*tabs))) if column else tabs
    # ragged rows are padded with empty cells up to the longest one
    max_len = max(len(row) for row in rows)
    header_line = "".join(f"| {h} " for h in head) + '|\n'
    align_line = "".join(f"| {alignment} " for _ in head) + '|\n'
    body_lines = []
    for idx in range(max_len):
        cells = [row[idx] if idx < len(row) else '' for row in rows]
        cells = file_manifest_filter_type(cells, filter_=None)
        body_lines.append("".join(f"| {cell} " for cell in cells) + '|\n')
    return header_line + align_line + "".join(body_lines)
def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies): def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies):
""" """
当文件被上传时的回调函数 当文件被上传时的回调函数
@ -627,15 +687,14 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
shutil.move(file.name, this_file_path) shutil.move(file.name, this_file_path)
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract') upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract')
# 整理文件集合
moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
if "浮动输入区" in checkboxes: if "浮动输入区" in checkboxes:
txt, txt2 = "", target_path_base txt, txt2 = "", target_path_base
else: else:
txt, txt2 = target_path_base, "" txt, txt2 = target_path_base, ""
# 输出消息 # 整理文件集合 输出消息
moved_files_str = '\t\n\n'.join(moved_files) moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
chatbot.append(['我上传了文件,请查收', chatbot.append(['我上传了文件,请查收',
f'[Local Message] 收到以下文件: \n\n{moved_files_str}' + f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
f'\n\n调用路径参数已自动修正到: \n\n{txt}' + f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
@ -1198,6 +1257,35 @@ def get_chat_default_kwargs():
return default_chat_kwargs return default_chat_kwargs
def get_pictures_list(path):
    """Recursively collect all .jpg/.jpeg/.png files under *path*."""
    pictures = []
    for ext in ('jpg', 'jpeg', 'png'):
        pictures.extend(glob.glob(f'{path}/**/*.{ext}', recursive=True))
    return pictures
def have_any_recent_upload_image_files(chatbot):
    """Check whether the user uploaded image files within the last 5 minutes.

    Args:
        chatbot: chatbot object carrying a ``_cookies`` dict (may be None).

    Returns:
        (True, file_manifest) when the recent upload directory holds at
        least one image file (per get_pictures_list); (False, None) otherwise.
    """
    _5min = 5 * 60
    if chatbot is None:
        return False, None  # chatbot is None
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False, None  # most_recent_uploaded is None
    if time.time() - most_recent_uploaded["time"] >= _5min:
        return False, None  # most_recent_uploaded is too old
    # fixed: the original re-read the same cookie a second time here, to no effect
    file_manifest = get_pictures_list(most_recent_uploaded['path'])
    if len(file_manifest) == 0:
        return False, None  # no images in the upload directory
    return True, file_manifest  # most_recent_uploaded is new
# Function to encode the image
def encode_image(image_path):
    """Return the contents of *image_path* base64-encoded as a utf-8 string."""
    with open(image_path, "rb") as fh:
        return base64.b64encode(fh.read()).decode('utf-8')
def get_max_token(llm_kwargs): def get_max_token(llm_kwargs):
from request_llms.bridge_all import model_info from request_llms.bridge_all import model_info
return model_info[llm_kwargs['llm_model']]['max_token'] return model_info[llm_kwargs['llm_model']]['max_token']