From 33bbb4779e373602cb8cfa984b2e4e8342f62a75 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Thu, 18 May 2023 22:54:41 +0800 Subject: [PATCH] llm_model_dict now handles loader presets such as the load location, model name, and model handler instance, and defines checkpoint names and remote paths; loader.py: model reloading; define generatorAnswer and add AnswerResultStream; define a generate_with_callback collector that syncs queue data into AnswerResult on every response; requirements.txt: update project dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chains/local_doc_qa.py | 54 +-- configs/model_config.py | 52 +- fastchat/__init__.py | 0 fastchat/api/__init__.py | 1 + fastchat/api/conversation.py | 261 ++++++++++ fastchat/api/fastchat_api.py | 459 ++++++++++++++++++ models/__init__.py | 4 +- models/__main__.py | 97 ++++ models/base.py | 198 ++++++++ models/chatglm_llm.py | 213 +++----- models/extensions/callback.py | 222 +++++++++ models/extensions/extensions.py | 10 + .../extensions/llamacpp_model_alternative.py | 64 +++ models/extensions/thread_with_exception.py | 30 ++ models/llama_llm.py | 297 ++++++++++++ models/loader/__init__.py | 2 + models/loader/args.py | 59 +++ models/loader/loader.py | 405 ++++++++++++++++ models/moss_llm.py | 149 ++---- models/shared.py | 43 ++ requirements.txt | 4 +- webui.py | 51 +- 22 files changed, 2360 insertions(+), 315 deletions(-) create mode 100644 fastchat/__init__.py create mode 100644 fastchat/api/__init__.py create mode 100644 fastchat/api/conversation.py create mode 100644 fastchat/api/fastchat_api.py create mode 100644 models/__main__.py create mode 100644 models/base.py create mode 100644 models/extensions/callback.py create mode 100644 models/extensions/extensions.py create mode 100644 models/extensions/llamacpp_model_alternative.py create mode 100644 models/extensions/thread_with_exception.py create mode 100644 models/llama_llm.py create mode 100644 models/loader/__init__.py create mode 100644 models/loader/args.py create mode 100644 models/loader/loader.py create mode 100644 models/shared.py diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index d4cdcb04..5ba2f7ea 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -12,10 +12,14 @@ from tqdm import tqdm from pypinyin import lazy_pinyin from loader import UnstructuredPaddleImageLoader from loader import UnstructuredPaddlePDFLoader +from models.base import (BaseAnswer, + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +from models.loader.args import parser +from models.loader import LoaderCheckPoint +import models.shared as shared -DEVICE_ = EMBEDDING_DEVICE -DEVICE_ID = "0" if torch.cuda.is_available() else None -DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_ def load_file(filepath, sentence_size=SENTENCE_SIZE): @@ -132,7 +136,7
@@ def similarity_search_with_score_by_vector( class LocalDocQA: - llm: object = None + llm: BaseAnswer = None embeddings: object = None top_k: int = VECTOR_SEARCH_TOP_K chunk_size: int = CHUNK_SIZE @@ -142,23 +146,10 @@ class LocalDocQA: def init_cfg(self, embedding_model: str = EMBEDDING_MODEL, embedding_device=EMBEDDING_DEVICE, - llm_history_len: int = LLM_HISTORY_LEN, - llm_model: str = LLM_MODEL, - llm_device=LLM_DEVICE, + llm_model: BaseAnswer = None, top_k=VECTOR_SEARCH_TOP_K, - use_ptuning_v2: bool = USE_PTUNING_V2, - use_lora: bool = USE_LORA, ): - if llm_model.startswith('moss'): - from models.moss_llm import MOSS - self.llm = MOSS() - else: - from models.chatglm_llm import ChatGLM - self.llm = ChatGLM() - self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], - llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora) - self.llm.history_len = llm_history_len - + self.llm = llm_model self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], model_kwargs={'device': embedding_device}) self.top_k = top_k @@ -259,16 +250,16 @@ class LocalDocQA: torch_gc() prompt = generate_prompt(related_docs_with_score, query) - for result, history in self.llm._call(prompt=prompt, - history=chat_history, - streaming=streaming): - torch_gc() + for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, + streaming=streaming): + resp = answer_result.llm_output["answer"] + history = answer_result.history history[-1][0] = query response = {"query": query, - "result": result, + "result": resp, "source_documents": related_docs_with_score} yield response, history - torch_gc() + # query 查询内容 # vs_path 知识库路径 @@ -297,10 +288,19 @@ class LocalDocQA: if __name__ == "__main__": + # 初始化消息 + args = None + args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model']) + + args_dict = vars(args) + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + llm_model_ins = shared.loaderLLM() + llm_model_ins.set_history_len(LLM_HISTORY_LEN) + local_doc_qa = LocalDocQA() - local_doc_qa.init_cfg() + local_doc_qa.init_cfg(llm_model=llm_model_ins) query = "本项目使用的embedding模型是什么,消耗多少显存" - vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa" + vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test" last_print_len = 0 for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, vs_path=vs_path, diff --git a/configs/model_config.py b/configs/model_config.py index b04ec3af..3a9b4691 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -22,14 +22,54 @@ EMBEDDING_MODEL = "text2vec" # Embedding running device EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + # supported LLM models +""" +llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例 +""" llm_model_dict = { - "chatyuan": "ClueAI/ChatYuan-large-v2", - "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe", - "chatglm-6b-int4": "THUDM/chatglm-6b-int4", - "chatglm-6b-int8": "THUDM/chatglm-6b-int8", - "chatglm-6b": "THUDM/chatglm-6b", - "moss": "fnlp/moss-moon-003-sft", + "chatglm-6b-int4-qe": { + "name": "chatglm-6b-int4-qe", + "remote-checkpoint": "THUDM/chatglm-6b-int4-qe", + "path": None, + "provides": "ChatGLM" + }, + "chatglm-6b-int4": { + "name": "chatglm-6b-int4", + "remote-checkpoint": "THUDM/chatglm-6b-int4", + "path": None, + "provides": "ChatGLM" + }, + "chatglm-6b": { + "name": "chatglm-6b", + "remote-checkpoint": 
"THUDM/chatglm-6b-int4", + "path": None, + "provides": "ChatGLM" + }, + "llama-7b-hf": { + "name": "llama-7b-hf", + "remote-checkpoint": "llama-7b-hf", + "path": None, + "provides": "LLamaLLM" + }, + "vicuna-13b-hf": { + "name": "vicuna-13b-hf", + "remote-checkpoint": "vicuna-13b-hf", + "path": None, + "provides": "LLamaLLM" + }, + "chatyuan": { + "name": "chatyuan", + "remote-checkpoint": "ClueAI/ChatYuan-large-v2", + "path": None, + "provides": None + }, + "chatglm-6b-int8":{ + "name": "chatglm-6b-int8", + "remote-checkpoint": "THUDM/chatglm-6b-int8", + "path": None, + "provides": "ChatGLM" + }, } # LLM model name diff --git a/fastchat/__init__.py b/fastchat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastchat/api/__init__.py b/fastchat/api/__init__.py new file mode 100644 index 00000000..7d4f6df3 --- /dev/null +++ b/fastchat/api/__init__.py @@ -0,0 +1 @@ +from .fastchat_api import * \ No newline at end of file diff --git a/fastchat/api/conversation.py b/fastchat/api/conversation.py new file mode 100644 index 00000000..e349e83c --- /dev/null +++ b/fastchat/api/conversation.py @@ -0,0 +1,261 @@ +""" +Conversation prompt template. + +Now we support +- Vicuna +- Koala +- OpenAssistant/oasst-sft-1-pythia-12b +- StabilityAI/stablelm-tuned-alpha-7b +- databricks/dolly-v2-12b +- THUDM/chatglm-6b +- Alpaca/LLaMa +""" + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + DOLLY = auto() + OASST_PYTHIA = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + # Used for gradio server + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + for role, message in self.messages: + if message: + ret += self.sep + " " + role + ": " + message + else: + ret += self.sep + " " + role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.OASST_PYTHIA: + ret = self.system + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id, + ) + + def dict(self): + return { + 
"system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + + +conv_one_shot = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "What are the key differences between renewable and non-renewable energy sources?", + ), + ( + "Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.", + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + + +conv_koala_v1 = Conversation( + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_dolly = Conversation( + system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", +) + +conv_oasst = Conversation( + system="", + roles=("<|prompter|>", "<|assistant|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.OASST_PYTHIA, + sep="<|endoftext|>", +) + +conv_stablelm = Conversation( + system="""<|SYSTEM|># StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 
+- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.OASST_PYTHIA, + sep="", +) + +conv_templates = { + "conv_one_shot": conv_one_shot, + "vicuna_v1.1": conv_vicuna_v1_1, + "koala_v1": conv_koala_v1, + "dolly": conv_dolly, + "oasst": conv_oasst, +} + + +def get_default_conv_template(model_name): + model_name = model_name.lower() + if "vicuna" in model_name or "output" in model_name: + return conv_vicuna_v1_1 + elif "koala" in model_name: + return conv_koala_v1 + elif "dolly-v2" in model_name: + return conv_dolly + elif "oasst" in model_name and "pythia" in model_name: + return conv_oasst + elif "stablelm" in model_name: + return conv_stablelm + return conv_one_shot + + +def compute_skip_echo_len(model_name, conv, prompt): + model_name = model_name.lower() + if "chatglm" in model_name: + skip_echo_len = len(conv.messages[-2][1]) + 1 + elif "dolly-v2" in model_name: + special_toks = ["### Instruction:", "### Response:", "### End"] + skip_echo_len = len(prompt) + for tok in special_toks: + skip_echo_len -= prompt.count(tok) * len(tok) + elif "oasst" in model_name and "pythia" in model_name: + special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"] + skip_echo_len = len(prompt) + for tok in special_toks: + skip_echo_len -= prompt.count(tok) * len(tok) + elif "stablelm" in model_name: + special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"] + skip_echo_len = len(prompt) + for tok in special_toks: + skip_echo_len -= prompt.count(tok) * len(tok) + else: + skip_echo_len = len(prompt) + 1 - prompt.count("") * 3 + return skip_echo_len + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/fastchat/api/fastchat_api.py b/fastchat/api/fastchat_api.py new file mode 100644 index 00000000..11d28e36 --- /dev/null +++ b/fastchat/api/fastchat_api.py @@ -0,0 +1,459 @@ +"""Wrapper around FastChat APIs.""" +from __future__ import annotations + +import logging +import sys +import warnings +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Dict, + Generator, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + Union, +) + +from pydantic import Extra, Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.llms.base import BaseLLM +from langchain.schema import Generation, LLMResult +from langchain.utils import get_from_dict_or_env + +import requests +import json + + +logger = logging.getLogger(__name__) +FAST_CHAT_API = "http://localhost:21002/worker_generate_stream" + + +def _streaming_response_template() -> Dict[str, Any]: + """ + :return: 响应结构 + """ + return { + "text": "", + "error_code": 0, + } + + +def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: + """Update response from the stream response.""" + response["text"] += stream_response["text"] + response["error_code"] += stream_response["error_code"] + + +class BaseFastChat(BaseLLM): + """Wrapper around FastChat large language models.""" + + model_name: str = "text-davinci-003" + """Model name to use.""" + temperature: float = 0.7 + """What sampling temperature to use.""" + max_new_tokens: int = 200 + stop: int = 20 + batch_size: int = 20 + """Maximum number of retries to make when generating.""" + streaming: bool = False + """Penalizes repeated tokens.""" + n: int = 1 + """Whether to stream the results or not.""" + 
allowed_special: Union[Literal["all"], AbstractSet[str]] = set() + """Set of special tokens that are allowed。""" + disallowed_special: Union[Literal["all"], Collection[str]] = "all" + """Set of special tokens that are not allowed。""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.ignore + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = {field.alias for field in cls.__fields__.values()} + + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name not in all_required_field_names: + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transfered to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + values["model_kwargs"] = extra + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling FastChat API.""" + normal_params = { + "model": self.model_name, + "prompt": '', + "max_new_tokens": self.max_new_tokens, + "temperature": self.temperature, + } + + return {**normal_params} + + def _generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + """Call out to FastChat's endpoint with k unique prompts. + + Args: + prompts: The prompts to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The full LLM output. + + Example: + .. code-block:: python + + response = fastchat.generate(["Tell me a joke."]) + """ + # TODO: write a unit test for this + params = self._invocation_params + sub_prompts = self.get_sub_prompts(params, prompts) + choices = [] + token_usage: Dict[str, int] = {} + headers = {"User-Agent": "fastchat Client"} + for _prompts in sub_prompts: + + params["prompt"] = _prompts[0] + + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + if self.streaming: + if len(_prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + + response_template = _streaming_response_template() + response = requests.post( + FAST_CHAT_API, + headers=headers, + json=params, + stream=True, + ) + for stream_resp in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if stream_resp: + data = json.loads(stream_resp.decode("utf-8")) + skip_echo_len = len(_prompts[0]) + output = data["text"][skip_echo_len:].strip() + data["text"] = output + self.callback_manager.on_llm_new_token( + output, + verbose=self.verbose, + logprobs=data["error_code"], + ) + _update_response(response_template, data) + choices.append(response_template) + else: + response_template = _streaming_response_template() + response = requests.post( + FAST_CHAT_API, + headers=headers, + json=params, + stream=True, + ) + for stream_resp in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if stream_resp: + data = json.loads(stream_resp.decode("utf-8")) + skip_echo_len = len(_prompts[0]) + output = data["text"][skip_echo_len:].strip() + data["text"] = output + _update_response(response_template, data) + + choices.append(response_template) + + return self.create_llm_result(choices, prompts, token_usage) + + async def _agenerate( + self, 
prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + """Call out to FastChat's endpoint async with k unique prompts.""" + params = self._invocation_params + sub_prompts = self.get_sub_prompts(params, prompts) + choices = [] + token_usage: Dict[str, int] = {} + + headers = {"User-Agent": "fastchat Client"} + for _prompts in sub_prompts: + + params["prompt"] = _prompts[0] + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + if self.streaming: + if len(_prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + + response_template = _streaming_response_template() + response = requests.post( + FAST_CHAT_API, + headers=headers, + json=params, + stream=True, + ) + for stream_resp in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if stream_resp: + data = json.loads(stream_resp.decode("utf-8")) + skip_echo_len = len(_prompts[0]) + output = data["text"][skip_echo_len:].strip() + data["text"] = output + self.callback_manager.on_llm_new_token( + output, + verbose=self.verbose, + logprobs=data["error_code"], + ) + _update_response(response_template, data) + choices.append(response_template) + else: + response_template = _streaming_response_template() + response = requests.post( + FAST_CHAT_API, + headers=headers, + json=params, + stream=True, + ) + for stream_resp in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if stream_resp: + data = json.loads(stream_resp.decode("utf-8")) + skip_echo_len = len(_prompts[0]) + output = data["text"][skip_echo_len:].strip() + data["text"] = output + _update_response(response_template, data) + + choices.append(response_template) + + return self.create_llm_result(choices, prompts, token_usage) + + def get_sub_prompts( + self, + params: Dict[str, Any], + prompts: List[str], + ) -> List[List[str]]: + """Get the sub prompts for llm call.""" + if params["max_new_tokens"] == -1: + if len(prompts) != 1: + raise ValueError( + "max_new_tokens set to -1 not supported for multiple inputs." + ) + params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0]) + # append pload + sub_prompts = [ + prompts[i: i + self.batch_size] + for i in range(0, len(prompts), self.batch_size) + ] + + return sub_prompts + + def create_llm_result( + self, choices: Any, prompts: List[str], token_usage: Dict[str, int] + ) -> LLMResult: + """Create the LLMResult from the choices and prompts.""" + generations = [] + for i, _ in enumerate(prompts): + sub_choices = choices[i * self.n: (i + 1) * self.n] + generations.append( + [ + Generation( + text=choice["text"], + generation_info=dict( + finish_reason='over', + logprobs=choice["text"], + ), + ) + for choice in sub_choices + ] + ) + llm_output = {"token_usage": token_usage, "model_name": self.model_name} + return LLMResult(generations=generations, llm_output=llm_output) + + def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator: + """Call FastChat with streaming flag and return the resulting generator. + + BETA: this is a beta feature while we figure out the right abstraction. + Once that happens, this interface could change. + + Args: + prompt: The prompts to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + A generator representing the stream of tokens from OpenAI. + + Example: + .. 
code-block:: python + + generator = fastChat.stream("Tell me a joke.") + for token in generator: + yield token + """ + params = self._invocation_params + params["prompt"] = prompt + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + headers = {"User-Agent": "fastchat Client"} + response = requests.post( + FAST_CHAT_API, + headers=headers, + json=params, + stream=True, + ) + for stream_resp in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if stream_resp: + data = json.loads(stream_resp.decode("utf-8")) + skip_echo_len = len(_prompts[0]) + output = data["text"][skip_echo_len:].strip() + data["text"] = output + yield data + + @property + def _invocation_params(self) -> Dict[str, Any]: + """Get the parameters used to invoke the model.""" + return self._default_params + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "fastChat" + + def get_num_tokens(self, text: str) -> int: + """Calculate num tokens with tiktoken package.""" + # tiktoken NOT supported for Python < 3.8 + if sys.version_info[1] < 8: + return super().get_num_tokens(text) + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_num_tokens. " + "Please install it with `pip install tiktoken`." + ) + + enc = tiktoken.encoding_for_model(self.model_name) + + tokenized_text = enc.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) + + # calculate the number of tokens in the encoded text + return len(tokenized_text) + + def modelname_to_contextsize(self, modelname: str) -> int: + """Calculate the maximum number of tokens possible to generate for a model. + + Args: + modelname: The modelname we want to know the context size for. + + Returns: + The maximum context size + + Example: + .. code-block:: python + + max_new_tokens = openai.modelname_to_contextsize("text-davinci-003") + """ + model_token_mapping = { + "vicuna-13b": 2049, + "koala": 2049, + "dolly-v2": 2049, + "oasst": 2049, + "stablelm": 2049, + } + + context_size = model_token_mapping.get(modelname, None) + + if context_size is None: + raise ValueError( + f"Unknown model: {modelname}. Please provide a valid OpenAI model name." + "Known models are: " + ", ".join(model_token_mapping.keys()) + ) + + return context_size + + def max_new_tokens_for_prompt(self, prompt: str) -> int: + """Calculate the maximum number of tokens possible to generate for a prompt. + + Args: + prompt: The prompt to pass into the model. + + Returns: + The maximum number of tokens to generate for a prompt. + + Example: + .. code-block:: python + + max_new_tokens = openai.max_token_for_prompt("Tell me a joke.") + """ + num_tokens = self.get_num_tokens(prompt) + + # get max context size for model by name + max_size = self.modelname_to_contextsize(self.model_name) + return max_size - num_tokens + + +class FastChat(BaseFastChat): + """Wrapper around OpenAI large language models. + + To use, you should have the ``openai`` python package installed, and the + environment variable ``OPENAI_API_KEY`` set with your API key. 
+ + Any parameters that are valid to be passed to the openai.create call can be passed + in, even if not explicitly saved on this class. + + Example: + .. code-block:: python + + from langchain.llms import OpenAI + openai = FastChat(model_name="vicuna") + """ + + @property + def _invocation_params(self) -> Dict[str, Any]: + return {**{"model": self.model_name}, **super()._invocation_params} diff --git a/models/__init__.py b/models/__init__.py index 153a78bb..2a58a8ff 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1,4 @@ -from .chatglm_llm import * \ No newline at end of file +from .chatglm_llm import ChatGLM +from .llama_llm import LLamaLLM +from .moss_llm import MOSSLLM diff --git a/models/__main__.py b/models/__main__.py new file mode 100644 index 00000000..11495bee --- /dev/null +++ b/models/__main__.py @@ -0,0 +1,97 @@ +import sys +import os + +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') +import asyncio +from argparse import Namespace +from models.loader.args import parser +from models.loader import LoaderCheckPoint + +from langchain.agents import initialize_agent, Tool +from langchain.agents import AgentType + +import models.shared as shared + +from langchain.chains import LLMChain +from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory +from langchain.prompts import PromptTemplate +from langchain.agents import ZeroShotAgent, Tool, AgentExecutor +from typing import List, Set + + + +class CustomLLMSingleActionAgent(ZeroShotAgent): + allowed_tools: List[str] + + def __init__(self, *args, **kwargs): + super(CustomLLMSingleActionAgent, self).__init__(*args, **kwargs) + self.allowed_tools = kwargs['allowed_tools'] + + def get_allowed_tools(self) -> Set[str]: + return set(self.allowed_tools) + + +async def dispatch(args: Namespace): + args_dict = vars(args) + + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + llm_model_ins = shared.loaderLLM() + + template = """This is a conversation between a human and a bot: + +{chat_history} + +Write a summary of the conversation for {input}: +""" + + prompt = PromptTemplate( + input_variables=["input", "chat_history"], + template=template + ) + memory = ConversationBufferMemory(memory_key="chat_history") + readonlymemory = ReadOnlySharedMemory(memory=memory) + summry_chain = LLMChain( + llm=llm_model_ins, + prompt=prompt, + verbose=True, + memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory + ) + + + tools = [ + Tool( + name="Summary", + func=summry_chain.run, + description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary." + ) + ] + + prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:""" + suffix = """Begin! 
+ +Question: {input} +{agent_scratchpad}""" + + + prompt = CustomLLMSingleActionAgent.create_prompt( + tools, + prefix=prefix, + suffix=suffix, + input_variables=["input", "agent_scratchpad"] + ) + tool_names = [tool.name for tool in tools] + llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt) + agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names) + agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools) + + agent_chain.run(input="你好") + agent_chain.run(input="你是谁?") + agent_chain.run(input="我们之前聊了什么?") + +if __name__ == '__main__': + args = None + args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit']) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(dispatch(args)) diff --git a/models/base.py b/models/base.py new file mode 100644 index 00000000..c8f1c1b4 --- /dev/null +++ b/models/base.py @@ -0,0 +1,198 @@ +from abc import ABC, abstractmethod +from typing import Optional, List +import traceback +from collections import deque +from queue import Queue +from threading import Thread + +import torch +import transformers +from models.loader import LoaderCheckPoint + + +class ListenerToken: + """ + 观测结果 + """ + + input_ids: torch.LongTensor + _scores: torch.FloatTensor + + def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor): + self.input_ids = input_ids + self._scores = _scores + + +class AnswerResult: + """ + 消息实体 + """ + history: List[List[str]] = [] + llm_output: Optional[dict] = None + listenerToken: ListenerToken = None + + +class AnswerResultStream: + def __init__(self, callback_func=None): + self.callback_func = callback_func + + def __call__(self, answerResult: AnswerResult): + if self.callback_func is not None: + self.callback_func(answerResult) + + +class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria): + """ + 定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult + 实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数, + 通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件 + 当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束 + 输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制 + """ + + listenerQueue: deque = deque(maxlen=1) + + def __init__(self): + transformers.StoppingCriteria.__init__(self) + + def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool: + """ + 每次响应时将数据添加到响应队列 + :param input_ids: + :param _scores: + :param kwargs: + :return: + """ + self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores)) + return False + + +class Iteratorize: + """ + Transforms a function that takes a callback + into a lazy iterator (generator). 
+ """ + + def __init__(self, func, kwargs={}): + self.mfunc = func + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + self.stop_now = False + + def _callback(val): + """ + 模型输出预测结果收集 + 通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束 + 结束条件包含如下 + 1、模型预测结束、收集器self.q队列收到 self.sentinel标识 + 2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件 + 3、模型预测出错 + 因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为 + 迭代器收集的行为如下 + 创建Iteratorize迭代对象, + 定义generate_with_callback收集器AnswerResultStream + 启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer + _generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体 + 由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测 + 这时generate_with_callback会被阻塞 + 主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费 + 1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理 + 2、消息为self.sentinel标识,抛出StopIteration异常 + 主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新 + 异步线程检测stop_now属性被更新,抛出异常结束预测行为 + 迭代行为结束 + :param val: + :return: + """ + if self.stop_now: + raise ValueError + self.q.put(val) + + def gen(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass + except: + traceback.print_exc() + pass + + self.q.put(self.sentinel) + + self.thread = Thread(target=gen) + self.thread.start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True, None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __del__(self): + """ + 暂无实现 + :return: + """ + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ break 后会执行 """ + self.stop_now = True + + +class BaseAnswer(ABC): + """上层业务包装器.用于结果生成统一api调用""" + + @property + @abstractmethod + def _check_point(self) -> LoaderCheckPoint: + """Return _check_point of llm.""" + + @property + @abstractmethod + def _history_len(self) -> int: + """Return _history_len of llm.""" + + @abstractmethod + def set_history_len(self, history_len: int) -> None: + """Return _history_len of llm.""" + + def generatorAnswer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False): + def generate_with_callback(callback=None, **kwargs): + kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback) + self._generate_answer(**kwargs) + + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs) + + """ + eos_token_id是指定token(例如,""), + 用于表示序列的结束。在生成文本任务中,生成器在生成序列时,将不断地生成token,直到生成此特殊的eos_token_id,表示序列生成已经完成。 + 在Hugging Face Transformer模型中,eos_token_id是由tokenizer自动添加到输入中的。 + 在模型生成输出时,如果模型生成了eos_token_id,则生成过程将停止并返回生成的序列。 + """ + eos_token_ids = [ + self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else [] + + with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator: + for answerResult in generator: + if answerResult.listenerToken: + output = answerResult.listenerToken.input_ids + yield answerResult + + @abstractmethod + def _generate_answer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False, + generate_with_callback: AnswerResultStream = None) -> None: + pass diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index 49ce97c7..ae5b6dd8 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -1,189 +1,96 @@ -import json + +from abc import ABC + from langchain.llms.base import LLM -from typing import List, Dict, Optional -from transformers import AutoTokenizer, AutoModel, AutoConfig -import 
torch -from configs.model_config import * -from utils import torch_gc - -DEVICE_ = LLM_DEVICE -DEVICE_ID = "0" if torch.cuda.is_available() else None -DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_ +from typing import Optional, List +from models.loader import LoaderCheckPoint +from models.base import (BaseAnswer, + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) -def auto_configure_device_map(num_gpus: int, use_lora: bool) -> Dict[str, int]: - # transformer.word_embeddings 占用1层 - # transformer.final_layernorm 和 lm_head 占用1层 - # transformer.layers 占用 28 层 - # 总共30层分配到num_gpus张卡上 - num_trans_layers = 28 - per_gpu_layers = 30 / num_gpus - - # bugfix: PEFT加载lora模型出现的层命名不同 - if LLM_LORA_PATH and use_lora: - layer_prefix = 'base_model.model.transformer' - else: - layer_prefix = 'transformer' - - # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError - # windows下 model.device 会被设置成 transformer.word_embeddings.device - # linux下 model.device 会被设置成 lm_head.device - # 在调用chat或者stream_chat时,input_ids会被放到model.device上 - # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError - # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 - device_map = {f'{layer_prefix}.word_embeddings': 0, - f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, - f'base_model.model.lm_head': 0, } - - used = 2 - gpu_target = 0 - for i in range(num_trans_layers): - if used >= per_gpu_layers: - gpu_target += 1 - used = 0 - assert gpu_target < num_gpus - device_map[f'{layer_prefix}.layers.{i}'] = gpu_target - used += 1 - - return device_map +import transformers -class ChatGLM(LLM): +class ChatGLM(BaseAnswer, LLM, ABC): max_token: int = 10000 - temperature: float = 0.8 + temperature: float = 0.01 top_p = 0.9 + checkPoint: LoaderCheckPoint = None # history = [] - tokenizer: object = None - model: object = None history_len: int = 10 - def __init__(self): + def __init__(self, checkPoint: LoaderCheckPoint = None): super().__init__() + self.checkPoint = checkPoint @property def _llm_type(self) -> str: return "ChatGLM" - def _call(self, - prompt: str, - history: List[List[str]] = [], - streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]: + @property + def _check_point(self) -> LoaderCheckPoint: + return self.checkPoint + + @property + def _history_len(self) -> int: + return self.history_len + + def set_history_len(self, history_len: int = 10) -> None: + self.history_len = history_len + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + pass + + def _generate_answer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False, + generate_with_callback: AnswerResultStream = None) -> None: + # Create the StoppingCriteriaList with the stopping strings + stopping_criteria_list = transformers.StoppingCriteriaList() + # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult + listenerQueue = AnswerResultQueueSentinelTokenListenerQueue() + stopping_criteria_list.append(listenerQueue) + if streaming: - for inum, (stream_resp, _) in enumerate(self.model.stream_chat( - self.tokenizer, + + for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat( + self.checkPoint.tokenizer, prompt, history=history[-self.history_len:-1] if self.history_len > 0 else [], max_length=self.max_token, temperature=self.temperature, - top_p=self.top_p, + stopping_criteria=stopping_criteria_list )): - torch_gc() + self.checkPoint.clear_torch_cache() if inum == 0: history 
+= [[prompt, stream_resp]] else: history[-1] = [prompt, stream_resp] - yield stream_resp, history - torch_gc() + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": stream_resp} + if listenerQueue.listenerQueue.__len__() > 0: + answer_result.listenerToken = listenerQueue.listenerQueue.pop() + generate_with_callback(answer_result) else: - response, _ = self.model.chat( - self.tokenizer, + response, _ = self.checkPoint.model.chat( + self.checkPoint.tokenizer, prompt, history=history[-self.history_len:] if self.history_len > 0 else [], max_length=self.max_token, temperature=self.temperature, - top_p=self.top_p, + stopping_criteria=stopping_criteria_list ) - torch_gc() + self.checkPoint.clear_torch_cache() history += [[prompt, response]] - yield response, history - torch_gc() + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": response} + if listenerQueue.listenerQueue.__len__() > 0: + answer_result.listenerToken = listenerQueue.listenerQueue.pop() - # def chat(self, - # prompt: str) -> str: - # response, _ = self.model.chat( - # self.tokenizer, - # prompt, - # history=self.history[-self.history_len:] if self.history_len > 0 else [], - # max_length=self.max_token, - # temperature=self.temperature, - # ) - # torch_gc() - # self.history = self.history + [[None, response]] - # return response - - def load_model(self, - model_name_or_path: str = "THUDM/chatglm-6b", - llm_device=LLM_DEVICE, - use_ptuning_v2=False, - use_lora=False, - device_map: Optional[Dict[str, int]] = None, - **kwargs): - self.tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, - trust_remote_code=True - ) - - model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) - - if use_ptuning_v2: - try: - prefix_encoder_file = open('ptuning-v2/config.json', 'r') - prefix_encoder_config = json.loads(prefix_encoder_file.read()) - prefix_encoder_file.close() - model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] - model_config.prefix_projection = prefix_encoder_config['prefix_projection'] - except Exception as e: - logger.error(f"加载PrefixEncoder config.json失败: {e}") - self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True, - **kwargs) - if LLM_LORA_PATH and use_lora: - from peft import PeftModel - self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH) - - if torch.cuda.is_available() and llm_device.lower().startswith("cuda"): - # 根据当前设备GPU数量决定是否进行多卡部署 - num_gpus = torch.cuda.device_count() - if num_gpus < 2 and device_map is None: - self.model = self.model.half().cuda() - else: - from accelerate import dispatch_model - - # model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, - # config=model_config, **kwargs) - if LLM_LORA_PATH and use_lora: - from peft import PeftModel - model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH) - # 可传入device_map自定义每张卡的部署情况 - if device_map is None: - device_map = auto_configure_device_map(num_gpus, use_lora) - - self.model = dispatch_model(self.model.half(), device_map=device_map) - else: - self.model = self.model.float().to(llm_device) - - if use_ptuning_v2: - try: - prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin') - new_prefix_state_dict = {} - for k, v in prefix_state_dict.items(): - if k.startswith("transformer.prefix_encoder."): - new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v - 
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) - self.model.transformer.prefix_encoder.float() - except Exception as e: - logger.error(f"加载PrefixEncoder模型参数失败:{e}") - - self.model = self.model.eval() + generate_with_callback(answer_result) -if __name__ == "__main__": - llm = ChatGLM() - llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL], - llm_device=LLM_DEVICE, ) - last_print_len = 0 - for resp, history in llm._call("你好", streaming=True): - logger.info(resp[last_print_len:], end="", flush=True) - last_print_len = len(resp) - for resp, history in llm._call("你好", streaming=False): - logger.info(resp) - pass diff --git a/models/extensions/callback.py b/models/extensions/callback.py new file mode 100644 index 00000000..61047796 --- /dev/null +++ b/models/extensions/callback.py @@ -0,0 +1,222 @@ +import gc +import traceback +from queue import Queue +from threading import Thread +import threading +from typing import Optional, List, Dict, Any +from collections import deque +import torch +import transformers + +from models.extensions.thread_with_exception import ThreadWithException +import models.shared as shared + + +class LimitedLengthDict(dict): + def __init__(self, maxlen=None, *args, **kwargs): + self.maxlen = maxlen + self._keys = deque() + super().__init__(*args, **kwargs) + + def __setitem__(self, key, value): + if key not in self: + if self.maxlen is not None and len(self) >= self.maxlen: + oldest_key = self._keys.popleft() + if oldest_key in self: + del self[oldest_key] + self._keys.append(key) + super().__setitem__(key, value) + + +class FixedLengthQueue: + + # 停止符号列表 + stop_sequence: Optional[str] = [] + # 缓冲区 + max_length: int = 0 + # 缓冲区容器 + queue: deque = None + # 输入区容器 + queue_in: LimitedLengthDict[int, str] = {} + # 输出区容器 + queue_out: Dict[int, str] = {} + + def __new__(cls, *args, **kwargs): + # 创建新的实例 + instance = super().__new__(cls) + # 在这里可以对实例进行额外的设置 + return instance + + def __init__(self, stop_sequence): + if stop_sequence is None: + self.stop_sequence = [] + self.max_length = 0 + elif isinstance(stop_sequence, str): + self.stop_sequence = [stop_sequence] + self.max_length = 1 + else: + self.stop_sequence = stop_sequence + self.max_length = len(''.join(stop_sequence)) + + self.queue = deque(maxlen=self.max_length) + self.queue.clear() + self.queue_in.clear() + self.queue_out.clear() + + def add(self, index, item): + self.queue_in[index] = item + + def _add_out(self, index, item): + self.queue_out[index] = item + + def put_replace_out(self, index): + return self.queue_out[index] + + def contains_replace_sequence(self): + """ + 替换字符 + :return: + """ + + for key, value in self.queue_in.items(): + + word_index = value.rfind(":") + if word_index != -1: + value = value.replace(":", ":") + + word_index = value.rfind("[") + if word_index != -1: + value = value.replace("[", "") + + word_index = value.rfind("]") + if word_index != -1: + value = value.replace("]", "") + + self._add_out(key, value) + + def contains_stop_sequence(self): + # 截取固定大小的数据判断 + self.queue.clear() + last_three_keys = list(self.queue_out.keys())[-self.max_length:] + joined_queue = ''.join([self.queue_out[key] for key in last_three_keys]) + for char in joined_queue: + self.queue.append(char) + + joined_queue = ''.join(self.queue) + # Initialize a variable to store the index of the last found stop string + last_stop_str_index = -1 + + # Iterate through the stop string list + for stop_word in self.stop_sequence: + # Find the last occurrence of the stop string in the output + 
stop_word_index = joined_queue.rfind(stop_word) + + # If the stop string is found, compare the index with the previously found index + if stop_word_index != -1 and stop_word_index > last_stop_str_index: + last_stop_str_index = stop_word_index + + # Handle the last found stop string index here + return last_stop_str_index + + def __repr__(self): + return str(self.queue) + + +# Copied from https://github.com/PygmalionAI/gradio-ui/ +class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): + + def __init__(self, sentinel_token_ids: list, starting_idx: int): + transformers.StoppingCriteria.__init__(self) + self.sentinel_token_ids = sentinel_token_ids + self.starting_idx = starting_idx + + def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool: + for sample in input_ids: + trimmed_sample = sample[self.starting_idx:] + + for i in range(len(self.sentinel_token_ids)): + # Can't unfold, output is still too tiny. Skip. + if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]: + continue + for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1): + if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)): + return True + return False + + +class Stream(transformers.StoppingCriteria): + def __init__(self, callback_func=None): + self.callback_func = callback_func + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + if shared.stop_everything: + raise ValueError + if self.callback_func is not None: + self.callback_func(input_ids[0]) + return False + + +class Iteratorize: + """ + Transforms a function that takes a callback + into a lazy iterator (generator). + """ + + thread: ThreadWithException = None + + def __new__(cls, *args, **kwargs): + # 创建新的实例 + instance = super().__new__(cls) + # 在这里可以对实例进行额外的设置 + return instance + + + def __init__(self, func, kwargs={}, callback=None): + self.mfunc = func + self.c_callback = callback + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + + def _callback(val): + if shared.stop_everything: + raise ValueError + self.q.put(val) + + def gen(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + print("print(ValueError)") + except: + traceback.print_exc() + print("traceback.print_exc()") + self.q.put(self.sentinel) + + self.thread = ThreadWithException(target=gen) + self.thread.start() + + def __iter__(self): + shared.stop_everything = False + return self + + def __next__(self): + obj = self.q.get(True, None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __del__(self): + shared.stop_everything = False + self.q.empty() + shared.loaderCheckPoint.clear_torch_cache() + + def __enter__(self): + shared.stop_everything = False + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + shared.stop_everything = True + shared.loaderCheckPoint.clear_torch_cache() + self.thread.raise_exception() diff --git a/models/extensions/extensions.py b/models/extensions/extensions.py new file mode 100644 index 00000000..edd5c9e9 --- /dev/null +++ b/models/extensions/extensions.py @@ -0,0 +1,10 @@ +import gc +import traceback +import torch + +# This iterator returns the extensions in the order specified in the command-line +def iterator(): + state_extensions = {} + for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]): + if state_extensions[name][0]: + yield getattr(extensions, name).script, name \ No newline at end of file diff --git 
a/models/extensions/llamacpp_model_alternative.py b/models/extensions/llamacpp_model_alternative.py new file mode 100644 index 00000000..6bdf9bc3 --- /dev/null +++ b/models/extensions/llamacpp_model_alternative.py @@ -0,0 +1,64 @@ +''' +Based on +https://github.com/abetlen/llama-cpp-python + +Documentation: +https://abetlen.github.io/llama-cpp-python/ +''' + +from llama_cpp import Llama, LlamaCache + +from modules import shared +from modules.callbacks import Iteratorize + + +class LlamaCppModel: + def __init__(self): + self.initialized = False + + @classmethod + def from_pretrained(self, path): + result = self() + + params = { + 'model_path': str(path), + 'n_ctx': 2048, + 'seed': 0, + 'n_threads': shared.args.threads or None + } + self.model = Llama(**params) + self.model.set_cache(LlamaCache) + + # This is ugly, but the model and the tokenizer are the same object in this library. + return result, result + + def encode(self, string): + if type(string) is str: + string = string.encode() + return self.model.tokenize(string) + + def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None): + if type(context) is str: + context = context.encode() + tokens = self.model.tokenize(context) + + output = b"" + count = 0 + for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty): + text = self.model.detokenize([token]) + output += text + if callback: + callback(text.decode()) + + count += 1 + if count >= token_count or (token == self.model.token_eos()): + break + + return output.decode() + + def generate_with_streaming(self, **kwargs): + with Iteratorize(self.generate, kwargs, callback=None) as generator: + reply = '' + for token in generator: + reply += token + yield reply diff --git a/models/extensions/thread_with_exception.py b/models/extensions/thread_with_exception.py new file mode 100644 index 00000000..d28b4800 --- /dev/null +++ b/models/extensions/thread_with_exception.py @@ -0,0 +1,30 @@ +# Python program raising +# exceptions in a python +# thread + +import threading +import ctypes +import time + + +class ThreadWithException(threading.Thread): + + def get_id(self): + return self.ident + + def raise_exception(self): + """raises the exception, performs cleanup if needed""" + try: + thread_id = self.get_id() + tid = ctypes.c_long(thread_id) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit)) + if res == 0: + # pass + raise ValueError("invalid thread id") + elif res != 1: + # """if it returns a number greater than one, you're in trouble, + # and you should call it again with exc=NULL to revert the effect""" + ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) + raise SystemError("PyThreadState_SetAsyncExc failed") + except Exception as err: + print(err) diff --git a/models/llama_llm.py b/models/llama_llm.py new file mode 100644 index 00000000..4af0321c --- /dev/null +++ b/models/llama_llm.py @@ -0,0 +1,297 @@ +from abc import ABC + +from langchain.llms.base import LLM +import random +import torch +import transformers +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +from typing import Optional, List, Dict, Any +from models.loader import LoaderCheckPoint +from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue) +import models.shared as shared +from models.base import (BaseAnswer, + AnswerResult, + AnswerResultStream, + 
AnswerResultQueueSentinelTokenListenerQueue) + + +def _streaming_response_template() -> Dict[str, Any]: + """ + :return: 响应结构 + """ + return { + "text": "" + } + + +def _update_response(response: Dict[str, Any], stream_response: str) -> None: + """Update response from the stream response.""" + response["text"] += stream_response + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class LLamaLLM(BaseAnswer, LLM, ABC): + checkPoint: LoaderCheckPoint = None + history = [] + history_len: int = 3 + max_new_tokens: int = 500 + num_beams: int = 1 + temperature: float = 0.5 + top_p: float = 0.4 + top_k: int = 10 + repetition_penalty: float = 1.2 + encoder_repetition_penalty: int = 1 + min_length: int = 0 + logits_processor: LogitsProcessorList = None + stopping_criteria: Optional[StoppingCriteriaList] = None + eos_token_id: Optional[int] = [2] + + state: object = {'max_new_tokens': 50, + 'seed': 1, + 'temperature': 0, 'top_p': 0.1, + 'top_k': 40, 'typical_p': 1, + 'repetition_penalty': 1.2, + 'encoder_repetition_penalty': 1, + 'no_repeat_ngram_size': 0, + 'min_length': 0, + 'penalty_alpha': 0, + 'num_beams': 1, + 'length_penalty': 1, + 'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False, + 'truncation_length': 2048, 'custom_stopping_strings': '', + 'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False, + 'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None', + 'pre_layer': 0, 'gpu_memory_0': 0} + + def __init__(self, checkPoint: LoaderCheckPoint = None): + super().__init__() + self.checkPoint = checkPoint + + @property + def _llm_type(self) -> str: + return "LLamaLLM" + + @property + def _check_point(self) -> LoaderCheckPoint: + return self.checkPoint + + def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): + input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt', + add_special_tokens=add_special_tokens) + # This is a hack for making replies more creative. 
+ if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id: + input_ids = input_ids[:, 1:] + + # Llama adds this extra token when the first character is '\n', and this + # compromises the stopping criteria, so we just remove it + if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: + input_ids = input_ids[:, 1:] + + # Handling truncation + if truncation_length is not None: + input_ids = input_ids[:, -truncation_length:] + + return input_ids.cuda() + + def decode(self, output_ids): + reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True) + return reply + + def generate_with_callback(self, callback=None, **kwargs): + self.checkPoint.clear_torch_cache() + kwargs['stopping_criteria'].append(Stream(callback_func=callback)) + with torch.no_grad(): + self.checkPoint.model.generate(**kwargs) + print("方法结束") + + def generate_with_streaming(self, **kwargs): + return Iteratorize(self.generate_with_callback, kwargs) + + # 将历史对话数组转换为文本格式 + def history_to_text(self, query): + formatted_history = '' + history = self.history[-self.history_len:] if self.history_len > 0 else [] + for i, (old_query, response) in enumerate(history): + formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query) + return formatted_history + + def prepare_inputs_for_generation(self, + input_ids: torch.LongTensor): + """ + 预生成注意力掩码和 输入序列中每个位置的索引的张量 + # TODO 没有思路 + :return: + """ + + mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device) + + attention_mask = self.get_masks(input_ids, input_ids.device) + + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions + ) + + return input_ids, position_ids, attention_mask + + def get_position_ids(self, input_ids: torch.LongTensor, mask_positions, device): + """ + 注意力偏移量 + :param input_ids: + :param mask_positions: + :param device: + :param use_gmasks: + :return: + """ + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids] + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + return position_ids + + def get_masks(self, input_ids, device): + """ + 获取注意力掩码 + :param input_ids: + :param device: + :return: + """ + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + return attention_mask + + def generate_softprompt_history_tensors(self, query): + """ + 历史对话软提示 + 这段代码首先定义了一个名为 history_to_text 的函数,用于将 self.history + 
数组转换为所需的文本格式。然后,我们将格式化后的历史文本 + 再用 self.encode 将其转换为向量表示。最后,将历史对话向量与当前输入的对话向量拼接在一起。 + :return: + """ + + # 对话内容 + # 处理历史对话 + formatted_history = self.history_to_text(query) + return formatted_history + + @property + def _history_len(self) -> int: + return self.history_len + + def set_history_len(self, history_len: int = 10) -> None: + self.history_len = history_len + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + print(f"__call:{prompt}") + if self.logits_processor is None: + self.logits_processor = LogitsProcessorList() + self.logits_processor.append(InvalidScoreLogitsProcessor()) + + gen_kwargs = { + "max_new_tokens": self.max_new_tokens, + "num_beams": self.num_beams, + "top_p": self.top_p, + "top_k": self.top_k, + "repetition_penalty": self.repetition_penalty, + "encoder_repetition_penalty": self.encoder_repetition_penalty, + "min_length": self.min_length, + "temperature": self.temperature, + "eos_token_id": self.eos_token_id, + "logits_processor": self.logits_processor} + + # 向量拼接 + input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens) + # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids) + + # 对话模型prompt + gen_kwargs.update({'inputs': input_ids}) + # 注意力掩码 + # gen_kwargs.update({'attention_mask': attention_mask}) + # gen_kwargs.update({'position_ids': position_ids}) + if self.stopping_criteria is None: + self.stopping_criteria = transformers.StoppingCriteriaList() + # 观测输出 + gen_kwargs.update({'stopping_criteria': self.stopping_criteria}) + shared.stop_everything = False + stopped = False + response_template = _streaming_response_template() + + # TODO 此流输出方法需要重写!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # stopping_criteria方法不可控制 迭代器的变量无法共享 + with self.generate_with_streaming(**gen_kwargs) as generator: + last_reply_len = 0 + reply_index = 0 + # Create a FixedLengthQueue with the desired stop sequence and a maximum length. 
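        # Editor's note, not part of the original patch: FixedLengthQueue is defined in
        # models/extensions/callback.py (not shown in this hunk); it is assumed to keep only
        # the most recent characters of the stream so that a stop string arriving split
        # across several tokens is still caught, roughly:
        #     window = collections.deque(maxlen=len('\n###'))
        #     window.extend(new_text)
        #     stopped = ''.join(window) == '\n###'
        # while put_replace_out() presumably returns the chunk that is safe to emit once the
        # replace/stop markers have been handled.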
+ queue = FixedLengthQueue(stop) + for output in generator: + new_tokens = len(output) - len(input_ids[0]) + reply = self.decode(output[-new_tokens:]) + + new_reply = len(reply) - last_reply_len + output_reply = reply[-new_reply:] + queue.add(reply_index, output_reply) + queue.contains_replace_sequence() + if stop: + pos = queue.contains_stop_sequence() + if pos != -1: + shared.stop_everything = True + stopped = True + + #print(f"{reply_index}:reply {output_reply}") + english_reply = queue.put_replace_out(reply_index) + #print(f"{reply_index}:english_reply {english_reply}") + _update_response(response_template, english_reply) + last_reply_len = len(reply) + + reply_index += 1 + if new_tokens == self.max_new_tokens - 1 or stopped: + break + + response = response_template['text'] + print(f"response:{response}") + self.history = self.history + [[None, response]] + return response + + def _generate_answer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False, + generate_with_callback: AnswerResultStream = None) -> None: + if history: + self.history = history + # Create the StoppingCriteriaList with the stopping strings + self.stopping_criteria = transformers.StoppingCriteriaList() + # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult + listenerQueue = AnswerResultQueueSentinelTokenListenerQueue() + self.stopping_criteria.append(listenerQueue) + # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现 + softprompt = self.generate_softprompt_history_tensors(prompt) + response = self._call(prompt=softprompt, stop=['\n###']) + answer_result = AnswerResult() + answer_result.history = self.history + if listenerQueue.listenerQueue.__len__() > 0: + answer_result.listenerToken = listenerQueue.listenerQueue.pop() + answer_result.llm_output = {"answer": response} + generate_with_callback(answer_result) diff --git a/models/loader/__init__.py b/models/loader/__init__.py new file mode 100644 index 00000000..35c71e3c --- /dev/null +++ b/models/loader/__init__.py @@ -0,0 +1,2 @@ + +from .loader import * diff --git a/models/loader/args.py b/models/loader/args.py new file mode 100644 index 00000000..224bf6c6 --- /dev/null +++ b/models/loader/args.py @@ -0,0 +1,59 @@ +import argparse +import os + + + +# Additional argparse types +def path(string): + if not string: + return '' + s = os.path.expanduser(string) + if not os.path.exists(s): + raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"') + return s + + +def file_path(string): + if not string: + return '' + s = os.path.expanduser(string) + if not os.path.isfile(s): + raise argparse.ArgumentTypeError(f'No such file: "{string}"') + return s + + +def dir_path(string): + if not string: + return '' + s = os.path.expanduser(string) + if not os.path.isdir(s): + raise argparse.ArgumentTypeError(f'No such directory: "{string}"') + return s + + +parser = argparse.ArgumentParser(prog='langchina-ChatGLM', + description='基于langchain和chatGML的LLM文档阅读器') + + + +parser.add_argument('--no-remote-model', action='store_true', default=False, help='remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model`') +parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.') +parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') +parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models") 
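# Editor's note, not part of the original patch: --model-dir (and --lora-dir just below) are
# joined with the model/LoRA name by LoaderCheckPoint, e.g. Path(f'{model_dir}/{model_name}'),
# so when no explicit model_path is configured and --no-remote-model is passed, a local copy
# under model/chatglm-6b is what actually gets loaded.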
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") + +# Accelerate/transformers +parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') +parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') +parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.') +parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.') +parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') +parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') + + +args = parser.parse_args([]) +# Generares dict with a default value for each argument +DEFAULT_ARGS = vars(args) + + + diff --git a/models/loader/loader.py b/models/loader/loader.py new file mode 100644 index 00000000..c50f7d10 --- /dev/null +++ b/models/loader/loader.py @@ -0,0 +1,405 @@ +import gc +import json +import os +import re +import time +from pathlib import Path +from peft import PeftModel +from typing import Optional, List, Dict, Tuple, Union +import torch +import transformers + +from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, + AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer) +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.modeling_utils import no_init_weights +from transformers.utils import ContextManagers +from accelerate import init_empty_weights +from accelerate.utils import get_balanced_memory, infer_auto_device_map + + +class LoaderCheckPoint: + """ + 加载自定义 model CheckPoint + """ + # remote in the model on loader checkpoint + no_remote_model: bool = False + # 模型名称 + model_name: str = None + tokenizer: object = None + # 模型全路径 + model_path: str = None + model: object = None + model_config: object = None + lora_names: set = [] + model_dir: str = None + lora_dir: str = None + ptuning_dir: str = None + use_ptuning_v2: bool = False + cpu: bool = False + gpu_memory: object = None + cpu_memory: object = None + auto_devices: object = True + # 如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156 + load_in_8bit: bool = False + is_llamacpp: bool = False + bf16: bool = False + params: object = None + # 自定义设备网络 + device_map: Optional[Dict[str, int]] = None + # 默认 cuda ,如果不支持cuda使用多卡, 如果不支持多卡 使用cpu + llm_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + + def __init__(self, params: dict = None): + """ + 模型初始化 + :param params: + """ + self.model_path = None + self.params = params or {} + self.no_remote_model = params.get('no_remote_model', False) + self.model_name = params.get('model', '') + self.lora = params.get('lora', '') + self.use_ptuning_v2 = params.get('use_ptuning_v2', False) + self.model = None + self.tokenizer = None + self.model_dir = params.get('model_dir', '') + self.lora_dir = params.get('lora_dir', '') + self.ptuning_dir = params.get('ptuning_dir', '') + self.cpu = params.get('cpu', False) + self.gpu_memory = params.get('gpu_memory', None) + self.cpu_memory = params.get('cpu_memory', None) + 
self.auto_devices = params.get('auto_devices', True) + self.load_in_8bit = params.get('load_in_8bit', False) + self.bf16 = params.get('bf16', False) + + def _load_model_config(self, model_name): + checkpoint = Path(f'{self.model_dir}/{model_name}') + + if self.model_path: + checkpoint = Path(f'{self.model_path}') + else: + if not self.no_remote_model: + checkpoint = model_name + + model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + + return model_config + + def _load_model(self, model_name): + """ + 加载自定义位置的model + :param model_name: + :return: + """ + print(f"Loading {model_name}...") + t0 = time.time() + + checkpoint = Path(f'{self.model_dir}/{model_name}') + + self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0 + + if self.model_path: + checkpoint = Path(f'{self.model_path}') + else: + if not self.no_remote_model: + checkpoint = model_name + + if 'chatglm' in model_name.lower(): + LoaderClass = AutoModel + else: + LoaderClass = AutoModelForCausalLM + + # Load the model in simple 16-bit mode by default + if not any([self.cpu, self.load_in_8bit, self.auto_devices, self.gpu_memory is not None, + self.cpu_memory is not None, self.is_llamacpp]): + + if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"): + # 根据当前设备GPU数量决定是否进行多卡部署 + num_gpus = torch.cuda.device_count() + if num_gpus < 2 and self.device_map is None: + model = ( + LoaderClass.from_pretrained(checkpoint, + low_cpu_mem_usage=True, + config=self.model_config, + torch_dtype=torch.bfloat16 if self.bf16 else torch.float16, + trust_remote_code=True) + .half() + .cuda() + ) + else: + from accelerate import dispatch_model + + model = LoaderClass.from_pretrained(checkpoint, + low_cpu_mem_usage=True, + config=self.model_config, + torch_dtype=torch.bfloat16 if self.bf16 else torch.float16, + trust_remote_code=True).half() + # 可传入device_map自定义每张卡的部署情况 + if self.device_map is None: + if 'chatglm' in model_name.lower(): + device_map = self.chatglm_auto_configure_device_map(num_gpus) + elif 'moss' in model_name.lower(): + device_map = self.moss_auto_configure_device_map(num_gpus,model_name) + else: + device_map = self.chatglm_auto_configure_device_map(num_gpus) + + model = dispatch_model(model, device_map=device_map) + else: + print( + "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") + model = ( + AutoModel.from_pretrained( + checkpoint, + config=self.model_config, + trust_remote_code=True) + .float() + .to(self.llm_device) + ) + + elif self.is_llamacpp: + from models.extensions.llamacpp_model_alternative import LlamaCppModel + + model_file = list(checkpoint.glob('ggml*.bin'))[0] + print(f"llama.cpp weights detected: {model_file}\n") + + model, tokenizer = LlamaCppModel.from_pretrained(model_file) + return model, tokenizer + + # Custom + else: + params = {"low_cpu_mem_usage": True} + if not any((self.cpu, torch.cuda.is_available(), torch.has_mps)): + print( + "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") + self.cpu = True + + if self.cpu: + params["torch_dtype"] = torch.float32 + else: + params["device_map"] = 'auto' + params["trust_remote_code"] = True + if self.load_in_8bit and any((self.auto_devices, self.gpu_memory)): + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=True) + elif self.load_in_8bit: + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) + 
elif self.bf16: + params["torch_dtype"] = torch.bfloat16 + else: + params["torch_dtype"] = torch.float16 + + if self.gpu_memory: + memory_map = list(map(lambda x: x.strip(), self.gpu_memory)) + max_cpu_memory = self.cpu_memory.strip() if self.cpu_memory is not None else '99GiB' + max_memory = {} + for i in range(len(memory_map)): + max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else \ + memory_map[i] + max_memory['cpu'] = max_cpu_memory + params['max_memory'] = max_memory + elif self.auto_devices: + total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)) + suggestion = round((total_mem - 1000) / 1000) * 1000 + if total_mem - suggestion < 800: + suggestion -= 1000 + suggestion = int(round(suggestion / 1000)) + print( + f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") + + max_memory = {0: f'{suggestion}GiB', 'cpu': f'{self.cpu_memory or 99}GiB'} + params['max_memory'] = max_memory + + if self.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': + config = AutoConfig.from_pretrained(checkpoint) + with init_empty_weights(): + model = LoaderClass.from_config(config) + model.tie_weights() + if self.device_map is not None: + params['device_map'] = self.device_map + else: + params['device_map'] = infer_auto_device_map( + model, + dtype=torch.int8, + max_memory=params['max_memory'], + no_split_module_classes=model._no_split_modules + ) + + model = LoaderClass.from_pretrained(checkpoint, **params) + + # Loading the tokenizer + if type(model) is transformers.LlamaForCausalLM: + tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True) + # Leaving this here until the LLaMA tokenizer gets figured out. + # For some people this fixes things, for others it causes an error. 
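            # Editor's note, not part of the original patch: 1, 2 and 0 are the conventional
            # LLaMA special-token ids (bos, eos and pad/unk); early LLaMA tokenizer configs
            # shipped without them, so they are patched in here, with the try/except guarding
            # against tokenizers that reject the assignment.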
+ try: + tokenizer.eos_token_id = 2 + tokenizer.bos_token_id = 1 + tokenizer.pad_token_id = 0 + except: + pass + else: + tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) + + print(f"Loaded the model in {(time.time() - t0):.2f} seconds.") + return model, tokenizer + + def chatglm_auto_configure_device_map(self, num_gpus: int) -> Dict[str, int]: + # transformer.word_embeddings 占用1层 + # transformer.final_layernorm 和 lm_head 占用1层 + # transformer.layers 占用 28 层 + # 总共30层分配到num_gpus张卡上 + num_trans_layers = 28 + per_gpu_layers = 30 / num_gpus + + # bugfix: PEFT加载lora模型出现的层命名不同 + if self.lora: + layer_prefix = 'base_model.model.transformer' + else: + layer_prefix = 'transformer' + + # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError + # windows下 model.device 会被设置成 transformer.word_embeddings.device + # linux下 model.device 会被设置成 lm_head.device + # 在调用chat或者stream_chat时,input_ids会被放到model.device上 + # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError + # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 + device_map = {f'{layer_prefix}.word_embeddings': 0, + f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, + f'base_model.model.lm_head': 0, } + + used = 2 + gpu_target = 0 + for i in range(num_trans_layers): + if used >= per_gpu_layers: + gpu_target += 1 + used = 0 + assert gpu_target < num_gpus + device_map[f'{layer_prefix}.layers.{i}'] = gpu_target + used += 1 + + return device_map + + def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]: + checkpoint = Path(f'{self.model_dir}/{model_name}') + + if self.model_path: + checkpoint = Path(f'{self.model_path}') + else: + if not self.no_remote_model: + checkpoint = model_name + cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM", + pretrained_model_name_or_path=checkpoint) + + with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]): + model = cls(self.model_config) + max_memory = get_balanced_memory(model, dtype=torch.int8 if self.load_in_8bit else None, + low_zero=False, no_split_module_classes=model._no_split_modules) + device_map = infer_auto_device_map( + model, dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory, + no_split_module_classes=model._no_split_modules) + device_map["transformer.wte"] = 0 + device_map["transformer.drop"] = 0 + device_map["transformer.ln_f"] = 0 + device_map["lm_head"] = 0 + return device_map + + def _add_lora_to_model(self, lora_names): + # 目前加载的lora + prior_set = set(self.lora_names) + # 需要加载的 + added_set = set(lora_names) - prior_set + # 删除的lora + removed_set = prior_set - set(lora_names) + self.lora_names = list(lora_names) + + # Nothing to do = skip. + if len(added_set) == 0 and len(removed_set) == 0: + return + + # Only adding, and already peft? Do it the easy way. + if len(removed_set) == 0 and len(prior_set) > 0: + print(f"Adding the LoRA(s) named {added_set} to the model...") + for lora in added_set: + self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora) + return + + # If removing anything, disable all and re-add. + if len(removed_set) > 0: + self.model.disable_adapter() + + if len(lora_names) > 0: + print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names))) + params = {} + if not self.cpu: + params['dtype'] = self.model.dtype + if hasattr(self.model, "hf_device_map"): + params['device_map'] = {"base_model.model." 
+ k: v for k, v in self.model.hf_device_map.items()} + elif self.load_in_8bit: + params['device_map'] = {'': 0} + self.model.resize_token_embeddings(len(self.tokenizer)) + + self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params) + + for lora in lora_names[1:]: + self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora) + + if not self.load_in_8bit and not self.cpu: + + if not hasattr(self.model, "hf_device_map"): + if torch.has_mps: + device = torch.device('mps') + self.model = self.model.to(device) + else: + self.model = self.model.cuda() + + def clear_torch_cache(self): + gc.collect() + if not self.cpu: + device_id = "0" if torch.cuda.is_available() else None + CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device + with torch.cuda.device(CUDA_DEVICE): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + def unload_model(self): + del self.model + del self.tokenizer + self.model = self.tokenizer = None + self.clear_torch_cache() + + def set_model_path(self, model_path): + self.model_path = model_path + + def reload_model(self): + self.unload_model() + self.model_config = self._load_model_config(self.model_name) + + if self.use_ptuning_v2: + try: + prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r') + prefix_encoder_config = json.loads(prefix_encoder_file.read()) + prefix_encoder_file.close() + self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] + self.model_config.prefix_projection = prefix_encoder_config['prefix_projection'] + except Exception: + print("加载PrefixEncoder config.json失败") + + self.model, self.tokenizer = self._load_model(self.model_name) + + if self.lora: + self._add_lora_to_model([self.lora]) + + if self.use_ptuning_v2: + try: + prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin')) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v + self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + self.model.transformer.prefix_encoder.float() + except Exception: + print("加载PrefixEncoder模型参数失败") + + self.model = self.model.eval() diff --git a/models/moss_llm.py b/models/moss_llm.py index 343c79e6..c958baae 100644 --- a/models/moss_llm.py +++ b/models/moss_llm.py @@ -1,20 +1,14 @@ -import json +from abc import ABC + from langchain.llms.base import LLM -from typing import List, Dict, Optional -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -from transformers.dynamic_module_utils import get_class_from_dynamic_module -from transformers.modeling_utils import no_init_weights -from transformers.utils import ContextManagers +from typing import Optional, List +from models.loader import LoaderCheckPoint +from models.base import (BaseAnswer, + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) + import torch -from configs.model_config import * -from utils import torch_gc - -from accelerate import init_empty_weights -from accelerate.utils import get_balanced_memory, infer_auto_device_map - -DEVICE_ = LLM_DEVICE -DEVICE_ID = "0" if torch.cuda.is_available() else None -DEVICE = f"{DEVICE_}:{DEVICE_ID}" if DEVICE_ID else DEVICE_ META_INSTRUCTION = \ """You are an AI assistant whose name is MOSS. 
@@ -30,45 +24,40 @@ META_INSTRUCTION = \ """ -def auto_configure_device_map() -> Dict[str, int]: - cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM", - pretrained_model_name_or_path=llm_model_dict['moss']) - - with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]): - model_config = AutoConfig.from_pretrained(llm_model_dict['moss'], trust_remote_code=True) - model = cls(model_config) - max_memory = get_balanced_memory(model, dtype=torch.int8 if LOAD_IN_8BIT else None, - low_zero=False, no_split_module_classes=model._no_split_modules) - device_map = infer_auto_device_map( - model, dtype=torch.float16 if not LOAD_IN_8BIT else torch.int8, max_memory=max_memory, - no_split_module_classes=model._no_split_modules) - device_map["transformer.wte"] = 0 - device_map["transformer.drop"] = 0 - device_map["transformer.ln_f"] = 0 - device_map["lm_head"] = 0 - return device_map - - -class MOSS(LLM): +class MOSSLLM(BaseAnswer, LLM, ABC): max_token: int = 2048 temperature: float = 0.7 top_p = 0.8 # history = [] - tokenizer: object = None - model: object = None + checkPoint: LoaderCheckPoint = None history_len: int = 10 - def __init__(self): + def __init__(self, checkPoint: LoaderCheckPoint = None): super().__init__() + self.checkPoint = checkPoint @property def _llm_type(self) -> str: return "MOSS" - def _call(self, - prompt: str, - history: List[List[str]] = [], - streaming: bool = STREAMING): # -> Tuple[str, List[List[str]]]: + @property + def _check_point(self) -> LoaderCheckPoint: + return self.checkPoint + + @property + def set_history_len(self) -> int: + return self.history_len + + def _set_history_len(self, history_len: int) -> None: + self.history_len = history_len + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + pass + + def _generate_answer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False, + generate_with_callback: AnswerResultStream = None) -> None: if len(history) > 0: history = history[-self.history_len:-1] if self.history_len > 0 else [] prompt_w_history = str(history) @@ -77,9 +66,9 @@ class MOSS(LLM): prompt_w_history = META_INSTRUCTION prompt_w_history += '<|Human|>: ' + prompt + '' - inputs = self.tokenizer(prompt_w_history, return_tensors="pt") + inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt") with torch.no_grad(): - outputs = self.model.generate( + outputs = self.checkPoint.model.generate( inputs.input_ids.cuda(), attention_mask=inputs.attention_mask.cuda(), max_length=self.max_token, @@ -92,78 +81,8 @@ class MOSS(LLM): eos_token_id=106068, pad_token_id=self.tokenizer.pad_token_id) response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) - torch_gc() + self.checkPoint.clear_torch_cache() history += [[prompt, response]] yield response, history - torch_gc() - - def load_model(self, - model_name_or_path: str = "fnlp/moss-moon-003-sft", - llm_device=LLM_DEVICE, - use_ptuning_v2=False, - use_lora=False, - device_map: Optional[Dict[str, int]] = None, - **kwargs): - self.tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, - trust_remote_code=True - ) - - model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) - - if use_ptuning_v2: - try: - prefix_encoder_file = open('ptuning-v2/config.json', 'r') - prefix_encoder_config = json.loads(prefix_encoder_file.read()) - prefix_encoder_file.close() - model_config.pre_seq_len = 
prefix_encoder_config['pre_seq_len'] - model_config.prefix_projection = prefix_encoder_config['prefix_projection'] - except Exception as e: - print(e) - print("加载PrefixEncoder config.json失败") - - if torch.cuda.is_available() and llm_device.lower().startswith("cuda"): - # accelerate自动多卡部署 - self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config, - load_in_8bit=LOAD_IN_8BIT, trust_remote_code=True, - device_map=auto_configure_device_map(), **kwargs) - - if LLM_LORA_PATH and use_lora: - from peft import PeftModel - self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH) - - else: - self.model = self.model.float().to(llm_device) - if LLM_LORA_PATH and use_lora: - from peft import PeftModel - self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH) - - if use_ptuning_v2: - try: - prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin') - new_prefix_state_dict = {} - for k, v in prefix_state_dict.items(): - if k.startswith("transformer.prefix_encoder."): - new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v - self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) - self.model.transformer.prefix_encoder.float() - except Exception as e: - print(e) - print("加载PrefixEncoder模型参数失败") - - self.model = self.model.eval() -if __name__ == "__main__": - llm = MOSS() - llm.load_model(model_name_or_path=llm_model_dict['moss'], - llm_device=LLM_DEVICE, ) - last_print_len = 0 - # for resp, history in llm._call("你好", streaming=True): - # print(resp[last_print_len:], end="", flush=True) - # last_print_len = len(resp) - for resp, history in llm._call("你好", streaming=False): - print(resp) - import time - time.sleep(10) - pass diff --git a/models/shared.py b/models/shared.py new file mode 100644 index 00000000..68325b0f --- /dev/null +++ b/models/shared.py @@ -0,0 +1,43 @@ +import sys + +from models.loader.args import parser +from models.loader import LoaderCheckPoint +from configs.model_config import (llm_model_dict, LLM_MODEL) +from models.base import BaseAnswer +"""迭代器是否停止状态""" +stop_everything = False + +loaderCheckPoint: LoaderCheckPoint = None + + +def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> BaseAnswer: + """ + init llm_model_ins LLM + :param llm_model: model_name + :param no_remote_model: remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model + :param use_ptuning_v2: Use p-tuning-v2 PrefixEncoder + :return: + """ + pre_model_name = loaderCheckPoint.model_name + llm_model_info = llm_model_dict[pre_model_name] + + if no_remote_model: + loaderCheckPoint.no_remote_model = no_remote_model + if use_ptuning_v2: + loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2 + + if llm_model: + llm_model_info = llm_model_dict[llm_model] + + if loaderCheckPoint.no_remote_model: + loaderCheckPoint.model_name = llm_model_info['name'] + else: + loaderCheckPoint.model_name = llm_model_info['remote-checkpoint'] + + loaderCheckPoint.model_path = llm_model_info['path'] + + loaderCheckPoint.reload_model() + + provides_class = getattr(sys.modules['models'], llm_model_info['provides']) + modelInsLLM = provides_class(checkPoint=loaderCheckPoint) + return modelInsLLM diff --git a/requirements.txt b/requirements.txt index cdd1e5b6..be2406c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,8 @@ fastapi uvicorn peft pypinyin -bitsandbytes click~=8.1.3 tabulate +bitsandbytes; platform_system != "Windows" +llama-cpp-python==0.1.34; 
platform_system != "Windows" +https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows" diff --git a/webui.py b/webui.py index c6d74bf6..5eb2a1a7 100644 --- a/webui.py +++ b/webui.py @@ -1,9 +1,17 @@ import gradio as gr import os import shutil + from chains.local_doc_qa import LocalDocQA from configs.model_config import * import nltk +from models.base import (BaseAnswer, + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +import models.shared as shared +from models.loader.args import parser +from models.loader import LoaderCheckPoint nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -69,7 +77,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR yield history + [[query, "请选择知识库后进行测试,当前未选择知识库。"]], "" else: - for resp, history in local_doc_qa.llm._call(query, history, streaming=streaming): + for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, + streaming=streaming): + + resp = answer_result.llm_output["answer"] + history = answer_result.history history[-1][-1] = resp + ( "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") yield history, "" @@ -77,10 +89,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) -def init_model(): +def init_model(llm_model: BaseAnswer = None): try: - local_doc_qa.init_cfg() - local_doc_qa.llm._call("你好") + local_doc_qa.init_cfg(llm_model=llm_model) + generator = local_doc_qa.llm.generatorAnswer("你好") + for answer_result in generator: + print(answer_result.llm_output) reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" logger.info(reply) return reply @@ -95,14 +109,13 @@ def init_model(): return reply -def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history): +def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): try: - local_doc_qa.init_cfg(llm_model=llm_model, + llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) + llm_model_ins.history_len = llm_history_len + local_doc_qa.init_cfg(llm_model=llm_model_ins, embedding_model=embedding_model, - llm_history_len=llm_history_len, - use_ptuning_v2=use_ptuning_v2, - use_lora=use_lora, - top_k=top_k, ) + top_k=top_k) model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" logger.info(model_status) except Exception as e: @@ -219,7 +232,17 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI! 
知识库暂不支持文件删除,该功能将在后续版本中推出。 """ -model_status = init_model() +# 初始化消息 +args = None +args = parser.parse_args() + +args_dict = vars(args) +shared.loaderCheckPoint = LoaderCheckPoint(args_dict) +llm_model_ins = shared.loaderLLM() +llm_model_ins.set_history_len(LLM_HISTORY_LEN) + +model_status = init_model(llm_model=llm_model_ins) + default_theme_args = dict( font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], @@ -399,6 +422,10 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as label="LLM 模型", value=LLM_MODEL, interactive=True) + no_remote_model = gr.Checkbox(shared.LoaderCheckPoint.no_remote_model, + label="加载本地模型", + interactive=True) + llm_history_len = gr.Slider(0, 10, value=LLM_HISTORY_LEN, step=1, @@ -418,7 +445,7 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as label="向量匹配 top k", interactive=True) load_model_button = gr.Button("重新加载模型") load_model_button.click(reinit_model, show_progress=True, - inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, + inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, chatbot], outputs=chatbot) (demo
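As a closing usage sketch (an editor's addition, not part of the patch): the pieces introduced above are wired together the same way webui.py now does it: parse the shared argparse flags, build a LoaderCheckPoint, obtain a concrete BaseAnswer implementation through models.shared.loaderLLM, and consume generatorAnswer as a stream of AnswerResult objects. generatorAnswer and AnswerResult live in models/base.py, which this section only imports, so their exact signatures are assumed from the calls shown in webui.py above.

    from models.loader.args import parser
    from models.loader import LoaderCheckPoint
    import models.shared as shared

    # Build the checkpoint holder from the CLI flags, exactly as webui.py does.
    args = parser.parse_args()                      # e.g. --model chatglm-6b --no-remote-model
    shared.loaderCheckPoint = LoaderCheckPoint(vars(args))

    # loaderLLM() reloads the weights and wraps them in the class named by
    # llm_model_dict[...]['provides'] (e.g. LLamaLLM or MOSSLLM from this patch).
    llm = shared.loaderLLM()
    llm.set_history_len(3)

    history = []
    for answer_result in llm.generatorAnswer(prompt="你好", history=history, streaming=True):
        history = answer_result.history
        print(answer_result.llm_output["answer"])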