mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-04 05:33:12 +08:00

llm_model_dict now handles the loader's preset behaviors, such as load location, model name, and the model handler instance, and defines checkpoint names and remote paths.
loader.py: model reloading; defines generatorAnswer, adds AnswerResultStream, and defines the generate_with_callback collector, which syncs queue data into AnswerResult on each response.
requirements.txt: project dependencies updated.

This commit is contained in:
parent c3924b2ece
commit 33bbb4779e
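For orientation before the diffs: this commit replaces the old `for result, history in llm._call(...)` loop with an iterator of AnswerResult objects. A minimal consumption sketch, assuming `llm` is any BaseAnswer implementation already obtained via shared.loaderLLM() (the prompt text is only illustrative):

    # Hedged sketch of the streaming answer API introduced by this commit.
    for answer_result in llm.generatorAnswer(prompt="你好", history=[], streaming=True):
        resp = answer_result.llm_output["answer"]   # partial (streaming) or final answer text
        history = answer_result.history             # chat history updated by the model wrapper
        print(resp)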
@@ -12,10 +12,14 @@ from tqdm import tqdm
 from pypinyin import lazy_pinyin
 from loader import UnstructuredPaddleImageLoader
 from loader import UnstructuredPaddlePDFLoader
+from models.base import (BaseAnswer,
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+from models.loader.args import parser
+from models.loader import LoaderCheckPoint
+import models.shared as shared
-DEVICE_ = EMBEDDING_DEVICE
-DEVICE_ID = "0" if torch.cuda.is_available() else None
-DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 
 def load_file(filepath, sentence_size=SENTENCE_SIZE):

@@ -132,7 +136,7 @@ def similarity_search_with_score_by_vector(
 
 
 class LocalDocQA:
-    llm: object = None
+    llm: BaseAnswer = None
     embeddings: object = None
     top_k: int = VECTOR_SEARCH_TOP_K
     chunk_size: int = CHUNK_SIZE

@@ -142,23 +146,10 @@
     def init_cfg(self,
                  embedding_model: str = EMBEDDING_MODEL,
                  embedding_device=EMBEDDING_DEVICE,
-                 llm_history_len: int = LLM_HISTORY_LEN,
-                 llm_model: str = LLM_MODEL,
-                 llm_device=LLM_DEVICE,
+                 llm_model: BaseAnswer = None,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2,
-                 use_lora: bool = USE_LORA,
                  ):
-        if llm_model.startswith('moss'):
-            from models.moss_llm import MOSS
-            self.llm = MOSS()
-        else:
-            from models.chatglm_llm import ChatGLM
-            self.llm = ChatGLM()
-        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
-        self.llm.history_len = llm_history_len
+        self.llm = llm_model
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
         self.top_k = top_k

@@ -259,16 +250,16 @@
         torch_gc()
         prompt = generate_prompt(related_docs_with_score, query)

-        for result, history in self.llm._call(prompt=prompt,
-                                              history=chat_history,
-                                              streaming=streaming):
-            torch_gc()
+        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
+                                                      streaming=streaming):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
             history[-1][0] = query
             response = {"query": query,
-                        "result": result,
+                        "result": resp,
                         "source_documents": related_docs_with_score}
             yield response, history
-            torch_gc()
 
     # query      query content
     # vs_path    knowledge base path

@@ -297,10 +288,19 @@
 
 
 if __name__ == "__main__":
+    # initialize messages
+    args = None
+    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])
+
+    args_dict = vars(args)
+    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+    llm_model_ins = shared.loaderLLM()
+    llm_model_ins.set_history_len(LLM_HISTORY_LEN)
+
     local_doc_qa = LocalDocQA()
-    local_doc_qa.init_cfg()
+    local_doc_qa.init_cfg(llm_model=llm_model_ins)
     query = "本项目使用的embedding模型是什么,消耗多少显存"
-    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
+    vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
     last_print_len = 0
     for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                   vs_path=vs_path,
@@ -22,14 +22,54 @@ EMBEDDING_MODEL = "text2vec"
 # Embedding running device
 EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # supported LLM models
+"""
+llm_model_dict handles the loader's preset behaviors, such as load location, model name, and the model handler instance
+"""
 llm_model_dict = {
-    "chatyuan": "ClueAI/ChatYuan-large-v2",
-    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
-    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
-    "chatglm-6b-int8": "THUDM/chatglm-6b-int8",
-    "chatglm-6b": "THUDM/chatglm-6b",
-    "moss": "fnlp/moss-moon-003-sft",
+    "chatglm-6b-int4-qe": {
+        "name": "chatglm-6b-int4-qe",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4-qe",
+        "path": None,
+        "provides": "ChatGLM"
+    },
+    "chatglm-6b-int4": {
+        "name": "chatglm-6b-int4",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4",
+        "path": None,
+        "provides": "ChatGLM"
+    },
+    "chatglm-6b": {
+        "name": "chatglm-6b",
+        "remote-checkpoint": "THUDM/chatglm-6b-int4",
+        "path": None,
+        "provides": "ChatGLM"
+    },
+    "llama-7b-hf": {
+        "name": "llama-7b-hf",
+        "remote-checkpoint": "llama-7b-hf",
+        "path": None,
+        "provides": "LLamaLLM"
+    },
+    "vicuna-13b-hf": {
+        "name": "vicuna-13b-hf",
+        "remote-checkpoint": "vicuna-13b-hf",
+        "path": None,
+        "provides": "LLamaLLM"
+    },
+    "chatyuan": {
+        "name": "chatyuan",
+        "remote-checkpoint": "ClueAI/ChatYuan-large-v2",
+        "path": None,
+        "provides": None
+    },
+    "chatglm-6b-int8": {
+        "name": "chatglm-6b-int8",
+        "remote-checkpoint": "THUDM/chatglm-6b-int8",
+        "path": None,
+        "provides": "ChatGLM"
+    },
 }
 
 # LLM model name
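As a reading aid for the new llm_model_dict layout above: each entry carries a local "path", a "remote-checkpoint" id, and a "provides" field naming the model wrapper class. A hedged sketch of how a loader-side lookup might use one entry (the helper name resolve_checkpoint is hypothetical, not part of this commit):

    # Hypothetical helper: prefer a local path when configured, otherwise fall back to the remote checkpoint id.
    def resolve_checkpoint(model_name: str) -> str:
        cfg = llm_model_dict[model_name]
        return cfg["path"] if cfg["path"] is not None else cfg["remote-checkpoint"]

    # e.g. resolve_checkpoint("chatglm-6b-int4-qe") -> "THUDM/chatglm-6b-int4-qe" with the values shown above.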
0  fastchat/__init__.py  Normal file

1  fastchat/api/__init__.py  Normal file
@@ -0,0 +1 @@
from .fastchat_api import *
261  fastchat/api/conversation.py  Normal file
@ -0,0 +1,261 @@
|
|||||||
|
"""
|
||||||
|
Conversation prompt template.
|
||||||
|
|
||||||
|
Now we support
|
||||||
|
- Vicuna
|
||||||
|
- Koala
|
||||||
|
- OpenAssistant/oasst-sft-1-pythia-12b
|
||||||
|
- StabilityAI/stablelm-tuned-alpha-7b
|
||||||
|
- databricks/dolly-v2-12b
|
||||||
|
- THUDM/chatglm-6b
|
||||||
|
- Alpaca/LLaMa
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from enum import auto, Enum
|
||||||
|
from typing import List, Tuple, Any
|
||||||
|
|
||||||
|
|
||||||
|
class SeparatorStyle(Enum):
|
||||||
|
"""Different separator style."""
|
||||||
|
|
||||||
|
SINGLE = auto()
|
||||||
|
TWO = auto()
|
||||||
|
DOLLY = auto()
|
||||||
|
OASST_PYTHIA = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Conversation:
|
||||||
|
"""A class that keeps all conversation history."""
|
||||||
|
|
||||||
|
system: str
|
||||||
|
roles: List[str]
|
||||||
|
messages: List[List[str]]
|
||||||
|
offset: int
|
||||||
|
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||||
|
sep: str = "###"
|
||||||
|
sep2: str = None
|
||||||
|
|
||||||
|
# Used for gradio server
|
||||||
|
skip_next: bool = False
|
||||||
|
conv_id: Any = None
|
||||||
|
|
||||||
|
def get_prompt(self):
|
||||||
|
if self.sep_style == SeparatorStyle.SINGLE:
|
||||||
|
ret = self.system
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
ret += self.sep + " " + role + ": " + message
|
||||||
|
else:
|
||||||
|
ret += self.sep + " " + role + ":"
|
||||||
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.TWO:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
ret = self.system + seps[0]
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
ret += role + ":"
|
||||||
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.DOLLY:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
ret = self.system
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
ret += role + ":\n" + message + seps[i % 2]
|
||||||
|
if i % 2 == 1:
|
||||||
|
ret += "\n\n"
|
||||||
|
else:
|
||||||
|
ret += role + ":\n"
|
||||||
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
|
||||||
|
ret = self.system
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
ret += role + message + self.sep
|
||||||
|
else:
|
||||||
|
ret += role
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||||
|
|
||||||
|
def append_message(self, role, message):
|
||||||
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
def to_gradio_chatbot(self):
|
||||||
|
ret = []
|
||||||
|
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||||
|
if i % 2 == 0:
|
||||||
|
ret.append([msg, None])
|
||||||
|
else:
|
||||||
|
ret[-1][-1] = msg
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return Conversation(
|
||||||
|
system=self.system,
|
||||||
|
roles=self.roles,
|
||||||
|
messages=[[x, y] for x, y in self.messages],
|
||||||
|
offset=self.offset,
|
||||||
|
sep_style=self.sep_style,
|
||||||
|
sep=self.sep,
|
||||||
|
sep2=self.sep2,
|
||||||
|
conv_id=self.conv_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
return {
|
||||||
|
"system": self.system,
|
||||||
|
"roles": self.roles,
|
||||||
|
"messages": self.messages,
|
||||||
|
"offset": self.offset,
|
||||||
|
"sep": self.sep,
|
||||||
|
"sep2": self.sep2,
|
||||||
|
"conv_id": self.conv_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
conv_one_shot = Conversation(
|
||||||
|
system="A chat between a curious human and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||||
|
roles=("Human", "Assistant"),
|
||||||
|
messages=(
|
||||||
|
(
|
||||||
|
"Human",
|
||||||
|
"What are the key differences between renewable and non-renewable energy sources?",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Assistant",
|
||||||
|
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
||||||
|
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
||||||
|
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
||||||
|
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
||||||
|
"renewable and non-renewable energy sources:\n"
|
||||||
|
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
||||||
|
"energy sources are finite and will eventually run out.\n"
|
||||||
|
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
||||||
|
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
||||||
|
"and other negative effects.\n"
|
||||||
|
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
||||||
|
"have lower operational costs than non-renewable sources.\n"
|
||||||
|
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
||||||
|
"locations than non-renewable sources.\n"
|
||||||
|
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
||||||
|
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
||||||
|
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
||||||
|
"non-renewable sources are not, and their depletion can lead to economic and social instability.",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
offset=2,
|
||||||
|
sep_style=SeparatorStyle.SINGLE,
|
||||||
|
sep="###",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
conv_vicuna_v1_1 = Conversation(
|
||||||
|
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||||
|
roles=("USER", "ASSISTANT"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.TWO,
|
||||||
|
sep=" ",
|
||||||
|
sep2="</s>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
conv_koala_v1 = Conversation(
|
||||||
|
system="BEGINNING OF CONVERSATION:",
|
||||||
|
roles=("USER", "GPT"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.TWO,
|
||||||
|
sep=" ",
|
||||||
|
sep2="</s>",
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_dolly = Conversation(
|
||||||
|
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
|
||||||
|
roles=("### Instruction", "### Response"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.DOLLY,
|
||||||
|
sep="\n\n",
|
||||||
|
sep2="### End",
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_oasst = Conversation(
|
||||||
|
system="",
|
||||||
|
roles=("<|prompter|>", "<|assistant|>"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.OASST_PYTHIA,
|
||||||
|
sep="<|endoftext|>",
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_stablelm = Conversation(
|
||||||
|
system="""<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||||
|
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||||
|
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||||
|
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||||
|
- StableLM will refuse to participate in anything that could harm a human.
|
||||||
|
""",
|
||||||
|
roles=("<|USER|>", "<|ASSISTANT|>"),
|
||||||
|
messages=(),
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.OASST_PYTHIA,
|
||||||
|
sep="",
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_templates = {
|
||||||
|
"conv_one_shot": conv_one_shot,
|
||||||
|
"vicuna_v1.1": conv_vicuna_v1_1,
|
||||||
|
"koala_v1": conv_koala_v1,
|
||||||
|
"dolly": conv_dolly,
|
||||||
|
"oasst": conv_oasst,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_conv_template(model_name):
|
||||||
|
model_name = model_name.lower()
|
||||||
|
if "vicuna" in model_name or "output" in model_name:
|
||||||
|
return conv_vicuna_v1_1
|
||||||
|
elif "koala" in model_name:
|
||||||
|
return conv_koala_v1
|
||||||
|
elif "dolly-v2" in model_name:
|
||||||
|
return conv_dolly
|
||||||
|
elif "oasst" in model_name and "pythia" in model_name:
|
||||||
|
return conv_oasst
|
||||||
|
elif "stablelm" in model_name:
|
||||||
|
return conv_stablelm
|
||||||
|
return conv_one_shot
|
||||||
|
|
||||||
|
|
||||||
|
def compute_skip_echo_len(model_name, conv, prompt):
|
||||||
|
model_name = model_name.lower()
|
||||||
|
if "chatglm" in model_name:
|
||||||
|
skip_echo_len = len(conv.messages[-2][1]) + 1
|
||||||
|
elif "dolly-v2" in model_name:
|
||||||
|
special_toks = ["### Instruction:", "### Response:", "### End"]
|
||||||
|
skip_echo_len = len(prompt)
|
||||||
|
for tok in special_toks:
|
||||||
|
skip_echo_len -= prompt.count(tok) * len(tok)
|
||||||
|
elif "oasst" in model_name and "pythia" in model_name:
|
||||||
|
special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
|
||||||
|
skip_echo_len = len(prompt)
|
||||||
|
for tok in special_toks:
|
||||||
|
skip_echo_len -= prompt.count(tok) * len(tok)
|
||||||
|
elif "stablelm" in model_name:
|
||||||
|
special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
|
||||||
|
skip_echo_len = len(prompt)
|
||||||
|
for tok in special_toks:
|
||||||
|
skip_echo_len -= prompt.count(tok) * len(tok)
|
||||||
|
else:
|
||||||
|
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
|
||||||
|
return skip_echo_len
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(default_conversation.get_prompt())
|
||||||
459  fastchat/api/fastchat_api.py  Normal file
@ -0,0 +1,459 @@
|
|||||||
|
"""Wrapper around FastChat APIs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from typing import (
|
||||||
|
AbstractSet,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import Extra, Field, root_validator
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"
|
||||||
|
|
||||||
|
|
||||||
|
def _streaming_response_template() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
:return: 响应结构
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"text": "",
|
||||||
|
"error_code": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
|
||||||
|
"""Update response from the stream response."""
|
||||||
|
response["text"] += stream_response["text"]
|
||||||
|
response["error_code"] += stream_response["error_code"]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFastChat(BaseLLM):
|
||||||
|
"""Wrapper around FastChat large language models."""
|
||||||
|
|
||||||
|
model_name: str = "text-davinci-003"
|
||||||
|
"""Model name to use."""
|
||||||
|
temperature: float = 0.7
|
||||||
|
"""What sampling temperature to use."""
|
||||||
|
max_new_tokens: int = 200
|
||||||
|
stop: int = 20
|
||||||
|
batch_size: int = 20
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
streaming: bool = False
|
||||||
|
"""Penalizes repeated tokens."""
|
||||||
|
n: int = 1
|
||||||
|
"""Whether to stream the results or not."""
|
||||||
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
|
||||||
|
"""Set of special tokens that are allowed。"""
|
||||||
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all"
|
||||||
|
"""Set of special tokens that are not allowed。"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.ignore
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
|
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||||
|
|
||||||
|
extra = values.get("model_kwargs", {})
|
||||||
|
for field_name in list(values):
|
||||||
|
if field_name not in all_required_field_names:
|
||||||
|
if field_name in extra:
|
||||||
|
raise ValueError(f"Found {field_name} supplied twice.")
|
||||||
|
logger.warning(
|
||||||
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
|
{field_name} was transfered to model_kwargs.
|
||||||
|
Please confirm that {field_name} is what you intended."""
|
||||||
|
)
|
||||||
|
extra[field_name] = values.pop(field_name)
|
||||||
|
values["model_kwargs"] = extra
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling FastChat API."""
|
||||||
|
normal_params = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"prompt": '',
|
||||||
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {**normal_params}
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Call out to FastChat's endpoint with k unique prompts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: The prompts to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full LLM output.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = fastchat.generate(["Tell me a joke."])
|
||||||
|
"""
|
||||||
|
# TODO: write a unit test for this
|
||||||
|
params = self._invocation_params
|
||||||
|
sub_prompts = self.get_sub_prompts(params, prompts)
|
||||||
|
choices = []
|
||||||
|
token_usage: Dict[str, int] = {}
|
||||||
|
headers = {"User-Agent": "fastchat Client"}
|
||||||
|
for _prompts in sub_prompts:
|
||||||
|
|
||||||
|
params["prompt"] = _prompts[0]
|
||||||
|
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
params["stop"] = stop
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
if len(_prompts) > 1:
|
||||||
|
raise ValueError("Cannot stream results with multiple prompts.")
|
||||||
|
|
||||||
|
response_template = _streaming_response_template()
|
||||||
|
response = requests.post(
|
||||||
|
FAST_CHAT_API,
|
||||||
|
headers=headers,
|
||||||
|
json=params,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for stream_resp in response.iter_lines(
|
||||||
|
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||||
|
):
|
||||||
|
if stream_resp:
|
||||||
|
data = json.loads(stream_resp.decode("utf-8"))
|
||||||
|
skip_echo_len = len(_prompts[0])
|
||||||
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
data["text"] = output
|
||||||
|
self.callback_manager.on_llm_new_token(
|
||||||
|
output,
|
||||||
|
verbose=self.verbose,
|
||||||
|
logprobs=data["error_code"],
|
||||||
|
)
|
||||||
|
_update_response(response_template, data)
|
||||||
|
choices.append(response_template)
|
||||||
|
else:
|
||||||
|
response_template = _streaming_response_template()
|
||||||
|
response = requests.post(
|
||||||
|
FAST_CHAT_API,
|
||||||
|
headers=headers,
|
||||||
|
json=params,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for stream_resp in response.iter_lines(
|
||||||
|
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||||
|
):
|
||||||
|
if stream_resp:
|
||||||
|
data = json.loads(stream_resp.decode("utf-8"))
|
||||||
|
skip_echo_len = len(_prompts[0])
|
||||||
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
data["text"] = output
|
||||||
|
_update_response(response_template, data)
|
||||||
|
|
||||||
|
choices.append(response_template)
|
||||||
|
|
||||||
|
return self.create_llm_result(choices, prompts, token_usage)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Call out to FastChat's endpoint async with k unique prompts."""
|
||||||
|
params = self._invocation_params
|
||||||
|
sub_prompts = self.get_sub_prompts(params, prompts)
|
||||||
|
choices = []
|
||||||
|
token_usage: Dict[str, int] = {}
|
||||||
|
|
||||||
|
headers = {"User-Agent": "fastchat Client"}
|
||||||
|
for _prompts in sub_prompts:
|
||||||
|
|
||||||
|
params["prompt"] = _prompts[0]
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
params["stop"] = stop
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
if len(_prompts) > 1:
|
||||||
|
raise ValueError("Cannot stream results with multiple prompts.")
|
||||||
|
|
||||||
|
response_template = _streaming_response_template()
|
||||||
|
response = requests.post(
|
||||||
|
FAST_CHAT_API,
|
||||||
|
headers=headers,
|
||||||
|
json=params,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for stream_resp in response.iter_lines(
|
||||||
|
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||||
|
):
|
||||||
|
if stream_resp:
|
||||||
|
data = json.loads(stream_resp.decode("utf-8"))
|
||||||
|
skip_echo_len = len(_prompts[0])
|
||||||
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
data["text"] = output
|
||||||
|
self.callback_manager.on_llm_new_token(
|
||||||
|
output,
|
||||||
|
verbose=self.verbose,
|
||||||
|
logprobs=data["error_code"],
|
||||||
|
)
|
||||||
|
_update_response(response_template, data)
|
||||||
|
choices.append(response_template)
|
||||||
|
else:
|
||||||
|
response_template = _streaming_response_template()
|
||||||
|
response = requests.post(
|
||||||
|
FAST_CHAT_API,
|
||||||
|
headers=headers,
|
||||||
|
json=params,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for stream_resp in response.iter_lines(
|
||||||
|
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||||
|
):
|
||||||
|
if stream_resp:
|
||||||
|
data = json.loads(stream_resp.decode("utf-8"))
|
||||||
|
skip_echo_len = len(_prompts[0])
|
||||||
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
data["text"] = output
|
||||||
|
_update_response(response_template, data)
|
||||||
|
|
||||||
|
choices.append(response_template)
|
||||||
|
|
||||||
|
return self.create_llm_result(choices, prompts, token_usage)
|
||||||
|
|
||||||
|
def get_sub_prompts(
|
||||||
|
self,
|
||||||
|
params: Dict[str, Any],
|
||||||
|
prompts: List[str],
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""Get the sub prompts for llm call."""
|
||||||
|
if params["max_new_tokens"] == -1:
|
||||||
|
if len(prompts) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"max_new_tokens set to -1 not supported for multiple inputs."
|
||||||
|
)
|
||||||
|
params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
|
||||||
|
# append pload
|
||||||
|
sub_prompts = [
|
||||||
|
prompts[i: i + self.batch_size]
|
||||||
|
for i in range(0, len(prompts), self.batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
return sub_prompts
|
||||||
|
|
||||||
|
def create_llm_result(
|
||||||
|
self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Create the LLMResult from the choices and prompts."""
|
||||||
|
generations = []
|
||||||
|
for i, _ in enumerate(prompts):
|
||||||
|
sub_choices = choices[i * self.n: (i + 1) * self.n]
|
||||||
|
generations.append(
|
||||||
|
[
|
||||||
|
Generation(
|
||||||
|
text=choice["text"],
|
||||||
|
generation_info=dict(
|
||||||
|
finish_reason='over',
|
||||||
|
logprobs=choice["text"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for choice in sub_choices
|
||||||
|
]
|
||||||
|
)
|
||||||
|
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||||
|
return LLMResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
|
||||||
|
"""Call FastChat with streaming flag and return the resulting generator.
|
||||||
|
|
||||||
|
BETA: this is a beta feature while we figure out the right abstraction.
|
||||||
|
Once that happens, this interface could change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompts to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A generator representing the stream of tokens from OpenAI.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
generator = fastChat.stream("Tell me a joke.")
|
||||||
|
for token in generator:
|
||||||
|
yield token
|
||||||
|
"""
|
||||||
|
params = self._invocation_params
|
||||||
|
params["prompt"] = prompt
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
params["stop"] = stop
|
||||||
|
|
||||||
|
headers = {"User-Agent": "fastchat Client"}
|
||||||
|
response = requests.post(
|
||||||
|
FAST_CHAT_API,
|
||||||
|
headers=headers,
|
||||||
|
json=params,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for stream_resp in response.iter_lines(
|
||||||
|
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||||
|
):
|
||||||
|
if stream_resp:
|
||||||
|
data = json.loads(stream_resp.decode("utf-8"))
|
||||||
|
skip_echo_len = len(_prompts[0])
|
||||||
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
data["text"] = output
|
||||||
|
yield data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used to invoke the model."""
|
||||||
|
return self._default_params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**{"model_name": self.model_name}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "fastChat"
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
"""Calculate num tokens with tiktoken package."""
|
||||||
|
# tiktoken NOT supported for Python < 3.8
|
||||||
|
if sys.version_info[1] < 8:
|
||||||
|
return super().get_num_tokens(text)
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import tiktoken python package. "
|
||||||
|
"This is needed in order to calculate get_num_tokens. "
|
||||||
|
"Please install it with `pip install tiktoken`."
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = tiktoken.encoding_for_model(self.model_name)
|
||||||
|
|
||||||
|
tokenized_text = enc.encode(
|
||||||
|
text,
|
||||||
|
allowed_special=self.allowed_special,
|
||||||
|
disallowed_special=self.disallowed_special,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate the number of tokens in the encoded text
|
||||||
|
return len(tokenized_text)
|
||||||
|
|
||||||
|
def modelname_to_contextsize(self, modelname: str) -> int:
|
||||||
|
"""Calculate the maximum number of tokens possible to generate for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modelname: The modelname we want to know the context size for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The maximum context size
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
max_new_tokens = openai.modelname_to_contextsize("text-davinci-003")
|
||||||
|
"""
|
||||||
|
model_token_mapping = {
|
||||||
|
"vicuna-13b": 2049,
|
||||||
|
"koala": 2049,
|
||||||
|
"dolly-v2": 2049,
|
||||||
|
"oasst": 2049,
|
||||||
|
"stablelm": 2049,
|
||||||
|
}
|
||||||
|
|
||||||
|
context_size = model_token_mapping.get(modelname, None)
|
||||||
|
|
||||||
|
if context_size is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
|
||||||
|
"Known models are: " + ", ".join(model_token_mapping.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
return context_size
|
||||||
|
|
||||||
|
def max_new_tokens_for_prompt(self, prompt: str) -> int:
|
||||||
|
"""Calculate the maximum number of tokens possible to generate for a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The maximum number of tokens to generate for a prompt.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
max_new_tokens = openai.max_token_for_prompt("Tell me a joke.")
|
||||||
|
"""
|
||||||
|
num_tokens = self.get_num_tokens(prompt)
|
||||||
|
|
||||||
|
# get max context size for model by name
|
||||||
|
max_size = self.modelname_to_contextsize(self.model_name)
|
||||||
|
return max_size - num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class FastChat(BaseFastChat):
|
||||||
|
"""Wrapper around OpenAI large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``openai`` python package installed, and the
|
||||||
|
environment variable ``OPENAI_API_KEY`` set with your API key.
|
||||||
|
|
||||||
|
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||||
|
in, even if not explicitly saved on this class.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
openai = FastChat(model_name="vicuna")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
|
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||||
@@ -1,2 +1,4 @@
 
-from .chatglm_llm import *
+from .chatglm_llm import ChatGLM
+from .llama_llm import LLamaLLM
+from .moss_llm import MOSSLLM
97  models/__main__.py  Normal file
@@ -0,0 +1,97 @@
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType

import models.shared as shared

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from typing import List, Set


class CustomLLMSingleActionAgent(ZeroShotAgent):
    allowed_tools: List[str]

    def __init__(self, *args, **kwargs):
        super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs)
        self.allowed_tools = kwargs['allowed_tools']

    def get_allowed_tools(self) -> Set[str]:
        return set(self.allowed_tools)


async def dispatch(args: Namespace):
    args_dict = vars(args)

    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    template = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""

    prompt = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=template
    )
    memory = ConversationBufferMemory(memory_key="chat_history")
    readonlymemory = ReadOnlySharedMemory(memory=memory)
    summry_chain = LLMChain(
        llm=llm_model_ins,
        prompt=prompt,
        verbose=True,
        memory=readonlymemory,  # use the read-only memory to prevent the tool from modifying the memory
    )

    tools = [
        Tool(
            name="Summary",
            func=summry_chain.run,
            description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
        )
    ]

    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!

Question: {input}
{agent_scratchpad}"""

    prompt = CustomLLMSingleActionAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"]
    )
    tool_names = [tool.name for tool in tools]
    llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
    agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)

    agent_chain.run(input="你好")
    agent_chain.run(input="你是谁?")
    agent_chain.run(input="我们之前聊了什么?")


if __name__ == '__main__':
    args = None
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(dispatch(args))
198  models/base.py  Normal file
@@ -0,0 +1,198 @@
from abc import ABC, abstractmethod
from typing import Optional, List
import traceback
from collections import deque
from queue import Queue
from threading import Thread

import torch
import transformers
from models.loader import LoaderCheckPoint


class ListenerToken:
    """
    Observation result
    """

    input_ids: torch.LongTensor
    _scores: torch.FloatTensor

    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
        self.input_ids = input_ids
        self._scores = _scores


class AnswerResult:
    """
    Message entity
    """
    history: List[List[str]] = []
    llm_output: Optional[dict] = None
    listenerToken: ListenerToken = None


class AnswerResultStream:
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, answerResult: AnswerResult):
        if self.callback_func is not None:
            self.callback_func(answerResult)


class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
    """
    Defines a stopping_criteria listener for the model; on each response it syncs the queue data into AnswerResult.
    The reason for this listener is that different models' prediction outputs may not be vector information. The HF framework
    lets a custom transformers.StoppingCriteria receive the Tensor and scores of every prediction step, and a
    StoppingCriteriaList tells the model when to stop generating an answer; each StoppingCriteria object represents one stop condition.
    At the start of each prediction round, every StoppingCriteria receives the same prediction result, and the lower-level
    implementation decides whether to stop.
    The captured values can be observed via the custom parameters of generatorAnswer / generate_with_streaming for finer-grained control.
    """

    listenerQueue: deque = deque(maxlen=1)

    def __init__(self):
        transformers.StoppingCriteria.__init__(self)

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Append the data to the response queue on every step.
        :param input_ids:
        :param _scores:
        :param kwargs:
        :return:
        """
        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    def __init__(self, func, kwargs={}):
        self.mfunc = func
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs
        self.stop_now = False

        def _callback(val):
            """
            Collects the model's prediction output.
            The generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses predicted by the model;
            the lower-level implementation decides when to stop. The stop conditions are:
            1. The model finished predicting and the collector queue self.q received the self.sentinel marker.
            2. A break was executed while handling the iterator queue messages, triggering a StopIteration event.
            3. The model prediction raised an error.
            Because this class is an iterator, __exit__ is called after a break inside "for ... in"; the stop_now attribute is then
            updated and an exception is raised to end the prediction.
            The collection flow is as follows:
              create the Iteratorize iterator object;
              define the generate_with_callback collector AnswerResultStream;
              start a thread that asynchronously produces predictions by calling the upstream checkpoint implementation _generate_answer;
              _generate_answer uses the collector defined by generate_with_callback to gather the AnswerResult messages wrapped by the upstream checkpoint;
              since self.q is blocking, each prediction must be consumed before the next one runs,
              at which point generate_with_callback blocks;
              the main thread's Iteratorize.__next__ call fetches and consumes the blocking message:
                1. if the message is an AnswerResult wrapped by the upstream checkpoint, it is returned downstream for handling;
                2. if the message is the self.sentinel marker, StopIteration is raised;
              the main thread's Iteratorize.__exit__ receives the message and stop_now is updated;
              the asynchronous thread detects that stop_now was updated and raises an exception to end the prediction;
              the iteration ends.
            :param val:
            :return:
            """
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except:
                traceback.print_exc()
                pass

            self.q.put(self.sentinel)

        self.thread = Thread(target=gen)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        """
        Not implemented yet.
        :return:
        """
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """ executed after a break """
        self.stop_now = True


class BaseAnswer(ABC):
    """Upper-level business wrapper, used to provide a unified API for answer generation."""

    @property
    @abstractmethod
    def _check_point(self) -> LoaderCheckPoint:
        """Return _check_point of llm."""

    @property
    @abstractmethod
    def _history_len(self) -> int:
        """Return _history_len of llm."""

    @abstractmethod
    def set_history_len(self, history_len: int) -> None:
        """Return _history_len of llm."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        def generate_with_callback(callback=None, **kwargs):
            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
            self._generate_answer(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs)

        """
        eos_token_id is the designated token (for example "</s>") used to mark the end of a sequence.
        In text generation the generator keeps producing tokens until this special eos_token_id appears, indicating the sequence is complete.
        In Hugging Face Transformers models, eos_token_id is added to the input automatically by the tokenizer.
        When the model generates eos_token_id, generation stops and the generated sequence is returned.
        """
        eos_token_ids = [
            self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []

        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
            for answerResult in generator:
                if answerResult.listenerToken:
                    output = answerResult.listenerToken.input_ids
                yield answerResult

    @abstractmethod
    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        pass
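The abstract surface of BaseAnswer above is small; a minimal skeleton of a new model wrapper, assuming only what the base class requires (the class name MyLLM and the placeholder answers are illustrative, not part of this commit). The ChatGLM diff that follows is the real implementation added here.

    # Hedged sketch: the pieces a concrete BaseAnswer implementation has to supply.
    class MyLLM(BaseAnswer):
        def __init__(self, checkPoint: LoaderCheckPoint):
            self.checkPoint = checkPoint
            self.history_len = 10

        @property
        def _check_point(self) -> LoaderCheckPoint:
            return self.checkPoint

        @property
        def _history_len(self) -> int:
            return self.history_len

        def set_history_len(self, history_len: int) -> None:
            self.history_len = history_len

        def _generate_answer(self, prompt, history=[], streaming=False,
                             generate_with_callback: AnswerResultStream = None) -> None:
            answer_result = AnswerResult()
            answer_result.history = history + [[prompt, "..."]]   # "..." stands for model output
            answer_result.llm_output = {"answer": "..."}
            generate_with_callback(answer_result)  # push each (partial) result to the collector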
@@ -1,189 +1,96 @@
-import json
+from abc import ABC
+
 from langchain.llms.base import LLM
-from typing import List, Dict, Optional
-from transformers import AutoTokenizer, AutoModel, AutoConfig
-import torch
-from configs.model_config import *
-from utils import torch_gc
+from typing import Optional, List
+from models.loader import LoaderCheckPoint
+from models.base import (BaseAnswer,
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
+import transformers
-
-DEVICE_ = LLM_DEVICE
-DEVICE_ID = "0" if torch.cuda.is_available() else None
-DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
-
-
-def auto_configure_device_map(num_gpus: int, use_lora: bool) -> Dict[str, int]:
-    # transformer.word_embeddings takes 1 layer
-    # transformer.final_layernorm and lm_head take 1 layer
-    # transformer.layers take 28 layers
-    # 30 layers in total, distributed across num_gpus cards
-    num_trans_layers = 28
-    per_gpu_layers = 30 / num_gpus
-
-    # bugfix: PEFT-loaded lora models use different layer names
-    if LLM_LORA_PATH and use_lora:
-        layer_prefix = 'base_model.model.transformer'
-    else:
-        layer_prefix = 'transformer'
-
-    # bugfix: on Linux the weight and input passed to torch.embedding may end up on different devices, causing a RuntimeError
-    # on Windows, model.device is set to transformer.word_embeddings.device
-    # on Linux, model.device is set to lm_head.device
-    # when chat or stream_chat is called, input_ids is placed on model.device
-    # if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
-    # therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card
-    device_map = {f'{layer_prefix}.word_embeddings': 0,
-                  f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
-                  f'base_model.model.lm_head': 0, }
-
-    used = 2
-    gpu_target = 0
-    for i in range(num_trans_layers):
-        if used >= per_gpu_layers:
-            gpu_target += 1
-            used = 0
-        assert gpu_target < num_gpus
-        device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
-        used += 1
-
-    return device_map
 
 
-class ChatGLM(LLM):
+class ChatGLM(BaseAnswer, LLM, ABC):
     max_token: int = 10000
-    temperature: float = 0.8
+    temperature: float = 0.01
     top_p = 0.9
+    checkPoint: LoaderCheckPoint = None
     # history = []
-    tokenizer: object = None
-    model: object = None
     history_len: int = 10
 
-    def __init__(self):
+    def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
+        self.checkPoint = checkPoint
 
     @property
     def _llm_type(self) -> str:
         return "ChatGLM"
 
-    def _call(self,
-              prompt: str,
-              history: List[List[str]] = [],
-              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+    @property
+    def _check_point(self) -> LoaderCheckPoint:
+        return self.checkPoint
+
+    @property
+    def _history_len(self) -> int:
+        return self.history_len
+
+    def set_history_len(self, history_len: int = 10) -> None:
+        self.history_len = history_len
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        pass
+
+    def _generate_answer(self, prompt: str,
+                         history: List[List[str]] = [],
+                         streaming: bool = False,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on every response it syncs torch.LongTensor, torch.FloatTensor into AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
+
         if streaming:
-            for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
-                    self.tokenizer,
+            for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
+                    self.checkPoint.tokenizer,
                     prompt,
                     history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
                     temperature=self.temperature,
-                    top_p=self.top_p,
+                    stopping_criteria=stopping_criteria_list
             )):
-                torch_gc()
+                self.checkPoint.clear_torch_cache()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
-                yield stream_resp, history
-                torch_gc()
+                answer_result = AnswerResult()
+                answer_result.history = history
+                answer_result.llm_output = {"answer": stream_resp}
+                if listenerQueue.listenerQueue.__len__() > 0:
+                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+                generate_with_callback(answer_result)
         else:
-            response, _ = self.model.chat(
-                self.tokenizer,
+            response, _ = self.checkPoint.model.chat(
+                self.checkPoint.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
-                top_p=self.top_p,
+                stopping_criteria=stopping_criteria_list
             )
-            torch_gc()
+            self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
-            yield response, history
-            torch_gc()
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": response}
+            if listenerQueue.listenerQueue.__len__() > 0:
+                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+
+            generate_with_callback(answer_result)
 
-    # def chat(self,
-    #          prompt: str) -> str:
-    #     response, _ = self.model.chat(
-    #         self.tokenizer,
-    #         prompt,
-    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
-    #         max_length=self.max_token,
-    #         temperature=self.temperature,
-    #     )
-    #     torch_gc()
-    #     self.history = self.history + [[None, response]]
-    #     return response
-
-    def load_model(self,
-                   model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE,
-                   use_ptuning_v2=False,
-                   use_lora=False,
-                   device_map: Optional[Dict[str, int]] = None,
-                   **kwargs):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True
-        )
-
-        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
-
-        if use_ptuning_v2:
-            try:
-                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
-                prefix_encoder_config = json.loads(prefix_encoder_file.read())
-                prefix_encoder_file.close()
-                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
-                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
-            except Exception as e:
-                logger.error(f"加载PrefixEncoder config.json失败: {e}")
-        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
-                                               **kwargs)
-        if LLM_LORA_PATH and use_lora:
-            from peft import PeftModel
-            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
-
-        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
-            # decide whether to deploy across multiple cards based on the number of GPUs on the current device
-            num_gpus = torch.cuda.device_count()
-            if num_gpus < 2 and device_map is None:
-                self.model = self.model.half().cuda()
-            else:
-                from accelerate import dispatch_model
-
-                # model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
-                #                                   config=model_config, **kwargs)
-                if LLM_LORA_PATH and use_lora:
-                    from peft import PeftModel
-                    model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
-                # a device_map can be passed in to customize the deployment on each card
-                if device_map is None:
-                    device_map = auto_configure_device_map(num_gpus, use_lora)
-
-                self.model = dispatch_model(self.model.half(), device_map=device_map)
-        else:
-            self.model = self.model.float().to(llm_device)
-
-        if use_ptuning_v2:
-            try:
-                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
-                new_prefix_state_dict = {}
-                for k, v in prefix_state_dict.items():
-                    if k.startswith("transformer.prefix_encoder."):
-                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
-                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-                self.model.transformer.prefix_encoder.float()
-            except Exception as e:
-                logger.error(f"加载PrefixEncoder模型参数失败:{e}")
-
-        self.model = self.model.eval()
-
-
-if __name__ == "__main__":
-    llm = ChatGLM()
-    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
-                   llm_device=LLM_DEVICE, )
-    last_print_len = 0
-    for resp, history in llm._call("你好", streaming=True):
-        logger.info(resp[last_print_len:], end="", flush=True)
-        last_print_len = len(resp)
-    for resp, history in llm._call("你好", streaming=False):
-        logger.info(resp)
-    pass
222  models/extensions/callback.py  Normal file
@@ -0,0 +1,222 @@
import gc
import traceback
from queue import Queue
from threading import Thread
import threading
from typing import Optional, List, Dict, Any
from collections import deque
import torch
import transformers

from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared


class LimitedLengthDict(dict):
    def __init__(self, maxlen=None, *args, **kwargs):
        self.maxlen = maxlen
        self._keys = deque()
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self:
            if self.maxlen is not None and len(self) >= self.maxlen:
                oldest_key = self._keys.popleft()
                if oldest_key in self:
                    del self[oldest_key]
            self._keys.append(key)
        super().__setitem__(key, value)


class FixedLengthQueue:

    # list of stop sequences
    stop_sequence: Optional[str] = []
    # buffer size
    max_length: int = 0
    # buffer container
    queue: deque = None
    # input container
    queue_in: LimitedLengthDict[int, str] = {}
    # output container
    queue_out: Dict[int, str] = {}

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional setup for the instance could be done here
        return instance

    def __init__(self, stop_sequence):
        if stop_sequence is None:
            self.stop_sequence = []
            self.max_length = 0
        elif isinstance(stop_sequence, str):
            self.stop_sequence = [stop_sequence]
            self.max_length = 1
        else:
            self.stop_sequence = stop_sequence
            self.max_length = len(''.join(stop_sequence))

        self.queue = deque(maxlen=self.max_length)
        self.queue.clear()
        self.queue_in.clear()
        self.queue_out.clear()

    def add(self, index, item):
        self.queue_in[index] = item

    def _add_out(self, index, item):
        self.queue_out[index] = item

    def put_replace_out(self, index):
        return self.queue_out[index]

    def contains_replace_sequence(self):
        """
        replace characters
        :return:
        """

        for key, value in self.queue_in.items():

            word_index = value.rfind("：")
            if word_index != -1:
                value = value.replace("：", ":")

            word_index = value.rfind("[")
            if word_index != -1:
                value = value.replace("[", "")

            word_index = value.rfind("]")
            if word_index != -1:
                value = value.replace("]", "")

            self._add_out(key, value)

    def contains_stop_sequence(self):
        # check against a fixed-size slice of the data
        self.queue.clear()
        last_three_keys = list(self.queue_out.keys())[-self.max_length:]
        joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
        for char in joined_queue:
            self.queue.append(char)

        joined_queue = ''.join(self.queue)
        # Initialize a variable to store the index of the last found stop string
        last_stop_str_index = -1

        # Iterate through the stop string list
        for stop_word in self.stop_sequence:
            # Find the last occurrence of the stop string in the output
            stop_word_index = joined_queue.rfind(stop_word)

            # If the stop string is found, compare the index with the previously found index
            if stop_word_index != -1 and stop_word_index > last_stop_str_index:
                last_stop_str_index = stop_word_index

        # Handle the last found stop string index here
|
||||||
|
return last_stop_str_index
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.queue)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
||||||
|
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||||
|
|
||||||
|
def __init__(self, sentinel_token_ids: list, starting_idx: int):
|
||||||
|
transformers.StoppingCriteria.__init__(self)
|
||||||
|
self.sentinel_token_ids = sentinel_token_ids
|
||||||
|
self.starting_idx = starting_idx
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||||
|
for sample in input_ids:
|
||||||
|
trimmed_sample = sample[self.starting_idx:]
|
||||||
|
|
||||||
|
for i in range(len(self.sentinel_token_ids)):
|
||||||
|
# Can't unfold, output is still too tiny. Skip.
|
||||||
|
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
|
||||||
|
continue
|
||||||
|
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
|
||||||
|
if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Stream(transformers.StoppingCriteria):
|
||||||
|
def __init__(self, callback_func=None):
|
||||||
|
self.callback_func = callback_func
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
if shared.stop_everything:
|
||||||
|
raise ValueError
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(input_ids[0])
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Iteratorize:
|
||||||
|
"""
|
||||||
|
Transforms a function that takes a callback
|
||||||
|
into a lazy iterator (generator).
|
||||||
|
"""
|
||||||
|
|
||||||
|
thread: ThreadWithException = None
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# 创建新的实例
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
# 在这里可以对实例进行额外的设置
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, func, kwargs={}, callback=None):
|
||||||
|
self.mfunc = func
|
||||||
|
self.c_callback = callback
|
||||||
|
self.q = Queue()
|
||||||
|
self.sentinel = object()
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def _callback(val):
|
||||||
|
if shared.stop_everything:
|
||||||
|
raise ValueError
|
||||||
|
self.q.put(val)
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
try:
|
||||||
|
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||||
|
except ValueError:
|
||||||
|
print("print(ValueError)")
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
print("traceback.print_exc()")
|
||||||
|
self.q.put(self.sentinel)
|
||||||
|
|
||||||
|
self.thread = ThreadWithException(target=gen)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
shared.stop_everything = False
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
obj = self.q.get(True, None)
|
||||||
|
if obj is self.sentinel:
|
||||||
|
raise StopIteration
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
shared.stop_everything = False
|
||||||
|
self.q.empty()
|
||||||
|
shared.loaderCheckPoint.clear_torch_cache()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
shared.stop_everything = False
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
shared.stop_everything = True
|
||||||
|
shared.loaderCheckPoint.clear_torch_cache()
|
||||||
|
self.thread.raise_exception()
|
||||||
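A minimal usage sketch of Iteratorize: it runs a callback-style function on a background ThreadWithException and exposes the streamed values as a plain iterator. The toy producer below is hypothetical (in the repository the wrapped function is generate_with_callback), and the sketch assumes shared.loaderCheckPoint has already been initialized, since __exit__ calls its clear_torch_cache().

from models.extensions.callback import Iteratorize

def fake_generate(callback=None, n=3):
    # Stand-in for a callback-based generate() that emits partial outputs.
    for i in range(n):
        callback(f"token-{i}")

with Iteratorize(fake_generate, kwargs={"n": 3}) as generator:
    for chunk in generator:
        print(chunk)  # token-0, token-1, token-2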
10  models/extensions/extensions.py  Normal file
@@ -0,0 +1,10 @@
import gc
import traceback
import torch


# This iterator returns the extensions in the order specified in the command-line
def iterator():
    state_extensions = {}
    for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
        if state_extensions[name][0]:
            yield getattr(extensions, name).script, name
64  models/extensions/llamacpp_model_alternative.py  Normal file
@@ -0,0 +1,64 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''

from llama_cpp import Llama, LlamaCache

from modules import shared
from modules.callbacks import Iteratorize


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        result = self()

        params = {
            'model_path': str(path),
            'n_ctx': 2048,
            'seed': 0,
            'n_threads': shared.args.threads or None
        }
        self.model = Llama(**params)
        self.model.set_cache(LlamaCache)

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if type(string) is str:
            string = string.encode()
        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        if type(context) is str:
            context = context.encode()
        tokens = self.model.tokenize(context)

        output = b""
        count = 0
        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
            text = self.model.detokenize([token])
            output += text
            if callback:
                callback(text.decode())

            count += 1
            if count >= token_count or (token == self.model.token_eos()):
                break

        return output.decode()

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
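A usage sketch for LlamaCppModel. Note that this file was adapted from text-generation-webui and still imports modules.shared and modules.callbacks, so running it as-is assumes those modules (and shared.args.threads) exist; the weights path below is also only an example.

from models.extensions.llamacpp_model_alternative import LlamaCppModel

model, tokenizer = LlamaCppModel.from_pretrained("model/ggml-model-q4_0.bin")

# Blocking generation
print(model.generate(context="Hello,", token_count=32))

# Streaming generation: each yielded value is the reply accumulated so far
for partial in model.generate_with_streaming(context="Hello,", token_count=32):
    print(partial)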
30  models/extensions/thread_with_exception.py  Normal file
@@ -0,0 +1,30 @@
# Python program raising
# exceptions in a python
# thread

import threading
import ctypes
import time


class ThreadWithException(threading.Thread):

    def get_id(self):
        return self.ident

    def raise_exception(self):
        """raises the exception, performs cleanup if needed"""
        try:
            thread_id = self.get_id()
            tid = ctypes.c_long(thread_id)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
            if res == 0:
                # pass
                raise ValueError("invalid thread id")
            elif res != 1:
                # """if it returns a number greater than one, you're in trouble,
                # and you should call it again with exc=NULL to revert the effect"""
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")
        except Exception as err:
            print(err)
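A small sketch of how ThreadWithException is meant to be used: the thread runs a long loop and raise_exception() asynchronously injects SystemExit into it, so generation can be cancelled from outside the thread.

import time
from models.extensions.thread_with_exception import ThreadWithException

def long_running_task():
    try:
        while True:
            time.sleep(0.1)
    finally:
        print("worker stopped")

worker = ThreadWithException(target=long_running_task)
worker.start()
time.sleep(0.5)
worker.raise_exception()  # injects SystemExit into the worker thread
worker.join()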
297  models/llama_llm.py  Normal file
@@ -0,0 +1,297 @@
from abc import ABC

from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue)
import models.shared as shared
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: response structure
    """
    return {
        "text": ""
    }


def _update_response(response: Dict[str, Any], stream_response: str) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class LLamaLLM(BaseAnswer, LLM, ABC):
    checkPoint: LoaderCheckPoint = None
    history = []
    history_len: int = 3
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    eos_token_id: Optional[int] = [2]

    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
                                                     add_special_tokens=add_special_tokens)
        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it
        if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handling truncation
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    def generate_with_callback(self, callback=None, **kwargs):
        self.checkPoint.clear_torch_cache()
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        with torch.no_grad():
            self.checkPoint.model.generate(**kwargs)
            print("generation finished")

    def generate_with_streaming(self, **kwargs):
        return Iteratorize(self.generate_with_callback, kwargs)

    # Convert the chat-history array into plain text
    def history_to_text(self, query):
        formatted_history = ''
        history = self.history[-self.history_len:] if self.history_len > 0 else []
        for i, (old_query, response) in enumerate(history):
            formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
        formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query)
        return formatted_history

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """
        Pre-compute the attention mask and the position-index tensor for the input sequence.
        # TODO no clear approach yet
        :return:
        """

        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)

        attention_mask = self.get_masks(input_ids, input_ids.device)

        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )

        return input_ids, position_ids, attention_mask

    def get_position_ids(self, input_ids: torch.LongTensor, mask_positions, device):
        """
        Position offsets used by attention.
        :param input_ids:
        :param mask_positions:
        :param device:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        block_position_ids = [torch.cat((
            torch.zeros(context_length, dtype=torch.long, device=device),
            torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
        )) for context_length in context_lengths]
        block_position_ids = torch.stack(block_position_ids, dim=0)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        return position_ids

    def get_masks(self, input_ids, device):
        """
        Build the attention mask.
        :param input_ids:
        :param device:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()
        return attention_mask

    def generate_softprompt_history_tensors(self, query):
        """
        Soft prompt built from the chat history.
        history_to_text converts self.history into the required text format; the formatted
        history can then be encoded with self.encode and concatenated with the current input vector.
        :return:
        """

        # Conversation content
        # Process the chat history
        formatted_history = self.history_to_text(query)
        return formatted_history

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
            self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.eos_token_id,
            "logits_processor": self.logits_processor}

        # Concatenate the input vectors
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)

        # Chat-model prompt
        gen_kwargs.update({'inputs': input_ids})
        # Attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})
        # gen_kwargs.update({'position_ids': position_ids})
        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()
        # Observe the output
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
        shared.stop_everything = False
        stopped = False
        response_template = _streaming_response_template()

        # TODO this streaming method needs to be rewritten!
        # stopping_criteria cannot be controlled from here, and the iterator's state cannot be shared
        with self.generate_with_streaming(**gen_kwargs) as generator:
            last_reply_len = 0
            reply_index = 0
            # Create a FixedLengthQueue with the desired stop sequence and a maximum length.
            queue = FixedLengthQueue(stop)
            for output in generator:
                new_tokens = len(output) - len(input_ids[0])
                reply = self.decode(output[-new_tokens:])

                new_reply = len(reply) - last_reply_len
                output_reply = reply[-new_reply:]
                queue.add(reply_index, output_reply)
                queue.contains_replace_sequence()
                if stop:
                    pos = queue.contains_stop_sequence()
                    if pos != -1:
                        shared.stop_everything = True
                        stopped = True

                # print(f"{reply_index}:reply {output_reply}")
                english_reply = queue.put_replace_out(reply_index)
                # print(f"{reply_index}:english_reply {english_reply}")
                _update_response(response_template, english_reply)
                last_reply_len = len(reply)

                reply_index += 1
                if new_tokens == self.max_new_tokens - 1 or stopped:
                    break

        response = response_template['text']
        print(f"response:{response}")
        self.history = self.history + [[None, response]]
        return response

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        if history:
            self.history = history
        # Create the StoppingCriteriaList with the stopping strings
        self.stopping_criteria = transformers.StoppingCriteriaList()
        # Define the stopping_criteria listener queue; on every step it syncs
        # torch.LongTensor / torch.FloatTensor into the AnswerResult
        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
        self.stopping_criteria.append(listenerQueue)
        # TODO a chat module and attention handling still need to be implemented; _call is currently the plain
        # langchain LLM API with no prompt template. See the chat_glm implementation for attention handling.
        softprompt = self.generate_softprompt_history_tensors(prompt)
        response = self._call(prompt=softprompt, stop=['\n###'])
        answer_result = AnswerResult()
        answer_result.history = self.history
        if listenerQueue.listenerQueue.__len__() > 0:
            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
        answer_result.llm_output = {"answer": response}
        generate_with_callback(answer_result)
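A hedged end-to-end sketch of wiring LLamaLLM to a LoaderCheckPoint. The parameter dict mirrors the options defined in models/loader/args.py; the model name and directory layout are assumptions.

import models.shared as shared
from models.loader import LoaderCheckPoint
from models.llama_llm import LLamaLLM

params = {"model": "vicuna-7b-hf", "model_dir": "model/", "no_remote_model": True}
shared.loaderCheckPoint = LoaderCheckPoint(params)
shared.loaderCheckPoint.reload_model()

llm = LLamaLLM(checkPoint=shared.loaderCheckPoint)
llm.set_history_len(3)
print(llm._call("Hello, who are you?", stop=["\n###"]))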
2  models/loader/__init__.py  Normal file
@@ -0,0 +1,2 @@

from .loader import *
59  models/loader/args.py  Normal file
@@ -0,0 +1,59 @@
import argparse
import os


# Additional argparse types
def path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
    return s


def file_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isfile(s):
        raise argparse.ArgumentTypeError(f'No such file: "{string}"')
    return s


def dir_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isdir(s):
        raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
    return s


parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                 description='基于langchain和ChatGLM的LLM文档阅读器')


parser.add_argument('--no-remote-model', action='store_true', default=False, help='Do not fetch the model from a remote checkpoint; add `--no-remote-model` when loading a local model.')
parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')


args = parser.parse_args([])
# Generates a dict with a default value for each argument
DEFAULT_ARGS = vars(args)
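The parser above is also importable on its own; a short sketch of overriding the defaults from a script (the flag values are examples only):

from models.loader.args import parser, DEFAULT_ARGS

args = parser.parse_args(["--model", "chatglm-6b", "--no-remote-model", "--load-in-8bit"])
args_dict = vars(args)
print(DEFAULT_ARGS["model"], "->", args_dict["model"])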
405  models/loader/loader.py  Normal file
@@ -0,0 +1,405 @@
import gc
import json
import os
import re
import time
from pathlib import Path
from peft import PeftModel
from typing import Optional, List, Dict, Tuple, Union
import torch
import transformers

from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers
from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map


class LoaderCheckPoint:
    """
    Loads a custom model checkpoint.
    """
    # Do not fetch the model from a remote checkpoint
    no_remote_model: bool = False
    # Model name
    model_name: str = None
    tokenizer: object = None
    # Full path to the model
    model_path: str = None
    model: object = None
    model_config: object = None
    lora_names: set = []
    model_dir: str = None
    lora_dir: str = None
    ptuning_dir: str = None
    use_ptuning_v2: bool = False
    cpu: bool = False
    gpu_memory: object = None
    cpu_memory: object = None
    auto_devices: object = True
    # If 8-bit quantized loading keeps the project from starting, pick a suitable CUDA build of
    # bitsandbytes; see https://github.com/TimDettmers/bitsandbytes/issues/156
    load_in_8bit: bool = False
    is_llamacpp: bool = False
    bf16: bool = False
    params: object = None
    # Custom device map
    device_map: Optional[Dict[str, int]] = None
    # Default to cuda; if cuda is unavailable fall back to mps, then cpu
    llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    def __init__(self, params: dict = None):
        """
        Model initialization.
        :param params:
        """
        self.model_path = None
        self.params = params or {}
        self.no_remote_model = params.get('no_remote_model', False)
        self.model_name = params.get('model', '')
        self.lora = params.get('lora', '')
        self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
        self.model = None
        self.tokenizer = None
        self.model_dir = params.get('model_dir', '')
        self.lora_dir = params.get('lora_dir', '')
        self.ptuning_dir = params.get('ptuning_dir', '')
        self.cpu = params.get('cpu', False)
        self.gpu_memory = params.get('gpu_memory', None)
        self.cpu_memory = params.get('cpu_memory', None)
        self.auto_devices = params.get('auto_devices', True)
        self.load_in_8bit = params.get('load_in_8bit', False)
        self.bf16 = params.get('bf16', False)

    def _load_model_config(self, model_name):
        checkpoint = Path(f'{self.model_dir}/{model_name}')

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name

        model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)

        return model_config

    def _load_model(self, model_name):
        """
        Load a model from a custom location.
        :param model_name:
        :return:
        """
        print(f"Loading {model_name}...")
        t0 = time.time()

        checkpoint = Path(f'{self.model_dir}/{model_name}')

        self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name

        if 'chatglm' in model_name.lower():
            LoaderClass = AutoModel
        else:
            LoaderClass = AutoModelForCausalLM

        # Load the model in simple 16-bit mode by default
        if not any([self.cpu, self.load_in_8bit, self.auto_devices, self.gpu_memory is not None,
                    self.cpu_memory is not None, self.is_llamacpp]):

            if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
                # Decide whether to deploy across multiple GPUs based on how many are available
                num_gpus = torch.cuda.device_count()
                if num_gpus < 2 and self.device_map is None:
                    model = (
                        LoaderClass.from_pretrained(checkpoint,
                                                    low_cpu_mem_usage=True,
                                                    config=self.model_config,
                                                    torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                    trust_remote_code=True)
                        .half()
                        .cuda()
                    )
                else:
                    from accelerate import dispatch_model

                    model = LoaderClass.from_pretrained(checkpoint,
                                                        low_cpu_mem_usage=True,
                                                        config=self.model_config,
                                                        torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                        trust_remote_code=True).half()
                    # A custom device_map can be passed in to control the placement on each GPU
                    if self.device_map is None:
                        if 'chatglm' in model_name.lower():
                            device_map = self.chatglm_auto_configure_device_map(num_gpus)
                        elif 'moss' in model_name.lower():
                            device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
                        else:
                            device_map = self.chatglm_auto_configure_device_map(num_gpus)

                    model = dispatch_model(model, device_map=device_map)
            else:
                print(
                    "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
                model = (
                    AutoModel.from_pretrained(
                        checkpoint,
                        config=self.model_config,
                        trust_remote_code=True)
                    .float()
                    .to(self.llm_device)
                )

        elif self.is_llamacpp:
            from models.extensions.llamacpp_model_alternative import LlamaCppModel

            model_file = list(checkpoint.glob('ggml*.bin'))[0]
            print(f"llama.cpp weights detected: {model_file}\n")

            model, tokenizer = LlamaCppModel.from_pretrained(model_file)
            return model, tokenizer

        # Custom
        else:
            params = {"low_cpu_mem_usage": True}
            if not any((self.cpu, torch.cuda.is_available(), torch.has_mps)):
                print(
                    "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
                self.cpu = True

            if self.cpu:
                params["torch_dtype"] = torch.float32
            else:
                params["device_map"] = 'auto'
                params["trust_remote_code"] = True
                if self.load_in_8bit and any((self.auto_devices, self.gpu_memory)):
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
                                                                       llm_int8_enable_fp32_cpu_offload=True)
                elif self.load_in_8bit:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
                elif self.bf16:
                    params["torch_dtype"] = torch.bfloat16
                else:
                    params["torch_dtype"] = torch.float16

                if self.gpu_memory:
                    memory_map = list(map(lambda x: x.strip(), self.gpu_memory))
                    max_cpu_memory = self.cpu_memory.strip() if self.cpu_memory is not None else '99GiB'
                    max_memory = {}
                    for i in range(len(memory_map)):
                        max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else \
                            memory_map[i]
                    max_memory['cpu'] = max_cpu_memory
                    params['max_memory'] = max_memory
                elif self.auto_devices:
                    total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                    suggestion = round((total_mem - 1000) / 1000) * 1000
                    if total_mem - suggestion < 800:
                        suggestion -= 1000
                    suggestion = int(round(suggestion / 1000))
                    print(
                        f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")

                    max_memory = {0: f'{suggestion}GiB', 'cpu': f'{self.cpu_memory or 99}GiB'}
                    params['max_memory'] = max_memory

                if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
                    config = AutoConfig.from_pretrained(checkpoint)
                    with init_empty_weights():
                        model = LoaderClass.from_config(config)
                    model.tie_weights()
                    if self.device_map is not None:
                        params['device_map'] = self.device_map
                    else:
                        params['device_map'] = infer_auto_device_map(
                            model,
                            dtype=torch.int8,
                            max_memory=params['max_memory'],
                            no_split_module_classes=model._no_split_modules
                        )

            model = LoaderClass.from_pretrained(checkpoint, **params)

        # Loading the tokenizer
        if type(model) is transformers.LlamaForCausalLM:
            tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True)
            # Leaving this here until the LLaMA tokenizer gets figured out.
            # For some people this fixes things, for others it causes an error.
            try:
                tokenizer.eos_token_id = 2
                tokenizer.bos_token_id = 1
                tokenizer.pad_token_id = 0
            except:
                pass
        else:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        print(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
        return model, tokenizer

    def chatglm_auto_configure_device_map(self, num_gpus: int) -> Dict[str, int]:
        # transformer.word_embeddings takes 1 layer
        # transformer.final_layernorm and lm_head take 1 layer
        # transformer.layers take 28 layers
        # 30 layers in total, distributed across num_gpus GPUs
        num_trans_layers = 28
        per_gpu_layers = 30 / num_gpus

        # bugfix: layers are named differently when a LoRA model is loaded with PEFT
        if self.lora:
            layer_prefix = 'base_model.model.transformer'
        else:
            layer_prefix = 'transformer'

        # bugfix: on Linux, torch.embedding can receive weight and input on different devices, raising a RuntimeError
        # on Windows, model.device is set to transformer.word_embeddings.device
        # on Linux, model.device is set to lm_head.device
        # when calling chat or stream_chat, input_ids is placed on model.device
        # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
        # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU
        device_map = {f'{layer_prefix}.word_embeddings': 0,
                      f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
                      f'base_model.model.lm_head': 0, }

        used = 2
        gpu_target = 0
        for i in range(num_trans_layers):
            if used >= per_gpu_layers:
                gpu_target += 1
                used = 0
            assert gpu_target < num_gpus
            device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
            used += 1

        return device_map

    def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
        checkpoint = Path(f'{self.model_dir}/{model_name}')

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name
        cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                            pretrained_model_name_or_path=checkpoint)

        with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
            model = cls(self.model_config)
            max_memory = get_balanced_memory(model, dtype=torch.int8 if self.load_in_8bit else None,
                                             low_zero=False, no_split_module_classes=model._no_split_modules)
            device_map = infer_auto_device_map(
                model, dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory,
                no_split_module_classes=model._no_split_modules)
            device_map["transformer.wte"] = 0
            device_map["transformer.drop"] = 0
            device_map["transformer.ln_f"] = 0
            device_map["lm_head"] = 0
            return device_map

    def _add_lora_to_model(self, lora_names):
        # LoRAs that are currently loaded
        prior_set = set(self.lora_names)
        # LoRAs that need to be loaded
        added_set = set(lora_names) - prior_set
        # LoRAs to remove
        removed_set = prior_set - set(lora_names)
        self.lora_names = list(lora_names)

        # Nothing to do = skip.
        if len(added_set) == 0 and len(removed_set) == 0:
            return

        # Only adding, and already peft? Do it the easy way.
        if len(removed_set) == 0 and len(prior_set) > 0:
            print(f"Adding the LoRA(s) named {added_set} to the model...")
            for lora in added_set:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
            return

        # If removing anything, disable all and re-add.
        if len(removed_set) > 0:
            self.model.disable_adapter()

        if len(lora_names) > 0:
            print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
            params = {}
            if not self.cpu:
                params['dtype'] = self.model.dtype
                if hasattr(self.model, "hf_device_map"):
                    params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
                elif self.load_in_8bit:
                    params['device_map'] = {'': 0}
                self.model.resize_token_embeddings(len(self.tokenizer))

            self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params)

            for lora in lora_names[1:]:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)

            if not self.load_in_8bit and not self.cpu:

                if not hasattr(self.model, "hf_device_map"):
                    if torch.has_mps:
                        device = torch.device('mps')
                        self.model = self.model.to(device)
                    else:
                        self.model = self.model.cuda()

    def clear_torch_cache(self):
        gc.collect()
        if not self.cpu:
            device_id = "0" if torch.cuda.is_available() else None
            CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
            with torch.cuda.device(CUDA_DEVICE):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

    def unload_model(self):
        del self.model
        del self.tokenizer
        self.model = self.tokenizer = None
        self.clear_torch_cache()

    def set_model_path(self, model_path):
        self.model_path = model_path

    def reload_model(self):
        self.unload_model()
        self.model_config = self._load_model_config(self.model_name)

        if self.use_ptuning_v2:
            try:
                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception:
                print("加载PrefixEncoder config.json失败")

        self.model, self.tokenizer = self._load_model(self.model_name)

        if self.lora:
            self._add_lora_to_model([self.lora])

        if self.use_ptuning_v2:
            try:
                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception:
                print("加载PrefixEncoder模型参数失败")

        self.model = self.model.eval()
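A minimal sketch of driving LoaderCheckPoint directly, assuming a local checkpoint under model/chatglm-6b (the directory layout and model name are assumptions):

from models.loader import LoaderCheckPoint

params = {"model": "chatglm-6b", "model_dir": "model/", "no_remote_model": True}
checkpoint = LoaderCheckPoint(params)
checkpoint.reload_model()          # loads config, weights and tokenizer (plus LoRA / P-Tuning v2 if configured)
print(type(checkpoint.model), type(checkpoint.tokenizer))
checkpoint.unload_model()          # frees the weights and clears the torch cache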
@@ -1,20 +1,14 @@
-import json
+from abc import ABC
 
 from langchain.llms.base import LLM
-from typing import List, Dict, Optional
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
-from transformers.dynamic_module_utils import get_class_from_dynamic_module
-from transformers.modeling_utils import no_init_weights
-from transformers.utils import ContextManagers
+from typing import Optional, List
+from models.loader import LoaderCheckPoint
+from models.base import (BaseAnswer,
+                         AnswerResult,
+                         AnswerResultStream,
+                         AnswerResultQueueSentinelTokenListenerQueue)
 
 import torch
-from configs.model_config import *
-from utils import torch_gc
-
-from accelerate import init_empty_weights
-from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
-DEVICE_ = LLM_DEVICE
-DEVICE_ID = "0" if torch.cuda.is_available() else None
-DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
 
 META_INSTRUCTION = \
     """You are an AI assistant whose name is MOSS.
@@ -30,45 +24,40 @@ META_INSTRUCTION = \
 """
 
 
-def auto_configure_device_map() -> Dict[str, int]:
-    cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
-                                        pretrained_model_name_or_path=llm_model_dict['moss'])
-
-    with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
-        model_config = AutoConfig.from_pretrained(llm_model_dict['moss'], trust_remote_code=True)
-        model = cls(model_config)
-        max_memory = get_balanced_memory(model, dtype=torch.int8 if LOAD_IN_8BIT else None,
-                                         low_zero=False, no_split_module_classes=model._no_split_modules)
-        device_map = infer_auto_device_map(
-            model, dtype=torch.float16 if not LOAD_IN_8BIT else torch.int8, max_memory=max_memory,
-            no_split_module_classes=model._no_split_modules)
-        device_map["transformer.wte"] = 0
-        device_map["transformer.drop"] = 0
-        device_map["transformer.ln_f"] = 0
-        device_map["lm_head"] = 0
-        return device_map
-
-
-class MOSS(LLM):
+class MOSSLLM(BaseAnswer, LLM, ABC):
     max_token: int = 2048
     temperature: float = 0.7
     top_p = 0.8
     # history = []
-    tokenizer: object = None
-    model: object = None
+    checkPoint: LoaderCheckPoint = None
     history_len: int = 10
 
-    def __init__(self):
+    def __init__(self, checkPoint: LoaderCheckPoint = None):
         super().__init__()
+        self.checkPoint = checkPoint
 
     @property
     def _llm_type(self) -> str:
         return "MOSS"
 
-    def _call(self,
-              prompt: str,
-              history: List[List[str]] = [],
-              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
+    @property
+    def _check_point(self) -> LoaderCheckPoint:
+        return self.checkPoint
+
+    @property
+    def set_history_len(self) -> int:
+        return self.history_len
+
+    def _set_history_len(self, history_len: int) -> None:
+        self.history_len = history_len
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        pass
+
+    def _generate_answer(self, prompt: str,
+                         history: List[List[str]] = [],
+                         streaming: bool = False,
+                         generate_with_callback: AnswerResultStream = None) -> None:
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)
@@ -77,9 +66,9 @@ class MOSS(LLM):
             prompt_w_history = META_INSTRUCTION
             prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
 
-        inputs = self.tokenizer(prompt_w_history, return_tensors="pt")
+        inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
         with torch.no_grad():
-            outputs = self.model.generate(
+            outputs = self.checkPoint.model.generate(
                 inputs.input_ids.cuda(),
                 attention_mask=inputs.attention_mask.cuda(),
                 max_length=self.max_token,
@@ -92,78 +81,8 @@ class MOSS(LLM):
                 eos_token_id=106068,
                 pad_token_id=self.tokenizer.pad_token_id)
             response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-            torch_gc()
+            self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             yield response, history
-            torch_gc()
-
-    def load_model(self,
-                   model_name_or_path: str = "fnlp/moss-moon-003-sft",
-                   llm_device=LLM_DEVICE,
-                   use_ptuning_v2=False,
-                   use_lora=False,
-                   device_map: Optional[Dict[str, int]] = None,
-                   **kwargs):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True
-        )
-
-        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
-
-        if use_ptuning_v2:
-            try:
-                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
-                prefix_encoder_config = json.loads(prefix_encoder_file.read())
-                prefix_encoder_file.close()
-                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
-                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
-            except Exception as e:
-                print(e)
-                print("加载PrefixEncoder config.json失败")
-
-        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
-            # automatic multi-GPU deployment via accelerate
-            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config,
-                                                              load_in_8bit=LOAD_IN_8BIT, trust_remote_code=True,
-                                                              device_map=auto_configure_device_map(), **kwargs)
-
-            if LLM_LORA_PATH and use_lora:
-                from peft import PeftModel
-                self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
-
-        else:
-            self.model = self.model.float().to(llm_device)
-            if LLM_LORA_PATH and use_lora:
-                from peft import PeftModel
-                self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
-
-        if use_ptuning_v2:
-            try:
-                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
-                new_prefix_state_dict = {}
-                for k, v in prefix_state_dict.items():
-                    if k.startswith("transformer.prefix_encoder."):
-                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
-                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-                self.model.transformer.prefix_encoder.float()
-            except Exception as e:
-                print(e)
-                print("加载PrefixEncoder模型参数失败")
-
-        self.model = self.model.eval()
-
-
-if __name__ == "__main__":
-    llm = MOSS()
-    llm.load_model(model_name_or_path=llm_model_dict['moss'],
-                   llm_device=LLM_DEVICE, )
-    last_print_len = 0
-    # for resp, history in llm._call("你好", streaming=True):
-    #     print(resp[last_print_len:], end="", flush=True)
-    #     last_print_len = len(resp)
-    for resp, history in llm._call("你好", streaming=False):
-        print(resp)
-    import time
-    time.sleep(10)
-    pass
43  models/shared.py  Normal file
@@ -0,0 +1,43 @@
import sys

from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer

"""Whether the streaming iterator has been asked to stop"""
stop_everything = False

loaderCheckPoint: LoaderCheckPoint = None


def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer:
    """
    init llm_model_ins LLM
    :param llm_model: model_name
    :param no_remote_model: do not fetch the model from a remote checkpoint; add `--no-remote-model` when loading a local model
    :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder
    :return:
    """
    pre_model_name = loaderCheckPoint.model_name
    llm_model_info = llm_model_dict[pre_model_name]

    if no_remote_model:
        loaderCheckPoint.no_remote_model = no_remote_model
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

    if llm_model:
        llm_model_info = llm_model_dict[llm_model]

    if loaderCheckPoint.no_remote_model:
        loaderCheckPoint.model_name = llm_model_info['name']
    else:
        loaderCheckPoint.model_name = llm_model_info['remote-checkpoint']

    loaderCheckPoint.model_path = llm_model_info['path']

    loaderCheckPoint.reload_model()

    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
    modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
    return modelInsLLM
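A usage sketch of models.shared, matching how webui.py wires it up further below; the llm_model_dict entry keys ('name', 'remote-checkpoint', 'path', 'provides') are assumed to come from configs/model_config.py as introduced in this commit.

import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import parser

args_dict = vars(parser.parse_args([]))
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)

llm = shared.loaderLLM(llm_model="chatglm-6b", no_remote_model=False, use_ptuning_v2=False)
llm.set_history_len(10)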
@@ -17,6 +17,8 @@ fastapi
 uvicorn
 peft
 pypinyin
-bitsandbytes
 click~=8.1.3
 tabulate
+bitsandbytes; platform_system != "Windows"
+llama-cpp-python==0.1.34; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|||||||
51
webui.py
51
webui.py
@ -1,9 +1,17 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from chains.local_doc_qa import LocalDocQA
|
from chains.local_doc_qa import LocalDocQA
|
||||||
from configs.model_config import *
|
from configs.model_config import *
|
||||||
import nltk
|
import nltk
|
||||||
|
from models.base import (BaseAnswer,
|
||||||
|
AnswerResult,
|
||||||
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
|
import models.shared as shared
|
||||||
|
from models.loader.args import parser
|
||||||
|
from models.loader import LoaderCheckPoint
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
@@ -69,7 +77,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
         yield history + [[query,
                           "请选择知识库后进行测试,当前未选择知识库。"]], ""
     else:
-        for resp, history in local_doc_qa.llm._call(query, history, streaming=streaming):
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
+                                                              streaming=streaming):
+
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
             yield history, ""
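Outside gradio the same generator can drive incremental console output. A minimal sketch, assuming `local_doc_qa` has been initialised as above and that under `streaming=True` each `llm_output["answer"]` holds the full text generated so far (the convention the deleted `last_print_len` logic relied on):

    last_print_len = 0
    for answer_result in local_doc_qa.llm.generatorAnswer(prompt="你好",
                                                          history=[],
                                                          streaming=True):
        resp = answer_result.llm_output["answer"]
        print(resp[last_print_len:], end="", flush=True)   # print only the new suffix
        last_print_len = len(resp)
    history = answer_result.history                        # final chat history after the loop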
@@ -77,10 +89,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
     flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
 
 
-def init_model():
+def init_model(llm_model: BaseAnswer = None):
     try:
-        local_doc_qa.init_cfg()
-        local_doc_qa.llm._call("你好")
+        local_doc_qa.init_cfg(llm_model=llm_model)
+        generator = local_doc_qa.llm.generatorAnswer("你好")
+        for answer_result in generator:
+            print(answer_result.llm_output)
         reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
         logger.info(reply)
         return reply
@@ -95,14 +109,13 @@ def init_model():
         return reply
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
     try:
-        local_doc_qa.init_cfg(llm_model=llm_model,
+        llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
+        llm_model_ins.history_len = llm_history_len
+        local_doc_qa.init_cfg(llm_model=llm_model_ins,
                               embedding_model=embedding_model,
-                              llm_history_len=llm_history_len,
-                              use_ptuning_v2=use_ptuning_v2,
-                              use_lora=use_lora,
-                              top_k=top_k, )
+                              top_k=top_k)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         logger.info(model_status)
     except Exception as e:
@@ -219,7 +232,17 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
 知识库暂不支持文件删除,该功能将在后续版本中推出。
 """
 
-model_status = init_model()
+# 初始化消息
+args = None
+args = parser.parse_args()
+
+args_dict = vars(args)
+shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+llm_model_ins = shared.loaderLLM()
+llm_model_ins.set_history_len(LLM_HISTORY_LEN)
+
+model_status = init_model(llm_model=llm_model_ins)
+
 
 default_theme_args = dict(
     font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
@@ -399,6 +422,10 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                                      label="LLM 模型",
                                      value=LLM_MODEL,
                                      interactive=True)
+            no_remote_model = gr.Checkbox(shared.loaderCheckPoint.no_remote_model,
+                                          label="加载本地模型",
+                                          interactive=True)
+
             llm_history_len = gr.Slider(0, 10,
                                         value=LLM_HISTORY_LEN,
                                         step=1,
@@ -418,7 +445,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                                  label="向量匹配 top k", interactive=True)
             load_model_button = gr.Button("重新加载模型")
             load_model_button.click(reinit_model, show_progress=True,
-                                    inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora,
+                                    inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora,
                                             top_k, chatbot], outputs=chatbot)
 
 (demo