use legacy image io for gemini

parent a7c960dcb0
commit 60ba712131
@@ -4,9 +4,10 @@
 # @Descr   :
 import json
 import re
+import os
 import time
 from request_llms.com_google import GoogleChatInit
-from toolbox import get_conf, update_ui, update_ui_lastest_msg
+from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc
 
 proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
 timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
@@ -48,7 +49,16 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
     if get_conf("GEMINI_API_KEY") == "":
         yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
         return
 
+    if "vision" in llm_kwargs["llm_model"]:
+        have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
+        def make_media_input(inputs, image_paths):
+            for image_path in image_paths:
+                inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
+            return inputs
+        if have_recent_file:
+            inputs = make_media_input(inputs, image_paths)
+
     chatbot.append((inputs, ""))
     yield from update_ui(chatbot=chatbot, history=history)
     genai = GoogleChatInit()
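For reference, a minimal sketch of what the new vision branch does, reusing the make_media_input logic added above; have_any_recent_upload_image_files is assumed to return a (found, paths) pair, as its usage in the hunk suggests:

import os

def make_media_input(inputs, image_paths):
    # Inline each recently uploaded image into the prompt as an HTML tag,
    # mirroring the helper added in this hunk.
    for image_path in image_paths:
        inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
    return inputs

# Hypothetical usage with two freshly uploaded files:
print(make_media_input("What is in these pictures?", ["cat.png", "dog.jpg"]))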
@@ -59,10 +69,9 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
             break
         except Exception as e:
             retry += 1
-            chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
-            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
-            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # 刷新界面
-            if retry > MAX_RETRY: raise TimeoutError
+            chatbot[-1] = ((chatbot[-1][0], trimmed_format_exc()))
+            yield from update_ui(chatbot=chatbot, history=history, msg="请求失败")  # 刷新界面
+            return
     gpt_replying_buffer = ""
     gpt_security_policy = ""
     history.extend([inputs, ''])
@@ -94,7 +103,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
 
 if __name__ == '__main__':
     import sys
-
     llm_kwargs = {'llm_model': 'gemini-pro'}
     result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
     for i in result:
request_llms/com_google.py
@@ -7,7 +7,7 @@ import os
 import re
 import requests
 from typing import List, Dict, Tuple
-from toolbox import get_conf, encode_image
+from toolbox import get_conf, encode_image, get_pictures_list
 
 proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
 
@@ -35,20 +35,15 @@ def files_filter_handler(file_list):
     return new_list
 
 
-def input_encode_handler(inputs):
+def input_encode_handler(inputs, llm_kwargs):
+    if llm_kwargs['most_recent_uploaded'].get('path'):
+        image_paths = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
     md_encode = []
-    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
-    matches_path = re.findall(pattern_md_file, inputs)
-    for md_path in matches_path:
-        pattern_file = r"\((file=.*)\)"
-        matches_path = re.findall(pattern_file, md_path)
-        encode_file = files_filter_handler(file_list=matches_path)
-        if encode_file:
-            md_encode.extend([{
-                "data": encode_image(i),
-                "type": os.path.splitext(i)[1].replace('.', '')
-            } for i in encode_file])
-            inputs = inputs.replace(md_path, '')
+    for md_path in image_paths:
+        md_encode.append({
+            "data": encode_image(md_path),
+            "type": os.path.splitext(md_path)[1].replace('.', '')
+        })
     return inputs, md_encode
 
 
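A minimal sketch of the new image-encoding flow, under stated assumptions: get_pictures_list is assumed to collect image files beneath the most recent upload directory, and encode_image to return base64 text; neither implementation is part of this diff.

import base64
import glob
import os

def get_pictures_list(path):
    # Assumed behavior: gather image files under the upload directory.
    exts = ('png', 'jpg', 'jpeg', 'webp')
    return [p for ext in exts for p in glob.glob(os.path.join(path, f'*.{ext}'))]

def encode_image(image_path):
    # Assumed behavior: base64-encode the raw file bytes.
    with open(image_path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')

# Under these assumptions, input_encode_handler now returns the prompt text
# untouched plus one {"data": ..., "type": ...} dict per uploaded picture,
# instead of parsing markdown image links out of the prompt as before.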
@@ -127,13 +122,19 @@ class GoogleChatInit:
     def __init__(self):
         self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
 
-    def __conversation_user(self, user_input):
+    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
+        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
+        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
+                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
+        return response.iter_lines()
+
+    def __conversation_user(self, user_input, llm_kwargs):
         what_i_have_asked = {"role": "user", "parts": []}
         if 'vision' not in self.url_gemini:
             input_ = user_input
             encode_img = []
         else:
-            input_, encode_img = input_encode_handler(user_input)
+            input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
         what_i_have_asked['parts'].append({'text': input_})
         if encode_img:
             for data in encode_img:
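generate_chat, moved up in this hunk but otherwise unchanged, posts the payload and hands back the raw streamed lines. A hypothetical consumption loop (the real parsing lives in predict and is outside this diff):

genai = GoogleChatInit()
# iter_lines() yields raw bytes of the streamed JSON response.
for line in genai.generate_chat('Hello', {'llm_model': 'gemini-pro'}, [], ''):
    if line:
        print(line.decode('utf-8'))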
@@ -144,12 +145,12 @@ class GoogleChatInit:
                 }})
         return what_i_have_asked
 
-    def __conversation_history(self, history):
+    def __conversation_history(self, history, llm_kwargs):
         messages = []
         conversation_cnt = len(history) // 2
         if conversation_cnt:
             for index in range(0, 2 * conversation_cnt, 2):
-                what_i_have_asked = self.__conversation_user(history[index])
+                what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
                 what_gpt_answer = {
                     "role": "model",
                     "parts": [{"text": history[index + 1]}]
@@ -158,12 +159,6 @@ class GoogleChatInit:
             messages.append(what_gpt_answer)
         return messages
 
-    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
-        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
-        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
-                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
-        return response.iter_lines()
-
     def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
         messages = [
             # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
@@ -176,14 +171,14 @@ class GoogleChatInit:
         )
         header = {'Content-Type': 'application/json'}
         if 'vision' not in self.url_gemini:  # 不是vision 才处理history
-            messages.extend(self.__conversation_history(history))  # 处理 history
-        messages.append(self.__conversation_user(inputs))  # 处理用户对话
+            messages.extend(self.__conversation_history(history, llm_kwargs))  # 处理 history
+        messages.append(self.__conversation_user(inputs, llm_kwargs))  # 处理用户对话
         payload = {
             "contents": messages,
             "generationConfig": {
+                # "maxOutputTokens": 800,
                 "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                 "temperature": llm_kwargs.get('temperature', 1),
-                # "maxOutputTokens": 800,
                 "topP": llm_kwargs.get('top_p', 0.8),
                 "topK": 10
             }
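For orientation, the payload assembled by generate_message_payload looks roughly like this; the values are illustrative, and image parts, when present, carry the base64 data produced by input_encode_handler:

payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "earlier question"}]},
        {"role": "model", "parts": [{"text": "earlier answer"}]},
        {"role": "user", "parts": [{"text": "current question"}]},
    ],
    "generationConfig": {
        "stopSequences": [""],  # str(llm_kwargs.get('stop', '')).split(' ')
        "temperature": 1,       # llm_kwargs.get('temperature', 1)
        "topP": 0.8,            # llm_kwargs.get('top_p', 0.8)
        "topK": 10,
    },
}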
@@ -193,6 +188,5 @@ class GoogleChatInit:
 
 if __name__ == '__main__':
     google = GoogleChatInit()
-    # print(gootle.generate_message_payload('你好呀', {},
-    #                                       ['123123', '3123123'], ''))
+    # print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
     # gootle.input_encode_handle('123123[123123](./123123), ')