Merge pull request #2946 from chatchat-space/dev

Dev
zR 2024-02-06 13:52:55 +08:00 committed by GitHub
commit ab650253d1
13 changed files with 77 additions and 33 deletions

View File

@@ -192,7 +192,7 @@ please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/
 ### WeChat Group
-<img src="img/qr_code_88.jpg" alt="二维码" width="300" height="300" />
+<img src="img/qr_code_89.jpg" alt="二维码" width="300" height="300" />
 ### WeChat Official Account

View File

@@ -185,7 +185,7 @@ $ python startup.py -a
 ### WeChat グループ
-<img src="img/qr_code_88.jpg" alt="二维码" width="300" height="300" />
+<img src="img/qr_code_89.jpg" alt="二维码" width="300" height="300" />
 ### WeChat 公式アカウント

View File

@@ -150,12 +150,16 @@ MODEL_PATH = {
     "m3e-small": "moka-ai/m3e-small",
     "m3e-base": "moka-ai/m3e-base",
     "m3e-large": "moka-ai/m3e-large",
     "bge-small-zh": "BAAI/bge-small-zh",
     "bge-base-zh": "BAAI/bge-base-zh",
     "bge-large-zh": "BAAI/bge-large-zh",
     "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
     "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
     "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
+    "bge-m3": "BAAI/bge-m3",
     "piccolo-base-zh": "sensenova/piccolo-base-zh",
     "piccolo-large-zh": "sensenova/piccolo-large-zh",
     "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@@ -181,6 +185,14 @@ MODEL_PATH = {
     "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
     "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+    # Qwen1.5 models; vLLM may have problems with them
+    "Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat",
+    "Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
+    "Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat",
+    "Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat",
+    "Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat",
+    "Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat",
     "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
     "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
     "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",

View File

@@ -90,13 +90,12 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
     },
-    "Qwen-1_8B-Chat": {
-        "device": "cpu",
-    },
     "chatglm3-6b": {
         "device": "cuda",
     },
+    "Qwen1.5-0.5B-Chat": {
+        "device": "cuda",
+    },
     # The settings below need not be modified; which models to launch is set in model_config
     "zhipu-api": {
         "port": 21001,

Binary file not shown (deleted image, 195 KiB).

Binary file not shown (deleted image, 318 KiB).

Binary file not shown (deleted image, 234 KiB).

View File

@@ -2,7 +2,7 @@ torch==2.1.2
 torchvision==0.16.2
 torchaudio==2.1.2
 xformers==0.0.23.post1
-transformers==4.37.1
+transformers==4.37.2
 sentence_transformers==2.2.2
 langchain==0.0.354
 langchain-experimental==0.0.47

View File

@@ -2,7 +2,7 @@ torch~=2.1.2
 torchvision~=0.16.2
 torchaudio~=2.1.2
 xformers>=0.0.23.post1
-transformers==4.37.1
+transformers==4.37.2
 sentence_transformers==2.2.2
 langchain==0.0.354
 langchain-experimental==0.0.47

View File

@@ -155,6 +155,20 @@ class ESKBService(KBService):
                                 k=top_k)
         return docs

+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
+        results = []
+        for doc_id in ids:
+            try:
+                response = self.es_client_python.get(index=self.index_name, id=doc_id)
+                source = response["_source"]
+                # Indexed documents store the text under "context" plus a "metadata" dict
+                text = source.get("context", "")
+                metadata = source.get("metadata", {})
+                results.append(Document(page_content=text, metadata=metadata))
+            except Exception as e:
+                logger.error(f"Error retrieving document from Elasticsearch! {e}")
+        return results
+
     def del_doc_by_ids(self, ids: List[str]) -> bool:
         for doc_id in ids:
             try:
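A hypothetical usage of the new method; only `get_doc_by_ids` itself comes from this diff, and the constructor arguments are assumptions:

```python
# Hypothetical usage; "samples" and the constructor signature are assumptions.
kb = ESKBService("samples")
docs = kb.get_doc_by_ids(["doc-id-1", "doc-id-2"])
for d in docs:
    print(d.metadata.get("source"), len(d.page_content))
```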
@@ -200,17 +214,21 @@ class ESKBService(KBService):
         # Get id and source; format: [{"id": str, "metadata": dict}, ...]
         print("写入数据成功.")
         print("*"*100)
         if self.es_client_python.indices.exists(index=self.index_name):
             file_path = docs[0].metadata.get("source")
             query = {
                 "query": {
                     "term": {
                         "metadata.source.keyword": file_path
+                    },
+                    "term": {
+                        "_index": self.index_name
                     }
                 }
             }
-            search_results = self.es_client_python.search(body=query)
+            # Note: set size explicitly; by default only 10 hits are returned.
+            search_results = self.es_client_python.search(body=query, size=50)
             if len(search_results["hits"]["hits"]) == 0:
                 raise ValueError("召回元素个数为0")
             info_docs = [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in search_results["hits"]["hits"]]
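Note that in a Python dict literal the second `"term"` key silently overwrites the first, so only the `_index` condition survives in the query as committed. A sketch of an equivalent query that keeps both conditions, using a standard Elasticsearch bool filter (`es_client`, `file_path`, and `index_name` stand in for the attributes above):

```python
# Equivalent query with both term conditions preserved via a bool filter.
query = {
    "query": {
        "bool": {
            "filter": [
                {"term": {"metadata.source.keyword": file_path}},
                {"term": {"_index": index_name}},
            ]
        }
    }
}
search_results = es_client.search(body=query, size=50)
```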

View File

@@ -50,7 +50,7 @@ class MilvusKBService(KBService):
     def _load_milvus(self):
         self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name,
                              connection_args=kbs_config.get("milvus"),
                              index_params=kbs_config.get("milvus_kwargs")["index_params"],
                              search_params=kbs_config.get("milvus_kwargs")["search_params"]
@@ -89,6 +89,14 @@ class MilvusKBService(KBService):
         if self.milvus.col:
             self.milvus.col.delete(expr=f'pk in {id_list}')

+        # Issue 2846, for Windows
+        # if self.milvus.col:
+        #     file_path = kb_file.filepath.replace("\\", "\\\\")
+        #     file_name = os.path.basename(file_path)
+        #     id_list = [item.get("pk") for item in
+        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
+        #     self.milvus.col.delete(expr=f'pk in {id_list}')
+
     def do_clear_vs(self):
         if self.milvus.col:
             self.do_drop_kb()
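The commented-out Windows variant (Issue 2846) exists because backslashes in a Windows path break a Milvus boolean `expr` string unless doubled, and matching on the basename sidesteps the path entirely. An illustrative example of the escaping step (the path is hypothetical):

```python
# Illustrative only: preparing a Windows path for a Milvus boolean expression.
import os

file_path = "C:\\kb\\samples\\content\\test.txt"  # hypothetical Windows path
escaped = file_path.replace("\\", "\\\\")         # double each backslash
file_name = os.path.basename(file_path)           # "test.txt"
expr = f'source == "{file_name}"'                 # match on basename instead
```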

View File

@@ -13,7 +13,6 @@ async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
     wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
     wsUrl = wsParam.create_url()
     data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
-    print(data)
     async with websockets.connect(wsUrl) as ws:
         await ws.send(json.dumps(data, ensure_ascii=False))
         finish = False

View File

@@ -1,13 +1,21 @@
+from contextlib import contextmanager
+import httpx
 from fastchat.conversation import Conversation
+from httpx_sse import EventSource
 from server.model_workers.base import *
 from fastchat import conversation as conv
 import sys
-from typing import List, Dict, Iterator, Literal
+from typing import List, Dict, Iterator, Literal, Any
+from configs import logger, log_verbose
-import requests
 import jwt
 import time
+import json


+@contextmanager
+def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
+    with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)


 def generate_token(apikey: str, exp_seconds: int):
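The body of `generate_token` is not shown in this hunk. For context, a Zhipu-style JWT helper conventionally looks like the sketch below; field names follow Zhipu's published HS256 auth scheme, and the details should be treated as assumptions rather than this commit's code:

```python
# Sketch only. Zhipu API keys have the form "<id>.<secret>", and the token is
# an HS256 JWT carrying a custom sign_type header.
import time
import jwt  # PyJWT

def generate_token(apikey: str, exp_seconds: int) -> str:
    api_key, secret = apikey.split(".")
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": api_key,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(payload, secret, algorithm="HS256",
                      headers={"alg": "HS256", "sign_type": "SIGN"})
```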
@@ -37,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker):
             model_names: List[str] = ["zhipu-api"],
             controller_addr: str = None,
             worker_addr: str = None,
-            version: Literal["chatglm_turbo"] = "chatglm_turbo",
+            version: Literal["glm-4"] = "glm-4",
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
@@ -59,25 +67,25 @@ class ChatGLMWorker(ApiModelWorker):
             "temperature": params.temperature,
             "stream": False
         }
         url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
-        response = requests.post(url, headers=headers, json=data)
-        # for chunk in response.iter_lines():
-        #     if chunk:
-        #         chunk_str = chunk.decode('utf-8')
-        #         json_start_pos = chunk_str.find('{"id"')
-        #         if json_start_pos != -1:
-        #             json_str = chunk_str[json_start_pos:]
-        #             json_data = json.loads(json_str)
-        #             for choice in json_data.get('choices', []):
-        #                 delta = choice.get('delta', {})
-        #                 content = delta.get('content', '')
-        #                 yield {"error_code": 0, "text": content}
-        ans = response.json()
-        content = ans["choices"][0]["message"]["content"]
-        yield {"error_code": 0, "text": content}
+        with httpx.Client(headers=headers) as client:
+            response = client.post(url, json=data)
+            response.raise_for_status()
+            chunk = response.json()
+            print(chunk)
+            yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]}
+
+            # with connect_sse(client, "POST", url, json=data) as event_source:
+            #     for sse in event_source.iter_sse():
+            #         chunk = json.loads(sse.data)
+            #         if len(chunk["choices"]) != 0:
+            #             text += chunk["choices"][0]["delta"]["content"]
+            #             yield {"error_code": 0, "text": text}

     def get_embeddings(self, params):
+        # Temporary workaround: embeddings are not supported
         print("embedding")
         print(params)
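The commented-out streaming branch above references a `text` accumulator it never initializes. A runnable sketch of that branch, assuming the `connect_sse` helper from this diff and OpenAI-style SSE framing (the `[DONE]` sentinel is an assumption, not confirmed by this commit):

```python
# Runnable sketch of the commented-out streaming path; url/headers/data are
# the same request pieces as above, with "stream" switched on.
import json
import httpx

def stream_chat(url: str, headers: dict, data: dict):
    data = {**data, "stream": True}
    text = ""
    with httpx.Client(headers=headers) as client:
        with connect_sse(client, "POST", url, json=data) as event_source:
            for sse in event_source.iter_sse():
                if sse.data == "[DONE]":   # assumed end-of-stream sentinel
                    break
                chunk = json.loads(sse.data)
                if chunk["choices"]:
                    text += chunk["choices"][0]["delta"]["content"]
                    yield {"error_code": 0, "text": text}
```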