Use multiprocessing to speed up importing documents into the knowledge base (#3276)

dignfei 2024-03-13 10:36:14 +08:00 committed by GitHub
parent 6310095a00
commit 7b2b24c0bc
12 changed files with 47 additions and 16 deletions

View File

@@ -193,7 +193,7 @@ $ python startup.py -a
 [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
 ### 项目交流群
-<img src="img/qr_code_90.jpg" alt="二维码" width="300" />
+<img src="img/qr_code_96.jpg" alt="二维码" width="300" />
 🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

New binary files added (contents not shown):
img/qr_code_90.jpg (273 KiB)
img/qr_code_91.jpg (227 KiB)
img/qr_code_92.jpg (213 KiB)
img/qr_code_93.jpg (226 KiB)
img/qr_code_94.jpg (244 KiB)
img/qr_code_95.jpg (252 KiB)
img/qr_code_96.jpg (222 KiB)
img/qrcode_90_2.jpg (232 KiB)

View File

@@ -38,6 +38,9 @@ def search_docs(
         data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
     elif file_name or metadata:
         data = kb.list_docs(file_name=file_name, metadata=metadata)
+        for d in data:
+            if "vector" in d.metadata:
+                del d.metadata["vector"]
     return data
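
Note: the three added lines strip the stored embedding from each document's metadata before search_docs returns it, since list_docs can hand back the raw vector that some vector stores keep alongside the text, and that payload is large and useless to API clients. A minimal sketch of the same idea on plain dicts (strip_vectors and the 1536-dim dummy vector are illustrative, not part of the repo):

from typing import Any, Dict, List

def strip_vectors(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Drop the stored embedding from each document's metadata; embeddings are
    # thousands of floats per document and only bloat the API response.
    for doc in docs:
        doc.get("metadata", {}).pop("vector", None)  # no-op if the key is absent
    return docs

docs = [{"page_content": "hello", "metadata": {"source": "a.md", "vector": [0.1] * 1536}}]
print(strip_vectors(docs)[0]["metadata"])  # {'source': 'a.md'}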

View File

@@ -16,7 +16,7 @@ import langchain_community.document_loaders
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
-from server.utils import run_in_thread_pool
+from server.utils import run_in_thread_pool, run_in_process_pool
 import json
 from typing import List, Union, Dict, Tuple, Generator
 import chardet
@@ -353,6 +353,16 @@
         return os.path.getsize(self.filepath)
+
+
+def files2docs_in_thread_file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
+    try:
+        return True, (file.kb_name, file.filename, file.file2text(**kwargs))
+    except Exception as e:
+        msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
+        logger.error(f'{e.__class__.__name__}: {msg}',
+                     exc_info=e if log_verbose else None)
+        return False, (file.kb_name, file.filename, msg)

 def files2docs_in_thread(
     files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
     chunk_size: int = CHUNK_SIZE,
@@ -365,14 +375,6 @@ def files2docs_in_thread(
     生成器返回值为 status, (kb_name, file_name, docs | error)
     '''
-    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
-        try:
-            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
-        except Exception as e:
-            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
-            logger.error(f'{e.__class__.__name__}: {msg}',
-                         exc_info=e if log_verbose else None)
-            return False, (file.kb_name, file.filename, msg)

     kwargs_list = []
     for i, file in enumerate(files):
@@ -395,7 +397,7 @@
         except Exception as e:
             yield False, (kb_name, filename, str(e))

-    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
+    for result in run_in_process_pool(func=files2docs_in_thread_file2docs, params=kwargs_list):
         yield result
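
Note: the per-file worker used to be a closure (file2docs) defined inside files2docs_in_thread; it is moved to module level and renamed files2docs_in_thread_file2docs because ProcessPoolExecutor pickles the callable it ships to worker processes, and a function defined inside another function cannot be pickled. A minimal sketch of that constraint, with illustrative names (square, cube) not taken from the repo:

from concurrent.futures import ProcessPoolExecutor

def square(*, x: int) -> int:
    # Module-level and keyword-only, so it can be pickled and sent to a worker process.
    return x * x

def main() -> None:
    def cube(*, x: int) -> int:
        # Defined inside main(): it cannot be pickled, so a process pool cannot run it.
        return x ** 3

    with ProcessPoolExecutor() as pool:
        print(pool.submit(square, x=3).result())  # 9
        # pool.submit(cube, x=3).result()  # would raise: cannot pickle local object

if __name__ == "__main__":  # guard needed where the "spawn" start method re-imports the module
    main()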

View File

@@ -2,8 +2,9 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 import os
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.embeddings.base import Embeddings
+import sys
+import multiprocessing as mp
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_openai.llms import OpenAI
 import httpx
@@ -572,11 +573,36 @@ def run_in_thread_pool(
     tasks = []
     with ThreadPoolExecutor() as pool:
         for kwargs in params:
-            thread = pool.submit(func, **kwargs)
-            tasks.append(thread)
+            tasks.append(pool.submit(func, **kwargs))

         for obj in as_completed(tasks):
-            yield obj.result()
+            try:
+                yield obj.result()
+            except Exception as e:
+                logger.error(f"error in sub thread: {e}", exc_info=True)
+
+
+def run_in_process_pool(
+        func: Callable,
+        params: List[Dict] = [],
+) -> Generator:
+    '''
+    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+    '''
+    tasks = []
+    max_workers = None
+    if sys.platform.startswith("win"):
+        max_workers = min(mp.cpu_count(), 60)  # max_workers should not exceed 60 on windows
+    with ProcessPoolExecutor(max_workers=max_workers) as pool:
+        for kwargs in params:
+            tasks.append(pool.submit(func, **kwargs))
+
+        for obj in as_completed(tasks):
+            try:
+                yield obj.result()
+            except Exception as e:
+                logger.error(f"error in sub process: {e}", exc_info=True)


def get_httpx_client(
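
Note on the new helper: run_in_process_pool mirrors run_in_thread_pool but schedules tasks on a ProcessPoolExecutor, so CPU-heavy document loading and splitting can use multiple cores instead of contending for the GIL. The Windows cap of 60 workers stays under the limit that concurrent.futures imposes on Windows process pools (61, due to the WaitForMultipleObjects handle limit). Results are yielded in completion order, and a task that raises is logged and skipped rather than re-raised. A minimal usage sketch under those assumptions (parse_one and the doc_*.md paths are made up for illustration):

from typing import Dict, List
from server.utils import run_in_process_pool  # the helper added in this commit

def parse_one(*, path: str, chunk_size: int = 250) -> Dict:
    # Stand-in for per-file CPU-bound work such as loading and splitting a document;
    # must be module-level and keyword-only, as run_in_process_pool expects.
    return {"path": path, "chunks": chunk_size}

if __name__ == "__main__":
    params: List[Dict] = [{"path": f"doc_{i}.md"} for i in range(4)]
    for result in run_in_process_pool(func=parse_one, params=params):
        print(result)  # results arrive in completion order, not submission order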