Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-08 07:53:29 +08:00)
Commit 0a37fe93b8
@@ -262,6 +262,10 @@ def make_text_splitter(
         text_splitter_module = importlib.import_module('langchain.text_splitter')
         TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
         text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+    # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
+    # text_splitter._tokenizer.max_length = 37016792
+    # text_splitter._tokenizer.prefer_gpu()
     return text_splitter

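For reference, a minimal sketch of the GPU hint that the newly added comments point at (Issue #1287). SpacyTextSplitter lives in the same langchain.text_splitter module; the underscore-prefixed attributes are spaCy implementation details rather than a supported API, so treat this as illustrative only:

from langchain.text_splitter import SpacyTextSplitter

# Build a spaCy-based splitter; the chunk sizes here are arbitrary examples.
text_splitter = SpacyTextSplitter(chunk_size=250, chunk_overlap=50)
# Mirrors the commented-out hints in the diff above (left disabled on purpose):
# text_splitter._tokenizer.max_length = 37016792   # raise spaCy's input-length cap
# text_splitter._tokenizer.prefer_gpu()            # ask spaCy to use the GPU if available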
File diff suppressed because one or more lines are too long
startup.py (103 changed lines)
@@ -6,9 +6,8 @@ import sys
 from multiprocessing import Process
 from datetime import datetime
 from pprint import pprint
+from langchain_core._api import deprecated

-
-# Set numexpr's maximum number of threads; defaults to the number of CPU cores
 try:
     import numexpr

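The try/import numexpr lines above belong to a thread-count setup whose body falls outside this hunk; a self-contained sketch of that pattern (an assumption based on the surrounding context, not a quote of the file):

import os
import numexpr

# Cap numexpr's thread pool at the number of detected CPU cores.
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)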
@@ -33,15 +32,18 @@ from configs import (
     HTTPX_DEFAULT_TIMEOUT,
 )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_config, get_httpx_client,
-                          get_model_worker_config, get_all_model_worker_configs,
+                          fschat_openai_api_address, get_httpx_client, get_model_worker_config,
                           MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 from server.knowledge_base.migrate import create_tables
 import argparse
-from typing import Tuple, List, Dict
+from typing import List, Dict
 from configs import VERSION


+@deprecated(
+    since="0.3.0",
+    message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
+    removal="0.3.0")
 def create_controller_app(
     dispatch_method: str,
     log_level: str = "INFO",
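The decorator added above comes from langchain_core._api; a self-contained sketch of how it behaves (the decorated function below is hypothetical, not from startup.py):

from langchain_core._api import deprecated

@deprecated(since="0.3.0",
            message="This helper will be rewritten in 0.3.x",
            removal="0.3.0")
def start_legacy_worker() -> None:
    # Calling the decorated function emits a LangChainDeprecationWarning first.
    print("starting legacy worker")

start_legacy_worker()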
@@ -88,7 +90,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:

     for k, v in kwargs.items():
         setattr(args, k, v)
-    if worker_class := kwargs.get("langchain_model"):  #Langchain-supported models need no extra handling
+    if worker_class := kwargs.get("langchain_model"):  # Langchain-supported models need no extra handling
         from fastchat.serve.base_model_worker import app
         worker = ""
     # Online model APIs
@@ -107,12 +109,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         import fastchat.serve.vllm_worker
         from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
         from vllm import AsyncLLMEngine
-        from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+        from vllm.engine.arg_utils import AsyncEngineArgs

-        args.tokenizer = args.model_path  # set here if the tokenizer differs from model_path
+        args.tokenizer = args.model_path
         args.tokenizer_mode = 'auto'
-        args.trust_remote_code= True
-        args.download_dir= None
+        args.trust_remote_code = True
+        args.download_dir = None
         args.load_format = 'auto'
         args.dtype = 'auto'
         args.seed = 0
@@ -122,13 +124,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         args.block_size = 16
         args.swap_space = 4  # GiB
         args.gpu_memory_utilization = 0.90
         args.max_num_batched_tokens = None  # maximum tokens per batch; depends on your GPU and model settings, too large a value exhausts GPU memory
         args.max_num_seqs = 256
         args.disable_log_stats = False
         args.conv_template = None
         args.limit_worker_concurrency = 5
         args.no_register = False
         args.num_gpus = 1  # the vllm worker is sharded with tensor parallelism; set this to the number of GPUs
         args.engine_use_ray = False
         args.disable_log_requests = False

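For orientation, a sketch of how settings like those above are consumed by vLLM. Parameter names follow vllm.engine.arg_utils.AsyncEngineArgs as of the vLLM releases contemporary with this commit and may differ in newer versions; the model path is a placeholder:

from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="/path/to/model",        # placeholder model path
    tokenizer="/path/to/model",
    trust_remote_code=True,
    dtype="auto",
    gpu_memory_utilization=0.90,
    swap_space=4,                  # GiB of CPU swap space per GPU
    max_num_seqs=256,
    tensor_parallel_size=1,        # tensor parallelism across this many GPUs
)
engine = AsyncLLMEngine.from_engine_args(engine_args)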
@@ -154,16 +156,16 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         engine = AsyncLLMEngine.from_engine_args(engine_args)

         worker = VLLMWorker(
-            controller_addr = args.controller_address,
-            worker_addr = args.worker_address,
-            worker_id = worker_id,
-            model_path = args.model_path,
-            model_names = args.model_names,
-            limit_worker_concurrency = args.limit_worker_concurrency,
-            no_register = args.no_register,
-            llm_engine = engine,
-            conv_template = args.conv_template,
+            controller_addr=args.controller_address,
+            worker_addr=args.worker_address,
+            worker_id=worker_id,
+            model_path=args.model_path,
+            model_names=args.model_names,
+            limit_worker_concurrency=args.limit_worker_concurrency,
+            no_register=args.no_register,
+            llm_engine=engine,
+            conv_template=args.conv_template,
         )
         sys.modules["fastchat.serve.vllm_worker"].engine = engine
         sys.modules["fastchat.serve.vllm_worker"].worker = worker
         sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
@@ -171,7 +173,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     else:
         from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id

         args.gpus = "0"  # GPU index; with multiple GPUs this can be set to "0,1,2,3"
         args.max_gpu_memory = "22GiB"
         args.num_gpus = 1  # the model worker is sharded with model parallelism; set this to the number of GPUs

@@ -325,7 +327,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):

         with get_httpx_client() as client:
             r = client.post(worker_address + "/release",
                             json={"new_model_name": new_model_name, "keep_origin": keep_origin})
             if r.status_code != 200:
                 msg = f"failed to release model: {model_name}"
                 logger.error(msg)
@@ -393,8 +395,8 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
         new_model_name: str = Body(None, description="释放后加载该模型"),
         keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
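The /release route defined above is what the controller-side logic posts to when switching models; a hedged client-side sketch of calling it directly (host and port are placeholders for a running model worker):

import httpx

# Ask the worker to release its current model and load chatglm3-6b instead.
r = httpx.post("http://127.0.0.1:20002/release",
               json={"new_model_name": "chatglm3-6b", "keep_origin": False})
print(r.status_code, r.json())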
@@ -450,13 +452,13 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
     port = WEBUI_SERVER["port"]

     cmd = ["streamlit", "run", "webui.py",
            "--server.address", host,
            "--server.port", str(port),
            "--theme.base", "light",
            "--theme.primaryColor", "#165dff",
            "--theme.secondaryBackgroundColor", "#f5f5f5",
            "--theme.textColor", "#000000",
            ]
     if run_mode == "lite":
         cmd += [
             "--",
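The cmd list above is handed to a subprocess later in run_webui; a rough, self-contained equivalent with the address, port and theme values hard-coded for illustration:

import subprocess

cmd = ["streamlit", "run", "webui.py",
       "--server.address", "127.0.0.1",
       "--server.port", "8501",
       "--theme.base", "light"]
# Launch Streamlit and block until it exits.
subprocess.Popen(cmd).wait()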
@@ -605,8 +607,10 @@ async def start_main_server():
         Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
         Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
         """
+
         def f(signal_received, frame):
             raise KeyboardInterrupt(f"{signalname} received")

         return f
+
     # This will be inherited by the child process if it is forked (not spawned)
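A sketch of the strsignal-based alternative that the docstring above mentions; the handler registration is shown only for illustration:

import signal

def handler(signal_received, frame):
    # strsignal() turns the numeric signal into a readable name (Python 3.8+).
    raise KeyboardInterrupt(f"{signal.strsignal(signal_received)} received")

signal.signal(signal.SIGTERM, handler)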
@@ -701,8 +705,8 @@ async def start_main_server():
         for model_name in args.model_name:
             config = get_model_worker_config(model_name)
             if (config.get("online_api")
                     and config.get("worker_class")
                     and model_name in FSCHAT_MODEL_WORKERS):
                 e = manager.Event()
                 model_worker_started.append(e)
                 process = Process(
@@ -742,12 +746,12 @@ async def start_main_server():
     else:
         try:
             # make sure tasks can exit cleanly after receiving SIGINT
-            if p:= processes.get("controller"):
+            if p := processes.get("controller"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"
                 controller_started.wait()  # wait for the controller to finish starting

-            if p:= processes.get("openai_api"):
+            if p := processes.get("openai_api"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"

@@ -763,24 +767,24 @@ async def start_main_server():
             for e in model_worker_started:
                 e.wait()

-            if p:= processes.get("api"):
+            if p := processes.get("api"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"
                 api_started.wait()  # wait for api.py to finish starting

-            if p:= processes.get("webui"):
+            if p := processes.get("webui"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"
                 webui_started.wait()  # wait for webui.py to finish starting

             dump_server_info(after_start=True, args=args)

             while True:
                 cmd = queue.get()  # received a model-switch message
                 e = manager.Event()
                 if isinstance(cmd, list):
                     model_name, cmd, new_model_name = cmd
                     if cmd == "start":  # start the new model
                         logger.info(f"准备启动新模型进程:{new_model_name}")
                         process = Process(
                             target=run_model_worker,
@@ -831,7 +835,6 @@ async def start_main_server():
                 else:
                     logger.error(f"未找到模型进程:{model_name}")

-
     # for process in processes.get("model_worker", {}).values():
     #     process.join()
     # for process in processes.get("online_api", {}).values():
@@ -866,10 +869,9 @@ async def start_main_server():
     for p in processes.values():
         logger.info("Process status: %s", p)

-if __name__ == "__main__":
-    # make sure the database tables are created
-    create_tables()

+if __name__ == "__main__":
+    create_tables()
     if sys.version_info < (3, 10):
         loop = asyncio.get_event_loop()
     else:
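Read together with the next hunk, the surrounding code selects the event loop by Python version; the try/except branch below is inferred from the visible context lines and is a sketch rather than a quote of the file:

import asyncio
import sys

async def main() -> None:          # stand-in for start_main_server()
    print("server tasks would run here")

if sys.version_info < (3, 10):
    loop = asyncio.get_event_loop()
else:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

loop.run_until_complete(main())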
@@ -879,16 +881,15 @@ if __name__ == "__main__":
         loop = asyncio.new_event_loop()

     asyncio.set_event_loop(loop)
-    # run the coroutine synchronously
-    loop.run_until_complete(start_main_server())

+    loop.run_until_complete(start_main_server())

 # Example API call after the services have started:
 # import openai
 # openai.api_key = "EMPTY" # Not support yet
 # openai.api_base = "http://localhost:8888/v1"

-# model = "chatglm2-6b"
+# model = "chatglm3-6b"

 # # create a chat completion
 # completion = openai.ChatCompletion.create(
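A hedged completion of the commented example above, written against the pre-1.0 openai Python client that the comment assumes; the port and model name follow the comment, and the message content is a placeholder:

import openai

openai.api_key = "EMPTY"                      # per the comment above: not supported yet
openai.api_base = "http://localhost:8888/v1"

completion = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)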