jpeg type align for gemini

qingxu fu 2023-12-31 20:28:39 +08:00
parent 480516380d
commit 37744a9cb1


@@ -9,7 +9,7 @@ import requests
 from typing import List, Dict, Tuple
 from toolbox import get_conf, encode_image, get_pictures_list

-proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
+proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")

 """
 ========================================================================
@@ -26,33 +26,56 @@ to_markdown_tabs: converts a file list into a markdown table
 def files_filter_handler(file_list):
     new_list = []
-    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+    filter_ = [
+        "png",
+        "jpg",
+        "jpeg",
+        "bmp",
+        "svg",
+        "webp",
+        "ico",
+        "tif",
+        "tiff",
+        "raw",
+        "eps",
+    ]
     for file in file_list:
-        file = str(file).replace('file=', '')
+        file = str(file).replace("file=", "")
         if os.path.exists(file):
-            if str(os.path.basename(file)).split('.')[-1] in filter_:
+            if str(os.path.basename(file)).split(".")[-1] in filter_:
                 new_list.append(file)
     return new_list


 def input_encode_handler(inputs, llm_kwargs):
-    if llm_kwargs['most_recent_uploaded'].get('path'):
-        image_paths = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
+    if llm_kwargs["most_recent_uploaded"].get("path"):
+        image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
     md_encode = []
     for md_path in image_paths:
-        md_encode.append({
-            "data": encode_image(md_path),
-            "type": os.path.splitext(md_path)[1].replace('.', '')
-        })
+        type_ = os.path.splitext(md_path)[1].replace(".", "")
+        type_ = "jpeg" if type_ == "jpg" else type_
+        md_encode.append({"data": encode_image(md_path), "type": type_})
     return inputs, md_encode


 def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
     new_list = []
     if not filter_:
-        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+        filter_ = [
+            "png",
+            "jpg",
+            "jpeg",
+            "bmp",
+            "svg",
+            "webp",
+            "ico",
+            "tif",
+            "tiff",
+            "raw",
+            "eps",
+        ]
     for file in file_list:
-        if str(os.path.basename(file)).split('.')[-1] in filter_:
+        if str(os.path.basename(file)).split(".")[-1] in filter_:
             new_list.append(html_local_img(file, md=md_type))
         elif os.path.exists(file):
             new_list.append(link_mtime_to_md(file))
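The substantive change of this commit sits in `input_encode_handler` above: `os.path.splitext` returns the literal file suffix, so a `.jpg` upload used to reach Gemini as `mime_type: "image/jpg"`, which is not a registered media type; the API expects `image/jpeg`. A minimal sketch of the normalization, using a hypothetical helper name (`guess_image_mime_type` is not part of this file):

# Illustrative sketch only; mirrors the normalization added in input_encode_handler.
import os

def guess_image_mime_type(path: str) -> str:
    ext = os.path.splitext(path)[1].replace(".", "").lower()
    ext = "jpeg" if ext == "jpg" else ext  # .jpg files carry the image/jpeg media type
    return f"image/{ext}"

assert guess_image_mime_type("photo.jpg") == "image/jpeg"
assert guess_image_mime_type("diagram.png") == "image/png"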
@@ -75,8 +98,8 @@ def html_local_file(file):
     return file


-def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
-    style = ''
+def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
+    style = ""
     if max_width is not None:
         style += f"max-width: {max_width};"
     if max_height is not None:
@@ -84,11 +107,11 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
     __file = html_local_file(__file)
     a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
     if md:
-        a = f'![{__file}]({__file})'
+        a = f"![{__file}]({__file})"
     return a


-def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
+def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
     """
     Args:
         head: table header []
@@ -106,43 +129,53 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
     max_len = max(len(column) for column in transposed_tabs)

     tab_format = "| %s "
-    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
-    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
+    tabs_list = "".join([tab_format % i for i in head]) + "|\n"
+    tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
     for i in range(max_len):
-        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
+        row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
         row_data = file_manifest_filter_html(row_data, filter_=None)
-        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
+        tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"

     return tabs_list


 class GoogleChatInit:
     def __init__(self):
-        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
+        self.url_gemini = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"

     def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
-        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
-        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
-                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
+        headers, payload = self.generate_message_payload(
+            inputs, llm_kwargs, history, system_prompt
+        )
+        response = requests.post(
+            url=self.url_gemini,
+            headers=headers,
+            data=json.dumps(payload),
+            stream=True,
+            proxies=proxies,
+            timeout=TIMEOUT_SECONDS,
+        )
         return response.iter_lines()

     def __conversation_user(self, user_input, llm_kwargs):
         what_i_have_asked = {"role": "user", "parts": []}
-        if 'vision' not in self.url_gemini:
+        if "vision" not in self.url_gemini:
             input_ = user_input
             encode_img = []
         else:
             input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
-        what_i_have_asked['parts'].append({'text': input_})
+        what_i_have_asked["parts"].append({"text": input_})
         if encode_img:
             for data in encode_img:
-                what_i_have_asked['parts'].append(
-                    {'inline_data': {
-                        "mime_type": f"image/{data['type']}",
-                        "data": data['data']
-                    }})
+                what_i_have_asked["parts"].append(
+                    {
+                        "inline_data": {
+                            "mime_type": f"image/{data['type']}",
+                            "data": data["data"],
+                        }
+                    }
+                )
         return what_i_have_asked

     def __conversation_history(self, history, llm_kwargs):
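With the extension normalized upstream, a user turn built by `__conversation_user` for a vision model carries the text part plus one `inline_data` part per uploaded image. A rough illustration of the resulting structure, with placeholder values and a truncated base64 string (real data comes from `encode_image`):

# Hypothetical value of what_i_have_asked for a single uploaded cat.jpg
what_i_have_asked = {
    "role": "user",
    "parts": [
        {"text": "What is in this picture?"},
        {
            "inline_data": {
                "mime_type": "image/jpeg",  # "jpeg", not "jpg", after this commit
                "data": "/9j/4AAQSkZJRg...",  # base64-encoded image bytes
            }
        },
    ],
}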
@@ -153,40 +186,43 @@ class GoogleChatInit:
             what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
             what_gpt_answer = {
                 "role": "model",
-                "parts": [{"text": history[index + 1]}]
+                "parts": [{"text": history[index + 1]}],
             }
             messages.append(what_i_have_asked)
             messages.append(what_gpt_answer)
         return messages

-    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
+    def generate_message_payload(
+        self, inputs, llm_kwargs, history, system_prompt
+    ) -> Tuple[Dict, Dict]:
         messages = [
             # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini does not allow an even number of turns, so this is unused for now; wait for future support...
             # {"role": "user", "parts": [{"text": ""}]},
             # {"role": "model", "parts": [{"text": ""}]}
         ]
         self.url_gemini = self.url_gemini.replace(
-            '%m', llm_kwargs['llm_model']).replace(
-            '%k', get_conf('GEMINI_API_KEY')
-        )
-        header = {'Content-Type': 'application/json'}
-        if 'vision' not in self.url_gemini:  # only feed history to non-vision models
-            messages.extend(self.__conversation_history(history, llm_kwargs))  # handle history
+            "%m", llm_kwargs["llm_model"]
+        ).replace("%k", get_conf("GEMINI_API_KEY"))
+        header = {"Content-Type": "application/json"}
+        if "vision" not in self.url_gemini:  # only feed history to non-vision models
+            messages.extend(
+                self.__conversation_history(history, llm_kwargs)
+            )  # handle history
         messages.append(self.__conversation_user(inputs, llm_kwargs))  # handle the current user turn
         payload = {
             "contents": messages,
             "generationConfig": {
                 # "maxOutputTokens": 800,
-                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
-                "temperature": llm_kwargs.get('temperature', 1),
-                "topP": llm_kwargs.get('top_p', 0.8),
-                "topK": 10
-            }
+                "stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
+                "temperature": llm_kwargs.get("temperature", 1),
+                "topP": llm_kwargs.get("top_p", 0.8),
+                "topK": 10,
+            },
         }
         return header, payload


-if __name__ == '__main__':
+if __name__ == "__main__":
     google = GoogleChatInit()
     # print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
     # gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')
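For context, a sketch of the request that `generate_chat` ultimately sends: `%m` and `%k` in `url_gemini` are substituted with the model name and the `GEMINI_API_KEY` from `get_conf`, history is included only for non-vision models, and the current user turn is appended last. The values below are placeholders, not output captured from the code:

# Illustrative request shape, assuming llm_model="gemini-pro" and default llm_kwargs.
url = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-pro:streamGenerateContent?key=<GEMINI_API_KEY>"
)
headers = {"Content-Type": "application/json"}
payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hi"}]},              # from history
        {"role": "model", "parts": [{"text": "Hello!"}]},         # from history
        {"role": "user", "parts": [{"text": "Summarize this"}]},  # current input
    ],
    "generationConfig": {
        "stopSequences": [""],  # str(llm_kwargs.get("stop", "")).split(" ")
        "temperature": 1,
        "topP": 0.8,
        "topK": 10,
    },
}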