From 56d32a9908c4bc6d9d990779373015430eca8a88 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Mon, 22 Jan 2024 13:42:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=8E=89=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E6=B2=A1=E7=94=A8=E7=9A=84=E6=B3=A8=E9=87=8A=EF=BC=8C=E5=B7=B2?= =?UTF-8?q?=E7=BB=8F=E4=B8=8D=E9=9C=80=E8=A6=81todo=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/server_config.py.example | 6 ---- document_loaders/mypdfloader.py | 3 -- server/embeddings_api.py | 1 - server/knowledge_base/kb_cache/faiss_cache.py | 2 -- server/knowledge_base/kb_doc_api.py | 2 -- server/knowledge_base/kb_service/base.py | 1 - .../kb_service/milvus_kb_service.py | 1 - .../kb_service/pg_kb_service.py | 2 -- server/knowledge_base/kb_summary/base.py | 1 - .../kb_summary/summary_chunk.py | 6 ---- server/knowledge_base/utils.py | 1 - server/model_workers/azure.py | 2 -- server/model_workers/baichuan.py | 2 -- server/model_workers/base.py | 2 -- server/model_workers/fangzhou.py | 33 +++++++++---------- server/model_workers/gemini.py | 1 - server/model_workers/minimax.py | 5 +-- server/model_workers/qianfan.py | 3 -- server/model_workers/qwen.py | 2 -- server/model_workers/tiangong.py | 2 -- server/model_workers/xinghuo.py | 3 -- server/utils.py | 7 ++-- startup.py | 2 +- webui_pages/dialogue/dialogue.py | 1 - 24 files changed, 21 insertions(+), 70 deletions(-) diff --git a/configs/server_config.py.example b/configs/server_config.py.example index a812376b..09b1546b 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -132,12 +132,6 @@ FSCHAT_MODEL_WORKERS = { }, } -# fastchat multi model worker server -FSCHAT_MULTI_MODEL_WORKERS = { - # TODO: -} - -# fastchat controller server FSCHAT_CONTROLLER = { "host": DEFAULT_BIND_HOST, "port": 20001, diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py index 5c480cff..faaf63dd 100644 --- a/document_loaders/mypdfloader.py +++ b/document_loaders/mypdfloader.py @@ -16,11 +16,8 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") for i, page in enumerate(doc): - # 更新描述 b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) - # 立即显示进度条更新结果 b_unit.refresh() - # TODO: 依据文本与图片顺序调整处理方式 text = page.get_text("") resp += text + "\n" diff --git a/server/embeddings_api.py b/server/embeddings_api.py index 440bb774..e907de07 100644 --- a/server/embeddings_api.py +++ b/server/embeddings_api.py @@ -16,7 +16,6 @@ def embed_texts( ) -> BaseResponse: ''' 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) - TODO: 也许需要加入缓存机制,减少 token 消耗 ''' try: if embed_model in list_embed_models(): # 使用本地Embeddings模型 diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index ed48b5dd..60c550ee 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -55,8 +55,6 @@ class _FaissPool(CachePool): embed_model: str = EMBEDDING_MODEL, embed_device: str = embedding_device(), ) -> FAISS: - # TODO: 整个Embeddings加载逻辑有些混乱,待清理 - # create an empty vector store embeddings = EmbeddingsFunAdapter(embed_model) doc = Document(page_content="init", metadata={}) vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 09a264f9..e58ea41f 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -95,7 +95,6 @@ def _save_files_in_thread(files: List[UploadFile], and not override and os.path.getsize(file_path) == len(file_content) ): - # TODO: filesize 不同后的处理 file_status = f"文件 {filename} 已存在。" logger.warn(file_status) return dict(code=404, msg=file_status, data=data) @@ -116,7 +115,6 @@ def _save_files_in_thread(files: List[UploadFile], yield result -# TODO: 等langchain.document_loaders支持内存文件的时候再开通 # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), # knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), # override: bool = Form(False, description="覆盖已有文件"), diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 44c0d64e..bd5a54eb 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -191,7 +191,6 @@ class KBService(ABC): ''' 传入参数为: {doc_id: Document, ...} 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 - TODO:是否要支持新增 docs ? ''' self.del_doc_by_ids(list(docs.keys())) docs = [] diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 32382929..43b616e2 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -70,7 +70,6 @@ class MilvusKBService(KBService): return score_threshold_process(score_threshold, top_k, docs) def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: - # TODO: workaround for bug #10492 in langchain for doc in docs: for k, v in doc.metadata.items(): doc.metadata[k] = str(v) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index ec0e147b..46efe7d8 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -32,8 +32,6 @@ class PGKBService(KBService): results = [Document(page_content=row[0], metadata=row[1]) for row in session.execute(stmt, {'ids': ids}).fetchall()] return results - - # TODO: def del_doc_by_ids(self, ids: List[str]) -> bool: return super().del_doc_by_ids(ids) diff --git a/server/knowledge_base/kb_summary/base.py b/server/knowledge_base/kb_summary/base.py index 00dcea6f..6d095fee 100644 --- a/server/knowledge_base/kb_summary/base.py +++ b/server/knowledge_base/kb_summary/base.py @@ -13,7 +13,6 @@ from server.db.repository.knowledge_metadata_repository import add_summary_to_db from langchain.docstore.document import Document -# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加 class KBSummaryService(ABC): kb_name: str embed_model: str diff --git a/server/knowledge_base/kb_summary/summary_chunk.py b/server/knowledge_base/kb_summary/summary_chunk.py index 0b88f233..7c2aaf47 100644 --- a/server/knowledge_base/kb_summary/summary_chunk.py +++ b/server/knowledge_base/kb_summary/summary_chunk.py @@ -112,12 +112,6 @@ class SummaryAdapter: docs: List[DocumentWithVSId] = []) -> List[Document]: logger.info("start summary") - # TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题 - # merge_docs = self._drop_overlap(docs) - # # 将merge_docs中的句子合并成一个文档 - # text = self._join_docs(merge_docs) - # 根据段落于句子的分隔符,将文档分成chunk,每个chunk长度小于token_max长度 - """ 这个过程分成两个部分: 1. 对每个文档进行处理,得到每个文档的摘要 diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 6064477a..3a8c701b 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -174,7 +174,6 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): if encode_detect is None: encode_detect = {"encoding": "utf-8"} loader_kwargs["encoding"] = encode_detect["encoding"] - ## TODO:支持更多的自定义CSV读取逻辑 elif loader_name == "JSONLoader": loader_kwargs.setdefault("jq_schema", ".") diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py index 70959325..f0835ae1 100644 --- a/server/model_workers/azure.py +++ b/server/model_workers/azure.py @@ -67,12 +67,10 @@ class AzureWorker(ApiModelWorker): self.logger.error(f"请求 Azure API 时发生错误:{resp}") def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="You are a helpful, respectful and honest assistant.", diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index 5e9cbbb0..75cfad4e 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -88,12 +88,10 @@ class BaiChuanWorker(ApiModelWorker): yield data def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="", diff --git a/server/model_workers/base.py b/server/model_workers/base.py index 88affb43..234ab47a 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -125,8 +125,6 @@ class ApiModelWorker(BaseModelWorker): def count_token(self, params): - # TODO:需要完善 - # print("count token") prompt = params["prompt"] return {"count": len(str(prompt)), "error_code": 0} diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py index ddbad4ab..fdb50a1c 100644 --- a/server/model_workers/fangzhou.py +++ b/server/model_workers/fangzhou.py @@ -12,16 +12,16 @@ class FangZhouWorker(ApiModelWorker): """ def __init__( - self, - *, - model_names: List[str] = ["fangzhou-api"], - controller_addr: str = None, - worker_addr: str = None, - version: Literal["chatglm-6b-model"] = "chatglm-6b-model", - **kwargs, + self, + *, + model_names: List[str] = ["fangzhou-api"], + controller_addr: str = None, + worker_addr: str = None, + version: Literal["chatglm-6b-model"] = "chatglm-6b-model", + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小 + kwargs.setdefault("context_len", 16384) super().__init__(**kwargs) self.version = version @@ -53,15 +53,15 @@ class FangZhouWorker(ApiModelWorker): if error := resp.error: if error.code_n > 0: data = { - "error_code": error.code_n, - "text": error.message, - "error": { - "message": error.message, - "type": "invalid_request_error", - "param": None, - "code": None, - } + "error_code": error.code_n, + "text": error.message, + "error": { + "message": error.message, + "type": "invalid_request_error", + "param": None, + "code": None, } + } self.logger.error(f"请求方舟 API 时发生错误:{data}") yield data elif chunk := resp.choice.message.content: @@ -77,7 +77,6 @@ class FangZhouWorker(ApiModelWorker): break def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) diff --git a/server/model_workers/gemini.py b/server/model_workers/gemini.py index db41029b..e9175b6e 100644 --- a/server/model_workers/gemini.py +++ b/server/model_workers/gemini.py @@ -95,7 +95,6 @@ class GeminiWorker(ApiModelWorker): print("Invalid JSON string:", json_string) def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index ba610d52..79d24514 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -37,7 +37,6 @@ class MiniMaxWorker(ApiModelWorker): def do_chat(self, params: ApiChatParams) -> Dict: # 按照官网推荐,直接调用abab 5.5模型 - # TODO: 支持指定回复要求,支持指定用户名称、AI名称 params.load_config(self.model_names[0]) url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' @@ -55,7 +54,7 @@ class MiniMaxWorker(ApiModelWorker): "temperature": params.temperature, "top_p": params.top_p, "tokens_to_generate": params.max_tokens or 1024, - # TODO: 以下参数为minimax特有,传入空值会出错。 + # 以下参数为minimax特有,传入空值会出错。 # "prompt": params.system_message or self.conv.system_message, # "bot_setting": [], # "role_meta": params.role_meta, @@ -143,12 +142,10 @@ class MiniMaxWorker(ApiModelWorker): return {"code": 200, "data": result} def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。", diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 7dd3a355..da362ec6 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -187,14 +187,11 @@ class QianFanWorker(ApiModelWorker): i += batch_size return {"code": 200, "data": result} - # TODO: qianfan支持续写模型 def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="你是一个聪明的助手,请根据用户的提示来完成任务", diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index 58d1bcd1..2741b74d 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -100,12 +100,10 @@ class QwenWorker(ApiModelWorker): return {"code": 200, "data": result} def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", diff --git a/server/model_workers/tiangong.py b/server/model_workers/tiangong.py index e127ea55..88010a15 100644 --- a/server/model_workers/tiangong.py +++ b/server/model_workers/tiangong.py @@ -70,12 +70,10 @@ class TianGongWorker(ApiModelWorker): yield data def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="", diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 1e772a33..de38308b 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -42,7 +42,6 @@ class XingHuoWorker(ApiModelWorker): self.version = version def do_chat(self, params: ApiChatParams) -> Dict: - # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接 params.load_config(self.model_names[0]) version_mapping = { @@ -73,12 +72,10 @@ class XingHuoWorker(ApiModelWorker): yield {"error_code": 0, "text": text} def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 return conv.Conversation( name=self.model_names[0], system_message="你是一个聪明的助手,请根据用户的提示来完成任务", diff --git a/server/utils.py b/server/utils.py index 26ef967e..7fed5f8c 100644 --- a/server/utils.py +++ b/server/utils.py @@ -36,7 +36,6 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): await fn except Exception as e: logging.exception(e) - # TODO: handle exception msg = f"Caught exception: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) @@ -404,7 +403,7 @@ def fschat_controller_address() -> str: def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str: - if model := get_model_worker_config(model_name): # TODO: depends fastchat + if model := get_model_worker_config(model_name): host = model["host"] if host == "0.0.0.0": host = "127.0.0.1" @@ -449,7 +448,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: from configs import prompt_config import importlib - importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载 + importlib.reload(prompt_config) return prompt_config.PROMPT_TEMPLATES[type].get(name) @@ -550,7 +549,7 @@ def run_in_thread_pool( thread = pool.submit(func, **kwargs) tasks.append(thread) - for obj in as_completed(tasks): # TODO: Ctrl+c无法停止 + for obj in as_completed(tasks): yield obj.result() diff --git a/startup.py b/startup.py index 0681dcda..359fb709 100644 --- a/startup.py +++ b/startup.py @@ -418,7 +418,7 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): set_httpx_config() controller_addr = fschat_controller_address() - app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. + app = create_openai_api_app(controller_addr, log_level=log_level) _set_app_event(app, started_event) host = FSCHAT_OPENAI_API["host"] diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index b5691ffd..b9d2f7fd 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -126,7 +126,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.use_chat_name(conversation_name) conversation_id = st.session_state["conversation_ids"][conversation_name] - # TODO: 对话模型与会话绑定 def on_mode_change(): mode = st.session_state.dialogue_mode text = f"已切换到 {mode} 模式。"