Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-28 01:33:17 +08:00)

Merge branch 'dev' of https://github.com/imClumsyPanda/langchain-ChatGLM into dev

This commit is contained in commit 5eccb58759.
@@ -5,9 +5,10 @@ COPY . /chatGLM/

WORKDIR /chatGLM

RUN apt-get update -y && apt-get install python3 python3-pip curl -y && apt-get clean
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone
RUN apt-get update -y && apt-get install python3 python3-pip curl libgl1 libglib2.0-0 -y && apt-get clean
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py

RUN pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn && rm -rf `pip3 cache dir`
RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir`

CMD ["python3","-u", "webui.py"]
@@ -129,6 +129,8 @@ The VUE front-end pages are shown below:

[screenshot: VUE front-end page]

2. `Knowledge Base Q&A` page

[screenshot: VUE `Knowledge Base Q&A` page]

3. `Bing Search` page

[screenshot: VUE `Bing Search` page (img/vue_0521_2.png)]

The WebUI pages are shown below:

1. `Dialogue` tab

@@ -219,6 +221,6 @@ The Web UI supports the following features:

- [x] VUE front-end

## Project Discussion Group

[QR code image: img/qr_code_22.jpg]

🎉 langchain-ChatGLM project discussion group. If you are also interested in this project, you are welcome to join the group chat and take part in the discussion.
@@ -20,46 +20,7 @@ from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sklearn.neighbors import NearestNeighbors


class SemanticSearch:
    def __init__(self):
        self.use = SentenceTransformer('GanymedeNil_text2vec-large-chinese')
        self.fitted = False

    def fit(self, data, batch=100, n_neighbors=10):
        self.data = data
        self.embeddings = self.get_text_embedding(data, batch=batch)
        n_neighbors = min(n_neighbors, len(self.embeddings))
        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
        self.nn.fit(self.embeddings)
        self.fitted = True

    def __call__(self, text, return_data=True):
        inp_emb = self.use.encode([text])
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]

        if return_data:
            return [self.data[i] for i in neighbors]
        else:
            return neighbors

    def get_text_embedding(self, texts, batch=100):
        embeddings = []
        for i in range(0, len(texts), batch):
            text_batch = texts[i: (i + batch)]
            emb_batch = self.use.encode(text_batch)
            embeddings.append(emb_batch)
        embeddings = np.vstack(embeddings)
        return embeddings


def get_docs_with_score(docs_with_score):
    docs = []
    for doc, score in docs_with_score:
        doc.metadata["score"] = score
        docs.append(doc)
    return docs


def load_file(filepath, sentence_size=SENTENCE_SIZE):
    if filepath.lower().endswith(".md"):

@@ -301,41 +262,9 @@ class LocalDocQA:
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)

        # Fine ranking: the FAISS retrieval above serves as the coarse ranking; model_config must set VECTOR_SEARCH_TOP_K = 300.
        # Idea: the coarse stage (FAISS + semantic search) retrieves a large number of related documents
        # (hence VECTOR_SEARCH_TOP_K = 300); the documents are then merged and re-split,
        # and a second pass with KNN + semantic search selects the chunks that go into the prompt.
        # Extract the documents
        related_docs = get_docs_with_score(related_docs_with_score)
        text_batch0 = []
        for i in range(len(related_docs)):
            cut_txt = " ".join([w for w in list(related_docs[i].page_content)])
            cut_txt = cut_txt.replace(" ", "")
            text_batch0.append(cut_txt)
        # De-duplicate the documents
        text_batch_new = []
        for i in range(len(text_batch0)):
            if text_batch0[i] in text_batch_new:
                continue
            else:
                while text_batch_new and text_batch_new[-1] > text_batch0[i] and text_batch_new[-1] in text_batch0[i + 1:]:
                    text_batch_new.pop()  # pop the top of the stack
                text_batch_new.append(text_batch0[i])
        text_batch_new0 = "\n".join([doc for doc in text_batch_new])
        # Fine ranking with KNN + semantic search
        recommender = SemanticSearch()
        chunks = text_to_chunks(text_batch_new0, start_page=1)
        recommender.fit(chunks)
        topn_chunks = recommender(query)
        torch_gc()
        # Strip spaces from the text
        topn_chunks0 = []
        for i in range(len(topn_chunks)):
            cut_txt = topn_chunks[i].replace(" ", "")
            topn_chunks0.append(cut_txt)
        # Build the prompt
        prompt = generate_prompt(topn_chunks0, query)
        prompt = generate_prompt(related_docs_with_score, query)

        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
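The block above (removed by this merge) wires the coarse-to-fine retrieval together: FAISS returns a large candidate set, the candidates are de-duplicated and re-chunked, and SemanticSearch re-ranks them before the prompt is built. A minimal sketch of the re-ranking step in isolation, assuming the SemanticSearch class from the hunk above and a locally available 'GanymedeNil_text2vec-large-chinese' model; the candidate list below is illustrative, not part of the repository:

# Minimal re-ranking sketch using the SemanticSearch class shown above.
# `candidate_texts` stands in for the de-duplicated FAISS results.
candidate_texts = [
    "ChatGLM-6B is an open bilingual dialogue language model.",
    "FAISS is a library for efficient similarity search over dense vectors.",
    "LangChain chains LLM calls with external knowledge sources.",
]

recommender = SemanticSearch()              # loads the text2vec sentence-transformer
recommender.fit(candidate_texts, batch=100, n_neighbors=2)   # embeds candidates, fits a KNN index
top_chunks = recommender("What is FAISS used for?")          # nearest candidate texts for the query
print(top_chunks)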
@@ -62,11 +62,33 @@ llm_model_dict = {
        "pretrained_model_name": "fnlp/moss-moon-003-sft",
        "local_model_path": None,
        "provides": "MOSSLLM"
    },
    "vicuna-13b-hf": {
        "name": "vicuna-13b-hf",
        "pretrained_model_name": "vicuna-13b-hf",
        "local_model_path": None,
        "provides": "LLamaLLM"
    },
    "fastChat": {
        "name": "fastChat",
        "pretrained_model_name": "fastChat",
        "local_model_path": None,
        "provides": "FastChatLLM"
    }
}

# LLM model name
LLM_MODEL = "chatglm-6b"
# If you need to load a local model, pass the ` --no-remote-model` flag or change the parameter below to `True`
NO_REMOTE_MODEL = False
# Load the model quantized to 8 bit
LOAD_IN_8BIT = False
# Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
BF16 = False
# Directory where local models are stored
MODEL_DIR = "model/"
# Directory where local LoRAs are stored
LORA_DIR = "loras/"

# LLM LoRA path; empty by default, set it to the folder path if you have one
LLM_LORA_PATH = ""
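Each entry in llm_model_dict maps a model key to the wrapper class named in "provides"; the loader code later resolves that string with getattr (see the shared.py hunk further down). A minimal sketch of the lookup, assuming the dict and constants above are importable from configs.model_config:

from configs.model_config import llm_model_dict, LLM_MODEL

info = llm_model_dict[LLM_MODEL]     # e.g. the "chatglm-6b" entry
print(info["provides"])              # name of the wrapper class to instantiate, e.g. "ChatGLM"
print(info["local_model_path"] or f"{info['pretrained_model_name']} (remote)")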
@@ -1 +0,0 @@
from .fastchat_api import *
@@ -1,261 +0,0 @@
"""
Conversation prompt template.

Now we support
- Vicuna
- Koala
- OpenAssistant/oasst-sft-1-pythia-12b
- StabilityAI/stablelm-tuned-alpha-7b
- databricks/dolly-v2-12b
- THUDM/chatglm-6b
- Alpaca/LLaMa
"""

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    DOLLY = auto()
    OASST_PYTHIA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)

conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)

conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1.1": conv_vicuna_v1_1,
    "koala_v1": conv_koala_v1,
    "dolly": conv_dolly,
    "oasst": conv_oasst,
}


def get_default_conv_template(model_name):
    model_name = model_name.lower()
    if "vicuna" in model_name or "output" in model_name:
        return conv_vicuna_v1_1
    elif "koala" in model_name:
        return conv_koala_v1
    elif "dolly-v2" in model_name:
        return conv_dolly
    elif "oasst" in model_name and "pythia" in model_name:
        return conv_oasst
    elif "stablelm" in model_name:
        return conv_stablelm
    return conv_one_shot


def compute_skip_echo_len(model_name, conv, prompt):
    model_name = model_name.lower()
    if "chatglm" in model_name:
        skip_echo_len = len(conv.messages[-2][1]) + 1
    elif "dolly-v2" in model_name:
        special_toks = ["### Instruction:", "### Response:", "### End"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "oasst" in model_name and "pythia" in model_name:
        special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "stablelm" in model_name:
        special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    else:
        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    return skip_echo_len


if __name__ == "__main__":
    print(conv_one_shot.get_prompt())
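For reference, the templates above are consumed by copying a template, appending the user turn plus an empty assistant turn, and rendering the prompt with get_prompt(). A minimal sketch of that flow, assuming this conversation.py is still importable (the module is being deleted in this commit, so the import context is illustrative):

conv = get_default_conv_template("vicuna-13b-hf").copy()
conv.append_message(conv.roles[0], "What is FAISS?")   # "USER"
conv.append_message(conv.roles[1], None)               # "ASSISTANT", left empty for the model to fill
prompt = conv.get_prompt()
skip = compute_skip_echo_len("vicuna-13b-hf", conv, prompt)  # length of the echoed prompt to strip from the output
print(prompt, skip)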
@@ -1,459 +0,0 @@
"""Wrapper around FastChat APIs."""
from __future__ import annotations

import logging
import sys
import warnings
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    Literal,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)

from pydantic import Extra, Field, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

import requests
import json


logger = logging.getLogger(__name__)
FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: the response structure
    """
    return {
        "text": "",
        "error_code": 0,
    }


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response["text"]
    response["error_code"] += stream_response["error_code"]


class BaseFastChat(BaseLLM):
    """Wrapper around FastChat large language models."""

    model_name: str = "text-davinci-003"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_new_tokens: int = 200
    stop: int = 20
    batch_size: int = 20
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Penalizes repeated tokens."""
    n: int = 1
    """Whether to stream the results or not."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling FastChat API."""
        normal_params = {
            "model": self.model_name,
            "prompt": '',
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        return {**normal_params}

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = fastchat.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}
        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]

            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint async with k unique prompts."""
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}

        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]
            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    def get_sub_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if params["max_new_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_new_tokens set to -1 not supported for multiple inputs."
                )
            params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
        # append pload
        sub_prompts = [
            prompts[i: i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]

        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * self.n: (i + 1) * self.n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info=dict(
                            finish_reason='over',
                            logprobs=choice["text"],
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return LLMResult(generations=generations, llm_output=llm_output)

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        """Call FastChat with streaming flag and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from OpenAI.

        Example:
            .. code-block:: python

                generator = fastChat.stream("Tell me a joke.")
                for token in generator:
                    yield token
        """
        params = self._invocation_params
        params["prompt"] = prompt
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop

        headers = {"User-Agent": "fastchat Client"}
        response = requests.post(
            FAST_CHAT_API,
            headers=headers,
            json=params,
            stream=True,
        )
        for stream_resp in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if stream_resp:
                data = json.loads(stream_resp.decode("utf-8"))
                skip_echo_len = len(prompt)
                output = data["text"][skip_echo_len:].strip()
                data["text"] = output
                yield data

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fastChat"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python < 3.8
        if sys.version_info[1] < 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )

        enc = tiktoken.encoding_for_model(self.model_name)

        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size

        Example:
            .. code-block:: python

                max_new_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        model_token_mapping = {
            "vicuna-13b": 2049,
            "koala": 2049,
            "dolly-v2": 2049,
            "oasst": 2049,
            "stablelm": 2049,
        }

        context_size = model_token_mapping.get(modelname, None)

        if context_size is None:
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
                "Known models are: " + ", ".join(model_token_mapping.keys())
            )

        return context_size

    def max_new_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_new_tokens = openai.max_token_for_prompt("Tell me a joke.")
        """
        num_tokens = self.get_num_tokens(prompt)

        # get max context size for model by name
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens


class FastChat(BaseFastChat):
    """Wrapper around OpenAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            openai = FastChat(model_name="vicuna")
    """

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}
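The wrapper above ultimately just POSTs a JSON payload to a FastChat model worker and reads back a NUL-delimited stream of JSON messages. A minimal sketch of consuming that worker protocol directly, assuming a worker is serving worker_generate_stream on localhost:21002 as in FAST_CHAT_API above; the payload fields mirror _default_params and the model name is illustrative:

import json
import requests

payload = {
    "model": "vicuna-13b",           # assumed worker model name
    "prompt": "Tell me a joke.",
    "max_new_tokens": 200,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:21002/worker_generate_stream",
                     headers={"User-Agent": "fastchat Client"},
                     json=payload, stream=True)
for chunk in resp.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        # each message carries the text generated so far; strip the echoed prompt
        print(data["text"][len(payload["prompt"]):].strip())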
BIN  img/qr_code_22.jpg  (new file, binary file not shown, 274 KiB)
BIN  img/vue_0521_2.png  (new file, binary file not shown, 1.7 MiB)
@@ -1,3 +1,4 @@
from .chatglm_llm import ChatGLM
# from .llama_llm import LLamaLLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
from .fastchat_llm import FastChatLLM
@@ -175,15 +175,6 @@ class BaseAnswer(ABC):
        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs)

        """
        eos_token_id is a designated token (for example, "</s>") that marks the end of a sequence.
        During text generation the model keeps emitting tokens until it produces this special eos_token_id,
        which signals that the sequence is complete.
        In Hugging Face Transformers models, eos_token_id is added to the input automatically by the tokenizer.
        If the model emits eos_token_id while generating, generation stops and the sequence produced so far is returned.
        """
        eos_token_ids = [
            self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []

        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
            for answerResult in generator:
                if answerResult.listenerToken:
@@ -1,223 +0,0 @@
# import gc
import traceback
from queue import Queue
# from threading import Thread
# import threading
from typing import Optional, List, Dict, Any, TypeVar, Deque
from collections import deque
import torch
import transformers

from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared


K = TypeVar('K')
V = TypeVar('V')

class LimitedLengthDict(Dict[K, V]):
    def __init__(self, maxlen=None, *args, **kwargs):
        self.maxlen = maxlen
        self._keys: Deque[K] = deque()
        super().__init__(*args, **kwargs)

    def __setitem__(self, key: K, value: V):
        if key not in self:
            if self.maxlen is not None and len(self) >= self.maxlen:
                oldest_key = self._keys.popleft()
                if oldest_key in self:
                    del self[oldest_key]
            self._keys.append(key)
        super().__setitem__(key, value)


class FixedLengthQueue:
    # list of stop sequences
    stop_sequence: Optional[str] = []
    # buffer size
    max_length: int = 0
    # buffer container
    queue: deque = None
    # input container
    queue_in: LimitedLengthDict[int, str] = {}
    # output container
    queue_out: Dict[int, str] = {}

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup can go here
        return instance

    def __init__(self, stop_sequence):
        if stop_sequence is None:
            self.stop_sequence = []
            self.max_length = 0
        elif isinstance(stop_sequence, str):
            self.stop_sequence = [stop_sequence]
            self.max_length = 1
        else:
            self.stop_sequence = stop_sequence
            self.max_length = len(''.join(stop_sequence))

        self.queue = deque(maxlen=self.max_length)
        self.queue.clear()
        self.queue_in.clear()
        self.queue_out.clear()

    def add(self, index, item):
        self.queue_in[index] = item

    def _add_out(self, index, item):
        self.queue_out[index] = item

    def put_replace_out(self, index):
        return self.queue_out[index]

    def contains_replace_sequence(self):
        """
        Replace full-width characters in the buffered output.
        :return:
        """

        for key, value in self.queue_in.items():

            word_index = value.rfind(":")
            if word_index != -1:
                value = value.replace(":", ":")

            word_index = value.rfind("[")
            if word_index != -1:
                value = value.replace("[", "")

            word_index = value.rfind("]")
            if word_index != -1:
                value = value.replace("]", "")

            self._add_out(key, value)

    def contains_stop_sequence(self):
        # check a fixed-size slice of the most recent output
        self.queue.clear()
        last_three_keys = list(self.queue_out.keys())[-self.max_length:]
        joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
        for char in joined_queue:
            self.queue.append(char)

        joined_queue = ''.join(self.queue)
        # Initialize a variable to store the index of the last found stop string
        last_stop_str_index = -1

        # Iterate through the stop string list
        for stop_word in self.stop_sequence:
            # Find the last occurrence of the stop string in the output
            stop_word_index = joined_queue.rfind(stop_word)

            # If the stop string is found, compare the index with the previously found index
            if stop_word_index != -1 and stop_word_index > last_stop_str_index:
                last_stop_str_index = stop_word_index

        # Handle the last found stop string index here
        return last_stop_str_index

    def __repr__(self):
        return str(self.queue)


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]

            for i in range(len(self.sentinel_token_ids)):
                # Can't unfold, output is still too tiny. Skip.
                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                    continue
                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                        return True
        return False


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if shared.stop_everything:
            raise ValueError
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    thread: ThreadWithException = None

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup can go here
        return instance

    def __init__(self, func, kwargs={}, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs

        def _callback(val):
            if shared.stop_everything:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                print("print(ValueError)")
            except:
                traceback.print_exc()
                print("traceback.print_exc()")
            self.q.put(self.sentinel)

        self.thread = ThreadWithException(target=gen)
        self.thread.start()

    def __iter__(self):
        shared.stop_everything = False
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        shared.stop_everything = False
        self.q.empty()
        shared.loaderCheckPoint.clear_torch_cache()

    def __enter__(self):
        shared.stop_everything = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        shared.stop_everything = True
        shared.loaderCheckPoint.clear_torch_cache()
        self.thread.raise_exception()
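Iteratorize is the piece that turns the callback-style model.generate(...) call into the generator consumed by generate_with_streaming above. A minimal sketch of the same pattern, assuming models.shared.loaderCheckPoint has already been initialized (Iteratorize touches it on exit); the slow_producer function is illustrative:

from models.extensions.callback import Iteratorize

def slow_producer(callback=None, n=3):
    # stands in for model.generate(...); pushes partial results through the callback
    for i in range(n):
        callback(f"token-{i}")

with Iteratorize(slow_producer, {"n": 3}) as generator:
    for partial in generator:    # Iteratorize drains the callback queue lazily
        print(partial)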
@@ -1,10 +0,0 @@
import gc
import traceback
import torch

# This iterator returns the extensions in the order specified in the command-line
def iterator():
    state_extensions = {}
    for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
        if state_extensions[name][0]:
            yield getattr(extensions, name).script, name
@@ -1,64 +0,0 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''

from llama_cpp import Llama, LlamaCache

from modules import shared
from modules.callbacks import Iteratorize


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        result = self()

        params = {
            'model_path': str(path),
            'n_ctx': 2048,
            'seed': 0,
            'n_threads': shared.args.threads or None
        }
        self.model = Llama(**params)
        self.model.set_cache(LlamaCache)

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if type(string) is str:
            string = string.encode()
        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        if type(context) is str:
            context = context.encode()
        tokens = self.model.tokenize(context)

        output = b""
        count = 0
        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
            text = self.model.detokenize([token])
            output += text
            if callback:
                callback(text.decode())

            count += 1
            if count >= token_count or (token == self.model.token_eos()):
                break

        return output.decode()

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
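For context, the deleted wrapper above was used roughly like this: load a llama.cpp-compatible model once, then either generate a reply in one call or stream the accumulated reply token by token. A minimal sketch under the same assumptions (the model path is illustrative, and shared.args.threads must be configured as in the module above):

model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")  # path is illustrative

# blocking generation
print(model.generate(context="Hello,", token_count=32, temperature=0.8))

# streaming generation: yields the reply accumulated so far
for partial_reply in model.generate_with_streaming(context="Hello,", token_count=32):
    print(partial_reply)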
@@ -1,30 +0,0 @@
# Python program raising
# exceptions in a python
# thread

import threading
import ctypes
import time


class ThreadWithException(threading.Thread):

    def get_id(self):
        return self.ident

    def raise_exception(self):
        """raises the exception, performs cleanup if needed"""
        try:
            thread_id = self.get_id()
            tid = ctypes.c_long(thread_id)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
            if res == 0:
                # pass
                raise ValueError("invalid thread id")
            elif res != 1:
                # """if it returns a number greater than one, you're in trouble,
                # and you should call it again with exc=NULL to revert the effect"""
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")
        except Exception as err:
            print(err)
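A minimal sketch of stopping such a thread from the outside, assuming the ThreadWithException class above; the worker function is illustrative:

import time

def worker():
    while True:              # loops until SystemExit is injected from the outside
        time.sleep(0.1)

t = ThreadWithException(target=worker, daemon=True)
t.start()
time.sleep(1)
t.raise_exception()          # asynchronously raises SystemExit inside the worker thread
t.join(timeout=2)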
models/fastchat_llm.py (new file, 54 lines)

@@ -0,0 +1,54 @@
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM

from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


class FastChatLLM(BaseAnswer, LLM, ABC):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 10

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "FastChat"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        pass

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:

        response = "fastchat 响应结果"
        history += [[prompt, response]]
        answer_result = AnswerResult()
        answer_result.history = history
        answer_result.llm_output = {"answer": response}

        generate_with_callback(answer_result)
@@ -8,28 +8,12 @@ from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue)
import models.shared as shared
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: the response structure
    """
    return {
        "text": ""
    }


def _update_response(response: Dict[str, Any], stream_response: str) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():

@@ -105,16 +89,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    def generate_with_callback(self, callback=None, **kwargs):
        self.checkPoint.clear_torch_cache()
        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
        with torch.no_grad():
            self.checkPoint.model.generate(**kwargs)
        print("方法结束")

    def generate_with_streaming(self, **kwargs):
        return Iteratorize(self.generate_with_callback, kwargs)

    # Convert the history of dialogue turns into plain text
    def history_to_text(self, query):
        formatted_history = ''

@@ -144,45 +118,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):

        return input_ids, position_ids, attention_mask

    def get_position_ids(self, input_ids: torch.LongTensor, mask_positions, device):
        """
        Attention position offsets (ChatGLM-style 2D position ids).
        :param input_ids:
        :param mask_positions:
        :param device:
        :param use_gmasks:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        for i, context_length in enumerate(context_lengths):
            position_ids[i, context_length:] = mask_positions[i]
        block_position_ids = [torch.cat((
            torch.zeros(context_length, dtype=torch.long, device=device),
            torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
        )) for context_length in context_lengths]
        block_position_ids = torch.stack(block_position_ids, dim=0)
        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        return position_ids

    def get_masks(self, input_ids, device):
        """
        Build the attention mask.
        :param input_ids:
        :param device:
        :return:
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()
        return attention_mask

    def generate_softprompt_history_tensors(self, query):
        """
        Soft prompt built from the conversation history

@@ -222,11 +157,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
                      "eos_token_id": self.eos_token_id,
                      "logits_processor": self.logits_processor}

        # vector concatenation
        # vector conversion
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)

        # prompt for the dialogue model

        gen_kwargs.update({'inputs': input_ids})
        # attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})

@@ -235,45 +170,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
        self.stopping_criteria = transformers.StoppingCriteriaList()
        # observe the output
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
        shared.stop_everything = False
        stopped = False
        response_template = _streaming_response_template()

        # TODO: this streaming output method needs to be rewritten!!!
        # stopping_criteria cannot be controlled from here, and the iterator's variables cannot be shared
        with self.generate_with_streaming(**gen_kwargs) as generator:
            last_reply_len = 0
            reply_index = 0
            # Create a FixedLengthQueue with the desired stop sequence and a maximum length.
            queue = FixedLengthQueue(stop)
            for output in generator:
                new_tokens = len(output) - len(input_ids[0])
                reply = self.decode(output[-new_tokens:])

                new_reply = len(reply) - last_reply_len
                output_reply = reply[-new_reply:]
                queue.add(reply_index, output_reply)
                queue.contains_replace_sequence()
                if stop:
                    pos = queue.contains_stop_sequence()
                    if pos != -1:
                        shared.stop_everything = True
                        stopped = True

                # print(f"{reply_index}:reply {output_reply}")
                english_reply = queue.put_replace_out(reply_index)
                # print(f"{reply_index}:english_reply {english_reply}")
                _update_response(response_template, english_reply)
                last_reply_len = len(reply)

                reply_index += 1
                if new_tokens == self.max_new_tokens - 1 or stopped:
                    break

            response = response_template['text']
            print(f"response:{response}")
            self.history = self.history + [[None, response]]
            return response
        output_ids = self.checkPoint.model.generate(**gen_kwargs)
        new_tokens = len(output_ids[0]) - len(input_ids[0])
        reply = self.decode(output_ids[0][-new_tokens:])
        print(f"response:{reply}")
        self.history = self.history + [[None, reply]]
        return reply

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
@@ -1,6 +1,6 @@
import argparse
import os

from configs.model_config import *


# Additional argparse types

@@ -32,28 +32,25 @@ def dir_path(string):


parser = argparse.ArgumentParser(prog='langchina-ChatGLM',
                                 description='基于langchain和chatGML的LLM文档阅读器')
                                 description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | '
                                             '基于本地知识库的 ChatGLM 问答')


parser.add_argument('--no-remote-model', action='store_true', default=False, help='remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model`')
parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.')
parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='remote in the model on '
                                                                                            'loader checkpoint, '
                                                                                            'if your load local '
                                                                                            'model to add the ` '
                                                                                            '--no-remote-model`')
parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')

parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
                    help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', default=BF16,
                    help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')

args = parser.parse_args([])
# Generates a dict with a default value for each argument
DEFAULT_ARGS = vars(args)
@@ -113,7 +113,6 @@ class LoaderCheckPoint:
            if num_gpus < 2 and self.device_map is None:
                model = (
                    LoaderClass.from_pretrained(checkpoint,
                                                low_cpu_mem_usage=True,
                                                config=self.model_config,
                                                torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                trust_remote_code=True)

@@ -124,7 +123,6 @@ class LoaderCheckPoint:
                from accelerate import dispatch_model

                model = LoaderClass.from_pretrained(checkpoint,
                                                    low_cpu_mem_usage=True,
                                                    config=self.model_config,
                                                    torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                    trust_remote_code=True).half()
@@ -4,8 +4,6 @@ from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer
"""Stop flag for the streaming iterator."""
stop_everything = False

loaderCheckPoint: LoaderCheckPoint = None

@@ -36,7 +34,10 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_

    loaderCheckPoint.model_path = llm_model_info["local_model_path"]

    loaderCheckPoint.reload_model()
    if 'fastChat' in loaderCheckPoint.model_name:
        loaderCheckPoint.unload_model()
    else:
        loaderCheckPoint.reload_model()

    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
    modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
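The hunk above shows how a model wrapper is resolved at runtime: the "provides" string from llm_model_dict is looked up in the models package and instantiated with the shared checkpoint. A minimal sketch of that dispatch in isolation, assuming the config and loader shown in this commit; the checkpoint initialization is an assumption, not a confirmed API:

import sys
import models
import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import DEFAULT_ARGS
from configs.model_config import llm_model_dict, LLM_MODEL

shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)   # assumed constructor signature
llm_model_info = llm_model_dict[LLM_MODEL]                 # e.g. the "chatglm-6b" entry
provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
llm = provides_class(checkPoint=shared.loaderCheckPoint)   # e.g. ChatGLM / LLamaLLM / FastChatLLM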
@@ -21,8 +21,8 @@ click~=8.1.3
tabulate
azure-core
bitsandbytes; platform_system != "Windows"
llama-cpp-python==0.1.34; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
#llama-cpp-python==0.1.34; platform_system != "Windows"
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"

torch~=2.0.0
pydantic~=1.10.7