From da7c03e868b89f71b52444a0565ae4d08e50293a Mon Sep 17 00:00:00 2001 From: qingxu fu <505030475@qq.com> Date: Fri, 10 Nov 2023 22:54:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=BE=E5=83=8F=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crazy_functional.py | 17 ++- .../multi_stage/multi_stage_utils.py | 45 +++++++ crazy_functions/图片生成.py | 125 ++++++++++++++++-- toolbox.py | 7 +- 4 files changed, 176 insertions(+), 18 deletions(-) create mode 100644 crazy_functions/multi_stage/multi_stage_utils.py diff --git a/crazy_functional.py b/crazy_functional.py index e82f399..2e94570 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -349,16 +349,16 @@ def get_crazy_functions(): print('Load function plugin failed') try: - from crazy_functions.图片生成 import 图片生成, 图片生成_DALLE3 + from crazy_functions.图片生成 import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2 function_plugins.update({ - "图片生成(先切换模型到openai或api2d)": { + "图片生成_DALLE2(先切换模型到openai或api2d)": { "Group": "对话", "Color": "stop", "AsButton": False, "AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False) "ArgsReminder": "在这里输入分辨率, 如1024x1024(默认),支持 256x256, 512x512, 1024x1024", # 高级参数输入区的显示提示 "Info": "使用DALLE2生成图片 | 输入参数字符串,提供图像的内容", - "Function": HotReload(图片生成) + "Function": HotReload(图片生成_DALLE2) }, }) function_plugins.update({ @@ -372,6 +372,17 @@ def get_crazy_functions(): "Function": HotReload(图片生成_DALLE3) }, }) + # function_plugins.update({ + # "图片修改_DALLE2(启动DALLE2图像修改向导程序)": { + # "Group": "对话", + # "Color": "stop", + # "AsButton": False, + # "AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False) + # "ArgsReminder": "在这里输入分辨率, 如1024x1024(默认),支持 1024x1024, 1792x1024, 1024x1792", # 高级参数输入区的显示提示 + # # "Info": "使用DALLE2修改图片 | 输入参数字符串,提供图像的内容", + # "Function": HotReload(图片修改_DALLE2) + # }, + # }) except: print('Load function plugin failed') diff --git a/crazy_functions/multi_stage/multi_stage_utils.py b/crazy_functions/multi_stage/multi_stage_utils.py new file mode 100644 
import pickle
import time


def have_any_recent_upload_files(chatbot):
    """Return True when `chatbot` records a file upload less than 5 minutes old.

    The upload record is read from ``chatbot._cookies["most_recent_uploaded"]``,
    whose ``"time"`` entry is compared against ``time.time()``.

    Args:
        chatbot: object carrying a ``_cookies`` mapping, or None.

    Returns:
        bool: False when `chatbot` is None, has no cookies, has no upload
        record, or the record is at least five minutes old; True otherwise.
    """
    recent_window_seconds = 5 * 60
    # Look for the cookies explicitly instead of testing `not chatbot`:
    # a chatbot that is an empty list is falsy but may still hold cookies.
    cookies = getattr(chatbot, "_cookies", None)
    if not cookies:
        return False
    most_recent_uploaded = cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False
    return time.time() - most_recent_uploaded["time"] < recent_window_seconds


class GptAcademicState():
    """Plugin state that is persisted, pickled, in ``chatbot._cookies['plugin_state']``.

    Subclasses override `reset` to declare their own fields; `get_state`
    restores (or creates) the state for a given chatbot.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore default field values. The base class has no fields."""
        pass

    def __getstate__(self):
        """Return the picklable attributes, leaving out the live `chatbot`.

        `get_state` attaches the chatbot to the state, and the chatbot's
        cookies hold the pickled state. Pickling the reference would nest
        every previous dump inside the next one, and would fail outright
        when the chatbot object is not picklable.
        """
        picklable = self.__dict__.copy()
        picklable.pop('chatbot', None)
        return picklable

    def _dump(self, chatbot):
        """Serialize this state into the chatbot's cookies."""
        chatbot._cookies['plugin_state'] = pickle.dumps(self)

    def lock_plugin(self, chatbot):
        """Persist the current state into `chatbot._cookies`."""
        self._dump(chatbot)

    def unlock_plugin(self, chatbot):
        """Reset the state to its defaults, then persist it."""
        self.reset()
        self._dump(chatbot)

    def set_state(self, chatbot, key, value):
        """Set attribute `key` to `value` and persist the state."""
        setattr(self, key, value)
        self._dump(chatbot)

    @staticmethod
    def get_state(chatbot, cls=None):
        """Load the state stored for `chatbot`, creating a fresh one if absent.

        Declared static so it behaves the same when called on the class
        (``GptAcademicState.get_state(chatbot)``) or on an instance.

        Args:
            chatbot: object carrying the ``_cookies`` mapping.
            cls: class to instantiate when nothing is stored yet; defaults
                to `GptAcademicState`.

        Returns:
            The restored or new state, with ``state.chatbot`` set to `chatbot`.
        """
        stored = chatbot._cookies.get('plugin_state', None)
        if stored is not None:
            # NOTE(review): pickle.loads runs arbitrary code on crafted input;
            # confirm `_cookies` is server-side only and never client-supplied.
            state = pickle.loads(stored)
        elif cls is not None:
            state = cls()
        else:
            state = GptAcademicState()
        state.chatbot = chatbot
        return state


class GatherMaterials():
    """Holds the list of material kinds a multi-stage plugin must collect."""

    def __init__(self, materials) -> None:
        # The original overwrote the argument and stored nothing. Keep its
        # hard-coded list as the fallback and retain what the caller passes.
        # NOTE(review): class looks like an unfinished stub — confirm intent.
        self.materials = list(materials) if materials else ['image', 'prompt']
def edit_image(llm_kwargs, prompt, image_path, resolution="1024x1024", model="dall-e-2"):
    """Send an image plus an edit prompt to the images/edits endpoint and save the result.

    The endpoint URL is derived from the chat endpoint registered for
    ``llm_kwargs['llm_model']`` by swapping ``chat/completions`` for
    ``images/edits``.

    Args:
        llm_kwargs: dict providing ``'api_key'`` and ``'llm_model'``.
        prompt: text describing the desired edit.
        image_path: path of the image file to upload.
        resolution: value sent as the ``size`` form field.
        model: value sent as the ``model`` form field.

    Returns:
        tuple: ``(image_url, local_file_path)`` — the URL returned by the
        service and the path the downloaded image was written to.

    Raises:
        RuntimeError: when the response does not contain ``data[0].url``;
            the message is the raw response body.
    """
    import json, os, time
    import requests
    from request_llms.bridge_all import model_info

    proxies = get_conf('proxies')
    api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
    chat_endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
    # e.g. 'https://api.openai.com/v1/chat/completions' -> '.../v1/images/edits'
    img_endpoint = chat_endpoint.replace('chat/completions', 'images/edits')

    # No explicit Content-Type: requests must generate the multipart header
    # (with its boundary) itself when `files=` is used.
    headers = {'Authorization': f"Bearer {api_key}"}
    form_fields = {
        'prompt': prompt,
        'n': 1,
        'size': resolution,
        'model': model,
        'response_format': 'url',
    }
    # The image goes up as a multipart file part — a file object cannot be
    # JSON-encoded — and the handle is closed as soon as the request returns.
    with open(image_path, 'rb') as image_file:
        response = requests.post(
            img_endpoint,
            headers=headers,
            data=form_fields,
            files={'image': (os.path.basename(image_path), image_file, 'image/png')},
            proxies=proxies,
        )

    try:
        image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
    except (ValueError, KeyError, IndexError, TypeError) as err:
        # Surface the raw service reply (usually an error JSON) to the caller.
        raise RuntimeError(response.content.decode('utf8', errors='replace')) from err

    # Save the generated image locally.
    downloaded = requests.get(image_url, proxies=proxies)
    file_path = f'{get_log_folder()}/image_gen/'
    os.makedirs(file_path, exist_ok=True)
    file_name = 'Image' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.png'
    with open(file_path + file_name, 'wb') as f:
        f.write(downloaded.content)

    return image_url, file_path + file_name
class ImageEditState(GptAcademicState):
    """Wizard state collecting the three inputs of a DALLE2 image edit.

    `self.req` lists, in order, the image file, the resolution and the edit
    prompt. Each entry holds the collected ``'value'`` (None until filled),
    a user-facing ``'description'`` and the ``'verify_fn'`` that validates
    a candidate input.
    """

    def get_image_file(self, x):
        """Resolve `x` to a .png file.

        Returns:
            tuple: ``(True, path)`` when `x` is an existing path ending in
            ``.png``, or an existing directory containing at least one
            ``.png`` (searched recursively; the first match is used);
            ``(False, None)`` otherwise.
        """
        import os, glob
        if len(x) == 0 or not os.path.exists(x):
            return False, None
        if x.endswith('.png'):
            return True, x
        candidates = glob.glob(f'{x}/**/*.png', recursive=True)
        if candidates and os.path.exists(candidates[0]):
            return True, candidates[0]
        return False, None

    def get_resolution(self, x):
        """Return ``(is_supported, x)`` where supported means one of the three DALLE2 sizes."""
        return (x in ['256x256', '512x512', '1024x1024']), x

    def get_prompt(self, x):
        """Accept `x` as the edit prompt when it is at least 5 characters long
        and is neither a supported resolution nor a resolvable image path."""
        confirm = (len(x) >= 5) and (not self.get_resolution(x)[0]) and (not self.get_image_file(x)[0])
        return confirm, x

    def reset(self):
        """Clear all collected materials."""
        self.req = [
            {'value': None, 'description': '请先上传图像(必须是.png格式), 然后再次点击本插件', 'verify_fn': self.get_image_file},
            {'value': None, 'description': '请输入分辨率,可选:256x256, 512x512 或 1024x1024', 'verify_fn': self.get_resolution},
            {'value': None, 'description': '请输入修改需求,建议您使用英文提示词', 'verify_fn': self.get_prompt},
        ]
        self.info = ""

    def feed(self, prompt, chatbot):
        """Offer `prompt` to each unfilled requirement; the first that accepts it keeps it.

        The state is persisted into the chatbot cookies when something was
        stored. Returns `self` so calls can be chained.
        """
        for requirement in self.req:
            if requirement['value'] is not None:
                continue
            confirm, res = requirement['verify_fn'](prompt)
            if confirm:
                requirement['value'] = res
                # Persist directly instead of setting a dummy attribute.
                self.lock_plugin(chatbot)
                break
        return self

    def next_req(self):
        """Return the description of the first unfilled requirement, or a completion notice."""
        for requirement in self.req:
            if requirement['value'] is None:
                return requirement['description']
        return "已经收集到所有信息"

    def already_obtained_all_materials(self):
        """Return True once every requirement has a value."""
        return all(requirement['value'] is not None for requirement in self.req)


@CatchException
def 图片修改_DALLE2(prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
    """Multi-step plugin: gather image, resolution and prompt, then call `edit_image`.

    Every invocation feeds `prompt` into the persisted `ImageEditState`.
    While something is missing the next request is shown and the call
    returns; once all three inputs are present the edit runs and the result
    is appended to the chat.
    """
    history = []    # clear the history
    state = ImageEditState.get_state(chatbot, ImageEditState)
    state = state.feed(prompt, chatbot)
    if not state.already_obtained_all_materials():
        chatbot.append(["图片修改(先上传图片,再输入修改需求,最后输入分辨率)", state.next_req()])
        yield from update_ui(chatbot=chatbot, history=history)
        return

    # Each `req` entry is a dict; the collected input lives under 'value'.
    image_path = state.req[0]['value']
    resolution = state.req[1]['value']
    prompt = state.req[2]['value']
    # Reset the wizard now so the next run starts clean, even if the edit
    # request below fails.
    state.unlock_plugin(chatbot)

    # NOTE(review): the markup inside the strings below was lost in the
    # extracted source; `<br/>` breaks and `<img>` previews are a
    # reconstruction — confirm against the real file.
    chatbot.append(["图片修改, 执行中", f"图片:`{image_path}`<br/>分辨率:`{resolution}`<br/>修改需求:`{prompt}`"])
    yield from update_ui(chatbot=chatbot, history=history)

    image_url, image_path = edit_image(llm_kwargs, prompt, image_path, resolution)
    chatbot.append([prompt,
        f'图像中转网址: <br/>`{image_url}`<br/>'+
        f'中转网址预览: <br/><div align="center"><img src="{image_url}"></div>'
        f'本地文件地址: <br/>`{image_path}`<br/>'+
        f'本地文件预览: <br/><div align="center"><img src="file={image_path}"></div>'
    ])
    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI