泛型标记错误 (#3855)

* 仓库设置,一些版本问题

* pip源仓库设置,一些版本问题,启动说明

* 配置说明

* 发布的依赖信息

* 泛型标记错误
This commit is contained in:
glide-the 2024-04-23 19:33:31 +08:00 committed by GitHub
parent eaea9a61cc
commit a3c4b6da49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 23 deletions

View File

@@ -3,11 +3,10 @@ import multiprocessing
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import multiprocessing as mp import multiprocessing as mp
import os import os
import subprocess import logging
import sys import sys
from multiprocessing import Process from multiprocessing import Process
logger = logging.getLogger()
from chatchat.model_loaders.init_server import init_server
# 设置numexpr最大线程数默认为CPU核心数 # 设置numexpr最大线程数默认为CPU核心数
try: try:
@@ -18,20 +17,9 @@ try:
except: except:
pass pass
from chatchat.configs import ( from fastapi import FastAPI
LOG_PATH,
log_verbose,
logger,
DEFAULT_EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
API_SERVER,
WEBUI_SERVER, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT
)
from chatchat.server.utils import FastAPI
from chatchat.server.knowledge_base.migrate import create_tables
import argparse import argparse
from typing import List, Dict from typing import List, Dict
from chatchat.configs import VERSION
def _set_app_event(app: FastAPI, started_event: mp.Event = None): def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@@ -48,9 +36,20 @@ def run_init_server(
model_platforms_shard: Dict, model_platforms_shard: Dict,
started_event: mp.Event = None, started_event: mp.Event = None,
run_mode: str = None, run_mode: str = None,
model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG, model_providers_cfg_path: str = None,
provider_host: str = MODEL_PROVIDERS_CFG_HOST, provider_host: str = None,
provider_port: int = MODEL_PROVIDERS_CFG_PORT): provider_port: int = None):
from chatchat.model_loaders.init_server import init_server
from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG,
MODEL_PROVIDERS_CFG_HOST,
MODEL_PROVIDERS_CFG_PORT)
if model_providers_cfg_path is None:
model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG
if provider_host is None:
provider_host = MODEL_PROVIDERS_CFG_HOST
if provider_port is None:
provider_port = MODEL_PROVIDERS_CFG_PORT
init_server(model_platforms_shard=model_platforms_shard, init_server(model_platforms_shard=model_platforms_shard,
started_event=started_event, started_event=started_event,
model_providers_cfg_path=model_providers_cfg_path, model_providers_cfg_path=model_providers_cfg_path,
@@ -64,7 +63,7 @@ def run_api_server(model_platforms_shard: Dict,
from chatchat.server.api_server.server_app import create_app from chatchat.server.api_server.server_app import create_app
import uvicorn import uvicorn
from chatchat.server.utils import set_httpx_config from chatchat.server.utils import set_httpx_config
from chatchat.configs import MODEL_PLATFORMS from chatchat.configs import MODEL_PLATFORMS, API_SERVER
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
set_httpx_config() set_httpx_config()
@@ -81,7 +80,7 @@ def run_webui(model_platforms_shard: Dict,
started_event: mp.Event = None, run_mode: str = None): started_event: mp.Event = None, run_mode: str = None):
import sys import sys
from chatchat.server.utils import set_httpx_config from chatchat.server.utils import set_httpx_config
from chatchat.configs import MODEL_PLATFORMS from chatchat.configs import MODEL_PLATFORMS, WEBUI_SERVER
if model_platforms_shard.get('provider_platforms'): if model_platforms_shard.get('provider_platforms'):
MODEL_PLATFORMS.extend(model_platforms_shard.get('provider_platforms')) MODEL_PLATFORMS.extend(model_platforms_shard.get('provider_platforms'))
logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}") logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}")
@@ -184,7 +183,7 @@ def dump_server_info(after_start=False, args=None):
import platform import platform
import langchain import langchain
from chatchat.server.utils import api_address, webui_address from chatchat.server.utils import api_address, webui_address
from chatchat.configs import VERSION, TEXT_SPLITTER_NAME, DEFAULT_EMBEDDING_MODEL
print("\n") print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.") print(f"操作系统:{platform.platform()}.")
@@ -216,6 +215,7 @@ def dump_server_info(after_start=False, args=None):
async def start_main_server(): async def start_main_server():
import time import time
import signal import signal
from chatchat.configs import LOG_PATH
def handler(signalname): def handler(signalname):
""" """
@@ -346,9 +346,13 @@ async def start_main_server():
logger.info("Process status: %s", p) logger.info("Process status: %s", p)
if __name__ == "__main__": def main():
# 添加这行代码 # 添加这行代码
cwd = os.getcwd()
sys.path.append(cwd)
multiprocessing.freeze_support() multiprocessing.freeze_support()
print("cwd:"+cwd)
from chatchat.server.knowledge_base.migrate import create_tables
create_tables() create_tables()
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -360,3 +364,7 @@ if __name__ == "__main__":
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop.run_until_complete(start_main_server()) loop.run_until_complete(start_main_server())
if __name__ == "__main__":
main()

View File

@@ -5,6 +5,9 @@ description = ""
authors = ["chatchat"] authors = ["chatchat"]
readme = "README.md" readme = "README.md"
[tool.poetry.scripts]
chatchat = 'chatchat.startup:main'
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7" python = ">=3.8.1,<3.12,!=3.9.7"
model-providers = "^0.3.0" model-providers = "^0.3.0"

View File

@@ -526,7 +526,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls( def _extract_response_tool_calls(
self, self,
response_tool_calls: Union[ response_tool_calls: Union[
List[ChatCompletionMessageToolCall, ChoiceDeltaToolCall] List[ChatCompletionMessageToolCall],
List[ChoiceDeltaToolCall]
], ],
) -> List[AssistantPromptMessage.ToolCall]: ) -> List[AssistantPromptMessage.ToolCall]:
""" """