Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-24 15:53:21 +08:00)
llm_model_dict now handles some of the loader's preset behavior, such as the load location, model name, and model handler instance, and defines the checkpoint name and remote path.
loader.py: model reloading; defines generatorAnswer, adds AnswerResultStream, and defines the generate_with_callback collector, which syncs queued data into AnswerResult on each response. requirements.txt: updated project dependencies.
This commit is contained in:
parent c3924b2ece
commit 33bbb4779e
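For orientation, a minimal sketch of how the new streaming interface introduced in this commit is meant to be consumed (names are taken from models/base.py and chains code in this diff; the CLI arguments are illustrative assumptions):

# Minimal sketch, assuming the LoaderCheckPoint / shared helpers added in this commit.
import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import parser

# Hypothetical CLI arguments, for illustration only.
args = parser.parse_args(args=['--model', 'chatglm-6b', '--no-remote-model'])
shared.loaderCheckPoint = LoaderCheckPoint(vars(args))
llm = shared.loaderLLM()  # returns a BaseAnswer implementation, e.g. ChatGLM

# generatorAnswer yields AnswerResult objects; the answer text and the updated
# history are read from each result rather than from a raw (resp, history) tuple.
for answer_result in llm.generatorAnswer(prompt="你好", history=[], streaming=True):
    print(answer_result.llm_output["answer"])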
@@ -12,10 +12,14 @@ from tqdm import tqdm
from pypinyin import lazy_pinyin
from loader import UnstructuredPaddleImageLoader
from loader import UnstructuredPaddlePDFLoader
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared

DEVICE_ = EMBEDDING_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_


def load_file(filepath, sentence_size=SENTENCE_SIZE):
@@ -132,7 +136,7 @@ def similarity_search_with_score_by_vector(


class LocalDocQA:
    llm: object = None
    llm: BaseAnswer = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE
@@ -142,23 +146,10 @@ class LocalDocQA:
    def init_cfg(self,
                 embedding_model: str = EMBEDDING_MODEL,
                 embedding_device=EMBEDDING_DEVICE,
                 llm_history_len: int = LLM_HISTORY_LEN,
                 llm_model: str = LLM_MODEL,
                 llm_device=LLM_DEVICE,
                 llm_model: BaseAnswer = None,
                 top_k=VECTOR_SEARCH_TOP_K,
                 use_ptuning_v2: bool = USE_PTUNING_V2,
                 use_lora: bool = USE_LORA,
                 ):
        if llm_model.startswith('moss'):
            from models.moss_llm import MOSS
            self.llm = MOSS()
        else:
            from models.chatglm_llm import ChatGLM
            self.llm = ChatGLM()
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
        self.llm.history_len = llm_history_len

        self.llm = llm_model
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                model_kwargs={'device': embedding_device})
        self.top_k = top_k
@@ -259,16 +250,16 @@ class LocalDocQA:
        torch_gc()
        prompt = generate_prompt(related_docs_with_score, query)

        for result, history in self.llm._call(prompt=prompt,
                                              history=chat_history,
                                              streaming=streaming):
            torch_gc()
        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][0] = query
            response = {"query": query,
                        "result": result,
                        "result": resp,
                        "source_documents": related_docs_with_score}
            yield response, history
            torch_gc()

    # query: the query text
    # vs_path: path to the knowledge base
@@ -297,10 +288,19 @@ class LocalDocQA:


if __name__ == "__main__":
    # initialization
    args = None
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])

    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg()
    local_doc_qa.init_cfg(llm_model=llm_model_ins)
    query = "本项目使用的embedding模型是什么,消耗多少显存"
    vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
    vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
    last_print_len = 0
    for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
                                                                  vs_path=vs_path,
@@ -22,14 +22,54 @@ EMBEDDING_MODEL = "text2vec"
# Embedding running device
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# supported LLM models
"""
llm_model_dict handles some of the loader's preset behavior, such as the load location,
model name, and model handler instance
"""
llm_model_dict = {
    "chatyuan": "ClueAI/ChatYuan-large-v2",
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b-int8": "THUDM/chatglm-6b-int8",
    "chatglm-6b": "THUDM/chatglm-6b",
    "moss": "fnlp/moss-moon-003-sft",
    "chatglm-6b-int4-qe": {
        "name": "chatglm-6b-int4-qe",
        "remote-checkpoint": "THUDM/chatglm-6b-int4-qe",
        "path": None,
        "provides": "ChatGLM"
    },
    "chatglm-6b-int4": {
        "name": "chatglm-6b-int4",
        "remote-checkpoint": "THUDM/chatglm-6b-int4",
        "path": None,
        "provides": "ChatGLM"
    },
    "chatglm-6b": {
        "name": "chatglm-6b",
        "remote-checkpoint": "THUDM/chatglm-6b-int4",  # note: in this commit the entry still points at the int4 checkpoint
        "path": None,
        "provides": "ChatGLM"
    },
    "llama-7b-hf": {
        "name": "llama-7b-hf",
        "remote-checkpoint": "llama-7b-hf",
        "path": None,
        "provides": "LLamaLLM"
    },
    "vicuna-13b-hf": {
        "name": "vicuna-13b-hf",
        "remote-checkpoint": "vicuna-13b-hf",
        "path": None,
        "provides": "LLamaLLM"
    },
    "chatyuan": {
        "name": "chatyuan",
        "remote-checkpoint": "ClueAI/ChatYuan-large-v2",
        "path": None,
        "provides": None
    },
    "chatglm-6b-int8": {
        "name": "chatglm-6b-int8",
        "remote-checkpoint": "THUDM/chatglm-6b-int8",
        "path": None,
        "provides": "ChatGLM"
    },
}

# LLM model name
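As a quick illustration of how these entries are consumed elsewhere in this commit (the lookup shown is an assumption based on the name/provides fields, not a quote of the loader code):

# Sketch: resolve a model entry to its "provides" class name.
from configs.model_config import llm_model_dict

entry = llm_model_dict["chatglm-6b-int4"]
print(entry["provides"])            # "ChatGLM" -> class exported from models/__init__.py
print(entry["remote-checkpoint"])   # Hugging Face repo used when no local path is set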
0    fastchat/__init__.py    Normal file

1    fastchat/api/__init__.py    Normal file
@@ -0,0 +1 @@
from .fastchat_api import *
261    fastchat/api/conversation.py    Normal file
@@ -0,0 +1,261 @@
"""
Conversation prompt template.

Now we support
- Vicuna
- Koala
- OpenAssistant/oasst-sft-1-pythia-12b
- StabilityAI/stablelm-tuned-alpha-7b
- databricks/dolly-v2-12b
- THUDM/chatglm-6b
- Alpaca/LLaMa
"""

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    DOLLY = auto()
    OASST_PYTHIA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)

conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)

conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1.1": conv_vicuna_v1_1,
    "koala_v1": conv_koala_v1,
    "dolly": conv_dolly,
    "oasst": conv_oasst,
}


def get_default_conv_template(model_name):
    model_name = model_name.lower()
    if "vicuna" in model_name or "output" in model_name:
        return conv_vicuna_v1_1
    elif "koala" in model_name:
        return conv_koala_v1
    elif "dolly-v2" in model_name:
        return conv_dolly
    elif "oasst" in model_name and "pythia" in model_name:
        return conv_oasst
    elif "stablelm" in model_name:
        return conv_stablelm
    return conv_one_shot


def compute_skip_echo_len(model_name, conv, prompt):
    model_name = model_name.lower()
    if "chatglm" in model_name:
        skip_echo_len = len(conv.messages[-2][1]) + 1
    elif "dolly-v2" in model_name:
        special_toks = ["### Instruction:", "### Response:", "### End"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "oasst" in model_name and "pythia" in model_name:
        special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "stablelm" in model_name:
        special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    else:
        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    return skip_echo_len


if __name__ == "__main__":
    print(default_conversation.get_prompt())
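A brief usage sketch of the conversation templates added above (illustrative only; the exact prompt text depends on the template chosen):

from fastchat.api.conversation import get_default_conv_template

conv = get_default_conv_template("vicuna-13b").copy()
conv.append_message(conv.roles[0], "Tell me a joke.")
conv.append_message(conv.roles[1], None)  # placeholder for the assistant's reply
prompt = conv.get_prompt()                # e.g. "... USER: Tell me a joke. ASSISTANT:"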
459    fastchat/api/fastchat_api.py    Normal file
@@ -0,0 +1,459 @@
"""Wrapper around FastChat APIs."""
from __future__ import annotations

import logging
import sys
import warnings
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    Literal,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)

from pydantic import Extra, Field, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

import requests
import json


logger = logging.getLogger(__name__)
FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: the response structure
    """
    return {
        "text": "",
        "error_code": 0,
    }


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response["text"]
    response["error_code"] += stream_response["error_code"]


class BaseFastChat(BaseLLM):
    """Wrapper around FastChat large language models."""

    model_name: str = "text-davinci-003"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_new_tokens: int = 200
    stop: int = 20
    batch_size: int = 20
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Penalizes repeated tokens."""
    n: int = 1
    """Whether to stream the results or not."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transfered to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling FastChat API."""
        normal_params = {
            "model": self.model_name,
            "prompt": '',
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        return {**normal_params}

    def _generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = fastchat.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}
        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]

            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint async with k unique prompts."""
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}

        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]
            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    def get_sub_prompts(
            self,
            params: Dict[str, Any],
            prompts: List[str],
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if params["max_new_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_new_tokens set to -1 not supported for multiple inputs."
                )
            params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
        # append pload
        sub_prompts = [
            prompts[i: i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]

        return sub_prompts

    def create_llm_result(
            self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * self.n: (i + 1) * self.n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info=dict(
                            finish_reason='over',
                            logprobs=choice["text"],
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return LLMResult(generations=generations, llm_output=llm_output)

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        """Call FastChat with streaming flag and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from OpenAI.

        Example:
            .. code-block:: python

                generator = fastChat.stream("Tell me a joke.")
                for token in generator:
                    yield token
        """
        params = self._invocation_params
        params["prompt"] = prompt
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop

        headers = {"User-Agent": "fastchat Client"}
        response = requests.post(
            FAST_CHAT_API,
            headers=headers,
            json=params,
            stream=True,
        )
        for stream_resp in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if stream_resp:
                data = json.loads(stream_resp.decode("utf-8"))
                skip_echo_len = len(_prompts[0])  # note: _prompts is not defined in this scope; len(prompt) was presumably intended
                output = data["text"][skip_echo_len:].strip()
                data["text"] = output
                yield data

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fastChat"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python < 3.8
        if sys.version_info[1] < 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )

        enc = tiktoken.encoding_for_model(self.model_name)

        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size

        Example:
            .. code-block:: python

                max_new_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        model_token_mapping = {
            "vicuna-13b": 2049,
            "koala": 2049,
            "dolly-v2": 2049,
            "oasst": 2049,
            "stablelm": 2049,
        }

        context_size = model_token_mapping.get(modelname, None)

        if context_size is None:
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
                "Known models are: " + ", ".join(model_token_mapping.keys())
            )

        return context_size

    def max_new_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_new_tokens = openai.max_token_for_prompt("Tell me a joke.")
        """
        num_tokens = self.get_num_tokens(prompt)

        # get max context size for model by name
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens


class FastChat(BaseFastChat):
    """Wrapper around OpenAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            openai = FastChat(model_name="vicuna")
    """

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}
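For context, a minimal sketch of how this wrapper is presumably meant to be driven (it assumes a FastChat model worker is already serving at FAST_CHAT_API; parameter values are illustrative):

from fastchat.api.fastchat_api import FastChat

llm = FastChat(model_name="vicuna-13b", max_new_tokens=256, temperature=0.7)

# Non-streaming: the wrapper posts the prompt batch to the worker endpoint.
result = llm.generate(["Tell me a joke."])
print(result.generations[0][0].text)

# Streaming: consume the worker's chunked responses as they arrive.
for chunk in llm.stream("Tell me a joke."):
    print(chunk["text"], end="", flush=True)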
@@ -1,2 +1,4 @@

from .chatglm_llm import *
from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
97    models/__main__.py    Normal file
@@ -0,0 +1,97 @@
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType

import models.shared as shared

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from typing import List, Set


class CustomLLMSingleActionAgent(ZeroShotAgent):
    allowed_tools: List[str]

    def __init__(self, *args, **kwargs):
        super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs)
        self.allowed_tools = kwargs['allowed_tools']

    def get_allowed_tools(self) -> Set[str]:
        return set(self.allowed_tools)


async def dispatch(args: Namespace):
    args_dict = vars(args)

    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    template = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""

    prompt = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=template
    )
    memory = ConversationBufferMemory(memory_key="chat_history")
    readonlymemory = ReadOnlySharedMemory(memory=memory)
    summry_chain = LLMChain(
        llm=llm_model_ins,
        prompt=prompt,
        verbose=True,
        memory=readonlymemory,  # use the read-only memory to prevent the tool from modifying the memory
    )

    tools = [
        Tool(
            name="Summary",
            func=summry_chain.run,
            description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
        )
    ]

    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!

Question: {input}
{agent_scratchpad}"""

    prompt = CustomLLMSingleActionAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"]
    )
    tool_names = [tool.name for tool in tools]
    llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
    agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)

    agent_chain.run(input="你好")
    agent_chain.run(input="你是谁?")
    agent_chain.run(input="我们之前聊了什么?")


if __name__ == '__main__':
    args = None
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(dispatch(args))
198    models/base.py    Normal file
@@ -0,0 +1,198 @@
from abc import ABC, abstractmethod
from typing import Optional, List
import traceback
from collections import deque
from queue import Queue
from threading import Thread

import torch
import transformers
from models.loader import LoaderCheckPoint


class ListenerToken:
    """
    Observed result of a generation step.
    """

    input_ids: torch.LongTensor
    _scores: torch.FloatTensor

    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
        self.input_ids = input_ids
        self._scores = _scores


class AnswerResult:
    """
    Message entity.
    """
    history: List[List[str]] = []
    llm_output: Optional[dict] = None
    listenerToken: ListenerToken = None


class AnswerResultStream:
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, answerResult: AnswerResult):
        if self.callback_func is not None:
            self.callback_func(answerResult)


class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
    """
    Defines a stopping_criteria listener for the model; on every response it syncs the queue data into AnswerResult.
    The motivation is that different models' prediction output may not be vector information; the HF framework lets a
    custom transformers.StoppingCriteria receive the Tensor and scores of each prediction step.
    A StoppingCriteriaList specifies the conditions under which generation stops; each StoppingCriteria object
    represents one stop condition. At each prediction step every StoppingCriteria receives the same prediction result,
    and the lower-level implementation ultimately decides whether to stop.
    The captured values can be observed through the custom parameters of generatorAnswer / generate_with_streaming
    for finer-grained control.
    """

    listenerQueue: deque = deque(maxlen=1)

    def __init__(self):
        transformers.StoppingCriteria.__init__(self)

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Append the data to the listener queue on every generation step.
        :param input_ids:
        :param _scores:
        :param kwargs:
        :return:
        """
        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    def __init__(self, func, kwargs={}):
        self.mfunc = func
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs
        self.stop_now = False

        def _callback(val):
            """
            Collects the model's prediction output.
            A generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses produced by
            the model; the lower-level implementation ultimately decides when to stop.
            Stop conditions:
            1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
            2. A break is executed while consuming the iterator queue, which triggers StopIteration.
            3. The model prediction raises an error.
            Because this class is an iterator, __exit__ is called after a break inside a for loop; stop_now is then
            updated and an exception is raised to end the prediction.
            The collection flow is:
              - create an Iteratorize object and define the generate_with_callback collector (AnswerResultStream),
              - start a thread that asynchronously calls the upstream checkpoint's _generate_answer,
              - _generate_answer uses the collector to gather the AnswerResult messages wrapped by the upstream checkpoint,
              - since self.q blocks, each prediction is consumed before the next one runs, so generate_with_callback
                blocks in between,
              - the main thread consumes the blocking queue through Iteratorize.__next__:
                1. if the message is an AnswerResult wrapped by the upstream checkpoint, it is returned downstream,
                2. if the message is the self.sentinel marker, StopIteration is raised,
              - when Iteratorize.__exit__ runs, stop_now is updated,
              - the async thread detects the updated stop_now and raises an exception to end the prediction,
              - the iteration ends.
            :param val:
            :return:
            """
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except:
                traceback.print_exc()
                pass

            self.q.put(self.sentinel)

        self.thread = Thread(target=gen)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        """
        Not implemented yet.
        :return:
        """
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Runs after a break."""
        self.stop_now = True


class BaseAnswer(ABC):
    """Upper-level wrapper providing a unified API for answer generation."""

    @property
    @abstractmethod
    def _check_point(self) -> LoaderCheckPoint:
        """Return _check_point of llm."""

    @property
    @abstractmethod
    def _history_len(self) -> int:
        """Return _history_len of llm."""

    @abstractmethod
    def set_history_len(self, history_len: int) -> None:
        """Return _history_len of llm."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        def generate_with_callback(callback=None, **kwargs):
            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
            self._generate_answer(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs)

        """
        eos_token_id is a designated token (e.g. "</s>") that marks the end of a sequence. During generation the model
        keeps emitting tokens until this special eos_token_id is produced, which means the sequence is complete.
        In Hugging Face Transformers the eos_token_id is added to the input automatically by the tokenizer; when the
        model emits it, generation stops and the generated sequence is returned.
        """
        eos_token_ids = [
            self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []

        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
            for answerResult in generator:
                if answerResult.listenerToken:
                    output = answerResult.listenerToken.input_ids
                yield answerResult

    @abstractmethod
    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        pass
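To make the contract concrete, a minimal hypothetical BaseAnswer implementation (not part of this commit) only has to push AnswerResult objects through the provided collector; names other than those imported from models/base.py are made up for illustration:

from typing import List
from models.base import AnswerResult, AnswerResultStream, BaseAnswer

class EchoLLM(BaseAnswer):
    """Toy sketch: echoes the prompt once. A real subclass must expose a loaded
    LoaderCheckPoint via _check_point, since generatorAnswer reads its tokenizer."""

    checkPoint = None  # a real implementation returns a LoaderCheckPoint here

    @property
    def _check_point(self):
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return 3

    def set_history_len(self, history_len: int) -> None:
        pass

    def _generate_answer(self, prompt: str, history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        # Wrap one result and hand it to the collector; Iteratorize turns this
        # callback-style flow into the generator consumed by generatorAnswer.
        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, prompt]]
        answer_result.llm_output = {"answer": prompt}
        generate_with_callback(answer_result)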
@@ -1,189 +1,96 @@
import json

from abc import ABC

from langchain.llms.base import LLM
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from configs.model_config import *
from utils import torch_gc

DEVICE_ = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_
from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


def auto_configure_device_map(num_gpus: int, use_lora: bool) -> Dict[str, int]:
    # transformer.word_embeddings takes 1 layer
    # transformer.final_layernorm and lm_head take 1 layer
    # transformer.layers take 28 layers
    # 30 layers in total, distributed across num_gpus cards
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: layer names differ when PEFT loads a LoRA model
    if LLM_LORA_PATH and use_lora:
        layer_prefix = 'base_model.model.transformer'
    else:
        layer_prefix = 'transformer'

    # bugfix: on Linux, torch.embedding may receive weight and input on different devices, causing a RuntimeError
    # on Windows model.device is set to transformer.word_embeddings.device
    # on Linux model.device is set to lm_head.device
    # when chat or stream_chat is called, input_ids are placed on model.device
    # if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
    # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card
    device_map = {f'{layer_prefix}.word_embeddings': 0,
                  f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
                  f'base_model.model.lm_head': 0, }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
        used += 1

    return device_map
import transformers


class ChatGLM(LLM):
class ChatGLM(BaseAnswer, LLM, ABC):
    max_token: int = 10000
    temperature: float = 0.8
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    # history = []
    tokenizer: object = None
    model: object = None
    history_len: int = 10

    def __init__(self):
    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self,
              prompt: str,
              history: List[List[str]] = [],
              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        pass

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        # Create the StoppingCriteriaList with the stopping strings
        stopping_criteria_list = transformers.StoppingCriteriaList()
        # define the stopping_criteria listener queue; on every step it syncs torch.LongTensor / torch.FloatTensor into AnswerResult
        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
        stopping_criteria_list.append(listenerQueue)

        if streaming:
            for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,

            for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
                    self.checkPoint.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    stopping_criteria=stopping_criteria_list
            )):
                torch_gc()
                self.checkPoint.clear_torch_cache()
                if inum == 0:
                    history += [[prompt, stream_resp]]
                else:
                    history[-1] = [prompt, stream_resp]
                yield stream_resp, history
                torch_gc()
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": stream_resp}
                if listenerQueue.listenerQueue.__len__() > 0:
                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                generate_with_callback(answer_result)
        else:
            response, _ = self.model.chat(
                self.tokenizer,
            response, _ = self.checkPoint.model.chat(
                self.checkPoint.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
                top_p=self.top_p,
                stopping_criteria=stopping_criteria_list
            )
            torch_gc()
            self.checkPoint.clear_torch_cache()
            history += [[prompt, response]]
            yield response, history
            torch_gc()
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
            if listenerQueue.listenerQueue.__len__() > 0:
                answer_result.listenerToken = listenerQueue.listenerQueue.pop()

    # def chat(self,
    #          prompt: str) -> str:
    #     response, _ = self.model.chat(
    #         self.tokenizer,
    #         prompt,
    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
    #         max_length=self.max_token,
    #         temperature=self.temperature,
    #     )
    #     torch_gc()
    #     self.history = self.history + [[None, response]]
    #     return response

    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   llm_device=LLM_DEVICE,
                   use_ptuning_v2=False,
                   use_lora=False,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        if use_ptuning_v2:
            try:
                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                logger.error(f"加载PrefixEncoder config.json失败: {e}")
        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
                                               **kwargs)
        if LLM_LORA_PATH and use_lora:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)

        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            # decide whether to use multi-GPU deployment based on the number of GPUs on this machine
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                self.model = self.model.half().cuda()
            else:
                from accelerate import dispatch_model

                # model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
                #                                   config=model_config, **kwargs)
                if LLM_LORA_PATH and use_lora:
                    from peft import PeftModel
                    model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
                # a custom device_map can be passed in to control placement per card
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus, use_lora)

                self.model = dispatch_model(self.model.half(), device_map=device_map)
        else:
            self.model = self.model.float().to(llm_device)

        if use_ptuning_v2:
            try:
                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception as e:
                logger.error(f"加载PrefixEncoder模型参数失败:{e}")

        self.model = self.model.eval()
            generate_with_callback(answer_result)


if __name__ == "__main__":
    llm = ChatGLM()
    llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                   llm_device=LLM_DEVICE, )
    last_print_len = 0
    for resp, history in llm._call("你好", streaming=True):
        logger.info(resp[last_print_len:], end="", flush=True)
        last_print_len = len(resp)
    for resp, history in llm._call("你好", streaming=False):
        logger.info(resp)
    pass
222    models/extensions/callback.py    Normal file
@@ -0,0 +1,222 @@
import gc
import traceback
from queue import Queue
from threading import Thread
import threading
from typing import Optional, List, Dict, Any
from collections import deque
import torch
import transformers

from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared


class LimitedLengthDict(dict):
    def __init__(self, maxlen=None, *args, **kwargs):
        self.maxlen = maxlen
        self._keys = deque()
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        if key not in self:
            if self.maxlen is not None and len(self) >= self.maxlen:
                oldest_key = self._keys.popleft()
                if oldest_key in self:
                    del self[oldest_key]
            self._keys.append(key)
        super().__setitem__(key, value)


class FixedLengthQueue:

    # list of stop sequences
    stop_sequence: Optional[str] = []
    # buffer length
    max_length: int = 0
    # buffer container
    queue: deque = None
    # input container
    queue_in: LimitedLengthDict[int, str] = {}
    # output container
    queue_out: Dict[int, str] = {}

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup can be done here
        return instance

    def __init__(self, stop_sequence):
        if stop_sequence is None:
            self.stop_sequence = []
            self.max_length = 0
        elif isinstance(stop_sequence, str):
            self.stop_sequence = [stop_sequence]
            self.max_length = 1
        else:
            self.stop_sequence = stop_sequence
            self.max_length = len(''.join(stop_sequence))

        self.queue = deque(maxlen=self.max_length)
        self.queue.clear()
        self.queue_in.clear()
        self.queue_out.clear()

    def add(self, index, item):
        self.queue_in[index] = item

    def _add_out(self, index, item):
        self.queue_out[index] = item

    def put_replace_out(self, index):
        return self.queue_out[index]

    def contains_replace_sequence(self):
        """
        Replace characters.
        :return:
        """

        for key, value in self.queue_in.items():

            word_index = value.rfind(":")
            if word_index != -1:
                value = value.replace(":", ":")

            word_index = value.rfind("[")
            if word_index != -1:
                value = value.replace("[", "")

            word_index = value.rfind("]")
            if word_index != -1:
                value = value.replace("]", "")

            self._add_out(key, value)

    def contains_stop_sequence(self):
        # check only a fixed-size window of the data
        self.queue.clear()
        last_three_keys = list(self.queue_out.keys())[-self.max_length:]
        joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
        for char in joined_queue:
            self.queue.append(char)

        joined_queue = ''.join(self.queue)
        # Initialize a variable to store the index of the last found stop string
        last_stop_str_index = -1

        # Iterate through the stop string list
        for stop_word in self.stop_sequence:
            # Find the last occurrence of the stop string in the output
            stop_word_index = joined_queue.rfind(stop_word)

            # If the stop string is found, compare the index with the previously found index
            if stop_word_index != -1 and stop_word_index > last_stop_str_index:
                last_stop_str_index = stop_word_index

        # Handle the last found stop string index here
        return last_stop_str_index

    def __repr__(self):
        return str(self.queue)


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]

            for i in range(len(self.sentinel_token_ids)):
                # Can't unfold, output is still too tiny. Skip.
                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                    continue
                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                        return True
        return False


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if shared.stop_everything:
            raise ValueError
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    thread: ThreadWithException = None

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup can be done here
        return instance

    def __init__(self, func, kwargs={}, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs

        def _callback(val):
            if shared.stop_everything:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                print("print(ValueError)")
            except:
                traceback.print_exc()
                print("traceback.print_exc()")
            self.q.put(self.sentinel)

        self.thread = ThreadWithException(target=gen)
        self.thread.start()

    def __iter__(self):
        shared.stop_everything = False
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        shared.stop_everything = False
        self.q.empty()
        shared.loaderCheckPoint.clear_torch_cache()

    def __enter__(self):
        shared.stop_everything = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        shared.stop_everything = True
        shared.loaderCheckPoint.clear_torch_cache()
        self.thread.raise_exception()
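A small illustration of the FixedLengthQueue stop-sequence scan defined above (indices and text pieces are made up for the example):

from models.extensions.callback import FixedLengthQueue

q = FixedLengthQueue(stop_sequence="### End")
for i, piece in enumerate(["The answer", " is 42.", " ### En", "d"]):
    q.add(i, piece)
q.contains_replace_sequence()      # normalizes ":" / "[" / "]" into queue_out
pos = q.contains_stop_sequence()   # index of "### End" inside the tail window, or -1 if absent
print(pos)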
10    models/extensions/extensions.py    Normal file
@@ -0,0 +1,10 @@
import gc
import traceback
import torch

# This iterator returns the extensions in the order specified in the command-line
def iterator():
    state_extensions = {}
    for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
        if state_extensions[name][0]:
            yield getattr(extensions, name).script, name
models/extensions/llamacpp_model_alternative.py (Normal file, 64 lines)
@ -0,0 +1,64 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''

from llama_cpp import Llama, LlamaCache

from modules import shared
from modules.callbacks import Iteratorize


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        result = self()

        params = {
            'model_path': str(path),
            'n_ctx': 2048,
            'seed': 0,
            'n_threads': shared.args.threads or None
        }
        self.model = Llama(**params)
        self.model.set_cache(LlamaCache)

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if type(string) is str:
            string = string.encode()
        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        if type(context) is str:
            context = context.encode()
        tokens = self.model.tokenize(context)

        output = b""
        count = 0
        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
            text = self.model.detokenize([token])
            output += text
            if callback:
                callback(text.decode())

            count += 1
            if count >= token_count or (token == self.model.token_eos()):
                break

        return output.decode()

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply

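A minimal, hedged usage sketch for the class above; the ggml path is a placeholder and `shared.args.threads` is assumed to be set:

# Sketch: load a ggml checkpoint and stream a short completion.
model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")
for partial in model.generate_with_streaming(context="Hello", token_count=16):
    print(partial)
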
models/extensions/thread_with_exception.py (Normal file, 30 lines)
@ -0,0 +1,30 @@
# Python program raising
# exceptions in a python
# thread

import threading
import ctypes
import time


class ThreadWithException(threading.Thread):

    def get_id(self):
        return self.ident

    def raise_exception(self):
        """raises the exception, performs cleanup if needed"""
        try:
            thread_id = self.get_id()
            tid = ctypes.c_long(thread_id)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
            if res == 0:
                # pass
                raise ValueError("invalid thread id")
            elif res != 1:
                # """if it returns a number greater than one, you're in trouble,
                # and you should call it again with exc=NULL to revert the effect"""
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")
        except Exception as err:
            print(err)

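A short, hedged example of how this class is meant to be used: start a worker thread, then inject SystemExit into it to stop it. The worker body below is a toy stand-in:

def worker():
    while True:
        time.sleep(0.1)

t = ThreadWithException(target=worker)
t.start()
time.sleep(1)
t.raise_exception()  # asks the interpreter to raise SystemExit inside the worker
t.join()
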
models/llama_llm.py (Normal file, 297 lines)
@ -0,0 +1,297 @@
from abc import ABC

from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue)
import models.shared as shared
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: the response structure
    """
    return {
        "text": ""
    }


def _update_response(response: Dict[str, Any], stream_response: str) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores

class LLamaLLM(BaseAnswer, LLM, ABC):
    checkPoint: LoaderCheckPoint = None
    history = []
    history_len: int = 3
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    eos_token_id: Optional[int] = [2]

    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
                                                     add_special_tokens=add_special_tokens)
        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it
        if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handling truncation
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    def generate_with_callback(self, callback=None, **kwargs):
        self.checkPoint.clear_torch_cache()
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        with torch.no_grad():
            self.checkPoint.model.generate(**kwargs)
            print("generation method finished")

    def generate_with_streaming(self, **kwargs):
        return Iteratorize(self.generate_with_callback, kwargs)

    # Convert the conversation history array into plain text
    def history_to_text(self, query):
        formatted_history = ''
        history = self.history[-self.history_len:] if self.history_len > 0 else []
        for i, (old_query, response) in enumerate(history):
            formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
        formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query)
        return formatted_history

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """
        Pre-build the attention mask and the tensor of position indices for the input sequence.
        # TODO no clear approach for this yet
        :return:
        """

        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)

        attention_mask = self.get_masks(input_ids, input_ids.device)

        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )

        return input_ids, position_ids, attention_mask

    def get_position_ids(self, input_ids: torch.LongTensor, mask_positions, device):
        """
        Positional offsets used by the attention layers.
        :param input_ids:
        :param mask_positions:
        :param device:
        :param use_gmasks:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        block_position_ids = [torch.cat((
            torch.zeros(context_length, dtype=torch.long, device=device),
            torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
        )) for context_length in context_lengths]
        block_position_ids = torch.stack(block_position_ids, dim=0)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        return position_ids

    def get_masks(self, input_ids, device):
        """
        Build the attention mask.
        :param input_ids:
        :param device:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()
        return attention_mask

    def generate_softprompt_history_tensors(self, query):
        """
        Soft prompt built from the conversation history.
        history_to_text converts the self.history array into the required text format;
        the formatted history can then be turned into a vector representation with self.encode
        and concatenated with the vector of the current input.
        :return:
        """

        # Conversation content
        # Process the conversation history
        formatted_history = self.history_to_text(query)
        return formatted_history

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
            self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.eos_token_id,
            "logits_processor": self.logits_processor}

        # Vector concatenation
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)

        # Prompt for the chat model
        gen_kwargs.update({'inputs': input_ids})
        # Attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})
        # gen_kwargs.update({'position_ids': position_ids})
        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()
        # Observe the output
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
        shared.stop_everything = False
        stopped = False
        response_template = _streaming_response_template()

        # TODO this streaming output method needs to be rewritten!
        # The stopping_criteria hook cannot be controlled from here and the iterator's variables cannot be shared.
        with self.generate_with_streaming(**gen_kwargs) as generator:
            last_reply_len = 0
            reply_index = 0
            # Create a FixedLengthQueue with the desired stop sequence and a maximum length.
            queue = FixedLengthQueue(stop)
            for output in generator:
                new_tokens = len(output) - len(input_ids[0])
                reply = self.decode(output[-new_tokens:])

                new_reply = len(reply) - last_reply_len
                output_reply = reply[-new_reply:]
                queue.add(reply_index, output_reply)
                queue.contains_replace_sequence()
                if stop:
                    pos = queue.contains_stop_sequence()
                    if pos != -1:
                        shared.stop_everything = True
                        stopped = True

                # print(f"{reply_index}:reply {output_reply}")
                english_reply = queue.put_replace_out(reply_index)
                # print(f"{reply_index}:english_reply {english_reply}")
                _update_response(response_template, english_reply)
                last_reply_len = len(reply)

                reply_index += 1
                if new_tokens == self.max_new_tokens - 1 or stopped:
                    break

        response = response_template['text']
        print(f"response:{response}")
        self.history = self.history + [[None, response]]
        return response

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        if history:
            self.history = history
        # Create the StoppingCriteriaList with the stopping strings
        self.stopping_criteria = transformers.StoppingCriteriaList()
        # Register the model's stopping_criteria listener queue; on every response the
        # torch.LongTensor / torch.FloatTensor pair is synchronized into the AnswerResult.
        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
        self.stopping_criteria.append(listenerQueue)
        # TODO a chat module and attention handling still need to be implemented. _call is currently the plain
        # langchain LLM extension API with no prompt template; to work with the attention model, see the
        # chat_glm implementation for reference.
        softprompt = self.generate_softprompt_history_tensors(prompt)
        response = self._call(prompt=softprompt, stop=['\n###'])
        answer_result = AnswerResult()
        answer_result.history = self.history
        if listenerQueue.listenerQueue.__len__() > 0:
            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
        answer_result.llm_output = {"answer": response}
        generate_with_callback(answer_result)

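A hedged sketch of how this class is driven in practice. generatorAnswer comes from BaseAnswer in models.base, which is not shown in this diff, and the sketch assumes shared.loaderCheckPoint has already loaded a LLaMA checkpoint via reload_model():

llm = LLamaLLM(checkPoint=shared.loaderCheckPoint)
llm.set_history_len(3)
for answer_result in llm.generatorAnswer(prompt="Hello", history=[], streaming=False):
    print(answer_result.llm_output["answer"])
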
models/loader/__init__.py (Normal file, 2 lines)
@ -0,0 +1,2 @@

from .loader import *

models/loader/args.py (Normal file, 59 lines)
@ -0,0 +1,59 @@
import argparse
import os


# Additional argparse types
def path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
    return s


def file_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isfile(s):
        raise argparse.ArgumentTypeError(f'No such file: "{string}"')
    return s


def dir_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isdir(s):
        raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
    return s


parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                 description='An LLM document reader based on langchain and ChatGLM')


parser.add_argument('--no-remote-model', action='store_true', default=False, help='Do not fetch the model from a remote hub on loader checkpoint; add `--no-remote-model` when loading a local checkpoint.')
parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')


args = parser.parse_args([])
# Generates a dict with a default value for each argument
DEFAULT_ARGS = vars(args)

models/loader/loader.py (Normal file, 405 lines)
@ -0,0 +1,405 @@
import gc
import json
import os
import re
import time
from pathlib import Path
from peft import PeftModel
from typing import Optional, List, Dict, Tuple, Union
import torch
import transformers

from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers
from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map


class LoaderCheckPoint:
    """
    Loads a custom model checkpoint.
    """
    # load the model from a remote checkpoint
    no_remote_model: bool = False
    # model name
    model_name: str = None
    tokenizer: object = None
    # full path to the model
    model_path: str = None
    model: object = None
    model_config: object = None
    lora_names: set = []
    model_dir: str = None
    lora_dir: str = None
    ptuning_dir: str = None
    use_ptuning_v2: bool = False
    cpu: bool = False
    gpu_memory: object = None
    cpu_memory: object = None
    auto_devices: object = True
    # If 8-bit quantized loading is enabled and the project fails to start, pick a suitable CUDA version;
    # see https://github.com/TimDettmers/bitsandbytes/issues/156
    load_in_8bit: bool = False
    is_llamacpp: bool = False
    bf16: bool = False
    params: object = None
    # custom device map
    device_map: Optional[Dict[str, int]] = None
    # Default to cuda; fall back to mps if cuda is unavailable, and to cpu if neither is available
    llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    def __init__(self, params: dict = None):
        """
        Model initialization.
        :param params:
        """
        self.model_path = None
        self.params = params or {}
        self.no_remote_model = params.get('no_remote_model', False)
        self.model_name = params.get('model', '')
        self.lora = params.get('lora', '')
        self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
        self.model = None
        self.tokenizer = None
        self.model_dir = params.get('model_dir', '')
        self.lora_dir = params.get('lora_dir', '')
        self.ptuning_dir = params.get('ptuning_dir', '')
        self.cpu = params.get('cpu', False)
        self.gpu_memory = params.get('gpu_memory', None)
        self.cpu_memory = params.get('cpu_memory', None)
        self.auto_devices = params.get('auto_devices', True)
        self.load_in_8bit = params.get('load_in_8bit', False)
        self.bf16 = params.get('bf16', False)

    def _load_model_config(self, model_name):
        checkpoint = Path(f'{self.model_dir}/{model_name}')

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name

        model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)

        return model_config

    def _load_model(self, model_name):
        """
        Load a model from a custom location.
        :param model_name:
        :return:
        """
        print(f"Loading {model_name}...")
        t0 = time.time()

        checkpoint = Path(f'{self.model_dir}/{model_name}')

        self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name

        if 'chatglm' in model_name.lower():
            LoaderClass = AutoModel
        else:
            LoaderClass = AutoModelForCausalLM

        # Load the model in simple 16-bit mode by default
        if not any([self.cpu, self.load_in_8bit, self.auto_devices, self.gpu_memory is not None,
                    self.cpu_memory is not None, self.is_llamacpp]):

            if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
                # Decide whether to deploy across multiple GPUs based on how many are available
                num_gpus = torch.cuda.device_count()
                if num_gpus < 2 and self.device_map is None:
                    model = (
                        LoaderClass.from_pretrained(checkpoint,
                                                    low_cpu_mem_usage=True,
                                                    config=self.model_config,
                                                    torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                    trust_remote_code=True)
                        .half()
                        .cuda()
                    )
                else:
                    from accelerate import dispatch_model

                    model = LoaderClass.from_pretrained(checkpoint,
                                                        low_cpu_mem_usage=True,
                                                        config=self.model_config,
                                                        torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                        trust_remote_code=True).half()
                    # A custom device_map can be passed in to control how each GPU is used
                    if self.device_map is None:
                        if 'chatglm' in model_name.lower():
                            device_map = self.chatglm_auto_configure_device_map(num_gpus)
                        elif 'moss' in model_name.lower():
                            device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
                        else:
                            device_map = self.chatglm_auto_configure_device_map(num_gpus)

                    model = dispatch_model(model, device_map=device_map)
            else:
                print(
                    "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
                model = (
                    AutoModel.from_pretrained(
                        checkpoint,
                        config=self.model_config,
                        trust_remote_code=True)
                    .float()
                    .to(self.llm_device)
                )

        elif self.is_llamacpp:
            from models.extensions.llamacpp_model_alternative import LlamaCppModel

            model_file = list(checkpoint.glob('ggml*.bin'))[0]
            print(f"llama.cpp weights detected: {model_file}\n")

            model, tokenizer = LlamaCppModel.from_pretrained(model_file)
            return model, tokenizer

        # Custom
        else:
            params = {"low_cpu_mem_usage": True}
            if not any((self.cpu, torch.cuda.is_available(), torch.has_mps)):
                print(
                    "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
                self.cpu = True

            if self.cpu:
                params["torch_dtype"] = torch.float32
            else:
                params["device_map"] = 'auto'
                params["trust_remote_code"] = True
                if self.load_in_8bit and any((self.auto_devices, self.gpu_memory)):
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
                                                                       llm_int8_enable_fp32_cpu_offload=True)
                elif self.load_in_8bit:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
                elif self.bf16:
                    params["torch_dtype"] = torch.bfloat16
                else:
                    params["torch_dtype"] = torch.float16

                if self.gpu_memory:
                    memory_map = list(map(lambda x: x.strip(), self.gpu_memory))
                    max_cpu_memory = self.cpu_memory.strip() if self.cpu_memory is not None else '99GiB'
                    max_memory = {}
                    for i in range(len(memory_map)):
                        max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else \
                            memory_map[i]
                    max_memory['cpu'] = max_cpu_memory
                    params['max_memory'] = max_memory
                elif self.auto_devices:
                    total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
                    suggestion = round((total_mem - 1000) / 1000) * 1000
                    if total_mem - suggestion < 800:
                        suggestion -= 1000
                    suggestion = int(round(suggestion / 1000))
                    print(
                        f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")

                    max_memory = {0: f'{suggestion}GiB', 'cpu': f'{self.cpu_memory or 99}GiB'}
                    params['max_memory'] = max_memory

                if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
                    config = AutoConfig.from_pretrained(checkpoint)
                    with init_empty_weights():
                        model = LoaderClass.from_config(config)
                    model.tie_weights()
                    if self.device_map is not None:
                        params['device_map'] = self.device_map
                    else:
                        params['device_map'] = infer_auto_device_map(
                            model,
                            dtype=torch.int8,
                            max_memory=params['max_memory'],
                            no_split_module_classes=model._no_split_modules
                        )

            model = LoaderClass.from_pretrained(checkpoint, **params)

        # Loading the tokenizer
        if type(model) is transformers.LlamaForCausalLM:
            tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True)
            # Leaving this here until the LLaMA tokenizer gets figured out.
            # For some people this fixes things, for others it causes an error.
            try:
                tokenizer.eos_token_id = 2
                tokenizer.bos_token_id = 1
                tokenizer.pad_token_id = 0
            except:
                pass
        else:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        print(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
        return model, tokenizer

    def chatglm_auto_configure_device_map(self, num_gpus: int) -> Dict[str, int]:
        # transformer.word_embeddings occupies 1 layer
        # transformer.final_layernorm and lm_head occupy 1 layer
        # transformer.layers occupy 28 layers
        # Distribute the 30 layers in total across num_gpus GPUs
        num_trans_layers = 28
        per_gpu_layers = 30 / num_gpus

        # bugfix: layers are named differently when a LoRA model is loaded through PEFT
        if self.lora:
            layer_prefix = 'base_model.model.transformer'
        else:
            layer_prefix = 'transformer'

        # bugfix: on Linux, torch.embedding raises a RuntimeError when weight and input are on different devices.
        # On Windows, model.device is set to transformer.word_embeddings.device;
        # on Linux, model.device is set to lm_head.device.
        # When chat or stream_chat is called, input_ids is placed on model.device; if
        # transformer.word_embeddings.device differs from model.device, a RuntimeError follows.
        # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU.
        device_map = {f'{layer_prefix}.word_embeddings': 0,
                      f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
                      f'base_model.model.lm_head': 0, }

        used = 2
        gpu_target = 0
        for i in range(num_trans_layers):
            if used >= per_gpu_layers:
                gpu_target += 1
                used = 0
            assert gpu_target < num_gpus
            device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
            used += 1

        return device_map

    def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
        checkpoint = Path(f'{self.model_dir}/{model_name}')

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name
        cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                            pretrained_model_name_or_path=checkpoint)

        with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
            model = cls(self.model_config)
            max_memory = get_balanced_memory(model, dtype=torch.int8 if self.load_in_8bit else None,
                                             low_zero=False, no_split_module_classes=model._no_split_modules)
            device_map = infer_auto_device_map(
                model, dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory,
                no_split_module_classes=model._no_split_modules)
            device_map["transformer.wte"] = 0
            device_map["transformer.drop"] = 0
            device_map["transformer.ln_f"] = 0
            device_map["lm_head"] = 0
        return device_map

    def _add_lora_to_model(self, lora_names):
        # LoRAs currently loaded
        prior_set = set(self.lora_names)
        # LoRAs that need to be loaded
        added_set = set(lora_names) - prior_set
        # LoRAs to remove
        removed_set = prior_set - set(lora_names)
        self.lora_names = list(lora_names)

        # Nothing to do = skip.
        if len(added_set) == 0 and len(removed_set) == 0:
            return

        # Only adding, and already peft? Do it the easy way.
        if len(removed_set) == 0 and len(prior_set) > 0:
            print(f"Adding the LoRA(s) named {added_set} to the model...")
            for lora in added_set:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
            return

        # If removing anything, disable all and re-add.
        if len(removed_set) > 0:
            self.model.disable_adapter()

        if len(lora_names) > 0:
            print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
            params = {}
            if not self.cpu:
                params['dtype'] = self.model.dtype
                if hasattr(self.model, "hf_device_map"):
                    params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
                elif self.load_in_8bit:
                    params['device_map'] = {'': 0}
                self.model.resize_token_embeddings(len(self.tokenizer))

            self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params)

            for lora in lora_names[1:]:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)

            if not self.load_in_8bit and not self.cpu:

                if not hasattr(self.model, "hf_device_map"):
                    if torch.has_mps:
                        device = torch.device('mps')
                        self.model = self.model.to(device)
                    else:
                        self.model = self.model.cuda()

    def clear_torch_cache(self):
        gc.collect()
        if not self.cpu:
            device_id = "0" if torch.cuda.is_available() else None
            CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
            with torch.cuda.device(CUDA_DEVICE):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

    def unload_model(self):
        del self.model
        del self.tokenizer
        self.model = self.tokenizer = None
        self.clear_torch_cache()

    def set_model_path(self, model_path):
        self.model_path = model_path

    def reload_model(self):
        self.unload_model()
        self.model_config = self._load_model_config(self.model_name)

        if self.use_ptuning_v2:
            try:
                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception:
                print("Failed to load the PrefixEncoder config.json")

        self.model, self.tokenizer = self._load_model(self.model_name)

        if self.lora:
            self._add_lora_to_model([self.lora])

        if self.use_ptuning_v2:
            try:
                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception:
                print("Failed to load the PrefixEncoder model parameters")

        self.model = self.model.eval()

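A quick, hedged sanity check of the ChatGLM device-map logic above. The constructor arguments are placeholders and no model weights are touched; the method only arithmetic-splits the 30 logical layers:

# Sketch: for two GPUs, embeddings/final_layernorm/lm_head land on GPU 0,
# transformer.layers.0..12 on GPU 0 and transformer.layers.13..27 on GPU 1.
ckpt = LoaderCheckPoint({'model': 'chatglm-6b', 'model_dir': 'model/'})
device_map = ckpt.chatglm_auto_configure_device_map(num_gpus=2)
print(device_map['transformer.layers.0'], device_map['transformer.layers.27'])  # 0 1
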
@ -1,20 +1,14 @@
import json
from abc import ABC

from langchain.llms.base import LLM
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.modeling_utils import no_init_weights
from transformers.utils import ContextManagers
from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)

import torch
from configs.model_config import *
from utils import torch_gc

from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map

DEVICE_ = LLM_DEVICE
DEVICE_ID = "0" if torch.cuda.is_available() else None
DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_

META_INSTRUCTION = \
    """You are an AI assistant whose name is MOSS.
@ -30,45 +24,40 @@ META_INSTRUCTION = \
    """

def auto_configure_device_map() -> Dict[str, int]:
    cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                        pretrained_model_name_or_path=llm_model_dict['moss'])

    with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
        model_config = AutoConfig.from_pretrained(llm_model_dict['moss'], trust_remote_code=True)
        model = cls(model_config)
        max_memory = get_balanced_memory(model, dtype=torch.int8 if LOAD_IN_8BIT else None,
                                         low_zero=False, no_split_module_classes=model._no_split_modules)
        device_map = infer_auto_device_map(
            model, dtype=torch.float16 if not LOAD_IN_8BIT else torch.int8, max_memory=max_memory,
            no_split_module_classes=model._no_split_modules)
        device_map["transformer.wte"] = 0
        device_map["transformer.drop"] = 0
        device_map["transformer.ln_f"] = 0
        device_map["lm_head"] = 0
        return device_map


class MOSS(LLM):
class MOSSLLM(BaseAnswer, LLM, ABC):
    max_token: int = 2048
    temperature: float = 0.7
    top_p = 0.8
    # history = []
    tokenizer: object = None
    model: object = None
    checkPoint: LoaderCheckPoint = None
    history_len: int = 10

    def __init__(self):
    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "MOSS"

    def _call(self,
              prompt: str,
              history: List[List[str]] = [],
              streaming: bool = STREAMING):  # -> Tuple[str, List[List[str]]]:
    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def set_history_len(self) -> int:
        return self.history_len

    def _set_history_len(self, history_len: int) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        pass

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
        if len(history) > 0:
            history = history[-self.history_len:-1] if self.history_len > 0 else []
        prompt_w_history = str(history)

@ -77,9 +66,9 @@ class MOSS(LLM):
            prompt_w_history = META_INSTRUCTION
        prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'

        inputs = self.tokenizer(prompt_w_history, return_tensors="pt")
        inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(
            outputs = self.checkPoint.model.generate(
                inputs.input_ids.cuda(),
                attention_mask=inputs.attention_mask.cuda(),
                max_length=self.max_token,
@ -92,78 +81,8 @@ class MOSS(LLM):
                eos_token_id=106068,
                pad_token_id=self.tokenizer.pad_token_id)
            response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            torch_gc()
            self.checkPoint.clear_torch_cache()
            history += [[prompt, response]]
            yield response, history
            torch_gc()

    def load_model(self,
                   model_name_or_path: str = "fnlp/moss-moon-003-sft",
                   llm_device=LLM_DEVICE,
                   use_ptuning_v2=False,
                   use_lora=False,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        if use_ptuning_v2:
            try:
                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                print(e)
                print("Failed to load the PrefixEncoder config.json")

        if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
            # Automatic multi-GPU deployment via accelerate
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config,
                                                              load_in_8bit=LOAD_IN_8BIT, trust_remote_code=True,
                                                              device_map=auto_configure_device_map(), **kwargs)

            if LLM_LORA_PATH and use_lora:
                from peft import PeftModel
                self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)

        else:
            self.model = self.model.float().to(llm_device)
            if LLM_LORA_PATH and use_lora:
                from peft import PeftModel
                self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)

        if use_ptuning_v2:
            try:
                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception as e:
                print(e)
                print("Failed to load the PrefixEncoder model parameters")

        self.model = self.model.eval()


if __name__ == "__main__":
    llm = MOSS()
    llm.load_model(model_name_or_path=llm_model_dict['moss'],
                   llm_device=LLM_DEVICE, )
    last_print_len = 0
    # for resp, history in llm._call("你好", streaming=True):
    #     print(resp[last_print_len:], end="", flush=True)
    #     last_print_len = len(resp)
    for resp, history in llm._call("你好", streaming=False):
        print(resp)
        import time
        time.sleep(10)
    pass

models/shared.py (Normal file, 43 lines)
@ -0,0 +1,43 @@
import sys

from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer

"""Whether the iterator has been asked to stop"""
stop_everything = False

loaderCheckPoint: LoaderCheckPoint = None


def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer:
    """
    init llm_model_ins LLM
    :param llm_model: model_name
    :param no_remote_model: do not fetch the model from a remote hub on loader checkpoint; pass `--no-remote-model` to load a local checkpoint
    :param use_ptuning_v2: use the p-tuning-v2 PrefixEncoder
    :return:
    """
    pre_model_name = loaderCheckPoint.model_name
    llm_model_info = llm_model_dict[pre_model_name]

    if no_remote_model:
        loaderCheckPoint.no_remote_model = no_remote_model
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

    if llm_model:
        llm_model_info = llm_model_dict[llm_model]

    if loaderCheckPoint.no_remote_model:
        loaderCheckPoint.model_name = llm_model_info['name']
    else:
        loaderCheckPoint.model_name = llm_model_info['remote-checkpoint']

    loaderCheckPoint.model_path = llm_model_info['path']

    loaderCheckPoint.reload_model()

    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
    modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
    return modelInsLLM

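loaderLLM reads several keys from llm_model_dict. A hypothetical entry illustrating the shape it expects; the key names come from the lookups above, while the values are made up for illustration:

# Hypothetical configs/model_config.py entry; only the keys read by loaderLLM are shown.
llm_model_dict = {
    "chatglm-6b": {
        "name": "chatglm-6b",                     # local folder name under --model-dir
        "remote-checkpoint": "THUDM/chatglm-6b",  # hub id used when remote loading is allowed
        "path": None,                             # optional absolute path to the checkpoint
        "provides": "ChatGLM",                    # class name resolved from the models package
    }
}
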
@ -17,6 +17,8 @@ fastapi
uvicorn
peft
pypinyin
bitsandbytes
click~=8.1.3
tabulate
bitsandbytes; platform_system != "Windows"
llama-cpp-python==0.1.34; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"

webui.py (51 lines)
@ -1,9 +1,17 @@
import gradio as gr
import os
import shutil

from chains.local_doc_qa import LocalDocQA
from configs.model_config import *
import nltk
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)
import models.shared as shared
from models.loader.args import parser
from models.loader import LoaderCheckPoint

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@ -69,7 +77,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
            yield history + [[query,
                              "Please select a knowledge base before testing; none is currently selected."]], ""
    else:
        for resp, history in local_doc_qa.llm._call(query, history, streaming=streaming):
        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
                                                              streaming=streaming):

            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][-1] = resp + (
                "\n\nThe knowledge base is currently empty. To answer questions against a knowledge base, please load one first and then ask again." if mode == "知识库问答" else "")
            yield history, ""
@ -77,10 +89,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
    flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)


def init_model():
def init_model(llm_model: BaseAnswer = None):
    try:
        local_doc_qa.init_cfg()
        local_doc_qa.llm._call("你好")
        local_doc_qa.init_cfg(llm_model=llm_model)
        generator = local_doc_qa.llm.generatorAnswer("你好")
        for answer_result in generator:
            print(answer_result.llm_output)
        reply = """The model has been loaded successfully. You can start chatting, or pick a mode on the right first."""
        logger.info(reply)
        return reply
@ -95,14 +109,13 @@ def init_model():
        return reply


def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
    try:
        local_doc_qa.init_cfg(llm_model=llm_model,
        llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
        llm_model_ins.history_len = llm_history_len
        local_doc_qa.init_cfg(llm_model=llm_model_ins,
                              embedding_model=embedding_model,
                              llm_history_len=llm_history_len,
                              use_ptuning_v2=use_ptuning_v2,
                              use_lora=use_lora,
                              top_k=top_k, )
                              top_k=top_k)
        model_status = """The model has been reloaded successfully. You can start chatting, or pick a mode on the right first."""
        logger.info(model_status)
    except Exception as e:
@ -219,7 +232,17 @@ init_message = f"""Welcome to the langchain-ChatGLM Web UI!
The knowledge base does not yet support file deletion; this feature will come in a later release.
"""

model_status = init_model()
# Initialization
args = None
args = parser.parse_args()

args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)

model_status = init_model(llm_model=llm_model_ins)


default_theme_args = dict(
    font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
@ -399,6 +422,10 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                                     label="LLM model",
                                     value=LLM_MODEL,
                                     interactive=True)
                no_remote_model = gr.Checkbox(shared.LoaderCheckPoint.no_remote_model,
                                              label="Load local model",
                                              interactive=True)

                llm_history_len = gr.Slider(0, 10,
                                            value=LLM_HISTORY_LEN,
                                            step=1,
@ -418,7 +445,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
                                     label="top k for vector matching", interactive=True)
                load_model_button = gr.Button("Reload model")
                load_model_button.click(reinit_model, show_progress=True,
                                        inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora,
                                        inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora,
                                                top_k, chatbot], outputs=chatbot)

(demo