jpeg type align for gemini
This commit is contained in:
parent
480516380d
commit
37744a9cb1
@ -9,7 +9,7 @@ import requests
|
||||
from typing import List, Dict, Tuple
|
||||
from toolbox import get_conf, encode_image, get_pictures_list
|
||||
|
||||
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
|
||||
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
||||
|
||||
"""
|
||||
========================================================================
|
||||
@ -26,33 +26,56 @@ to_markdown_tabs 文件list 转换为 md tab
|
||||
|
||||
def files_filter_handler(file_list):
|
||||
new_list = []
|
||||
filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
|
||||
filter_ = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"bmp",
|
||||
"svg",
|
||||
"webp",
|
||||
"ico",
|
||||
"tif",
|
||||
"tiff",
|
||||
"raw",
|
||||
"eps",
|
||||
]
|
||||
for file in file_list:
|
||||
file = str(file).replace('file=', '')
|
||||
file = str(file).replace("file=", "")
|
||||
if os.path.exists(file):
|
||||
if str(os.path.basename(file)).split('.')[-1] in filter_:
|
||||
if str(os.path.basename(file)).split(".")[-1] in filter_:
|
||||
new_list.append(file)
|
||||
return new_list
|
||||
|
||||
|
||||
def input_encode_handler(inputs, llm_kwargs):
|
||||
if llm_kwargs['most_recent_uploaded'].get('path'):
|
||||
image_paths = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
|
||||
if llm_kwargs["most_recent_uploaded"].get("path"):
|
||||
image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
|
||||
md_encode = []
|
||||
for md_path in image_paths:
|
||||
md_encode.append({
|
||||
"data": encode_image(md_path),
|
||||
"type": os.path.splitext(md_path)[1].replace('.', '')
|
||||
})
|
||||
type_ = os.path.splitext(md_path)[1].replace(".", "")
|
||||
type_ = "jpeg" if type_ == "jpg" else type_
|
||||
md_encode.append({"data": encode_image(md_path), "type": type_})
|
||||
return inputs, md_encode
|
||||
|
||||
|
||||
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
|
||||
new_list = []
|
||||
if not filter_:
|
||||
filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
|
||||
filter_ = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"bmp",
|
||||
"svg",
|
||||
"webp",
|
||||
"ico",
|
||||
"tif",
|
||||
"tiff",
|
||||
"raw",
|
||||
"eps",
|
||||
]
|
||||
for file in file_list:
|
||||
if str(os.path.basename(file)).split('.')[-1] in filter_:
|
||||
if str(os.path.basename(file)).split(".")[-1] in filter_:
|
||||
new_list.append(html_local_img(file, md=md_type))
|
||||
elif os.path.exists(file):
|
||||
new_list.append(link_mtime_to_md(file))
|
||||
@ -75,8 +98,8 @@ def html_local_file(file):
|
||||
return file
|
||||
|
||||
|
||||
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
|
||||
style = ''
|
||||
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
|
||||
style = ""
|
||||
if max_width is not None:
|
||||
style += f"max-width: {max_width};"
|
||||
if max_height is not None:
|
||||
@ -84,11 +107,11 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=Tr
|
||||
__file = html_local_file(__file)
|
||||
a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
|
||||
if md:
|
||||
a = f''
|
||||
a = f""
|
||||
return a
|
||||
|
||||
|
||||
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
|
||||
def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
|
||||
"""
|
||||
Args:
|
||||
head: 表头:[]
|
||||
@ -106,43 +129,53 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
|
||||
max_len = max(len(column) for column in transposed_tabs)
|
||||
|
||||
tab_format = "| %s "
|
||||
tabs_list = "".join([tab_format % i for i in head]) + '|\n'
|
||||
tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
|
||||
tabs_list = "".join([tab_format % i for i in head]) + "|\n"
|
||||
tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
|
||||
|
||||
for i in range(max_len):
|
||||
row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
|
||||
row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
|
||||
row_data = file_manifest_filter_html(row_data, filter_=None)
|
||||
tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
|
||||
tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
|
||||
|
||||
return tabs_list
|
||||
|
||||
|
||||
class GoogleChatInit:
|
||||
|
||||
def __init__(self):
|
||||
self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
|
||||
self.url_gemini = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"
|
||||
|
||||
def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
|
||||
headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
|
||||
response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
|
||||
stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
|
||||
headers, payload = self.generate_message_payload(
|
||||
inputs, llm_kwargs, history, system_prompt
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.url_gemini,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
proxies=proxies,
|
||||
timeout=TIMEOUT_SECONDS,
|
||||
)
|
||||
return response.iter_lines()
|
||||
|
||||
def __conversation_user(self, user_input, llm_kwargs):
|
||||
what_i_have_asked = {"role": "user", "parts": []}
|
||||
if 'vision' not in self.url_gemini:
|
||||
if "vision" not in self.url_gemini:
|
||||
input_ = user_input
|
||||
encode_img = []
|
||||
else:
|
||||
input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
|
||||
what_i_have_asked['parts'].append({'text': input_})
|
||||
what_i_have_asked["parts"].append({"text": input_})
|
||||
if encode_img:
|
||||
for data in encode_img:
|
||||
what_i_have_asked['parts'].append(
|
||||
{'inline_data': {
|
||||
"mime_type": f"image/{data['type']}",
|
||||
"data": data['data']
|
||||
}})
|
||||
what_i_have_asked["parts"].append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": f"image/{data['type']}",
|
||||
"data": data["data"],
|
||||
}
|
||||
}
|
||||
)
|
||||
return what_i_have_asked
|
||||
|
||||
def __conversation_history(self, history, llm_kwargs):
|
||||
@ -153,40 +186,43 @@ class GoogleChatInit:
|
||||
what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
|
||||
what_gpt_answer = {
|
||||
"role": "model",
|
||||
"parts": [{"text": history[index + 1]}]
|
||||
"parts": [{"text": history[index + 1]}],
|
||||
}
|
||||
messages.append(what_i_have_asked)
|
||||
messages.append(what_gpt_answer)
|
||||
return messages
|
||||
|
||||
def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
|
||||
def generate_message_payload(
|
||||
self, inputs, llm_kwargs, history, system_prompt
|
||||
) -> Tuple[Dict, Dict]:
|
||||
messages = [
|
||||
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
|
||||
# {"role": "user", "parts": [{"text": ""}]},
|
||||
# {"role": "model", "parts": [{"text": ""}]}
|
||||
]
|
||||
self.url_gemini = self.url_gemini.replace(
|
||||
'%m', llm_kwargs['llm_model']).replace(
|
||||
'%k', get_conf('GEMINI_API_KEY')
|
||||
)
|
||||
header = {'Content-Type': 'application/json'}
|
||||
if 'vision' not in self.url_gemini: # 不是vision 才处理history
|
||||
messages.extend(self.__conversation_history(history, llm_kwargs)) # 处理 history
|
||||
"%m", llm_kwargs["llm_model"]
|
||||
).replace("%k", get_conf("GEMINI_API_KEY"))
|
||||
header = {"Content-Type": "application/json"}
|
||||
if "vision" not in self.url_gemini: # 不是vision 才处理history
|
||||
messages.extend(
|
||||
self.__conversation_history(history, llm_kwargs)
|
||||
) # 处理 history
|
||||
messages.append(self.__conversation_user(inputs, llm_kwargs)) # 处理用户对话
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
# "maxOutputTokens": 800,
|
||||
"stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
|
||||
"temperature": llm_kwargs.get('temperature', 1),
|
||||
"topP": llm_kwargs.get('top_p', 0.8),
|
||||
"topK": 10
|
||||
}
|
||||
"stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
|
||||
"temperature": llm_kwargs.get("temperature", 1),
|
||||
"topP": llm_kwargs.get("top_p", 0.8),
|
||||
"topK": 10,
|
||||
},
|
||||
}
|
||||
return header, payload
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
google = GoogleChatInit()
|
||||
# print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
|
||||
# gootle.input_encode_handle('123123[123123](./123123), ')
|
||||
# gootle.input_encode_handle('123123[123123](./123123), ')
|
||||
|
Loading…
x
Reference in New Issue
Block a user