From 60ba712131cf8b7fe9efb81ea9ece16bebd7d6b9 Mon Sep 17 00:00:00 2001
From: qingxu fu <505030475@qq.com>
Date: Sun, 31 Dec 2023 19:02:40 +0800
Subject: [PATCH] use legacy image io for gemini
---
request_llms/bridge_google_gemini.py | 22 ++++++++----
request_llms/com_google.py | 52 ++++++++++++----------------
2 files changed, 38 insertions(+), 36 deletions(-)
diff --git a/request_llms/bridge_google_gemini.py b/request_llms/bridge_google_gemini.py
index 2438e09..49d8211 100644
--- a/request_llms/bridge_google_gemini.py
+++ b/request_llms/bridge_google_gemini.py
@@ -4,9 +4,10 @@
# @Descr :
import json
import re
+import os
import time
from request_llms.com_google import GoogleChatInit
-from toolbox import get_conf, update_ui, update_ui_lastest_msg
+from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
@@ -48,7 +49,16 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
if get_conf("GEMINI_API_KEY") == "":
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
return
-
+
+ if "vision" in llm_kwargs["llm_model"]:
+ have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
+ def make_media_input(inputs, image_paths):
+ for image_path in image_paths:
+                inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
+ return inputs
+ if have_recent_file:
+ inputs = make_media_input(inputs, image_paths)
+
chatbot.append((inputs, ""))
yield from update_ui(chatbot=chatbot, history=history)
genai = GoogleChatInit()
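
For reference, a self-contained sketch of what the new vision branch does: every recently uploaded image is appended to the prompt as an inline HTML <img> tag pointing at the local file, which the chatbot front end can render. The upload paths below are made up for illustration.

import os

def make_media_input(inputs, image_paths):
    # append one centered <img> tag per uploaded image to the prompt text
    for image_path in image_paths:
        inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
    return inputs

print(make_media_input("Describe these pictures.", ["./uploads/cat.png", "./uploads/dog.jpg"]))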
@@ -59,10 +69,9 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
break
except Exception as e:
retry += 1
- chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
- retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
- yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg) # 刷新界面
- if retry > MAX_RETRY: raise TimeoutError
+ chatbot[-1] = ((chatbot[-1][0], trimmed_format_exc()))
+ yield from update_ui(chatbot=chatbot, history=history, msg="请求失败") # 刷新界面
+ return
gpt_replying_buffer = ""
gpt_security_policy = ""
history.extend([inputs, ''])
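
The retry loop is gone: on any exception the traceback text is written into the last chat entry via trimmed_format_exc and the generator returns instead of retrying. A rough standalone stand-in for that toolbox helper (the real one may trim differently) could look like:

import sys
import traceback

def trimmed_format_exc_sketch(max_lines=8):
    # format the current exception and keep only the last few lines,
    # so the chat shows the relevant frames rather than the full stack
    lines = traceback.format_exception(*sys.exc_info())
    return "".join(lines[-max_lines:])

try:
    raise ConnectionError("proxy unreachable")
except Exception:
    print(trimmed_format_exc_sketch())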
@@ -94,7 +103,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
if __name__ == '__main__':
import sys
-
llm_kwargs = {'llm_model': 'gemini-pro'}
result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
for i in result:
diff --git a/request_llms/com_google.py b/request_llms/com_google.py
index 7981908..5d44796 100644
--- a/request_llms/com_google.py
+++ b/request_llms/com_google.py
@@ -7,7 +7,7 @@ import os
import re
import requests
from typing import List, Dict, Tuple
-from toolbox import get_conf, encode_image
+from toolbox import get_conf, encode_image, get_pictures_list
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
@@ -35,20 +35,15 @@ def files_filter_handler(file_list):
return new_list
-def input_encode_handler(inputs):
+def input_encode_handler(inputs, llm_kwargs):
+ if llm_kwargs['most_recent_uploaded'].get('path'):
+ image_paths = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
md_encode = []
- pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
- matches_path = re.findall(pattern_md_file, inputs)
- for md_path in matches_path:
- pattern_file = r"\((file=.*)\)"
- matches_path = re.findall(pattern_file, md_path)
- encode_file = files_filter_handler(file_list=matches_path)
- if encode_file:
- md_encode.extend([{
- "data": encode_image(i),
- "type": os.path.splitext(i)[1].replace('.', '')
- } for i in encode_file])
- inputs = inputs.replace(md_path, '')
+ for md_path in image_paths:
+ md_encode.append({
+ "data": encode_image(md_path),
+ "type": os.path.splitext(md_path)[1].replace('.', '')
+ })
return inputs, md_encode
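
The rewritten input_encode_handler no longer parses markdown image links out of the prompt text; it base64-encodes every picture found under the most recently uploaded directory and leaves inputs untouched. A self-contained approximation, with simplified stand-ins for the toolbox helpers encode_image and get_pictures_list, might read:

import base64
import glob
import os

def encode_image(path):
    # read a file and return its base64 text representation
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

def get_pictures_list(folder):
    # collect common image files from the upload folder
    exts = (".png", ".jpg", ".jpeg", ".webp", ".gif")
    return [p for p in glob.glob(os.path.join(folder, "*")) if p.lower().endswith(exts)]

def input_encode_handler(inputs, llm_kwargs):
    md_encode = []
    upload_dir = llm_kwargs.get("most_recent_uploaded", {}).get("path")
    if upload_dir:
        for md_path in get_pictures_list(upload_dir):
            md_encode.append({
                "data": encode_image(md_path),
                "type": os.path.splitext(md_path)[1].lstrip("."),
            })
    return inputs, md_encode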
@@ -127,13 +122,19 @@ class GoogleChatInit:
def __init__(self):
self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
- def __conversation_user(self, user_input):
+ def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
+ headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
+ response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
+ stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
+ return response.iter_lines()
+
+ def __conversation_user(self, user_input, llm_kwargs):
what_i_have_asked = {"role": "user", "parts": []}
if 'vision' not in self.url_gemini:
input_ = user_input
encode_img = []
else:
- input_, encode_img = input_encode_handler(user_input)
+ input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
what_i_have_asked['parts'].append({'text': input_})
if encode_img:
for data in encode_img:
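
For context: the encoded images produced above are attached to the user turn right after the text part (the tail of that append sits just outside this hunk). A hedged sketch of the resulting message, following the public Gemini REST contents/parts layout with inline_data entries:

def build_user_message(text, encoded_images):
    # one text part plus one inline_data part per base64-encoded image
    parts = [{"text": text}]
    for img in encoded_images:
        parts.append({
            "inline_data": {
                "mime_type": f"image/{img['type']}",
                "data": img["data"],
            }
        })
    return {"role": "user", "parts": parts}

msg = build_user_message("What is in this picture?",
                         [{"type": "png", "data": "<base64 placeholder>"}])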
@@ -144,12 +145,12 @@ class GoogleChatInit:
}})
return what_i_have_asked
- def __conversation_history(self, history):
+ def __conversation_history(self, history, llm_kwargs):
messages = []
conversation_cnt = len(history) // 2
if conversation_cnt:
for index in range(0, 2 * conversation_cnt, 2):
- what_i_have_asked = self.__conversation_user(history[index])
+ what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
what_gpt_answer = {
"role": "model",
"parts": [{"text": history[index + 1]}]
@@ -158,12 +159,6 @@ class GoogleChatInit:
messages.append(what_gpt_answer)
return messages
- def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
- headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
- response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
- stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
- return response.iter_lines()
-
def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
messages = [
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
@@ -176,14 +171,14 @@ class GoogleChatInit:
)
header = {'Content-Type': 'application/json'}
if 'vision' not in self.url_gemini: # 不是vision 才处理history
- messages.extend(self.__conversation_history(history)) # 处理 history
- messages.append(self.__conversation_user(inputs)) # 处理用户对话
+ messages.extend(self.__conversation_history(history, llm_kwargs)) # 处理 history
+ messages.append(self.__conversation_user(inputs, llm_kwargs)) # 处理用户对话
payload = {
"contents": messages,
"generationConfig": {
+ # "maxOutputTokens": 800,
"stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
"temperature": llm_kwargs.get('temperature', 1),
- # "maxOutputTokens": 800,
"topP": llm_kwargs.get('top_p', 0.8),
"topK": 10
}
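
Putting the pieces together, the JSON body posted to streamGenerateContent ends up shaped roughly like the following; the literal values mirror the defaults visible in this hunk, and the contents list comes from the two conversation helpers:

payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hello"}]},
    ],
    "generationConfig": {
        # "maxOutputTokens": 800,  # still commented out, as above
        "stopSequences": [""],     # str(llm_kwargs.get('stop', '')).split(' ') on an empty default
        "temperature": 1,
        "topP": 0.8,
        "topK": 10,
    },
}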
@@ -193,6 +188,5 @@ class GoogleChatInit:
if __name__ == '__main__':
google = GoogleChatInit()
- # print(gootle.generate_message_payload('你好呀', {},
- # ['123123', '3123123'], ''))
+ # print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
# gootle.input_encode_handle('123123[123123](./123123), ')
\ No newline at end of file
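
Finally, a hedged end-to-end sketch of driving the relocated generate_chat directly. It assumes GEMINI_API_KEY is configured and that generate_message_payload substitutes the model name and key into url_gemini, as the unchanged parts of the file suggest; the prompt and llm_kwargs values are illustrative.

from request_llms.com_google import GoogleChatInit

genai = GoogleChatInit()
llm_kwargs = {"llm_model": "gemini-pro", "temperature": 1, "top_p": 0.8, "stop": ""}
# generate_chat returns response.iter_lines(); print the raw streamed chunks
for line in genai.generate_chat("Say hello in one sentence.",
                                llm_kwargs, history=[], system_prompt=""):
    if line:
        print(line.decode("utf-8", errors="ignore"))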