Merge pull request #253 from RongkangXiong/dev
add crazy_functions 解析一个Java项目 (new parsers for Java, Go, and React projects, plus a text-generation-webui bridge)
commit ab57f4bfb0

crazy_functions/__init__.py (new file, 0 lines)
@@ -148,3 +148,66 @@ def 解析一个C项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
        return
    yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)


@CatchException
def 解析一个Java项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
    history = []  # clear the history to avoid overflowing the input
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
        yield chatbot, history, '正常'
        return
    file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.java', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.jar', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.xml', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.sh', recursive=True)]
    if len(file_manifest) == 0:
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何java文件: {txt}")
        yield chatbot, history, '正常'
        return
    yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)


@CatchException
def 解析一个Rect项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
    # "Rect" is the identifier used throughout this commit; the function targets React (.ts/.tsx/.js/.jsx) projects
    history = []  # clear the history to avoid overflowing the input
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
        yield chatbot, history, '正常'
        return
    file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.ts', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.tsx', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.json', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.js', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.jsx', recursive=True)]
    if len(file_manifest) == 0:
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何Rect文件: {txt}")
        yield chatbot, history, '正常'
        return
    yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)


@CatchException
def 解析一个Golang项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
    history = []  # clear the history to avoid overflowing the input
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
        yield chatbot, history, '正常'
        return
    file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.go', recursive=True)]
    if len(file_manifest) == 0:
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何golang文件: {txt}")
        yield chatbot, history, '正常'
        return
    yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
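The Java, Rect, and Golang parsers above share one skeleton and differ only in their glob suffixes and message strings. A hedged refactor sketch follows; the names 解析指定后缀的项目 and _make_manifest are illustrative and not part of the commit, and report_execption and 解析源代码 are assumed to behave as in this module:

```python
# Illustrative refactor sketch only; not part of the commit.
import glob, os

def _make_manifest(project_folder, suffixes):
    # Collect every file under project_folder ending in one of the given suffixes.
    manifest = []
    for suffix in suffixes:
        manifest += glob.glob(f'{project_folder}/**/*.{suffix}', recursive=True)
    return manifest

def 解析指定后缀的项目(suffixes, lang_name, txt, top_p, temperature, chatbot, history, systemPromptTxt):
    history = []  # clear the history to avoid overflowing the input
    if not os.path.exists(txt):
        if txt == "": txt = '空空如也的输入栏'
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
        yield chatbot, history, '正常'
        return
    file_manifest = _make_manifest(txt, suffixes)
    if len(file_manifest) == 0:
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何{lang_name}文件: {txt}")
        yield chatbot, history, '正常'
        return
    yield from 解析源代码(file_manifest, txt, top_p, temperature, chatbot, history, systemPromptTxt)

# The Java parser would then reduce to:
# yield from 解析指定后缀的项目(['java', 'jar', 'xml', 'sh'], 'java', txt, top_p, temperature, chatbot, history, systemPromptTxt)
```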
@@ -9,6 +9,9 @@ def get_crazy_functionals():
    from crazy_functions.解析项目源代码 import 解析一个Python项目
    from crazy_functions.解析项目源代码 import 解析一个C项目的头文件
    from crazy_functions.解析项目源代码 import 解析一个C项目
    from crazy_functions.解析项目源代码 import 解析一个Golang项目
    from crazy_functions.解析项目源代码 import 解析一个Java项目
    from crazy_functions.解析项目源代码 import 解析一个Rect项目
    from crazy_functions.高级功能函数模板 import 高阶功能模板函数
    from crazy_functions.代码重写为全英文_多线程 import 全项目切换英文

@@ -30,6 +33,21 @@ def get_crazy_functionals():
            "AsButton": False,  # place in the drop-down menu instead of as a button
            "Function": 解析一个C项目
        },
        "解析整个Go项目": {
            "Color": "stop",    # button color
            "AsButton": False,  # place in the drop-down menu instead of as a button
            "Function": 解析一个Golang项目
        },
        "解析整个Java项目": {
            "Color": "stop",  # button color
            "AsButton": False,  # place in the drop-down menu instead of as a button
            "Function": 解析一个Java项目
        },
        "解析整个React项目": {  # must be distinct from the Java key above: a repeated dict key would silently replace the Java entry
            "Color": "stop",  # button color
            "AsButton": False,  # place in the drop-down menu instead of as a button
            "Function": 解析一个Rect项目
        },
        "读Tex论文写摘要": {
            "Color": "stop",    # button color
            "Function": 读文章写摘要
        },
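These registrations live in one large dict literal, where a repeated key is silently collapsed to the last value instead of being reported. A minimal startup guard can catch that class of mistake; the sketch below uses a hypothetical helper name and is not part of the commit:

```python
# Sketch of a duplicate-name guard; build_registry is a hypothetical helper.
def build_registry(pairs):
    registry = {}
    for name, spec in pairs:
        if name in registry:
            raise ValueError(f"duplicate plugin name: {name!r}")
        registry[name] = spec
    return registry

registry = build_registry([
    ("解析整个Java项目", {"Function": "解析一个Java项目"}),
    ("解析整个React项目", {"Function": "解析一个Rect项目"}),
])
print(sorted(registry))  # both entries survive because the names are distinct
```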
							
								
								
									
request_llm/README.md (new file, 36 lines)

@@ -0,0 +1,36 @@
# How to use other large language models

## 1. Run text-generation first
``` sh
# Download the model (text-generation is a terrific project; don't forget to give it a star)
git clone https://github.com/oobabooga/text-generation-webui.git

# Install text-generation's extra dependencies
pip install accelerate bitsandbytes flexgen gradio llamacpp markdown numpy peft requests rwkv safetensors sentencepiece tqdm datasets git+https://github.com/huggingface/transformers

# Change into the project directory
cd text-generation-webui

# Download a model
python download-model.py facebook/galactica-1.3b
# Other options include facebook/opt-1.3b
#           facebook/galactica-6.7b
#           facebook/galactica-120b
#           facebook/pygmalion-1.3b, etc.
# See https://github.com/oobabooga/text-generation-webui for details

# Launch text-generation; note that the slash in the model name is replaced by an underscore
python server.py --cpu --listen --listen-port 7860 --model facebook_galactica-1.3b
```

## 2. Modify config.py
``` sh
# The LLM_MODEL format is somewhat involved: TGUI:[model]@[ws address]:[ws port]; the port must match the one chosen above
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"
```

## 3. Run!
``` sh
cd chatgpt-academic
python main.py
```
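The `TGUI:[model]@[address]:[port]` string is easy to get wrong. A standalone sanity check is sketched below; it is only an illustration, and is slightly stricter than the parsing actually done in request_llm/bridge_tgui.py further down:

```python
# Sketch of a format check for the TGUI LLM_MODEL string; not part of the commit.
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"

prefix, _, rest = LLM_MODEL.partition(':')
assert prefix == "TGUI" and '@' in rest, "expected TGUI:[model]@[address]:[port]"
model_name, _, addr_port = rest.partition('@')
addr, _, port = addr_port.partition(':')
assert port.isdigit(), "the port must be numeric and match the text-generation server"
print(model_name, addr, int(port))  # galactica-1.3b localhost 7860
```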
							
								
								
									
request_llm/bridge_tgui.py (new file, 167 lines)

@@ -0,0 +1,167 @@
'''
Contributed by SagsMug. Modified by binary-husky
https://github.com/oobabooga/text-generation-webui/pull/175
'''

import asyncio
import json
import random
import string
import websockets
import logging
import time
import threading
import importlib
from toolbox import get_conf
LLM_MODEL, = get_conf('LLM_MODEL')

# expected form: "TGUI:galactica-1.3b@localhost:7860"
model_name, addr_port = LLM_MODEL.split('@')
assert ':' in addr_port, "LLM_MODEL 格式不正确!" + LLM_MODEL
addr, port = addr_port.split(':')

def random_hash():
    # nine random lowercase letters/digits, used as the gradio session hash
    letters = string.ascii_lowercase + string.digits
    return ''.join(random.choice(letters) for i in range(9))
async def run(context, max_token=512):
    params = {
        'max_new_tokens': max_token,
        'do_sample': True,
        'temperature': 0.5,
        'top_p': 0.9,
        'typical_p': 1,
        'repetition_penalty': 1.05,
        'encoder_repetition_penalty': 1.0,
        'top_k': 0,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': True,
        'seed': -1,
    }
    session = random_hash()

    async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
        while content := json.loads(await websocket.recv()):
            # the assignment expression (:=) requires Python 3.8+; replace with a plain loop on older versions
            if content["msg"] == "send_hash":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": 12
                }))
            elif content["msg"] == "estimation":
                pass
            elif content["msg"] == "send_data":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": 12,
                    "data": [
                        context,
                        params['max_new_tokens'],
                        params['do_sample'],
                        params['temperature'],
                        params['top_p'],
                        params['typical_p'],
                        params['repetition_penalty'],
                        params['encoder_repetition_penalty'],
                        params['top_k'],
                        params['min_length'],
                        params['no_repeat_ngram_size'],
                        params['num_beams'],
                        params['penalty_alpha'],
                        params['length_penalty'],
                        params['early_stopping'],
                        params['seed'],
                    ]
                }))
            elif content["msg"] == "process_starts":
                pass
            elif content["msg"] in ["process_generating", "process_completed"]:
                yield content["output"]["data"][0]
                # You can search for your desired end indicator and
                # stop generation by closing the websocket here
                if (content["msg"] == "process_completed"):
                    break
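`run` is an async generator, and each yield carries the full text generated so far rather than a delta, which is why the callers below slice off the previously seen prefix. A minimal standalone consumer, assuming a text-generation server is reachable as configured in the README, could be:

```python
# Consumer sketch, not part of the commit; requires a running text-generation-webui
# instance at ws://{addr}:{port} as configured in config.py.
async def demo():
    final = ""
    async for partial in run("The capital of France is", max_token=32):
        final = partial  # each yield is the whole completion so far, not an increment
    print(final)

asyncio.run(demo())
```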
def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream=True, additional_fn=None):
    """
        Send the query to the TGUI backend and fetch the output as a stream.
        Used for the basic conversation feature.
        inputs is the text of the current query
        top_p and temperature are the model's sampling parameters
        history is the list of earlier conversation turns (note: if either inputs or history is too long, it will trigger a token-overflow error)
        chatbot is the conversation list shown in the WebUI; modify it and yield it out to update the chat interface directly
        additional_fn indicates which button was clicked; the buttons are defined in functional.py
    """
    if additional_fn is not None:
        import functional
        importlib.reload(functional)    # hot-reload the prompt definitions
        functional = functional.get_functionals()
        if "PreProcess" in functional[additional_fn]: inputs = functional[additional_fn]["PreProcess"](inputs)  # apply the pre-processing function, if one is defined
        inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]

    raw_input = "What I would like to say is the following: " + inputs
    logging.info(f'[raw_input] {raw_input}')
    history.extend([inputs, ""])
    chatbot.append([inputs, ""])
    yield chatbot, history, "等待响应"

    prompt = inputs
    tgui_say = ""

    # mutable[0] carries the latest partial response from the worker thread;
    # mutable[1] is a heartbeat the consumer refreshes so the worker can stop when nobody is listening
    mutable = ["", time.time()]
    def run_coroutine(mutable):
        async def get_result(mutable):
            async for response in run(prompt):
                print(response[len(mutable[0]):])   # print only the newly generated part
                mutable[0] = response
                if (time.time() - mutable[1]) > 3:
                    print('exit when no listener')
                    break
        asyncio.run(get_result(mutable))

    thread_listen = threading.Thread(target=run_coroutine, args=(mutable,), daemon=True)
    thread_listen.start()

    while thread_listen.is_alive():
        time.sleep(1)
        mutable[1] = time.time()    # refresh the heartbeat
        # Print intermediate steps
        if tgui_say != mutable[0]:
            tgui_say = mutable[0]
            history[-1] = tgui_say
            chatbot[-1] = (history[-2], history[-1])
            yield chatbot, history, "status_text"

    logging.info(f'[response] {tgui_say}')
def predict_tgui_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
    raw_input = "What I would like to say is the following: " + inputs
    prompt = inputs
    tgui_say = ""
    mutable = ["", time.time()]
    def run_coroutine(mutable):
        async def get_result(mutable):
            async for response in run(prompt, max_token=20):
                print(response[len(mutable[0]):])
                mutable[0] = response
                if (time.time() - mutable[1]) > 3:
                    print('exit when no listener')
                    break
        asyncio.run(get_result(mutable))
    thread_listen = threading.Thread(target=run_coroutine, args=(mutable,))
    thread_listen.start()
    while thread_listen.is_alive():
        time.sleep(1)
        mutable[1] = time.time()
    tgui_say = mutable[0]
    return tgui_say
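A hedged one-shot usage example of the no-UI variant, again assuming the server from the README is up:

```python
# Usage sketch, not part of the commit.
if __name__ == '__main__':
    answer = predict_tgui_no_ui("What is the capital of France?", top_p=0.9, temperature=0.5)
    print(answer)
```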