使用多进程提高导入知识库的速度 (#3276)
@ -193,7 +193,7 @@ $ python startup.py -a
|
|||||||
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
||||||
|
|
||||||
### 项目交流群
|
### 项目交流群
|
||||||
<img src="img/qr_code_90.jpg" alt="二维码" width="300" />
|
<img src="img/qr_code_96.jpg" alt="二维码" width="300" />
|
||||||
|
|
||||||
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||||
|
|
||||||
|
|||||||
BIN
img/qr_code_90.jpg
Normal file
|
After Width: | Height: | Size: 273 KiB |
BIN
img/qr_code_91.jpg
Normal file
|
After Width: | Height: | Size: 227 KiB |
BIN
img/qr_code_92.jpg
Normal file
|
After Width: | Height: | Size: 213 KiB |
BIN
img/qr_code_93.jpg
Normal file
|
After Width: | Height: | Size: 226 KiB |
BIN
img/qr_code_94.jpg
Normal file
|
After Width: | Height: | Size: 244 KiB |
BIN
img/qr_code_95.jpg
Normal file
|
After Width: | Height: | Size: 252 KiB |
BIN
img/qr_code_96.jpg
Normal file
|
After Width: | Height: | Size: 222 KiB |
BIN
img/qrcode_90_2.jpg
Normal file
|
After Width: | Height: | Size: 232 KiB |
@ -38,6 +38,9 @@ def search_docs(
|
|||||||
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||||
elif file_name or metadata:
|
elif file_name or metadata:
|
||||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||||
|
for d in data:
|
||||||
|
if "vector" in d.metadata:
|
||||||
|
del d.metadata["vector"]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import langchain_community.document_loaders
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter
|
from langchain.text_splitter import TextSplitter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from server.utils import run_in_thread_pool
|
from server.utils import run_in_thread_pool, run_in_process_pool
|
||||||
import json
|
import json
|
||||||
from typing import List, Union, Dict, Tuple, Generator
|
from typing import List, Union, Dict, Tuple, Generator
|
||||||
import chardet
|
import chardet
|
||||||
@ -353,6 +353,16 @@ class KnowledgeFile:
|
|||||||
return os.path.getsize(self.filepath)
|
return os.path.getsize(self.filepath)
|
||||||
|
|
||||||
|
|
||||||
|
def files2docs_in_thread_file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
    """Load a single KnowledgeFile into documents without ever raising.

    Returns (True, (kb_name, filename, docs)) on success, or
    (False, (kb_name, filename, error_message)) when loading fails.
    Defined at module level (rather than as a closure) so it is picklable
    and can be dispatched to a process-pool worker.
    """
    try:
        docs = file.file2text(**kwargs)
        return True, (file.kb_name, file.filename, docs)
    except Exception as exc:
        msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{exc}"
        logger.error(f'{exc.__class__.__name__}: {msg}',
                     exc_info=exc if log_verbose else None)
        return False, (file.kb_name, file.filename, msg)
|
||||||
|
|
||||||
|
|
||||||
def files2docs_in_thread(
|
def files2docs_in_thread(
|
||||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
||||||
chunk_size: int = CHUNK_SIZE,
|
chunk_size: int = CHUNK_SIZE,
|
||||||
@ -365,14 +375,6 @@ def files2docs_in_thread(
|
|||||||
生成器返回值为 status, (kb_name, file_name, docs | error)
|
生成器返回值为 status, (kb_name, file_name, docs | error)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
|
||||||
try:
|
|
||||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
|
||||||
except Exception as e:
|
|
||||||
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
return False, (file.kb_name, file.filename, msg)
|
|
||||||
|
|
||||||
kwargs_list = []
|
kwargs_list = []
|
||||||
for i, file in enumerate(files):
|
for i, file in enumerate(files):
|
||||||
@ -395,7 +397,7 @@ def files2docs_in_thread(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield False, (kb_name, filename, str(e))
|
yield False, (kb_name, filename, str(e))
|
||||||
|
|
||||||
for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
|
for result in run_in_process_pool(func=files2docs_in_thread_file2docs, params=kwargs_list):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,9 @@ from fastapi import FastAPI
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
import sys
|
||||||
from langchain.embeddings.base import Embeddings
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||||
from langchain_openai.chat_models import ChatOpenAI
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
from langchain_openai.llms import OpenAI
|
from langchain_openai.llms import OpenAI
|
||||||
import httpx
|
import httpx
|
||||||
@ -572,11 +573,36 @@ def run_in_thread_pool(
|
|||||||
tasks = []
|
tasks = []
|
||||||
with ThreadPoolExecutor() as pool:
|
with ThreadPoolExecutor() as pool:
|
||||||
for kwargs in params:
|
for kwargs in params:
|
||||||
thread = pool.submit(func, **kwargs)
|
tasks.append(pool.submit(func, **kwargs))
|
||||||
tasks.append(thread)
|
|
||||||
|
|
||||||
for obj in as_completed(tasks):
|
for obj in as_completed(tasks):
|
||||||
yield obj.result()
|
try:
|
||||||
|
yield obj.result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"error in sub thread: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def run_in_process_pool(
        func: Callable,
        params: List[Dict] = [],
) -> Generator:
    '''
    Run ``func`` once per kwargs-dict in ``params`` in a *process* pool,
    yielding each result as it completes (completion order, not input order).

    ``func`` must be picklable (i.e. a module-level function, not a closure
    or lambda) and is called with keyword arguments only; every value in the
    kwargs-dicts must be picklable too, since they cross a process boundary.

    A task that raises is logged and skipped; it does not abort the batch.

    Note: the ``params`` default is a mutable ``[]``; it is safe here only
    because the list is never mutated, just iterated.
    '''
    # On Windows, ProcessPoolExecutor is limited by WaitForMultipleObjects
    # and rejects large worker counts; elsewhere let the executor pick its
    # default (os.cpu_count()).
    max_workers = None
    if sys.platform.startswith("win"):
        max_workers = min(mp.cpu_count(), 60)  # max_workers should not exceed 60 on windows

    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        # Submit everything up front, then drain in completion order.
        tasks = [pool.submit(func, **kwargs) for kwargs in params]

        for task in as_completed(tasks):
            try:
                yield task.result()
            except Exception as e:
                logger.error(f"error in sub process: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def get_httpx_client(
|
def get_httpx_client(
|
||||||
|
|||||||