Rework Embeddings and FAISS cache loading; knowledge-base API endpoints now support multi-threaded concurrency (#1434)

* Rework the loading of Embeddings and FAISS caches to be thread-safe and add an in-memory FAISS pool

* Make the knowledge-base API endpoints handle concurrent multi-threaded requests

* Adjust ApiRequest and the test cases to match the new API endpoints

* Remove the outdated startup instructions from webui.py
Committed by liunux4odoo on 2023-09-11 20:41:41 +08:00 (via GitHub)
parent d0e654d847
commit 22ff073309
16 changed files with 497 additions and 530 deletions

View File

@ -4,12 +4,11 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH from configs import VERSION
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT from configs.model_config import NLTK_DATA_PATH
from configs import VERSION, logger, log_verbose from configs.server_config import OPEN_CROSS_DOMAIN
import argparse import argparse
import uvicorn import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat, from server.chat import (chat, knowledge_base_chat, openai_chat,
@ -18,8 +17,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store, update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore) search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
import httpx from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -126,79 +125,20 @@ def create_app():
)(recreate_vector_store) )(recreate_vector_store)
# LLM模型相关接口 # LLM模型相关接口
@app.post("/llm_model/list_models", app.post("/llm_model/list_models",
tags=["LLM Model Management"], tags=["LLM Model Management"],
summary="列出当前已加载的模型") summary="列出当前已加载的模型",
def list_models( )(list_llm_models)
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/stop", app.post("/llm_model/stop",
tags=["LLM Model Management"], tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)", summary="停止指定的LLM模型Model Worker)",
) )(stop_llm_model)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/change", app.post("/llm_model/change",
tags=["LLM Model Management"], tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)", summary="切换指定的LLM模型Model Worker)",
) )(change_llm_model)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
return app return app
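
The LLM-model routes above are no longer defined inline with @app.post decorators; the handlers now live in server/llm_api.py and are registered by calling app.post(...) as a plain function. A minimal sketch of that registration style (the echo handler below is hypothetical, not part of the project):

from fastapi import Body, FastAPI

app = FastAPI()

def echo(text: str = Body(..., embed=True)) -> dict:
    # stand-in for handlers such as list_llm_models / stop_llm_model / change_llm_model
    return {"echo": text}

# app.post(...) returns a decorator; calling it with a function registers the route,
# exactly as if `echo` had been decorated with @app.post("/echo", ...).
app.post("/echo", tags=["Demo"], summary="echo the request body")(echo)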

View File

@ -12,10 +12,10 @@ def list_kbs():
return ListResponse(data=list_kbs_from_db()) return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"), vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse: ) -> BaseResponse:
# Create selected knowledge base # Create selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -38,8 +38,8 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb( def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"]) knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse: ) -> BaseResponse:
# Delete selected knowledge base # Delete selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
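
Since create_kb and delete_kb are now synchronous functions, local callers such as ApiRequest(no_remote_api=True) can invoke them directly without an event loop. A minimal sketch, assuming the project root is on sys.path; note that Body(...) defaults only apply to HTTP requests, so every argument is passed explicitly:

from configs.model_config import EMBEDDING_MODEL
from server.knowledge_base.kb_api import create_kb, delete_kb

# pass all arguments explicitly; outside FastAPI the Body(...) defaults are FieldInfo objects
resp = create_kb(knowledge_base_name="samples",
                 vector_store_type="faiss",
                 embed_model=EMBEDDING_MODEL)
print(resp.code, resp.msg)
print(delete_kb(knowledge_base_name="samples").msg)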

View File

@ -0,0 +1,137 @@
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
embedding_model_dict, logger, log_verbose)
from server.utils import embedding_device
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
class ThreadSafeObject:
def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
self._obj = obj
self._key = key
self._pool = pool
self._lock = threading.RLock()
self._loaded = threading.Event()
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self._key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}")
self._lock.release()
def start_loading(self):
self._loaded.clear()
def finish_loading(self):
self._loaded.set()
def wait_for_loading(self):
self._loaded.wait()
@property
def obj(self):
return self._obj
@obj.setter
def obj(self, val: Any):
self._obj = val
class CachePool:
def __init__(self, cache_num: int = -1):
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()
def keys(self) -> List[str]:
return list(self._cache.keys())
def _check_count(self):
if isinstance(self._cache_num, int) and self._cache_num > 0:
while len(self._cache) > self._cache_num:
self._cache.popitem(last=False)
def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache
def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
self._cache[key] = obj
self._check_count()
return obj
def pop(self, key: str = None) -> ThreadSafeObject:
if key is None:
return self._cache.popitem(last=False)
else:
return self._cache.pop(key, None)
def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
cache = self.get(key)
if cache is None:
raise RuntimeError(f"请求的资源 {key} 不存在")
elif isinstance(cache, ThreadSafeObject):
self._cache.move_to_end(key)
return cache.acquire(owner=owner, msg=msg)
else:
return cache
def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
from server.db.repository.knowledge_base_repository import get_kb_detail
kb_detail = get_kb_detail(kb_name=kb_name)
print(kb_detail)
embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str, device: str) -> Embeddings:
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = device or embedding_device()
key = (model, device)
if not self.get(key):
item = ThreadSafeObject(key, pool=self)
self.set(key, item)
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:
self.atomic.release()
return self.get(key).obj
embeddings_pool = EmbeddingsPool(cache_num=1)
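
A minimal sketch of how the new embeddings_pool is meant to be used from multiple threads: the (model, device) pair is the cache key, the first caller holds the item's lock while instantiating the model, and later callers block in wait_for_loading() instead of loading a second copy. Assumes the configured EMBEDDING_MODEL is available locally.

import threading
from configs.model_config import EMBEDDING_MODEL
from server.utils import embedding_device
from server.knowledge_base.kb_cache.base import embeddings_pool

def worker(name: str):
    # every thread receives the same cached Embeddings instance
    emb = embeddings_pool.load_embeddings(model=EMBEDDING_MODEL, device=embedding_device())
    print(name, id(emb))

threads = [threading.Thread(target=worker, args=(f"t{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()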

View File

@ -0,0 +1,157 @@
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores import FAISS
import os
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
def save(self, path: str, create_path: bool = True):
with self.acquire():
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘")
return ret
def clear(self):
ret = []
with self.acquire():
ids = list(self._obj.docstore._dict.keys())
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空")
return ret
class _FaissPool(CachePool):
def new_vector_store(
self,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
# create an empty vector store
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
def save_vector_store(self, kb_name: str, path: str=None):
if cache := self.get(kb_name):
return cache.save(path)
def unload_vector_store(self, kb_name: str):
if cache := self.get(kb_name):
self.pop(kb_name)
logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
class MemoFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' to memory.")
# create an empty vector store
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool()
if __name__ == "__main__":
import time, random
from pprint import pprint
kb_names = ["vs1", "vs2", "vs3"]
# for name in kb_names:
# memo_faiss_pool.load_vector_store(name)
def worker(vs_name: str, name: str):
vs_name = "samples"
time.sleep(random.randint(1, 5))
embeddings = embeddings_pool.load_embeddings()
r = random.randint(1, 3)
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
if r == 1: # add docs
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()
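
Outside the self-test above, the typical caller goes through kb_faiss_pool and the acquire() context manager, which holds the per-store RLock while the raw FAISS object is in use. A minimal sketch, assuming a knowledge base named "samples" already exists on disk:

from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool
from server.knowledge_base.utils import get_vs_path

with kb_faiss_pool.load_vector_store("samples").acquire(owner="demo") as vs:
    # vs is the underlying langchain FAISS store, protected by the wrapper's lock
    docs = vs.similarity_search_with_score("什么是知识库", k=3)
    print(docs)

# write the in-memory index back to the knowledge base folder
kb_faiss_pool.save_vector_store("samples", get_vs_path("samples"))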

View File

@ -117,13 +117,13 @@ def _save_files_in_thread(files: List[UploadFile],
# yield json.dumps(result, ensure_ascii=False) # yield json.dumps(result, ensure_ascii=False)
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"), override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"), not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
''' '''
API接口上传文件/或向量化 API接口上传文件/或向量化
''' '''
@ -148,7 +148,7 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
# 对保存的文件进行向量化 # 对保存的文件进行向量化
if to_vector_store: if to_vector_store:
result = await update_docs( result = update_docs(
knowledge_base_name=knowledge_base_name, knowledge_base_name=knowledge_base_name,
file_names=file_names, file_names=file_names,
override_custom_docs=True, override_custom_docs=True,
@ -162,11 +162,11 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False), delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"), not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -196,12 +196,12 @@ async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"])
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
async def update_docs( def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]), file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"), not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
''' '''
更新知识库文档 更新知识库文档
@ -302,11 +302,11 @@ def download_doc(
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store( def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True), allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE), vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(EMBEDDING_MODEL),
): ):
''' '''
recreate vector store from the content. recreate vector store from the content.
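
Because upload_docs / delete_docs / update_docs / recreate_vector_store are now plain def endpoints, FastAPI runs each request in a worker thread, so several clients can modify a knowledge base concurrently while the FAISS cache serializes access internally. A rough client-side sketch (the base URL and the /knowledge_base/update_docs path are assumed from the project's default API config):

import httpx

base_url = "http://127.0.0.1:7861"   # assumed default API server address
payload = {
    "knowledge_base_name": "samples",
    "file_names": ["test.txt"],
    "not_refresh_vs_cache": False,
}
r = httpx.post(f"{base_url}/knowledge_base/update_docs", json=payload, timeout=300)
print(r.json())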

View File

@ -146,7 +146,6 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings) docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs return docs
# TODO: milvus/pg需要实现该方法
def get_doc_by_id(self, id: str) -> Optional[Document]: def get_doc_by_id(self, id: str) -> Optional[Document]:
return None return None

View File

@ -3,62 +3,16 @@ import shutil
from configs.model_config import ( from configs.model_config import (
KB_ROOT_PATH, KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
SCORE_THRESHOLD, SCORE_THRESHOLD,
logger, log_verbose, logger, log_verbose,
) )
from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from functools import lru_cache from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from server.knowledge_base.utils import KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from typing import List, Dict, Optional from typing import List, Dict, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from server.utils import torch_gc, embedding_device from server.utils import torch_gc
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
) -> FAISS:
logger.info(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
def refresh_vs_cache(kb_name: str):
"""
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService): class FaissKBService(KBService):
@ -74,24 +28,15 @@ class FaissKBService(KBService):
def get_kb_path(self): def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name) return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> FAISS: def load_vector_store(self) -> ThreadSafeFaiss:
return load_faiss_vector_store( return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def save_vector_store(self, vector_store: FAISS = None): def save_vector_store(self):
vector_store = vector_store or self.load_vector_store() self.load_vector_store().save(self.vs_path)
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def get_doc_by_id(self, id: str) -> Optional[Document]: def get_doc_by_id(self, id: str) -> Optional[Document]:
vector_store = self.load_vector_store() with self.load_vector_store().acquire() as vs:
return vector_store.docstore._dict.get(id) return vs.docstore._dict.get(id)
def do_init(self): def do_init(self):
self.kb_path = self.get_kb_path() self.kb_path = self.get_kb_path()
@ -112,43 +57,38 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None, embeddings: Embeddings = None,
) -> List[Document]: ) -> List[Document]:
search_index = self.load_vector_store() with self.load_vector_store().acquire() as vs:
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs return docs
def do_add_doc(self, def do_add_doc(self,
docs: List[Document], docs: List[Document],
**kwargs, **kwargs,
) -> List[Dict]: ) -> List[Dict]:
vector_store = self.load_vector_store() with self.load_vector_store().acquire() as vs:
ids = vector_store.add_documents(docs) ids = vs.add_documents(docs)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc() torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return doc_infos return doc_infos
def do_delete_doc(self, def do_delete_doc(self,
kb_file: KnowledgeFile, kb_file: KnowledgeFile,
**kwargs): **kwargs):
vector_store = self.load_vector_store() with self.load_vector_store().acquire() as vs:
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] if len(ids) > 0:
if len(ids) == 0: vs.delete(ids)
return None if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
vector_store.delete(ids) return ids
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return vector_store
def do_clear_vs(self): def do_clear_vs(self):
with kb_faiss_pool.atomic:
kb_faiss_pool.pop(self.kb_name)
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path) os.makedirs(self.vs_path)
self.refresh_vs_cache()
def exist_doc(self, file_name: str): def exist_doc(self, file_name: str):
if super().exist_doc(file_name): if super().exist_doc(file_name):
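
FaissKBService now routes every FAISS operation through kb_faiss_pool, so two threads that touch the same knowledge base share a single locked vector store instead of racing an lru_cache entry. A minimal sketch, assuming a "samples" knowledge base has already been created:

from langchain.docstore.document import Document
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService

kb = FaissKBService("samples")
kb.do_add_doc([Document(page_content="demo text", metadata={"source": "demo.txt"})])
print(kb.do_search("demo", top_k=3))
kb.save_vector_store()   # persists through ThreadSafeFaiss.save()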

View File

@ -1,7 +1,6 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE, logger, log_verbose from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE, logger, log_verbose
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool, list_files_from_folder,files2docs_in_thread,
files2docs_in_thread,
KnowledgeFile,) KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.repository.knowledge_file_repository import add_file_to_db
@ -72,7 +71,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store() kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "fill_info_only": elif mode == "fill_info_only":
files = list_files_from_folder(kb_name) files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files) kb_files = file_to_kbfile(kb_name, files)
@ -89,7 +87,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store() kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "increament": elif mode == "increament":
db_files = kb.list_files() db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name) folder_files = list_files_from_folder(kb_name)
@ -107,7 +104,6 @@ def folder2db(
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store() kb.save_vector_store()
kb.refresh_vs_cache()
else: else:
print(f"unspported migrate mode: {mode}") print(f"unspported migrate mode: {mode}")
@ -139,7 +135,6 @@ def prune_db_files(kb_name: str):
kb.delete_doc(kb_file, not_refresh_vs_cache=True) kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store() kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files return kb_files
def prune_folder_files(kb_name: str): def prune_folder_files(kb_name: str):

View File

@ -4,13 +4,13 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import ( from configs.model_config import (
embedding_model_dict, embedding_model_dict,
EMBEDDING_MODEL,
KB_ROOT_PATH, KB_ROOT_PATH,
CHUNK_SIZE, CHUNK_SIZE,
OVERLAP_SIZE, OVERLAP_SIZE,
ZH_TITLE_ENHANCE, ZH_TITLE_ENHANCE,
logger, log_verbose, logger, log_verbose,
) )
from functools import lru_cache
import importlib import importlib
from text_splitter import zh_title_enhance from text_splitter import zh_title_enhance
import langchain.document_loaders import langchain.document_loaders
@ -19,25 +19,11 @@ from langchain.text_splitter import TextSplitter
from pathlib import Path from pathlib import Path
import json import json
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool from server.utils import run_in_thread_pool, embedding_device
import io import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
def validate_kb_name(knowledge_base_id: str) -> bool: def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字 # 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id: if "../" in knowledge_base_id:
@ -72,19 +58,12 @@ def list_files_from_folder(kb_name: str):
if os.path.isfile(os.path.join(doc_path, file))] if os.path.isfile(os.path.join(doc_path, file))]
@lru_cache(1) def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
def load_embeddings(model: str, device: str): '''
if model == "text-embedding-ada-002": # openai text-embedding-ada-002 从缓存中加载embeddings可以避免多线程时竞争加载
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) '''
elif 'bge-' in model: from server.knowledge_base.kb_cache.base import embeddings_pool
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], return embeddings_pool.load_embeddings(model=model, device=device)
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],

View File

@ -1,279 +1,70 @@
from multiprocessing import Process, Queue from fastapi import Body
import multiprocessing as mp from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
import sys from server.utils import BaseResponse, fschat_controller_address
import os import httpx
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger, log_verbose
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
host_ip = "0.0.0.0" def list_llm_models(
controller_port = 20001 controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
model_worker_port = 20002 ) -> BaseResponse:
openai_api_port = 8888 '''
base_url = "http://127.0.0.1:{}" 从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def create_controller_app( def stop_llm_model(
dispatch_method="shortest_queue", model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
): ):
import fastchat.constants '''
fastchat.constants.LOGDIR = LOG_PATH 向fastchat controller请求切换LLM模型
from fastchat.serve.controller import app, Controller '''
try:
controller = Controller(dispatch_method) controller_address = controller_address or fschat_controller_address()
sys.modules["fastchat.serve.controller"].controller = controller r = httpx.post(
controller_address + "/release_worker",
MakeFastAPIOffline(app) json={"model_name": model_name, "new_model_name": new_model_name},
app.title = "FastChat Controller" timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
return app
def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,
cpu_offloading=None,
gptq_ckpt=None,
gptq_wbits=16,
gptq_groupsize=-1,
gptq_act_order=False,
awq_ckpt=None,
awq_wbits=16,
awq_groupsize=-1,
model_names=[LLM_MODEL],
num_gpus=1, # not in fastchat
conv_template=None,
limit_worker_concurrency=5,
stream_interval=2,
no_register=False,
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
import argparse
import threading
import fastchat.serve.model_worker
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
) )
self.heart_beat_thread.start() return r.json()
ModelWorker.init_heart_beat = _new_init_heart_beat except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
parser = argparse.ArgumentParser() exc_info=e if log_verbose else None)
args = parser.parse_args() return BaseResponse(
args.model_path = model_path code=500,
args.model_names = model_names msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
args.device = device
args.load_8bit = load_8bit
args.gptq_ckpt = gptq_ckpt
args.gptq_wbits = gptq_wbits
args.gptq_groupsize = gptq_groupsize
args.gptq_act_order = gptq_act_order
args.awq_ckpt = awq_ckpt
args.awq_wbits = awq_wbits
args.awq_groupsize = awq_groupsize
args.gpus = gpus
args.num_gpus = num_gpus
args.max_gpu_memory = max_gpu_memory
args.cpu_offloading = cpu_offloading
args.worker_address = worker_address
args.controller_address = controller_address
args.conv_template = conv_template
args.limit_worker_concurrency = limit_worker_concurrency
args.stream_interval = stream_interval
args.no_register = no_register
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if gpus and num_gpus is None:
num_gpus = len(gpus.split(','))
args.num_gpus = num_gpus
gptq_config = GptqConfig(
ckpt=gptq_ckpt or model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# torch.multiprocessing.set_start_method('spawn')
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
def create_openai_api_app(
controller_address=base_url.format(controller_port),
api_keys=[],
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def run_controller(q):
import uvicorn
app = create_controller_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
q.put(1)
uvicorn.run(app, host=host_ip, port=controller_port)
def run_model_worker(q, *args, **kwargs):
import uvicorn
app = create_model_worker_app(*args, **kwargs)
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 1:
q.put(no)
else:
break
q.put(2)
uvicorn.run(app, host=host_ip, port=model_worker_port)
def run_openai_api(q):
import uvicorn
app = create_openai_api_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 2:
q.put(no)
else:
break
q.put(3)
uvicorn.run(app, host=host_ip, port=openai_api_port)
if __name__ == "__main__":
mp.set_start_method("spawn")
queue = Queue()
logger.info(llm_model_dict[LLM_MODEL])
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
if not model_path:
logger.error("local_model_path 不能为空")
else:
controller_process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue,),
daemon=True,
)
controller_process.start()
model_worker_process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(queue,),
# kwargs={"load_8bit": True},
daemon=True,
)
model_worker_process.start()
openai_api_process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue,),
daemon=True,
)
openai_api_process.start()
try:
model_worker_process.join()
controller_process.join()
openai_api_process.join()
except KeyboardInterrupt:
model_worker_process.terminate()
controller_process.terminate()
openai_api_process.terminate()
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)
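
The three handlers above double as plain functions and as the bodies of the /llm_model/* routes registered in server/api.py, so they can be reached over HTTP or imported and called in-process. A rough sketch of the HTTP path (the address and model names are assumptions):

import httpx

api = "http://127.0.0.1:7861"   # assumed API server address

models = httpx.post(f"{api}/llm_model/list_models", json={}).json().get("data", [])
print("running models:", models)

# ask the fastchat controller (via the API server) to swap the worker's model
r = httpx.post(f"{api}/llm_model/change",
               json={"model_name": "chatglm2-6b",           # assumed current model
                     "new_model_name": "chatglm2-6b-32k"},  # assumed target model
               timeout=300)
print(r.json())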

View File

@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any from typing import Literal, Optional, Callable, Generator, Dict, Any
thread_pool = ThreadPoolExecutor() thread_pool = ThreadPoolExecutor(os.cpu_count())
class BaseResponse(BaseModel): class BaseResponse(BaseModel):

View File

@ -14,7 +14,7 @@ from pprint import pprint
api_base_url = api_address() api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url) api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True)
kb = "kb_for_api_test" kb = "kb_for_api_test"
@ -84,7 +84,7 @@ def test_upload_docs():
print(f"\n尝试重新上传知识文件, 覆盖自定义docs") print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)} data = {"knowledge_base_name": kb, "override": True, "docs": docs}
data = api.upload_kb_docs(files, **data) data = api.upload_kb_docs(files, **data)
pprint(data) pprint(data)
assert data["code"] == 200 assert data["code"] == 200

View File

@ -5,8 +5,9 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path)) sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict from configs.model_config import LLM_MODEL, llm_model_dict
from server.utils import api_address
from pprint import pprint from pprint import pprint
import random import random

View File

@ -47,7 +47,6 @@ data = {
} }
def test_chat_fastchat(api="/chat/fastchat"): def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}" url = f"{api_base_url}{api}"
data2 = { data2 = {

View File

@ -1,9 +1,3 @@
# 运行方式:
# 1. 安装必要的包pip install streamlit-option-menu streamlit-chatbox>=1.1.6
# 2. 运行本机fastchat服务python server\llm_api.py 或者 运行对应的sh文件
# 3. 运行API服务器python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。
# 4. 运行WEB UIstreamlit run webui.py --server.port 7860
import streamlit as st import streamlit as st
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_option_menu import option_menu from streamlit_option_menu import option_menu

View File

@ -20,6 +20,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import contextlib import contextlib
import json import json
import os
from io import BytesIO from io import BytesIO
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
@ -475,7 +476,7 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_api import create_kb from server.knowledge_base.kb_api import create_kb
response = run_async(create_kb(**data)) response = create_kb(**data)
return response.dict() return response.dict()
else: else:
response = self.post( response = self.post(
@ -497,7 +498,7 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_api import delete_kb from server.knowledge_base.kb_api import delete_kb
response = run_async(delete_kb(knowledge_base_name)) response = delete_kb(knowledge_base_name)
return response.dict() return response.dict()
else: else:
response = self.post( response = self.post(
@ -584,7 +585,7 @@ class ApiRequest:
filename = filename or file.name filename = filename or file.name
else: # a local path else: # a local path
file = Path(file).absolute().open("rb") file = Path(file).absolute().open("rb")
filename = filename or file.name filename = filename or os.path.split(file.name)[-1]
return filename, file return filename, file
files = [convert_file(file) for file in files] files = [convert_file(file) for file in files]
@ -602,13 +603,13 @@ class ApiRequest:
from tempfile import SpooledTemporaryFile from tempfile import SpooledTemporaryFile
upload_files = [] upload_files = []
for file, filename in files: for filename, file in files:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read()) temp_file.write(file.read())
temp_file.seek(0) temp_file.seek(0)
upload_files.append(UploadFile(file=temp_file, filename=filename)) upload_files.append(UploadFile(file=temp_file, filename=filename))
response = run_async(upload_docs(upload_files, **data)) response = upload_docs(upload_files, **data)
return response.dict() return response.dict()
else: else:
if isinstance(data["docs"], dict): if isinstance(data["docs"], dict):
@ -643,7 +644,7 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_docs from server.knowledge_base.kb_doc_api import delete_docs
response = run_async(delete_docs(**data)) response = delete_docs(**data)
return response.dict() return response.dict()
else: else:
response = self.post( response = self.post(
@ -676,7 +677,7 @@ class ApiRequest:
} }
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import update_docs from server.knowledge_base.kb_doc_api import update_docs
response = run_async(update_docs(**data)) response = update_docs(**data)
return response.dict() return response.dict()
else: else:
if isinstance(data["docs"], dict): if isinstance(data["docs"], dict):
@ -710,7 +711,7 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import recreate_vector_store from server.knowledge_base.kb_doc_api import recreate_vector_store
response = run_async(recreate_vector_store(**data)) response = recreate_vector_store(**data)
return self._fastapi_stream2generator(response, as_json=True) return self._fastapi_stream2generator(response, as_json=True)
else: else:
response = self.post( response = self.post(
@ -721,14 +722,30 @@ class ApiRequest:
) )
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None): # LLM模型相关操作
def list_running_models(
self,
controller_address: str = None,
no_remote_api: bool = None,
):
''' '''
获取Fastchat中正运行的模型列表 获取Fastchat中正运行的模型列表
''' '''
r = self.post( if no_remote_api is None:
"/llm_model/list_models", no_remote_api = self.no_remote_api
)
return r.json().get("data", []) data = {
"controller_address": controller_address,
}
if no_remote_api:
from server.llm_api import list_llm_models
return list_llm_models(**data).data
else:
r = self.post(
"/llm_model/list_models",
json=data,
)
return r.json().get("data", [])
def list_config_models(self): def list_config_models(self):
''' '''
@ -740,30 +757,43 @@ class ApiRequest:
self, self,
model_name: str, model_name: str,
controller_address: str = None, controller_address: str = None,
no_remote_api: bool = None,
): ):
''' '''
停止某个LLM模型 停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉 注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"model_name": model_name, "model_name": model_name,
"controller_address": controller_address, "controller_address": controller_address,
} }
r = self.post(
"/llm_model/stop", if no_remote_api:
json=data, from server.llm_api import stop_llm_model
) return stop_llm_model(**data).dict()
return r.json() else:
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model( def change_llm_model(
self, self,
model_name: str, model_name: str,
new_model_name: str, new_model_name: str,
controller_address: str = None, controller_address: str = None,
no_remote_api: bool = None,
): ):
''' '''
向fastchat controller请求切换LLM模型 向fastchat controller请求切换LLM模型
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if not model_name or not new_model_name: if not model_name or not new_model_name:
return return
@ -792,12 +822,17 @@ class ApiRequest:
"new_model_name": new_model_name, "new_model_name": new_model_name,
"controller_address": controller_address, "controller_address": controller_address,
} }
r = self.post(
"/llm_model/change", if no_remote_api:
json=data, from server.llm_api import change_llm_model
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model return change_llm_model(**data).dict()
) else:
return r.json() r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
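
With these changes, ApiRequest can exercise the same code paths in-process: when no_remote_api=True it imports the synchronous server-side functions instead of issuing HTTP requests, which is what the updated tests rely on. A short sketch (assumes a fastchat controller is reachable for the model-management calls; model names are placeholders):

from server.utils import api_address
from webui_pages.utils import ApiRequest

api = ApiRequest(api_address(), no_remote_api=True)
print(api.list_running_models())                        # calls server.llm_api.list_llm_models directly
api.change_llm_model("chatglm2-6b", "chatglm2-6b-32k")  # assumed model names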