mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-06 06:49:48 +08:00

Change the model-generation call style to be compatible with Chain invocation.
Fix a model-switching bug.

This commit is contained in:
parent ca13ab8173
commit c5bc21781c

api.py (7)
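For reference, the hunks below replace the old `generatorAnswer(prompt=..., history=..., streaming=...)` generator call with a LangChain `Chain` invocation: the chain is called with a dict and returns a dict whose `answer_result_stream` key holds the generator of `AnswerResult` objects. A minimal sketch of the new calling convention, assuming `llm_model_chain` is whatever `shared.loaderLLM()` returns (e.g. a `ChatGLMLLMChain`); the wrapper function and question text are illustrative, not part of the commit:

def ask(llm_model_chain, question: str, history: list):
    # New style: call the chain with a dict; it returns {"answer_result_stream": <generator>}.
    answer_result_stream_result = llm_model_chain(
        {"prompt": question, "history": history, "streaming": True})
    resp = ""
    for answer_result in answer_result_stream_result["answer_result_stream"]:
        resp = answer_result.llm_output["answer"]
        history = answer_result.history
    return resp, history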
@@ -384,8 +384,10 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                          streaming=True):
+    answer_result_stream_result = local_doc_qa.llm_model_chain(
+        {"prompt": question, "history": history, "streaming": True})
 
+    for answer_result in answer_result_stream_result['answer_result_stream']:
         resp = answer_result.llm_output["answer"]
         history = answer_result.history
         pass

@@ -486,7 +488,6 @@ def api_start(host, port, **kwargs):
     global local_doc_qa
 
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
 
     app = FastAPI()
     # Add CORS middleware to allow all origins
@@ -18,6 +18,7 @@ from agent import bing_search
 from langchain.docstore.document import Document
 from functools import lru_cache
 from textsplitter.zh_title_enhance import zh_title_enhance
+from langchain.chains.base import Chain
 
 
 # patch HuggingFaceEmbeddings to make it hashable

@@ -119,7 +120,7 @@ def search_result2docs(search_results):
 
 
 class LocalDocQA:
-    llm: BaseAnswer = None
+    llm_model_chain: Chain = None
     embeddings: object = None
     top_k: int = VECTOR_SEARCH_TOP_K
     chunk_size: int = CHUNK_SIZE

@@ -129,10 +130,10 @@ class LocalDocQA:
     def init_cfg(self,
                  embedding_model: str = EMBEDDING_MODEL,
                  embedding_device=EMBEDDING_DEVICE,
-                 llm_model: BaseAnswer = None,
+                 llm_model: Chain = None,
                  top_k=VECTOR_SEARCH_TOP_K,
                  ):
-        self.llm = llm_model
+        self.llm_model_chain = llm_model
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         self.top_k = top_k

@@ -236,8 +237,10 @@ class LocalDocQA:
         else:
             prompt = query
 
-        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-                                                      streaming=streaming):
+        answer_result_stream_result = self.llm_model_chain(
+            {"prompt": prompt, "history": chat_history, "streaming": streaming})
 
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][0] = query

@@ -276,8 +279,10 @@ class LocalDocQA:
         result_docs = search_result2docs(results)
         prompt = generate_prompt(result_docs, query)
 
-        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
-                                                      streaming=streaming):
+        answer_result_stream_result = self.llm_model_chain(
+            {"prompt": prompt, "history": chat_history, "streaming": streaming})
 
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][0] = query

@@ -296,7 +301,7 @@ class LocalDocQA:
     def update_file_from_vector_store(self,
                                       filepath: str or List[str],
                                       vs_path,
-                                      docs: List[Document],):
+                                      docs: List[Document], ):
         vector_store = load_vector_store(vs_path, self.embeddings)
         status = vector_store.update_doc(filepath, docs)
         return status

@@ -320,7 +325,6 @@ if __name__ == "__main__":
     args_dict = vars(args)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
 
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg(llm_model=llm_model_ins)
@@ -37,61 +37,67 @@ llm_model_dict = {
         "name": "chatglm-6b-int4-qe",
         "pretrained_model_name": "THUDM/chatglm-6b-int4-qe",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b-int4": {
         "name": "chatglm-6b-int4",
         "pretrained_model_name": "THUDM/chatglm-6b-int4",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b-int8": {
         "name": "chatglm-6b-int8",
         "pretrained_model_name": "THUDM/chatglm-6b-int8",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm-6b": {
         "name": "chatglm-6b",
         "pretrained_model_name": "THUDM/chatglm-6b",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b": {
         "name": "chatglm2-6b",
         "pretrained_model_name": "THUDM/chatglm2-6b",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b-int4": {
         "name": "chatglm2-6b-int4",
         "pretrained_model_name": "THUDM/chatglm2-6b-int4",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatglm2-6b-int8": {
         "name": "chatglm2-6b-int8",
         "pretrained_model_name": "THUDM/chatglm2-6b-int8",
         "local_model_path": None,
-        "provides": "ChatGLM"
+        "provides": "ChatGLMLLMChain"
     },
     "chatyuan": {
         "name": "chatyuan",
         "pretrained_model_name": "ClueAI/ChatYuan-large-v2",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     "moss": {
         "name": "moss",
         "pretrained_model_name": "fnlp/moss-moon-003-sft",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
-        "provides": "LLamaLLM"
+        "provides": "LLamaLLMChain"
+    },
+    "vicuna-7b-hf": {
+        "name": "vicuna-13b-hf",
+        "pretrained_model_name": "vicuna-13b-hf",
+        "local_model_path": None,
+        "provides": "LLamaLLMChain"
     },
     # Calling it directly returns a requests.exceptions.ConnectionError; the model has to be downloaded with the snapshot_download function
     # from the huggingface_hub package. If snapshot_download still returns a network error, retry a few times; it usually works.

@@ -101,7 +107,7 @@ llm_model_dict = {
         "name": "bloomz-7b1",
         "pretrained_model_name": "bigscience/bloomz-7b1",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
 
     },
     # Loading bigscience/bloom-3b takes about 170 seconds in practice; it is not yet clear why it is so slow.

@@ -110,14 +116,14 @@ llm_model_dict = {
         "name": "bloom-3b",
         "pretrained_model_name": "bigscience/bloom-3b",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
 
     },
     "baichuan-7b": {
         "name": "baichuan-7b",
         "pretrained_model_name": "baichuan-inc/baichuan-7B",
         "local_model_path": None,
-        "provides": "MOSSLLM"
+        "provides": "MOSSLLMChain"
     },
     # For llama-cpp model compatibility issues, see https://github.com/abetlen/llama-cpp-python/issues/204
     "ggml-vicuna-13b-1.1-q5": {

@@ -131,7 +137,7 @@ llm_model_dict = {
         # The matching wheel must be downloaded manually from https://github.com/abetlen/llama-cpp-python/releases/tag/ and installed.
         # Verified that v0.1.63 works with this model's vicuna/ggml-vicuna-13b-1.1/ggml-vic13b-q5_1.bin.
         "local_model_path": f'''{"/".join(os.path.abspath(__file__).split("/")[:3])}/.cache/huggingface/hub/models--vicuna--ggml-vicuna-13b-1.1/blobs/''',
-        "provides": "LLamaLLM"
+        "provides": "LLamaLLMChain"
     },
 
     # For models served through fastchat, use the following format

@@ -139,7 +145,7 @@ llm_model_dict = {
         "name": "chatglm-6b",  # set "name" to the "model_name" used by the fastchat service
         "pretrained_model_name": "chatglm-6b",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # when using the fastchat api, "provides" must be "FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1",  # set to the "api_base_url" of the fastchat service
         "api_key": "EMPTY"
     },

@@ -147,7 +153,7 @@ llm_model_dict = {
         "name": "chatglm2-6b",  # set "name" to the "model_name" used by the fastchat service
         "pretrained_model_name": "chatglm2-6b",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # when using the fastchat api, "provides" must be "FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1"  # set to the "api_base_url" of the fastchat service
     },
 

@@ -156,7 +162,7 @@ llm_model_dict = {
         "name": "vicuna-13b-hf",  # set "name" to the "model_name" used by the fastchat service
         "pretrained_model_name": "vicuna-13b-hf",
         "local_model_path": None,
-        "provides": "FastChatOpenAILLM",  # when using the fastchat api, "provides" must be "FastChatOpenAILLM"
+        "provides": "FastChatOpenAILLMChain",  # when using the fastchat api, "provides" must be "FastChatOpenAILLMChain"
         "api_base_url": "http://localhost:8000/v1",  # set to the "api_base_url" of the fastchat service
         "api_key": "EMPTY"
     },

@@ -171,7 +177,7 @@ llm_model_dict = {
     "openai-chatgpt-3.5": {
         "name": "gpt-3.5-turbo",
         "pretrained_model_name": "gpt-3.5-turbo",
-        "provides": "FastChatOpenAILLM",
+        "provides": "FastChatOpenAILLMChain",
         "local_model_path": None,
         "api_base_url": "https://api.openapi.com/v1",
         "api_key": ""

@@ -226,7 +232,7 @@ LLM_HISTORY_LEN = 3
 VECTOR_SEARCH_TOP_K = 5
 
 # Relevance score for knowledge retrieval; the range is roughly 0-1100. 0 disables the check; in testing, values below 500 give more precise matches.
-VECTOR_SEARCH_SCORE_THRESHOLD = 0
+VECTOR_SEARCH_SCORE_THRESHOLD = 390
 
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
@@ -1,4 +1,4 @@
-from .chatglm_llm import ChatGLM
-from .llama_llm import LLamaLLM
-from .moss_llm import MOSSLLM
-from .fastchat_openai_llm import FastChatOpenAILLM
+from .chatglm_llm import ChatGLMLLMChain
+from .llama_llm import LLamaLLMChain
+from .fastchat_openai_llm import FastChatOpenAILLMChain
+from .moss_llm import MOSSLLMChain
@@ -1,13 +1,15 @@
 from models.base.base import (
     AnswerResult,
-    BaseAnswer
-)
+    BaseAnswer,
+    AnswerResultStream,
+    AnswerResultQueueSentinelTokenListenerQueue)
 from models.base.remote_rpc_model import (
     RemoteRpcModel
 )
 
 __all__ = [
     "AnswerResult",
     "BaseAnswer",
     "RemoteRpcModel",
+    "AnswerResultStream",
+    "AnswerResultQueueSentinelTokenListenerQueue"
 ]
@@ -1,13 +1,26 @@
 from abc import ABC, abstractmethod
-from typing import Optional, List
+from typing import Any, Dict, List, Optional, Generator
 import traceback
 from collections import deque
 from queue import Queue
 from threading import Thread
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from models.loader import LoaderCheckPoint
 import torch
 import transformers
-from models.loader import LoaderCheckPoint
 
+
+class ListenerToken:
+    """
+    Observation result
+    """
+
+    input_ids: torch.LongTensor
+    _scores: torch.FloatTensor
+
+    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
+        self.input_ids = input_ids
+        self._scores = _scores
+
 
 class AnswerResult:

@@ -16,6 +29,123 @@ class AnswerResult:
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
+    listenerToken: ListenerToken = None
+
+
+class AnswerResultStream:
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, answerResult: AnswerResult):
+        if self.callback_func is not None:
+            self.callback_func(answerResult)
+
+
+class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
+    """
+    Defines a stopping_criteria listener for the model; on every step it syncs the queued data into AnswerResult.
+    The point of this listener is that different models' prediction output may not be vector data; the HF framework
+    allows a custom transformers.StoppingCriteria argument to receive the Tensor and scores of every prediction step.
+    StoppingCriteriaList specifies the conditions under which generation stops; each StoppingCriteria object is one stop condition.
+    At the start of every prediction round each StoppingCriteria receives the same prediction result, and the lower-level
+    implementation decides whether to stop.
+    The captured values can be observed through the custom parameters of generatorAnswer / generate_with_streaming
+    for finer-grained control.
+    """
+
+    listenerQueue: deque = deque(maxlen=1)
+
+    def __init__(self):
+        transformers.StoppingCriteria.__init__(self)
+
+    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Append the data to the response queue on every step.
+        :param input_ids:
+        :param _scores:
+        :param kwargs:
+        :return:
+        """
+        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
+        return False
+
+
+class Iteratorize:
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}):
+        self.mfunc = func
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            """
+            Collects the model's prediction output.
+            A generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses produced by the
+            model; the lower-level implementation decides when generation ends.
+            The end conditions are:
+            1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
+            2. A break is executed while consuming the iterator queue, triggering StopIteration.
+            3. The model prediction raises an error.
+            Because this class is an iterator, __exit__ is called after a break inside a "for ... in" loop, stop_now is
+            updated, and an exception is raised to end the prediction.
+            The collection flow is:
+            create the Iteratorize object,
+            define the generate_with_callback collector (AnswerResultStream),
+            start a thread that asynchronously calls the upstream checkpoint's _generate_answer implementation.
+            _generate_answer uses the generate_with_callback collector to gather the AnswerResult messages wrapped by
+            the upstream checkpoint.
+            Because self.q is blocking, each prediction is consumed before the next one runs,
+            so generate_with_callback blocks in the meantime.
+            The main thread consumes the blocked messages through Iteratorize.__next__:
+            1. If the message is an AnswerResult wrapped by the upstream checkpoint, it is returned downstream.
+            2. If the message is the self.sentinel marker, StopIteration is raised.
+            When the main thread's Iteratorize.__exit__ runs, stop_now is updated;
+            the background thread detects the updated stop_now and raises to end the prediction.
+            The iteration is then finished.
+            :param val:
+            :return:
+            """
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gen():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            self.q.put(self.sentinel)
+
+        self.thread = Thread(target=gen)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __del__(self):
+        """
+        Not implemented yet.
+        :return:
+        """
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """ Executed after a break """
+        self.stop_now = True
+
 
 class BaseAnswer(ABC):

@@ -25,17 +155,25 @@ class BaseAnswer(ABC):
     @abstractmethod
     def _check_point(self) -> LoaderCheckPoint:
         """Return _check_point of llm."""
 
-    @property
-    @abstractmethod
-    def _history_len(self) -> int:
-        """Return _history_len of llm."""
-
-    @abstractmethod
-    def set_history_len(self, history_len: int) -> None:
-        """Return _history_len of llm."""
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
+    def generatorAnswer(self,
+                        inputs: Dict[str, Any],
+                        run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
+        def generate_with_callback(callback=None, **kwargs):
+            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
+            self._generate_answer(**kwargs)
+
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(generate_with_callback, kwargs)
+
+        with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
+            for answerResult in generator:
+                if answerResult.listenerToken:
+                    output = answerResult.listenerToken.input_ids
+                yield answerResult
+
+    @abstractmethod
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
         pass
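A minimal sketch of how a concrete chain plugs into the new scaffolding above: `_call` hands the inputs to `generatorAnswer`, which runs `_generate_answer` on a worker thread through `Iteratorize`/`AnswerResultStream` and yields `AnswerResult` objects lazily. The `EchoLLMChain` below is hypothetical (it only echoes the prompt) and is not part of this commit:

from abc import ABC
from typing import Any, Dict, Generator, List, Optional

from langchain.chains.base import Chain
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import AnswerResult, AnswerResultStream, BaseAnswer


class EchoLLMChain(BaseAnswer, Chain, ABC):
    # Input/output keys follow the convention used by the chains in this commit.
    history_key: str = "history"
    prompt_key: str = "prompt"
    output_key: str = "answer_result_stream"

    @property
    def _chain_type(self) -> str:
        return "EchoLLMChain"

    @property
    def _check_point(self):
        return None  # no model checkpoint in this toy example

    @property
    def input_keys(self) -> List[str]:
        return [self.prompt_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any],
              run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Generator]:
        # generatorAnswer (inherited from BaseAnswer) drives _generate_answer and
        # exposes the collected AnswerResult objects as a lazy generator.
        return {self.output_key: self.generatorAnswer(inputs=inputs, run_manager=run_manager)}

    def _generate_answer(self, inputs: Dict[str, Any],
                         run_manager: Optional[CallbackManagerForChainRun] = None,
                         generate_with_callback: AnswerResultStream = None) -> None:
        prompt = inputs[self.prompt_key]
        answer_result = AnswerResult()
        answer_result.history = inputs[self.history_key] + [[prompt, prompt]]
        answer_result.llm_output = {"answer": prompt}
        generate_with_callback(answer_result)

Calling EchoLLMChain()({"prompt": "hi", "history": [], "streaming": False}) and iterating the returned answer_result_stream should yield one AnswerResult whose answer echoes the prompt.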
@@ -1,70 +1,102 @@
 from abc import ABC
-from langchain.llms.base import LLM
-from typing import Optional, List
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult)
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers
 
 
-class ChatGLM(BaseAnswer, LLM, ABC):
+class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
     max_token: int = 10000
     temperature: float = 0.01
-    top_p = 0.9
+    # relevance
+    top_p = 0.4
+    # number of candidate tokens
+    top_k = 10
     checkPoint: LoaderCheckPoint = None
     # history = []
     history_len: int = 10
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:
 
     def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
         self.checkPoint = checkPoint
 
     @property
-    def _llm_type(self) -> str:
-        return "ChatGLM"
+    def _chain_type(self) -> str:
+        return "ChatGLMLLMChain"
 
     @property
     def _check_point(self) -> LoaderCheckPoint:
         return self.checkPoint
 
     @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
         print(f"__call:{prompt}")
-        response, _ = self.checkPoint.model.chat(
-            self.checkPoint.tokenizer,
-            prompt,
-            history=[],
-            max_length=self.max_token,
-            temperature=self.temperature
-        )
-        print(f"response:{response}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-        return response
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on each step the torch.LongTensor / torch.FloatTensor are synced into AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if streaming:
             history += [[]]
             for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
                     self.checkPoint.tokenizer,
                     prompt,
-                    history=history[-self.history_len:-1] if self.history_len > 1 else [],
+                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
-                    temperature=self.temperature
+                    temperature=self.temperature,
+                    top_p=self.top_p,
+                    top_k=self.top_k,
+                    stopping_criteria=stopping_criteria_list
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                yield answer_result
+                if listenerQueue.listenerQueue.__len__() > 0:
+                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+                generate_with_callback(answer_result)
             self.checkPoint.clear_torch_cache()
         else:
             response, _ = self.checkPoint.model.chat(

@@ -72,13 +104,18 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                stopping_criteria=stopping_criteria_list
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            yield answer_result
+            if listenerQueue.listenerQueue.__len__() > 0:
+                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+
+            generate_with_callback(answer_result)
@@ -1,15 +1,15 @@
 from abc import ABC
-import requests
-from typing import Optional, List
-from langchain.llms.base import LLM
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator, Collection
 
 from models.loader import LoaderCheckPoint
-from models.base import (RemoteRpcModel,
-                         AnswerResult)
-from typing import (
-    Collection,
-    Dict
-)
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from models.base import (BaseAnswer,
+                         RemoteRpcModel,
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers
 
 
 def _build_message_template() -> Dict[str, str]:

@@ -22,41 +22,74 @@ def _build_message_template() -> Dict[str, str]:
     }
 
 
-class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
+# Convert the conversation-history array into the message format
+def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
+    build_messages: Collection[Dict[str, str]] = []
+    for i, (old_query, response) in enumerate(history):
+        user_build_message = _build_message_template()
+        user_build_message['role'] = 'user'
+        user_build_message['content'] = old_query
+        system_build_message = _build_message_template()
+        system_build_message['role'] = 'system'
+        system_build_message['content'] = response
+        build_messages.append(user_build_message)
+        build_messages.append(system_build_message)
+
+    user_build_message = _build_message_template()
+    user_build_message['role'] = 'user'
+    user_build_message['content'] = query
+    build_messages.append(user_build_message)
+    return build_messages
+
+
+class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
     checkPoint: LoaderCheckPoint = None
-    history = []
+    # history = []
     history_len: int = 10
     api_key: str = ""
 
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:
+
     def __init__(self,
                  checkPoint: LoaderCheckPoint = None,
                  # api_base_url:str="http://localhost:8000/v1",
                  # model_name:str="chatglm-6b",
                  # api_key:str=""
                  ):
         super().__init__()
         self.checkPoint = checkPoint
 
     @property
-    def _llm_type(self) -> str:
-        return "FastChat"
+    def _chain_type(self) -> str:
+        return "LLamaLLMChain"
 
     @property
     def _check_point(self) -> LoaderCheckPoint:
         return self.checkPoint
 
     @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
 
     @property
     def _api_key(self) -> str:

@@ -75,53 +108,25 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
         print(f"__call:{prompt}")
         try:
-            import openai
-            # Not support yet
-            # openai.api_key = "EMPTY"
-            openai.key = self.api_key
-            openai.api_base = self.api_base_url
-        except ImportError:
-            raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
-            )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=self.build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-        return completion.choices[0].message.content
-
-    # Convert the conversation-history array into the message format
-    def build_message_list(self, query) -> Collection[Dict[str, str]]:
-        build_message_list: Collection[Dict[str, str]] = []
-        history = self.history[-self.history_len:] if self.history_len > 0 else []
-        for i, (old_query, response) in enumerate(history):
-            user_build_message = _build_message_template()
-            user_build_message['role'] = 'user'
-            user_build_message['content'] = old_query
-            system_build_message = _build_message_template()
-            system_build_message['role'] = 'system'
-            system_build_message['content'] = response
-            build_message_list.append(user_build_message)
-            build_message_list.append(system_build_message)
-
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = query
-        build_message_list.append(user_build_message)
-        return build_message_list
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-
-        try:
             import openai
             # Not support yet
             # openai.api_key = "EMPTY"

@@ -135,12 +140,13 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
         # create a chat completion
         completion = openai.ChatCompletion.create(
             model=self.model_name,
-            messages=self.build_message_list(prompt)
+            messages=build_message_list(prompt)
         )
+        print(f"response:{completion.choices[0].message.content}")
+        print(f"+++++++++++++++++++++++++++++++++++")
 
         history += [[prompt, completion.choices[0].message.content]]
         answer_result = AnswerResult()
         answer_result.history = history
         answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        yield answer_result
+        generate_with_callback(answer_result)
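For reference, `build_message_list` is now a module-level helper that takes the history explicitly instead of an instance method reading `self.history`. A small illustrative call, assuming `_build_message_template()` returns a dict with only 'role' and 'content' keys; the example strings are not from the commit:

messages = build_message_list("How is the weather today?",
                              history=[["Hello", "Hi, how can I help you?"]])
# messages is now:
# [{'role': 'user', 'content': 'Hello'},
#  {'role': 'system', 'content': 'Hi, how can I help you?'},
#  {'role': 'user', 'content': 'How is the weather today?'}]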
@@ -1,29 +1,32 @@
 from abc import ABC
-
-from langchain.llms.base import LLM
-import random
-import torch
-import transformers
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator, Union
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from transformers.generation.logits_process import LogitsProcessor
 from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
-from typing import Optional, List, Dict, Any,Union
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult)
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
-    def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
+    def __call__(self, input_ids: Union[torch.LongTensor, list],
+                 scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
         # llama-cpp models return lists; for compatibility, check the types of input_ids and scores and convert lists to torch.Tensor
-        input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids
-        scores = torch.tensor(scores) if isinstance(scores,list) else scores
+        input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
+        scores = torch.tensor(scores) if isinstance(scores, list) else scores
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
             scores[..., 5] = 5e4
         return scores
 
 
-class LLamaLLM(BaseAnswer, LLM, ABC):
+class LLamaLLMChain(BaseAnswer, Chain, ABC):
     checkPoint: LoaderCheckPoint = None
     # history = []
     history_len: int = 3

@@ -37,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
     min_length: int = 0
     logits_processor: LogitsProcessorList = None
     stopping_criteria: Optional[StoppingCriteriaList] = None
-    eos_token_id: Optional[int] = [2]
-
-    state: object = {'max_new_tokens': 50,
-                     'seed': 1,
-                     'temperature': 0, 'top_p': 0.1,
-                     'top_k': 40, 'typical_p': 1,
-                     'repetition_penalty': 1.2,
-                     'encoder_repetition_penalty': 1,
-                     'no_repeat_ngram_size': 0,
-                     'min_length': 0,
-                     'penalty_alpha': 0,
-                     'num_beams': 1,
-                     'length_penalty': 1,
-                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
-                     'truncation_length': 2048, 'custom_stopping_strings': '',
-                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
-                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
-                     'pre_layer': 0, 'gpu_memory_0': 0}
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:
 
     def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
         self.checkPoint = checkPoint
 
     @property
-    def _llm_type(self) -> str:
-        return "LLamaLLM"
+    def _chain_type(self) -> str:
+        return "LLamaLLMChain"
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
 
     @property
     def _check_point(self) -> LoaderCheckPoint:

@@ -107,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         formatted_history += "### Human:{}\n### Assistant:".format(query)
         return formatted_history
 
-    def prepare_inputs_for_generation(self,
-                                      input_ids: torch.LongTensor):
-        """
-        Pre-generate the attention mask and a tensor of position indices for the input sequence
-        # TODO no approach yet
-        :return:
-        """
-        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
-
-        attention_mask = self.get_masks(input_ids, input_ids.device)
-
-        position_ids = self.get_position_ids(
-            input_ids,
-            device=input_ids.device,
-            mask_positions=mask_positions
-        )
-
-        return input_ids, position_ids, attention_mask
-
-    @property
-    def _history_len(self) -> int:
-        return self.history_len
-
-    def set_history_len(self, history_len: int = 10) -> None:
-        self.history_len = history_len
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
         print(f"__call:{prompt}")
+
+        # Create the StoppingCriteriaList with the stopping strings
+        self.stopping_criteria = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on each step the torch.LongTensor / torch.FloatTensor are synced into AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        self.stopping_criteria.append(listenerQueue)
+        # TODO A chat module and attention handling still need to be implemented; _call is currently the LangChain extension API and defaults to prompt-free mode. To work with the attention model, refer to the chat_glm implementation.
+        soft_prompt = self.history_to_text(query=prompt, history=history)
         if self.logits_processor is None:
             self.logits_processor = LogitsProcessorList()
         self.logits_processor.append(InvalidScoreLogitsProcessor())

@@ -154,16 +155,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
                       "logits_processor": self.logits_processor}
 
         # vector conversion
-        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
-        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
+        input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
+                                truncation_length=self.max_new_tokens)
 
         gen_kwargs.update({'inputs': input_ids})
-        # attention mask
-        # gen_kwargs.update({'attention_mask': attention_mask})
-        # gen_kwargs.update({'position_ids': position_ids})
-        if self.stopping_criteria is None:
-            self.stopping_criteria = transformers.StoppingCriteriaList()
         # observe the output
         gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
         # llama-cpp model parameters differ substantially from the transformers fields; calling directly returns unsupported-field errors

@@ -173,11 +168,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         if "llama_cpp" in self.checkPoint.model.__str__():
             import inspect
-            common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
-            common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys}
-            # ? The llama-cpp model's generate method seems to accept only CPU inputs; the response is painfully slow
-            # ? Why would the GPU not be supported? That should not be the case.
-            output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
+
+            common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
+                gen_kwargs.keys())
+            common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
+            # ? The llama-cpp model's generate method seems to accept only CPU inputs; the response is painfully slow
+            # ? Why would the GPU not be supported? That should not be the case.
+            output_ids = torch.tensor(
+                [list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
 
         else:
             output_ids = self.checkPoint.model.generate(**gen_kwargs)

@@ -185,17 +182,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         reply = self.decode(output_ids[0][-new_tokens:])
         print(f"response:{reply}")
         print(f"+++++++++++++++++++++++++++++++++++")
-        return reply
-
-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
-
-        # TODO A chat module and attention handling still need to be implemented; _call is currently the LangChain extension API and defaults to prompt-free mode. To work with the attention model, refer to the chat_glm implementation.
-        softprompt = self.history_to_text(prompt,history=history)
-        response = self._call(prompt=softprompt, stop=['\n###'])
-
         answer_result = AnswerResult()
-        answer_result.history = history + [[prompt, response]]
-        answer_result.llm_output = {"answer": response}
-        yield answer_result
+        history += [[prompt, reply]]
+        answer_result.history = history
+        if listenerQueue.listenerQueue.__len__() > 0:
+            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+        answer_result.llm_output = {"answer": reply}
+        generate_with_callback(answer_result)
@@ -20,6 +20,7 @@ class LoaderCheckPoint:
     no_remote_model: bool = False
     # model name
     model_name: str = None
+    pretrained_model_name: str = None
     tokenizer: object = None
     # full path to the model
     model_path: str = None

@@ -67,48 +68,49 @@ class LoaderCheckPoint:
         self.load_in_8bit = params.get('load_in_8bit', False)
         self.bf16 = params.get('bf16', False)
 
-    def _load_model_config(self, model_name):
+    def _load_model_config(self):
 
         if self.model_path:
-            self.model_path = re.sub("\s","",self.model_path)
+            self.model_path = re.sub("\s", "", self.model_path)
             checkpoint = Path(f'{self.model_path}')
         else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
+            if self.no_remote_model:
                 raise ValueError(
                     "本地模型local_model_path未配置路径"
                 )
+            else:
+                checkpoint = self.pretrained_model_name
+
+        print(f"load_model_config {checkpoint}...")
         try:
+
             model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
             return model_config
         except Exception as e:
             print(e)
             return checkpoint
 
-    def _load_model(self, model_name):
+    def _load_model(self):
         """
         Load a model from a custom location
-        :param model_name:
         :return:
         """
-        print(f"Loading {model_name}...")
         t0 = time.time()
 
         if self.model_path:
-            self.model_path = re.sub("\s","",self.model_path)
+            self.model_path = re.sub("\s", "", self.model_path)
            checkpoint = Path(f'{self.model_path}')
         else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
+            if self.no_remote_model:
                 raise ValueError(
                     "本地模型local_model_path未配置路径"
                 )
+            else:
+                checkpoint = self.pretrained_model_name
+
+        print(f"Loading {checkpoint}...")
         self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
-        if 'chatglm' in model_name.lower() or "chatyuan" in model_name.lower():
+        if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
             LoaderClass = AutoModel
         else:
             LoaderClass = AutoModelForCausalLM
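In both `_load_model_config` and `_load_model`, checkpoint resolution now prefers `model_path`, raises when `no_remote_model` is set without a local path, and otherwise falls back to `pretrained_model_name` (the `model_name` parameter is gone). A condensed, illustrative sketch of that logic; the standalone helper is hypothetical and not part of the commit:

import re
from pathlib import Path

def resolve_checkpoint(model_path, no_remote_model, pretrained_model_name):
    # Mirrors the branch order of the new _load_model_config/_load_model.
    if model_path:
        return Path(re.sub(r"\s", "", model_path))
    if no_remote_model:
        raise ValueError("本地模型local_model_path未配置路径")
    return pretrained_model_name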
@@ -134,11 +136,11 @@ class LoaderCheckPoint:
             # support custom cuda devices
             elif ":" in self.llm_device:
                 model = LoaderClass.from_pretrained(checkpoint,
                                                     config=self.model_config,
                                                     torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                     trust_remote_code=True).half().to(self.llm_device)
             else:
-                from accelerate import dispatch_model,infer_auto_device_map
+                from accelerate import dispatch_model, infer_auto_device_map
+
                 model = LoaderClass.from_pretrained(checkpoint,
                                                     config=self.model_config,

@@ -146,29 +148,29 @@ class LoaderCheckPoint:
                                                     trust_remote_code=True).half()
                 # a device_map can be passed in to customize how each GPU is used
                 if self.device_map is None:
-                    if 'chatglm' in model_name.lower():
+                    if 'chatglm' in self.model_name.lower():
                         self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
-                    elif 'moss' in model_name.lower():
-                        self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
+                    elif 'moss' in self.model_name.lower():
+                        self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
                     else:
                         # Use the following as the default multi-GPU loading scheme; it rarely fails for new models.
                         # Tested on chatglm2-6b, bloom-3b and bloomz-7b1; GPU load is also fairly balanced.
                         from accelerate.utils import get_balanced_memory
                         max_memory = get_balanced_memory(model,
                                                          dtype=torch.int8 if self.load_in_8bit else None,
                                                          low_zero=False,
                                                          no_split_module_classes=model._no_split_modules)
                         self.device_map = infer_auto_device_map(model,
                                                                 dtype=torch.float16 if not self.load_in_8bit else torch.int8,
                                                                 max_memory=max_memory,
                                                                 no_split_module_classes=model._no_split_modules)
                         # For models other than chatglm and moss, automatic assignment should be used instead of the chatglm configuration;
                         # other models almost never define the same layer classes as chatglm and moss, so chatglm_auto_configure_device_map
                         # is guaranteed to fail for them. infer_auto_device_map may give an unbalanced load, but at least it does not error.
                         # Observed on the bloom models.
                         # self.device_map = infer_auto_device_map(model,
                         #                                         dtype=torch.int8,
                         #                                         no_split_module_classes=model._no_split_modules)
 
                 model = dispatch_model(model, device_map=self.device_map)
             else:

@@ -202,7 +204,7 @@ class LoaderCheckPoint:
 
             # tokenizer = model.tokenizer
             # todo AutoTokenizer's tokenizer is used here; whether the bundled tokenizer is compatible can be tested later
-            #* -> the bundled tokenizer is not compatible with the transformers tokenizer and cannot be used
+            # * -> the bundled tokenizer is not compatible with the transformers tokenizer and cannot be used
 
             tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         return model, tokenizer

@@ -231,7 +233,7 @@ class LoaderCheckPoint:
                                                  llm_int8_enable_fp32_cpu_offload=False)
 
             with init_empty_weights():
-                model = LoaderClass.from_config(self.model_config,trust_remote_code = True)
+                model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
             model.tie_weights()
             if self.device_map is not None:
                 params['device_map'] = self.device_map

@@ -307,8 +309,8 @@ class LoaderCheckPoint:
             encode = ".encoder"
         else:
             device_map = {f'{layer_prefix}.word_embeddings': 0,
                           f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
                           f'base_model.model.lm_head': 0, }
             used = 2
             gpu_target = 0
             for i in range(num_trans_layers):

@@ -321,7 +323,7 @@ class LoaderCheckPoint:
 
         return device_map
 
-    def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
+    def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
         try:
+
             from accelerate import init_empty_weights

@@ -336,16 +338,6 @@ class LoaderCheckPoint:
                 "`pip install bitsandbytes``pip install accelerate`."
             ) from exc
 
-        if self.model_path:
-            checkpoint = Path(f'{self.model_path}')
-        else:
-            if not self.no_remote_model:
-                checkpoint = model_name
-            else:
-                raise ValueError(
-                    "本地模型local_model_path未配置路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
|
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
|
||||||
pretrained_model_name_or_path=checkpoint)
|
pretrained_model_name_or_path=checkpoint)
|
||||||
|
|
||||||
@ -452,7 +444,7 @@ class LoaderCheckPoint:
|
|||||||
|
|
||||||
def reload_model(self):
|
def reload_model(self):
|
||||||
self.unload_model()
|
self.unload_model()
|
||||||
self.model_config = self._load_model_config(self.model_name)
|
self.model_config = self._load_model_config()
|
||||||
|
|
||||||
if self.use_ptuning_v2:
|
if self.use_ptuning_v2:
|
||||||
try:
|
try:
|
||||||
@ -464,7 +456,7 @@ class LoaderCheckPoint:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("加载PrefixEncoder config.json失败")
|
print("加载PrefixEncoder config.json失败")
|
||||||
|
|
||||||
self.model, self.tokenizer = self._load_model(self.model_name)
|
self.model, self.tokenizer = self._load_model()
|
||||||
|
|
||||||
if self.lora:
|
if self.lora:
|
||||||
self._add_lora_to_model([self.lora])
|
self._add_lora_to_model([self.lora])
|
||||||
|
|||||||
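Note: the loader hunks above stop threading a loose model_name argument through _load_model and moss_auto_configure_device_map and instead read everything from the LoaderCheckPoint instance, which is the model-switching fix named in the commit message. A minimal sketch of the resulting checkpoint-resolution order, written as a standalone helper purely for illustration (the attribute names model_path, no_remote_model and pretrained_model_name come from this diff; the helper itself is not part of the commit):

from pathlib import Path

def resolve_checkpoint(model_path, no_remote_model, pretrained_model_name):
    # Prefer a configured local path; otherwise fall back to the hub name,
    # unless remote models are disabled, in which case loading fails early.
    if model_path:
        return Path(model_path)
    if no_remote_model:
        raise ValueError("本地模型local_model_path未配置路径")
    return pretrained_model_name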
@@ -1,11 +1,19 @@
 from abc import ABC
-from langchain.llms.base import LLM
-from typing import Optional, List
+from langchain.chains.base import Chain
+from typing import Any, Dict, List, Optional, Generator, Union
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult)
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import torch
+import transformers

 import torch

 # todo 建议重写instruction,在该instruction下,各模型的表现比较差
 META_INSTRUCTION = \
     """You are an AI assistant whose name is MOSS.

@@ -20,41 +28,65 @@ META_INSTRUCTION = \
 Capabilities and tools that MOSS can possess.
 """


 # todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
-class MOSSLLM(BaseAnswer, LLM, ABC):
+class MOSSLLMChain(BaseAnswer, Chain, ABC):
     max_token: int = 2048
     temperature: float = 0.7
     top_p = 0.8
     # history = []
     checkPoint: LoaderCheckPoint = None
     history_len: int = 10
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:

     def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
         self.checkPoint = checkPoint

     @property
-    def _llm_type(self) -> str:
-        return "MOSS"
+    def _chain_type(self) -> str:
+        return "MOSSLLMChain"

+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]

     @property
     def _check_point(self) -> LoaderCheckPoint:
         return self.checkPoint

-    @property
-    def _history_len(self) -> int:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}

-        return self.history_len
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:

-    def set_history_len(self, history_len: int) -> None:
-        self.history_len = history_len
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        pass
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")

-    def generatorAnswer(self, prompt: str,
-                        history: List[List[str]] = [],
-                        streaming: bool = False):
         if len(history) > 0:
             history = history[-self.history_len:] if self.history_len > 0 else []
             prompt_w_history = str(history)

@@ -79,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                                                        num_return_sequences=1,
                                                        eos_token_id=106068,
                                                        pad_token_id=self.checkPoint.tokenizer.pad_token_id)
-            response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+            response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
+                                                        skip_special_tokens=True)
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}

-            yield answer_result
+            generate_with_callback(answer_result)

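For reference, the rewritten wrapper is now invoked like any other LangChain Chain: a dict of inputs goes in, and a dict holding an answer_result_stream generator comes out. This is exactly the pattern the webui.py and webui_st.py hunks below switch to. A usage sketch, assuming a checkpoint has already been loaded the way webui.py does it (args_dict built from the CLI parser is assumed and not shown here):

import models.shared as shared
from models.loader import LoaderCheckPoint

shared.loaderCheckPoint = LoaderCheckPoint(args_dict)  # args_dict assumed to exist
llm_chain = shared.loaderLLM()  # returns a Chain such as MOSSLLMChain

result = llm_chain({"prompt": "你好", "history": [], "streaming": False})
for answer_result in result["answer_result_stream"]:
    print(answer_result.llm_output["answer"])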
@@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
     if use_ptuning_v2:
         loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

+    # 如果指定了参数,则使用参数的配置
     if llm_model:
         llm_model_info = llm_model_dict[llm_model]

-    if loaderCheckPoint.no_remote_model:
-        loaderCheckPoint.model_name = llm_model_info['name']
-    else:
-        loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
+    loaderCheckPoint.model_name = llm_model_info['name']
+    loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']

     loaderCheckPoint.model_path = llm_model_info["local_model_path"]

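The hunk above stops choosing between the two identifiers at configuration time: model_name now always carries the configured display name, the hub identifier lives in the new pretrained_model_name attribute, and the local-versus-remote decision is left to the loader. A hypothetical llm_model_dict entry with the three keys read here; only the key names are taken from this diff, the values are placeholders:

# Illustrative only; not copied from the project's configs.
llm_model_dict = {
    "moss": {
        "name": "moss",
        "pretrained_model_name": "fnlp/moss-moon-003-sft",
        "local_model_path": None,
    },
}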
|
|||||||
@ -1,39 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
|
|
||||||
import asyncio
|
|
||||||
from argparse import Namespace
|
|
||||||
from models.loader.args import parser
|
|
||||||
from models.loader import LoaderCheckPoint
|
|
||||||
|
|
||||||
|
|
||||||
import models.shared as shared
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch(args: Namespace):
|
|
||||||
args_dict = vars(args)
|
|
||||||
|
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
|
||||||
|
|
||||||
llm_model_ins = shared.loaderLLM()
|
|
||||||
|
|
||||||
history = [
|
|
||||||
("which city is this?", "tokyo"),
|
|
||||||
("why?", "she's japanese"),
|
|
||||||
|
|
||||||
]
|
|
||||||
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
|
|
||||||
streaming=False):
|
|
||||||
resp = answer_result.llm_output["answer"]
|
|
||||||
|
|
||||||
print(resp)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = None
|
|
||||||
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(dispatch(args))
|
|
||||||
47 webui.py

@@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
             yield history + [[query,
                               "请选择知识库后进行测试,当前未选择知识库。"]], ""
     else:
-        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
-                                                              streaming=streaming):
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": query, "history": history, "streaming": streaming})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][-1] = resp

@@ -101,11 +104,12 @@ def init_model():
     args_dict = vars(args)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
     try:
         local_doc_qa.init_cfg(llm_model=llm_model_ins)
-        generator = local_doc_qa.llm.generatorAnswer("你好")
-        for answer_result in generator:
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": "你好", "history": [], "streaming": False})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             print(answer_result.llm_output)
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         logger.info(reply)

@@ -141,7 +145,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
 def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
     vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
     filelist = []
-    if local_doc_qa.llm and local_doc_qa.embeddings:
+    if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
         if isinstance(files, list):
             for file in files:
                 filename = os.path.split(file.name)[-1]

@@ -165,8 +169,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte

 def change_vs_name_input(vs_id, history):
     if vs_id == "新建知识库":
-        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
+        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
                gr.update(choices=[]), gr.update(visible=False)
     else:
         vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
         if "index.faiss" in os.listdir(vs_path):

@@ -218,7 +222,7 @@ def change_chunk_conent(mode, label_conent, history):


 def add_vs_name(vs_name, chatbot):
-    if vs_name is None or vs_name.strip() == "" :
+    if vs_name is None or vs_name.strip() == "":
         vs_status = "知识库名称不能为空,请重新填写知识库名称"
         chatbot = chatbot + [[None, vs_status]]
         return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(

@@ -262,6 +266,7 @@ def reinit_vector_store(vs_id, history):
 def refresh_vs_list():
     return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())

+
 def delete_file(vs_id, files_to_delete, chatbot):
     vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
     content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")

@@ -275,11 +280,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
     rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
     if "fail" in status:
         vs_status = "文件删除失败。"
-    elif len(rested_files)>0:
+    elif len(rested_files) > 0:
         vs_status = "文件删除成功。"
     else:
         vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
-    logger.info(",".join(files_to_delete)+vs_status)
+    logger.info(",".join(files_to_delete) + vs_status)
     chatbot = chatbot + [[None, vs_status]]
     return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot

@@ -290,7 +295,8 @@ def delete_vs(vs_id, chatbot):
         status = f"成功删除知识库{vs_id}"
         logger.info(status)
         chatbot = chatbot + [[None, status]]
-        return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
+        return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
+            visible=True), \
                gr.update(visible=False), chatbot, gr.update(visible=False)
     except Exception as e:
         logger.error(e)

@@ -333,7 +339,8 @@ default_theme_args = dict(

 with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
     vs_path, file_status, model_status = gr.State(
-        os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
+        os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
+        ""), gr.State(
         model_status)
     gr.Markdown(webui_title)
     with gr.Tab("对话"):

@@ -386,8 +393,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                     load_folder_button = gr.Button("上传文件夹并加载知识库")
                 with gr.Tab("删除文件"):
                     files_to_delete = gr.CheckboxGroup(choices=[],
                                                        label="请从知识库已有文件中选择要删除的文件",
                                                        interactive=True)
                     delete_file_button = gr.Button("从知识库中删除选中文件")
             vs_refresh.click(fn=refresh_vs_list,
                              inputs=[],

@@ -455,9 +462,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
             with vs_setting:
                 vs_refresh = gr.Button("更新已有知识库选项")
                 select_vs_test = gr.Dropdown(get_vs_list(),
                                              label="请选择要加载的知识库",
                                              interactive=True,
                                              value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
                 vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
                                      lines=1,
                                      interactive=True,

@@ -497,8 +504,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                          inputs=[vs_name, chatbot],
                          outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
     select_vs_test.change(fn=change_vs_name_input,
                           inputs=[select_vs_test, chatbot],
                           outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
     load_file_button.click(get_vector_store,
                            show_progress=True,
                            inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],
14 webui_st.py

@@ -85,9 +85,10 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
             yield history + [[query,
                               "请选择知识库后进行测试,当前未选择知识库。"]], ""
     else:
-        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
-                                                              streaming=streaming):
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": query, "history": history, "streaming": streaming})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             resp = answer_result.llm_output["answer"]
             history = answer_result.history
             history[-1][-1] = resp + (

@@ -105,13 +106,14 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
     args_dict.update(model=llm_model)
     shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
     llm_model_ins = shared.loaderLLM()
-    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

     try:
         local_doc_qa.init_cfg(llm_model=llm_model_ins,
                               embedding_model=embedding_model)
-        generator = local_doc_qa.llm.generatorAnswer("你好")
-        for answer_result in generator:
+        answer_result_stream_result = local_doc_qa.llm_model_chain(
+            {"prompt": "你好", "history": [], "streaming": False})
+
+        for answer_result in answer_result_stream_result['answer_result_stream']:
             print(answer_result.llm_output)
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         logger.info(reply)

@@ -468,7 +470,7 @@ with st.sidebar:
             top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
             history_len = st.slider(
                 'LLM对话轮数', 1, 50, LLM_HISTORY_LEN)  # 也许要跟知识库分开设置
-            local_doc_qa.llm.set_history_len(history_len)
+            # local_doc_qa.llm.set_history_len(history_len)
             chunk_conent = st.checkbox('启用上下文关联', False)
             st.text('')
             # chunk_conent = st.checkbox('分割文本', True)  # 知识库文本分割入库