Merge pull request #1410 from binary-husky/frontier

fix spark image understanding api
This commit is contained in:
binary-husky 2023-12-23 17:49:35 +08:00 committed by GitHub
commit 6ca0dd2f9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 9 deletions

View File

@@ -139,6 +139,8 @@ def can_multi_process(llm):
     if llm.startswith('gpt-'): return True
     if llm.startswith('api2d-'): return True
     if llm.startswith('azure-'): return True
+    if llm.startswith('spark'): return True
+    if llm.startswith('zhipuai'): return True
     return False

 def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(

View File

@@ -26,7 +26,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
     from .com_sparkapi import SparkRequestInstance
     sri = SparkRequestInstance()
-    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
+    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt, use_image_api=False):
         if len(observe_window) >= 1:
             observe_window[0] = response
         if len(observe_window) >= 2:
@@ -52,7 +52,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
     # 开始接收回复
     from .com_sparkapi import SparkRequestInstance
     sri = SparkRequestInstance()
-    for response in sri.generate(inputs, llm_kwargs, history, system_prompt):
+    for response in sri.generate(inputs, llm_kwargs, history, system_prompt, use_image_api=True):
         chatbot[-1] = (inputs, response)
         yield from update_ui(chatbot=chatbot, history=history)

View File

@@ -72,12 +72,12 @@ class SparkRequestInstance():
         self.result_buf = ""

-    def generate(self, inputs, llm_kwargs, history, system_prompt):
+    def generate(self, inputs, llm_kwargs, history, system_prompt, use_image_api=False):
         llm_kwargs = llm_kwargs
         history = history
         system_prompt = system_prompt
         import _thread as thread
-        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt))
+        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt, use_image_api))
         while True:
             self.time_to_yield_event.wait(timeout=1)
             if self.time_to_yield_event.is_set():
@@ -86,7 +86,7 @@ class SparkRequestInstance():
         return self.result_buf

-    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
+    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt, use_image_api):
         if llm_kwargs['llm_model'] == 'sparkv2':
             gpt_url = self.gpt_url_v2
         elif llm_kwargs['llm_model'] == 'sparkv3':
@@ -94,10 +94,12 @@ class SparkRequestInstance():
         else:
             gpt_url = self.gpt_url
         file_manifest = []
-        if llm_kwargs.get('most_recent_uploaded'):
+        if use_image_api and llm_kwargs.get('most_recent_uploaded'):
             if llm_kwargs['most_recent_uploaded'].get('path'):
                 file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
-                gpt_url = self.gpt_url_img
+                if len(file_manifest) > 0:
+                    print('正在使用讯飞图片理解API')
+                    gpt_url = self.gpt_url_img

         wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
         websocket.enableTrace(False)
         wsUrl = wsParam.create_url()

View File

@@ -256,13 +256,13 @@ textarea.svelte-1pie7s6 {
     max-height: 95% !important;
     overflow-y: auto !important;
 }*/
-.app.svelte-1mya07g.svelte-1mya07g {
+/* .app.svelte-1mya07g.svelte-1mya07g {
     max-width: 100%;
     position: relative;
     padding: var(--size-4);
     width: 100%;
     height: 100%;
-}
+} */

 .gradio-container-3-32-2 h1 {
     font-weight: 700 !important;