move document_loaders & text_splitter under server

This commit is contained in:
liunux4odoo 2024-02-08 11:58:28 +08:00
parent 5d422ca9a1
commit 73eb5e2e32
15 changed files with 9 additions and 131 deletions

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader from langchain.document_loaders.unstructured import UnstructuredFileLoader
from document_loaders.ocr import get_ocr from server.document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader): class RapidOCRLoader(UnstructuredFileLoader):

View File

@ -4,7 +4,7 @@ import cv2
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from configs import PDF_OCR_THRESHOLD from configs import PDF_OCR_THRESHOLD
from document_loaders.ocr import get_ocr from server.document_loaders.ocr import get_ocr
import tqdm import tqdm

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import operator
import os import os
from pathlib import Path from pathlib import Path
from langchain.docstore.document import Document from langchain.docstore.document import Document

View File

@ -11,7 +11,7 @@ from configs import (
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
) )
import importlib import importlib
from text_splitter import zh_title_enhance as func_zh_title_enhance from server.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain.document_loaders import langchain.document_loaders
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter from langchain.text_splitter import TextSplitter
@ -153,15 +153,15 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
try: try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
"RapidOCRDocLoader", "RapidOCRPPTLoader"]: "RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module('document_loaders') document_loaders_module = importlib.import_module("server.document_loaders")
else: else:
document_loaders_module = importlib.import_module('langchain.document_loaders') document_loaders_module = importlib.import_module("langchain.document_loaders")
DocumentLoader = getattr(document_loaders_module, loader_name) DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e: except Exception as e:
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
document_loaders_module = importlib.import_module('langchain.document_loaders') document_loaders_module = importlib.import_module("langchain.document_loaders")
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader": if loader_name == "UnstructuredFileLoader":
@ -204,10 +204,10 @@ def make_text_splitter(
else: else:
try: ## 优先使用用户自定义的text_splitter try: ## 优先使用用户自定义的text_splitter
text_splitter_module = importlib.import_module('text_splitter') text_splitter_module = importlib.import_module("server.text_splitter")
TextSplitter = getattr(text_splitter_module, splitter_name) TextSplitter = getattr(text_splitter_module, splitter_name)
except: ## 否则使用langchain的text_splitter except: ## 否则使用langchain的text_splitter
text_splitter_module = importlib.import_module('langchain.text_splitter') text_splitter_module = importlib.import_module("langchain.text_splitter")
TextSplitter = getattr(text_splitter_module, splitter_name) TextSplitter = getattr(text_splitter_module, splitter_name)
if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载

View File

@ -1,123 +0,0 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
class GeminiWorker(ApiModelWorker):
    """API model worker that proxies chat requests to Google's Gemini
    (gemini-pro) REST ``generateContent`` endpoint.

    Registers under the model name ``gemini-api`` by default and converts
    OpenAI-style message lists into the Gemini "contents" payload.
    """

    def __init__(
        self,
        *,
        controller_addr: str = None,
        worker_addr: str = None,
        model_names: List[str] = ["gemini-api"],
        **kwargs,
    ):
        # Forward addressing/model info to the base ApiModelWorker; the
        # context length default matches Gemini's 4096-token output cap below.
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)

    def create_gemini_messages(self, messages) -> Dict:
        """Convert OpenAI-style ``messages`` into a Gemini request body.

        Rules (mirroring the Gemini API contract):
          * ``system`` messages are dropped (Gemini has no system role here).
          * If the conversation contains any ``assistant`` turn, every kept
            message carries an explicit role, with ``assistant`` mapped to
            ``model``.
          * Otherwise only ``user`` messages are kept, without a role field.

        Returns:
            dict with a single ``contents`` key as expected by
            ``models/gemini-pro:generateContent``.
        """
        has_history = any(msg['role'] == 'assistant' for msg in messages)
        gemini_msg = []
        for msg in messages:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                continue
            # Build the transformed message for this iteration only; the
            # original code appended unconditionally, which raised NameError
            # (or duplicated a stale entry) for unhandled roles when there
            # was no assistant history.
            transformed_msg = None
            if has_history:
                if role == 'assistant':
                    role = "model"
                transformed_msg = {"role": role, "parts": [{"text": content}]}
            elif role == 'user':
                transformed_msg = {"parts": [{"text": content}]}
            if transformed_msg is not None:
                gemini_msg.append(transformed_msg)
        return dict(contents=gemini_msg)

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a chat completion from Gemini, yielding cumulative text.

        Yields dicts of the form ``{"error_code": 0, "text": <text so far>}``
        each time a complete JSON document has been accumulated from the
        streamed response lines.
        """
        params.load_config(self.model_names[0])
        data = self.create_gemini_messages(messages=params.messages)
        generationConfig = dict(
            temperature=params.temperature,
            topK=1,
            topP=1,
            maxOutputTokens=4096,
            stopSequences=[]
        )
        data['generationConfig'] = generationConfig
        # API key is passed as a query parameter, per Gemini REST docs.
        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
        headers = {
            'Content-Type': 'application/json',
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        json_string = ""
        timeout = httpx.Timeout(60.0)
        client = get_httpx_client(timeout=timeout)
        with client.stream("POST", url, headers=headers, json=data) as response:
            for line in response.iter_lines():
                line = line.strip()
                if not line or "[DONE]" in line:
                    continue
                # Accumulate lines until they parse as a full JSON document;
                # JSONDecodeError on a partial buffer is expected and means
                # "keep reading".
                json_string += line
                try:
                    resp = json.loads(json_string)
                    if 'candidates' in resp:
                        for candidate in resp['candidates']:
                            content = candidate.get('content', {})
                            parts = content.get('parts', [])
                            for part in parts:
                                if 'text' in part:
                                    text += part['text']
                            yield {
                                "error_code": 0,
                                "text": text
                            }
                            print(text)
                except json.JSONDecodeError as e:
                    print("Failed to decode JSON:", e)
                    print("Invalid JSON string:", json_string)

    def get_embeddings(self, params):
        # Embeddings are not implemented for this worker; stub for interface.
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """Return the fastchat conversation template used for this worker."""
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful, respectful and honest assistant.",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )
if __name__ == "__main__":
    # Standalone launch: run this worker as its own FastAPI service.
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    gemini_worker = GeminiWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:21012",
    )
    # fastchat's app looks up the active worker on this module attribute.
    sys.modules["fastchat.serve.model_worker"].worker = gemini_worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21012)