diff --git a/document_loaders/FilteredCSVloader.py b/server/document_loaders/FilteredCSVloader.py similarity index 100% rename from document_loaders/FilteredCSVloader.py rename to server/document_loaders/FilteredCSVloader.py diff --git a/document_loaders/__init__.py b/server/document_loaders/__init__.py similarity index 100% rename from document_loaders/__init__.py rename to server/document_loaders/__init__.py diff --git a/document_loaders/mydocloader.py b/server/document_loaders/mydocloader.py similarity index 100% rename from document_loaders/mydocloader.py rename to server/document_loaders/mydocloader.py diff --git a/document_loaders/myimgloader.py b/server/document_loaders/myimgloader.py similarity index 94% rename from document_loaders/myimgloader.py rename to server/document_loaders/myimgloader.py index e09c6172..ffedf8ff 100644 --- a/document_loaders/myimgloader.py +++ b/server/document_loaders/myimgloader.py @@ -1,6 +1,6 @@ from typing import List from langchain.document_loaders.unstructured import UnstructuredFileLoader -from document_loaders.ocr import get_ocr +from server.document_loaders.ocr import get_ocr class RapidOCRLoader(UnstructuredFileLoader): diff --git a/document_loaders/mypdfloader.py b/server/document_loaders/mypdfloader.py similarity index 98% rename from document_loaders/mypdfloader.py rename to server/document_loaders/mypdfloader.py index 71dfd137..b15364be 100644 --- a/document_loaders/mypdfloader.py +++ b/server/document_loaders/mypdfloader.py @@ -4,7 +4,7 @@ import cv2 from PIL import Image import numpy as np from configs import PDF_OCR_THRESHOLD -from document_loaders.ocr import get_ocr +from server.document_loaders.ocr import get_ocr import tqdm diff --git a/document_loaders/mypptloader.py b/server/document_loaders/mypptloader.py similarity index 100% rename from document_loaders/mypptloader.py rename to server/document_loaders/mypptloader.py diff --git a/document_loaders/ocr.py b/server/document_loaders/ocr.py similarity index 100% rename from document_loaders/ocr.py rename to server/document_loaders/ocr.py diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 02962527..e7aa33d2 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +import operator import os from pathlib import Path from langchain.docstore.document import Document diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index d37b52c7..9e52162a 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -11,7 +11,7 @@ from configs import ( TEXT_SPLITTER_NAME, ) import importlib -from text_splitter import zh_title_enhance as func_zh_title_enhance +from server.text_splitter import zh_title_enhance as func_zh_title_enhance import langchain.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter @@ -153,15 +153,15 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): try: if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", "RapidOCRDocLoader", "RapidOCRPPTLoader"]: - document_loaders_module = importlib.import_module('document_loaders') + document_loaders_module = importlib.import_module("server.document_loaders") else: - document_loaders_module = importlib.import_module('langchain.document_loaders') + document_loaders_module = importlib.import_module("langchain.document_loaders") DocumentLoader = getattr(document_loaders_module, loader_name) except Exception as e: msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) - document_loaders_module = importlib.import_module('langchain.document_loaders') + document_loaders_module = importlib.import_module("langchain.document_loaders") DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if loader_name == "UnstructuredFileLoader": @@ -204,10 +204,10 @@ def make_text_splitter( else: try: ## 优先使用用户自定义的text_splitter - text_splitter_module = importlib.import_module('text_splitter') + text_splitter_module = importlib.import_module("server.text_splitter") TextSplitter = getattr(text_splitter_module, splitter_name) except: ## 否则使用langchain的text_splitter - text_splitter_module = importlib.import_module('langchain.text_splitter') + text_splitter_module = importlib.import_module("langchain.text_splitter") TextSplitter = getattr(text_splitter_module, splitter_name) if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 diff --git a/server/model_workers/gemini.py b/server/model_workers/gemini.py deleted file mode 100644 index e9175b6e..00000000 --- a/server/model_workers/gemini.py +++ /dev/null @@ -1,123 +0,0 @@ -import sys -from fastchat.conversation import Conversation -from server.model_workers.base import * -from server.utils import get_httpx_client -from fastchat import conversation as conv -import json, httpx -from typing import List, Dict -from configs import logger, log_verbose - - -class GeminiWorker(ApiModelWorker): - def __init__( - self, - *, - controller_addr: str = None, - worker_addr: str = None, - model_names: List[str] = ["gemini-api"], - **kwargs, - ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 4096) - super().__init__(**kwargs) - - def create_gemini_messages(self, messages) -> json: - has_history = any(msg['role'] == 'assistant' for msg in messages) - gemini_msg = [] - - for msg in messages: - role = msg['role'] - content = msg['content'] - if role == 'system': - continue - if has_history: - if role == 'assistant': - role = "model" - transformed_msg = {"role": role, "parts": [{"text": content}]} - else: - if role == 'user': - transformed_msg = {"parts": [{"text": content}]} - - gemini_msg.append(transformed_msg) - - msg = dict(contents=gemini_msg) - return msg - - def do_chat(self, params: ApiChatParams) -> Dict: - params.load_config(self.model_names[0]) - data = self.create_gemini_messages(messages=params.messages) - generationConfig = dict( - temperature=params.temperature, - topK=1, - topP=1, - maxOutputTokens=4096, - stopSequences=[] - ) - - data['generationConfig'] = generationConfig - url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key - headers = { - 'Content-Type': 'application/json', - } - if log_verbose: - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - logger.info(f'{self.__class__.__name__}:data: {data}') - - text = "" - json_string = "" - timeout = httpx.Timeout(60.0) - client = get_httpx_client(timeout=timeout) - with client.stream("POST", url, headers=headers, json=data) as response: - for line in response.iter_lines(): - line = line.strip() - if not line or "[DONE]" in line: - continue - - json_string += line - - try: - resp = json.loads(json_string) - if 'candidates' in resp: - for candidate in resp['candidates']: - content = candidate.get('content', {}) - parts = content.get('parts', []) - for part in parts: - if 'text' in part: - text += part['text'] - yield { - "error_code": 0, - "text": text - } - print(text) - except json.JSONDecodeError as e: - print("Failed to decode JSON:", e) - print("Invalid JSON string:", json_string) - - def get_embeddings(self, params): - print("embedding") - print(params) - - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - return conv.Conversation( - name=self.model_names[0], - system_message="You are a helpful, respectful and honest assistant.", - messages=[], - roles=["user", "assistant"], - sep="\n### ", - stop_str="###", - ) - - -if __name__ == "__main__": - import uvicorn - from server.utils import MakeFastAPIOffline - from fastchat.serve.base_model_worker import app - - worker = GeminiWorker( - controller_addr="http://127.0.0.1:20001", - worker_addr="http://127.0.0.1:21012", - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - MakeFastAPIOffline(app) - uvicorn.run(app, port=21012) diff --git a/text_splitter/__init__.py b/server/text_splitter/__init__.py similarity index 100% rename from text_splitter/__init__.py rename to server/text_splitter/__init__.py diff --git a/text_splitter/ali_text_splitter.py b/server/text_splitter/ali_text_splitter.py similarity index 100% rename from text_splitter/ali_text_splitter.py rename to server/text_splitter/ali_text_splitter.py diff --git a/text_splitter/chinese_recursive_text_splitter.py b/server/text_splitter/chinese_recursive_text_splitter.py similarity index 100% rename from text_splitter/chinese_recursive_text_splitter.py rename to server/text_splitter/chinese_recursive_text_splitter.py diff --git a/text_splitter/chinese_text_splitter.py b/server/text_splitter/chinese_text_splitter.py similarity index 100% rename from text_splitter/chinese_text_splitter.py rename to server/text_splitter/chinese_text_splitter.py diff --git a/text_splitter/zh_title_enhance.py b/server/text_splitter/zh_title_enhance.py similarity index 100% rename from text_splitter/zh_title_enhance.py rename to server/text_splitter/zh_title_enhance.py