2024-03-06 13:42:01 +08:00

257 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
OVERLAP_SIZE,
logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from sse_starlette import EventSourceResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        endpoint_host: str = Body(False, description="接入点地址"),
        endpoint_host_key: str = Body(False, description="接入点key"),
        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
        model_name: str = Body(None, description="LLM 模型名称。"),
        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Rebuild the file summaries of a single knowledge base.

    The existing summary store is dropped and created again, then every file
    returned by ``list_files_from_folder`` is summarized and stored. Progress
    is streamed to the client as one JSON string per file.

    :param knowledge_base_name: name of the knowledge base to process.
    :param allow_empty_kb: when False and the knowledge base does not exist,
        a single 404 event is emitted and nothing else is done.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param endpoint_host: LLM endpoint address.
    :param endpoint_host_key: LLM endpoint key.
    :param endpoint_host_proxy: LLM endpoint proxy address.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature (0.0 - 1.0).
    :param max_tokens: limit on generated tokens; None means the model maximum.
    :return: an ``EventSourceResponse`` streaming JSON-encoded status events.
    """
    # NOTE(review): the endpoint_* parameters are annotated ``str`` but default
    # to ``False``; kept as-is for interface compatibility - confirm intent.

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Must be serialized: sse_starlette expands a yielded dict into
            # ServerSentEvent keyword arguments, which rejects "code"/"msg".
            yield json.dumps({
                "code": 404,
                "msg": f"未找到知识库 {knowledge_base_name}",
            }, ensure_ascii=False)
            return

        # Drop and recreate the summary store for this knowledge base.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        # Both the map and the reduce LLM are built from the same settings.
        llm_kwargs = dict(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        llm = get_ChatOpenAI(**llm_kwargs)
        reduce_llm = get_ChatOpenAI(**llm_kwargs)

        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        files = list_files_from_folder(knowledge_base_name)
        total = len(files)
        for finished, file_name in enumerate(files, start=1):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({finished} / {total}): {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({finished} / {total}): {file_name}",
                    "total": total,
                    "finished": finished,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                # A failed file is reported and the loop moves on to the next one.
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        endpoint_host: str = Body(False, description="接入点地址"),
        endpoint_host_key: str = Body(False, description="接入点key"),
        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
        model_name: str = Body(None, description="LLM 模型名称。"),
        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Summarize one file of a knowledge base, selected by file name.

    The documents listed for ``file_name`` are summarized and the result is
    added to the knowledge base's summary store. The outcome is streamed to
    the client as a single JSON string.

    :param knowledge_base_name: name of the knowledge base to process.
    :param file_name: name of the file whose documents are summarized.
    :param allow_empty_kb: when False and the knowledge base does not exist,
        a single 404 event is emitted and nothing else is done.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB services.
    :param file_description: description forwarded to ``summary.summarize``.
    :param endpoint_host: LLM endpoint address.
    :param endpoint_host_key: LLM endpoint key.
    :param endpoint_host_proxy: LLM endpoint proxy address.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature (0.0 - 1.0).
    :param max_tokens: limit on generated tokens; None means the model maximum.
    :return: an ``EventSourceResponse`` streaming a JSON-encoded status event.
    """
    # NOTE(review): the endpoint_* parameters are annotated ``str`` but default
    # to ``False``; kept as-is for interface compatibility - confirm intent.

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Must be serialized: sse_starlette expands a yielded dict into
            # ServerSentEvent keyword arguments, which rejects "code"/"msg".
            yield json.dumps({
                "code": 404,
                "msg": f"未找到知识库 {knowledge_base_name}",
            }, ensure_ascii=False)
            return

        # Nothing is dropped here; only create_kb_summary() is called.
        # Presumably it is safe on an existing store - TODO confirm.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        # Both the map and the reduce LLM are built from the same settings.
        llm_kwargs = dict(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        llm = get_ChatOpenAI(**llm_kwargs)
        reduce_llm = get_ChatOpenAI(**llm_kwargs)

        # Text summarization adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        endpoint_host: str = Body(False, description="接入点地址"),
        endpoint_host_key: str = Body(False, description="接入点key"),
        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
        model_name: str = Body(None, description="LLM 模型名称。"),
        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
) -> BaseResponse:
    """
    Summarize the documents of a knowledge base selected by ``doc_ids``.

    Unlike the streaming endpoints, the result is returned in one response
    and is not written to the summary store by this function.

    :param knowledge_base_name: name of the knowledge base to read from.
    :param doc_ids: ids of the documents to summarize.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name passed to the KB service.
    :param file_description: description forwarded to ``summary.summarize``.
    :param endpoint_host: LLM endpoint address.
    :param endpoint_host_key: LLM endpoint key.
    :param endpoint_host_proxy: LLM endpoint proxy address.
    :param model_name: LLM model name.
    :param temperature: LLM sampling temperature (0.0 - 1.0).
    :param max_tokens: limit on generated tokens; None means the model maximum.
    :return: ``BaseResponse`` with code 404 when the knowledge base does not
        exist, otherwise code 200 with the summaries under ``data["summarize"]``.
    """
    kb_service = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)

    # Guard clause: bail out early when the knowledge base is missing.
    if not kb_service.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})

    # The map LLM and the reduce LLM share one set of construction arguments.
    chat_settings = {
        "endpoint_host": endpoint_host,
        "endpoint_host_key": endpoint_host_key,
        "endpoint_host_proxy": endpoint_host_proxy,
        "model_name": model_name,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    map_llm = get_ChatOpenAI(**chat_settings)
    combine_llm = get_ChatOpenAI(**chat_settings)

    # Text summarization adapter.
    summarizer = SummaryAdapter.form_summary(
        llm=map_llm,
        reduce_llm=combine_llm,
        overlap_size=OVERLAP_SIZE,
    )

    fetched_docs = kb_service.get_doc_by_ids(ids=doc_ids)

    # Wrap each fetched document in DocumentWithVSId, pairing it with its id.
    wrapped_docs = []
    for vs_id, fetched in zip(doc_ids, fetched_docs):
        wrapped_docs.append(DocumentWithVSId(**fetched.dict(), id=vs_id))

    summarized = summarizer.summarize(
        file_description=file_description,
        docs=wrapped_docs,
    )

    # Convert the summarized documents to plain dicts for the response.
    payload = []
    for item in summarized:
        payload.append(dict(item.dict()))

    return BaseResponse(code=200, msg="总结完成", data={"summarize": payload})