jpeg type align for gemini

parent 480516380d
commit 37744a9cb1
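In short: the uploaded file's extension is now normalized before it is used to build the image MIME type sent to Gemini, since "image/jpeg" is a standard MIME type while "image/jpg" is not; the rest of the diff is quote and line-wrapping reformatting. A minimal sketch of the normalization applied in input_encode_handler (the helper name _image_mime_type is illustrative and not part of the commit):

    import os

    def _image_mime_type(path):
        # Illustrative helper: same logic as the new lines in input_encode_handler.
        # "jpg" is mapped to "jpeg", so the value built later as f"image/{type_}"
        # becomes the standard "image/jpeg".
        type_ = os.path.splitext(path)[1].replace(".", "")
        type_ = "jpeg" if type_ == "jpg" else type_
        return f"image/{type_}"

    print(_image_mime_type("photo.jpg"))    # -> image/jpeg
    print(_image_mime_type("diagram.png"))  # -> image/png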
					
@@ -9,7 +9,7 @@ import requests
 from typing import List, Dict, Tuple
 from toolbox import get_conf, encode_image, get_pictures_list
 
-proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
+proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
 
 """
 ========================================================================
@@ -26,33 +26,56 @@ to_markdown_tabs 文件list 转换为 md tab
 
 def files_filter_handler(file_list):
     new_list = []
-    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+    filter_ = [
+        "png",
+        "jpg",
+        "jpeg",
+        "bmp",
+        "svg",
+        "webp",
+        "ico",
+        "tif",
+        "tiff",
+        "raw",
+        "eps",
+    ]
     for file in file_list:
-        file = str(file).replace('file=', '')
+        file = str(file).replace("file=", "")
         if os.path.exists(file):
-            if str(os.path.basename(file)).split('.')[-1] in filter_:
+            if str(os.path.basename(file)).split(".")[-1] in filter_:
                 new_list.append(file)
     return new_list
 
 
 def input_encode_handler(inputs, llm_kwargs):
-    if llm_kwargs['most_recent_uploaded'].get('path'):
-        image_paths = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
+    if llm_kwargs["most_recent_uploaded"].get("path"):
+        image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
     md_encode = []
     for md_path in image_paths:
-        md_encode.append({
-            "data": encode_image(md_path),
-            "type": os.path.splitext(md_path)[1].replace('.', '')
-        })
+        type_ = os.path.splitext(md_path)[1].replace(".", "")
+        type_ = "jpeg" if type_ == "jpg" else type_
+        md_encode.append({"data": encode_image(md_path), "type": type_})
     return inputs, md_encode
 
 
 def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
     new_list = []
     if not filter_:
-        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
+        filter_ = [
+            "png",
+            "jpg",
+            "jpeg",
+            "bmp",
+            "svg",
+            "webp",
+            "ico",
+            "tif",
+            "tiff",
+            "raw",
+            "eps",
+        ]
     for file in file_list:
-        if str(os.path.basename(file)).split('.')[-1] in filter_:
+        if str(os.path.basename(file)).split(".")[-1] in filter_:
             new_list.append(html_local_img(file, md=md_type))
         elif os.path.exists(file):
             new_list.append(link_mtime_to_md(file))
@@ -75,8 +98,8 @@ def html_local_file(file):
     return file
 
 
-def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
-    style = ''
+def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
+    style = ""
     if max_width is not None:
         style += f"max-width: {max_width};"
     if max_height is not None:
@@ -84,11 +107,11 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
     __file = html_local_file(__file)
     a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
     if md:
-        a = f''
+        a = f""
     return a
 
 
-def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
+def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
     """
     Args:
         head: 表头:[]
@@ -106,43 +129,53 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
     max_len = max(len(column) for column in transposed_tabs)
 
     tab_format = "| %s "
-    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
-    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
+    tabs_list = "".join([tab_format % i for i in head]) + "|\n"
+    tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
 
     for i in range(max_len):
-        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
+        row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
         row_data = file_manifest_filter_html(row_data, filter_=None)
-        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
+        tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
 
     return tabs_list
 
 
 class GoogleChatInit:
-
     def __init__(self):
-        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'
+        self.url_gemini = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"
 
     def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
-        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
-        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
-                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
+        headers, payload = self.generate_message_payload(
+            inputs, llm_kwargs, history, system_prompt
+        )
+        response = requests.post(
+            url=self.url_gemini,
+            headers=headers,
+            data=json.dumps(payload),
+            stream=True,
+            proxies=proxies,
+            timeout=TIMEOUT_SECONDS,
+        )
         return response.iter_lines()
 
     def __conversation_user(self, user_input, llm_kwargs):
         what_i_have_asked = {"role": "user", "parts": []}
-        if 'vision' not in self.url_gemini:
+        if "vision" not in self.url_gemini:
             input_ = user_input
             encode_img = []
         else:
             input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
-        what_i_have_asked['parts'].append({'text': input_})
+        what_i_have_asked["parts"].append({"text": input_})
         if encode_img:
             for data in encode_img:
-                what_i_have_asked['parts'].append(
-                    {'inline_data': {
-                        "mime_type": f"image/{data['type']}",
-                        "data": data['data']
-                    }})
+                what_i_have_asked["parts"].append(
+                    {
+                        "inline_data": {
+                            "mime_type": f"image/{data['type']}",
+                            "data": data["data"],
+                        }
+                    }
+                )
         return what_i_have_asked
 
     def __conversation_history(self, history, llm_kwargs):
@@ -153,40 +186,43 @@ class GoogleChatInit:
                 what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
                 what_gpt_answer = {
                     "role": "model",
-                    "parts": [{"text": history[index + 1]}]
+                    "parts": [{"text": history[index + 1]}],
                 }
                 messages.append(what_i_have_asked)
                 messages.append(what_gpt_answer)
         return messages
 
-    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
+    def generate_message_payload(
+        self, inputs, llm_kwargs, history, system_prompt
+    ) -> Tuple[Dict, Dict]:
         messages = [
             # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
             # {"role": "user", "parts": [{"text": ""}]},
             # {"role": "model", "parts": [{"text": ""}]}
         ]
         self.url_gemini = self.url_gemini.replace(
-            '%m', llm_kwargs['llm_model']).replace(
-            '%k', get_conf('GEMINI_API_KEY')
-        )
-        header = {'Content-Type': 'application/json'}
-        if 'vision' not in self.url_gemini:  # 不是vision 才处理history
-            messages.extend(self.__conversation_history(history, llm_kwargs))  # 处理 history
+            "%m", llm_kwargs["llm_model"]
+        ).replace("%k", get_conf("GEMINI_API_KEY"))
+        header = {"Content-Type": "application/json"}
+        if "vision" not in self.url_gemini:  # 不是vision 才处理history
+            messages.extend(
+                self.__conversation_history(history, llm_kwargs)
+            )  # 处理 history
         messages.append(self.__conversation_user(inputs, llm_kwargs))  # 处理用户对话
         payload = {
             "contents": messages,
             "generationConfig": {
                 # "maxOutputTokens": 800,
-                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
-                "temperature": llm_kwargs.get('temperature', 1),
-                "topP": llm_kwargs.get('top_p', 0.8),
-                "topK": 10
-            }
+                "stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
+                "temperature": llm_kwargs.get("temperature", 1),
+                "topP": llm_kwargs.get("top_p", 0.8),
+                "topK": 10,
+            },
         }
         return header, payload
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     google = GoogleChatInit()
     # print(gootle.generate_message_payload('你好呀', {},  ['123123', '3123123'], ''))
     # gootle.input_encode_handle('123123[123123](./123123), ')
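For reference, after this change a user turn carrying a .jpg upload is serialized by __conversation_user roughly as follows (a sketch with illustrative values; the base64 payload is abbreviated):

    # Shape of the message built by __conversation_user for one .jpg image,
    # after the jpg -> jpeg normalization (values are illustrative).
    what_i_have_asked = {
        "role": "user",
        "parts": [
            {"text": "describe this picture"},
            {
                "inline_data": {
                    "mime_type": "image/jpeg",  # previously "image/jpg" for .jpg files
                    "data": "<base64-encoded image bytes>",
                }
            },
        ],
    }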