Add GLM INT8

binary-husky 2023-07-24 18:19:57 +08:00
parent dd4ba0ea22
commit e93b6fa3a6
3 changed files with 14 additions and 9 deletions

@@ -80,7 +80,7 @@ ChatGLM_PTUNING_CHECKPOINT = "" # e.g. "/home/hmp/ChatGLM2-6B/ptuning/output/6b
 # Execution device for local LLM models such as ChatGLM: CPU/GPU
 LOCAL_MODEL_DEVICE = "cpu" # alternatively "cuda"
-LOCAL_MODEL_QUANT = "INT4" # default "", "INT4" enables the quantized INT4 version, "INT8" enables the quantized INT8 version
+LOCAL_MODEL_QUANT = "FP16" # default "FP16", "INT4" enables the quantized INT4 version, "INT8" enables the quantized INT8 version
 # Number of parallel gradio threads (no need to change)
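With this change quantization becomes opt-in: the default "FP16" loads the full-precision checkpoint, while "INT4" and "INT8" select the pre-quantized ChatGLM2 checkpoints. A minimal, self-contained sketch of that mapping (the helper name pick_chatglm_checkpoint is hypothetical; the repo ids are the ones used by the bridge change below):

def pick_chatglm_checkpoint(local_model_quant="FP16"):
    # Map the LOCAL_MODEL_QUANT setting to a Hugging Face repo id;
    # anything other than "INT4"/"INT8" falls back to the FP16 model.
    quantized = {
        "INT4": "THUDM/chatglm2-6b-int4",
        "INT8": "THUDM/chatglm2-6b-int8",
    }
    return quantized.get(local_model_quant, "THUDM/chatglm2-6b")

print(pick_chatglm_checkpoint("INT8"))  # THUDM/chatglm2-6b-int8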

@@ -37,19 +37,23 @@ class GetGLMHandle(Process):
         # runs in the child process
         # first run: load the parameters
         retry = 0
-        pretrained_model_name_or_path = "THUDM/chatglm2-6b"
-        LOCAL_MODEL_QUANT = get_conf('LOCAL_MODEL_QUANT')
-        if LOCAL_MODEL_QUANT and len(LOCAL_MODEL_QUANT) > 0 and LOCAL_MODEL_QUANT[0] == "INT4":
-            pretrained_model_name_or_path = "THUDM/chatglm2-6b-int4"
+        LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')
+
+        if LOCAL_MODEL_QUANT == "INT4":         # INT4
+            _model_name_ = "THUDM/chatglm2-6b-int4"
+        elif LOCAL_MODEL_QUANT == "INT8":       # INT8
+            _model_name_ = "THUDM/chatglm2-6b-int8"
+        else:
+            _model_name_ = "THUDM/chatglm2-6b"  # FP16
+
         while True:
             try:
                 if self.chatglm_model is None:
-                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
-                    device, = get_conf('LOCAL_MODEL_DEVICE')
+                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained(_model_name_, trust_remote_code=True)
                     if device=='cpu':
-                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).float()
+                        self.chatglm_model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).float()
                     else:
-                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).half().cuda()
+                        self.chatglm_model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).half().cuda()
                     self.chatglm_model = self.chatglm_model.eval()
                     break
                 else:
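The single-line unpacking on the new get_conf call works because get_conf returns one value per requested key, always as a tuple; that is also why the removed code indexed LOCAL_MODEL_QUANT[0] and wrote "device, = get_conf(...)". A minimal stand-in illustrating that contract (get_conf's real implementation is not part of this diff, and the defaults dict here is hypothetical):

_DEFAULTS = {"LOCAL_MODEL_QUANT": "INT8", "LOCAL_MODEL_DEVICE": "cpu"}

def get_conf(*args):
    # One value per requested key, returned as a tuple even for a single key.
    return tuple(_DEFAULTS[arg] for arg in args)

LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')
print(LOCAL_MODEL_QUANT, device)       # INT8 cpu
print(get_conf('LOCAL_MODEL_QUANT'))   # ('INT8',) -- hence the old [0] indexing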

@@ -681,6 +681,7 @@ def read_single_conf_with_lru_cache(arg):
         else:
             print亮红( "[API_KEY] Your API_KEY does not match any known key format. Please edit the API key in the config file and run again.")
     if arg == 'proxies':
+        if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # check USE_PROXY so that proxies cannot take effect on its own
         if r is None:
             print亮红('[PROXY] Proxy status: not configured. Without a proxy, the OpenAI-family models are very likely unreachable. Suggestion: check whether the USE_PROXY option has been changed.')
         else:
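The added line makes USE_PROXY the single switch: even if a proxies entry exists in the config, it is ignored unless USE_PROXY is enabled. A self-contained sketch of that guard under assumed config values (the in-memory _CONFIG dict and the plain print stand in for the project's config loading and colored output):

from functools import lru_cache

_CONFIG = {"USE_PROXY": False, "proxies": {"https": "socks5h://localhost:11284"}}

@lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg):
    r = _CONFIG.get(arg)
    if arg == 'proxies':
        # Ignore any proxies entry while USE_PROXY is off, so it cannot take effect on its own.
        if not read_single_conf_with_lru_cache('USE_PROXY'): r = None
        if r is None:
            print('[PROXY] Proxy status: not configured. Check the USE_PROXY option.')
    return r

print(read_single_conf_with_lru_cache('proxies'))  # None while USE_PROXY is False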