diff --git a/requirements.txt b/requirements.txt
index 871699c2..9a895b7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@
 torchaudio>=2.1.2
 # Langchain 0.1.x requirements
 langchain>=0.1.0
 langchain_openai>=0.0.2
-langchain-community>=1.0.0
+langchain-community>=0.0.11
 langchainhub>=0.1.14
 pydantic==1.10.13
diff --git a/server/api.py b/server/api.py
index 8ecf9503..d83be2f5 100644
--- a/server/api.py
+++ b/server/api.py
@@ -17,9 +17,7 @@ from server.chat.chat import chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
 from server.embeddings_api import embed_texts_endpoint
-from server.llm_api import (list_running_models, list_config_models,
-                            change_llm_model, stop_llm_model,
-                            get_model_config)
+
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
 from typing import List, Literal
@@ -73,32 +71,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
     # 摘要相关接口
     mount_filename_summary_routes(app)

-    # LLM模型相关接口
-    app.post("/llm_model/list_running_models",
-             tags=["LLM Model Management"],
-             summary="列出当前已加载的模型",
-             )(list_running_models)
-
-    app.post("/llm_model/list_config_models",
-             tags=["LLM Model Management"],
-             summary="列出configs已配置的模型",
-             )(list_config_models)
-
-    app.post("/llm_model/get_model_config",
-             tags=["LLM Model Management"],
-             summary="获取模型配置(合并后)",
-             )(get_model_config)
-
-    app.post("/llm_model/stop",
-             tags=["LLM Model Management"],
-             summary="停止指定的LLM模型(Model Worker)",
-             )(stop_llm_model)
-
-    app.post("/llm_model/change",
-             tags=["LLM Model Management"],
-             summary="切换指定的LLM模型(Model Worker)",
-             )(change_llm_model)
-
     # 服务器相关接口
     app.post("/server/configs",
              tags=["Server State"],
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 670cd770..44d42ebd 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus


-def create_models_from_config(configs, callbacks, stream):
+def create_models_from_config(configs, openai_config, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
@@ -30,6 +30,9 @@
         for model_name, params in model_configs.items():
             callbacks = callbacks if params.get('callbacks', False) else None
             model_instance = get_ChatOpenAI(
+                endpoint_host=openai_config.get('endpoint_host', None),
+                endpoint_host_key=openai_config.get('endpoint_host_key', None),
+                endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
                 model_name=model_name,
                 temperature=params.get('temperature', 0.5),
                 max_tokens=params.get('max_tokens', 1000),
@@ -113,6 +116,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                ),
                stream: bool = Body(True, description="流式输出"),
                model_config: Dict = Body({}, description="LLM 模型配置"),
+               openai_config: Dict = Body({}, description="openaiEndpoint配置"),
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
@@ -124,7 +128,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]

-        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
+        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
+                                                    openai_config=openai_config, stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         full_chain = create_models_chains(prompts=prompts,
diff --git a/server/chat/completion.py b/server/chat/completion.py
index bddf07e3..31eade96 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -14,6 +14,9 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
+                     endpoint_host: str = Body(False, description="接入点地址"),
+                     endpoint_host_key: str = Body(False, description="接入点key"),
+                     endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
                      model_name: str = Body(None, description="LLM 模型名称。"),
                      temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -24,6 +27,9 @@ async def completion(query: str = Body(..., description="用户输入", examples

     #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
+                                  endpoint_host: str,
+                                  endpoint_host_key: str,
+                                  endpoint_host_proxy: str,
                                   model_name: str = None,
                                   prompt_name: str = prompt_name,
                                   echo: bool = echo,
@@ -34,6 +40,9 @@ async def completion(query: str = Body(..., description="用户输入", examples
             max_tokens = None

         model = get_OpenAI(
+            endpoint_host=endpoint_host,
+            endpoint_host_key=endpoint_host_key,
+            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -63,7 +72,10 @@ async def completion(query: str = Body(..., description="用户输入", examples

         await task

-    return EventSourceResponse(completion_iterator(query=query,
+    return StreamingResponse(completion_iterator(query=query,
+                                                 endpoint_host=endpoint_host,
+                                                 endpoint_host_key=endpoint_host_key,
+                                                 endpoint_host_proxy=endpoint_host_proxy,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
                              )
diff --git a/server/llm_api.py b/server/llm_api.py
deleted file mode 100644
index 21410fc7..00000000
--- a/server/llm_api.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from fastapi import Body
-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
-from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
-                          get_httpx_client, get_model_worker_config)
-from typing import List
-
-
-def list_running_models(
-        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
-        placeholder: str = Body(None, description="该参数未使用,占位用"),
-) -> BaseResponse:
-    '''
-    从fastchat controller获取已加载模型列表及其配置项
-    '''
-    try:
-        controller_address = controller_address or fschat_controller_address()
-        with get_httpx_client() as client:
-            r = client.post(controller_address + "/list_models")
-            models = r.json()["models"]
-            data = {m: get_model_config(m).data for m in models}
-
-            ## 只有LLM模型才返回
-            result = {}
-            for model, config in data.items():
-                if model in LLM_MODEL_CONFIG['llm_model']:
-                    result[model] = config
-
-            return BaseResponse(data=result)
-    except Exception as e:
-        logger.error(f'{e.__class__.__name__}: {e}',
-                     exc_info=e if log_verbose else None)
-        return BaseResponse(
-            code=500,
-            data={},
-            msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
-
-
-def list_config_models(
-        types: List[str] = Body(["local", "online"],
description="模型配置项类别,如local, online, worker"), - placeholder: str = Body(None, description="占位用,无实际效果") -) -> BaseResponse: - ''' - 从本地获取configs中配置的模型列表 - ''' - data = {} - result = {} - - for type, models in list_config_llm_models().items(): - if type in types: - data[type] = {m: get_model_config(m).data for m in models} - - for model, config in data.items(): - if model in LLM_MODEL_CONFIG['llm_model']: - result[type][model] = config - - return BaseResponse(data=result) - - -def get_model_config( - model_name: str = Body(description="配置中LLM模型的名称"), - placeholder: str = Body(None, description="占位用,无实际效果") -) -> BaseResponse: - ''' - 获取LLM模型配置项(合并后的) - ''' - config = {} - # 删除ONLINE_MODEL配置中的敏感信息 - for k, v in get_model_worker_config(model_name=model_name).items(): - if not (k == "worker_class" - or "key" in k.lower() - or "secret" in k.lower() - or k.lower().endswith("id")): - config[k] = v - - return BaseResponse(data=config) - - -def stop_llm_model( - model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) -) -> BaseResponse: - ''' - 向fastchat controller请求停止某个LLM模型。 - 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - with get_httpx_client() as client: - r = client.post( - controller_address + "/release_worker", - json={"model_name": model_name}, - ) - return r.json() - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if log_verbose else None) - return BaseResponse( - code=500, - msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") - - -def change_llm_model( - model_name: str = Body(..., description="当前运行模型"), - new_model_name: str = Body(..., description="要切换的新模型"), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) -): - ''' - 向fastchat controller请求切换LLM模型。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - with get_httpx_client() as client: - r = client.post( - controller_address + "/release_worker", - json={"model_name": model_name, "new_model_name": new_model_name}, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model - ) - return r.json() - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if log_verbose else None) - return BaseResponse( - code=500, - msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") diff --git a/server/utils.py b/server/utils.py index bba6e5eb..a89471f7 100644 --- a/server/utils.py +++ b/server/utils.py @@ -45,6 +45,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): def get_ChatOpenAI( + endpoint_host: str, + endpoint_host_key: str, + endpoint_host_proxy: str, model_name: str, temperature: float, max_tokens: int = None, @@ -61,18 +64,21 @@ def get_ChatOpenAI( streaming=streaming, verbose=verbose, callbacks=callbacks, - openai_api_key=config.get("api_key", "EMPTY"), - openai_api_base=config.get("api_base_url", fschat_openai_api_address()), + openai_api_key=endpoint_host_key if endpoint_host_key else "None", + openai_api_base=endpoint_host if endpoint_host else "None", model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_proxy=config.get("openai_proxy"), + openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, **kwargs ) return model def get_OpenAI( + 
endpoint_host: str, + endpoint_host_key: str, + endpoint_host_proxy: str, model_name: str, temperature: float, max_tokens: int = None, @@ -82,19 +88,18 @@ def get_OpenAI( verbose: bool = True, **kwargs: Any, ) -> OpenAI: - config = get_model_worker_config(model_name) - if model_name == "openai-api": - model_name = config.get("model_name") + + # TODO: 从API获取模型信息 model = OpenAI( streaming=streaming, verbose=verbose, callbacks=callbacks, - openai_api_key=config.get("api_key", "EMPTY"), - openai_api_base=config.get("api_base_url", fschat_openai_api_address()), + openai_api_key=endpoint_host_key if endpoint_host_key else "None", + openai_api_base=endpoint_host if endpoint_host else "None", model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_proxy=config.get("openai_proxy"), + openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, echo=echo, **kwargs ) @@ -365,69 +370,14 @@ def get_model_worker_config(model_name: str = None) -> dict: ''' from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH from configs.server_config import FSCHAT_MODEL_WORKERS - from server import model_workers config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy()) config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy()) - if model_name in ONLINE_LLM_MODEL: - config["online_api"] = True - if provider := config.get("provider"): - try: - config["worker_class"] = getattr(model_workers, provider) - except Exception as e: - msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - # 本地模型 - if model_name in MODEL_PATH["llm_model"]: - path = get_model_path(model_name) - config["model_path"] = path - if path and os.path.isdir(path): - config["model_path_exists"] = True - config["device"] = llm_device(config.get("device")) + return config -def get_all_model_worker_configs() -> dict: - result = {} - model_names = set(FSCHAT_MODEL_WORKERS.keys()) - for name in model_names: - if name != "default": - result[name] = get_model_worker_config(name) - return result - - -def fschat_controller_address() -> str: - from configs.server_config import FSCHAT_CONTROLLER - - host = FSCHAT_CONTROLLER["host"] - if host == "0.0.0.0": - host = "127.0.0.1" - port = FSCHAT_CONTROLLER["port"] - return f"http://{host}:{port}" - - -def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str: - if model := get_model_worker_config(model_name): # TODO: depends fastchat - host = model["host"] - if host == "0.0.0.0": - host = "127.0.0.1" - port = model["port"] - return f"http://{host}:{port}" - return "" - - -def fschat_openai_api_address() -> str: - from configs.server_config import FSCHAT_OPENAI_API - - host = FSCHAT_OPENAI_API["host"] - if host == "0.0.0.0": - host = "127.0.0.1" - port = FSCHAT_OPENAI_API["port"] - return f"http://{host}:{port}/v1" - - def api_address() -> str: from configs.server_config import API_SERVER @@ -461,6 +411,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: def set_httpx_config( timeout: float = HTTPX_DEFAULT_TIMEOUT, proxy: Union[str, Dict] = None, + unused_proxies: List[str] = [], ): ''' 设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。 @@ -498,11 +449,7 @@ def set_httpx_config( "http://localhost", ] # do not use proxy for user deployed fastchat servers - for x in [ - fschat_controller_address(), - fschat_model_worker_address(), - fschat_openai_api_address(), - ]: + for x in unused_proxies: 
host = ":".join(x.split(":")[:2]) if host not in no_proxy: no_proxy.append(host) @@ -568,6 +515,7 @@ def get_httpx_client( use_async: bool = False, proxies: Union[str, Dict] = None, timeout: float = HTTPX_DEFAULT_TIMEOUT, + unused_proxies: List[str] = [], **kwargs, ) -> Union[httpx.Client, httpx.AsyncClient]: ''' @@ -579,11 +527,7 @@ def get_httpx_client( "all://localhost": None, } # do not use proxy for user deployed fastchat servers - for x in [ - fschat_controller_address(), - fschat_model_worker_address(), - fschat_openai_api_address(), - ]: + for x in unused_proxies: host = ":".join(x.split(":")[:2]) default_proxies.update({host: None}) @@ -629,8 +573,6 @@ def get_server_configs() -> Dict: 获取configs中的原始配置项,供前端使用 ''' _custom = { - "controller_address": fschat_controller_address(), - "openai_api_address": fschat_openai_api_address(), "api_address": api_address(), } @@ -638,14 +580,8 @@ def get_server_configs() -> Dict: def list_online_embed_models() -> List[str]: - from server import model_workers - ret = [] - for k, v in list_config_llm_models()["online"].items(): - if provider := v.get("provider"): - worker_class = getattr(model_workers, provider, None) - if worker_class is not None and worker_class.can_embedding(): - ret.append(k) + # TODO: 从在线API获取支持的模型列表 return ret diff --git a/startup.py b/startup.py index 380160df..a69a5473 100644 --- a/startup.py +++ b/startup.py @@ -25,17 +25,11 @@ from configs import ( LLM_MODEL_CONFIG, EMBEDDING_MODEL, TEXT_SPLITTER_NAME, - FSCHAT_CONTROLLER, - FSCHAT_OPENAI_API, - FSCHAT_MODEL_WORKERS, API_SERVER, WEBUI_SERVER, HTTPX_DEFAULT_TIMEOUT, ) -from server.utils import (fschat_controller_address, fschat_model_worker_address, - fschat_openai_api_address, get_httpx_client, - get_model_worker_config, - MakeFastAPIOffline, FastAPI, llm_device, embedding_device) +from server.utils import (FastAPI, embedding_device) from server.knowledge_base.migrate import create_tables import argparse from typing import List, Dict @@ -49,234 +43,6 @@ for model_category in LLM_MODEL_CONFIG.values(): all_model_names_list = list(all_model_names) -def create_controller_app( - dispatch_method: str, - log_level: str = "INFO", -) -> FastAPI: - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.controller import app, Controller, logger - logger.setLevel(log_level) - - controller = Controller(dispatch_method) - sys.modules["fastchat.serve.controller"].controller = controller - - MakeFastAPIOffline(app) - app.title = "FastChat Controller" - app._controller = controller - return app - - -def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: - """ - kwargs包含的字段如下: - host: - port: - model_names:[`model_name`] - controller_address: - worker_address: - - 对于Langchain支持的模型: - langchain_model:True - 不会使用fschat - 对于online_api: - online_api:True - worker_class: `provider` - 对于离线模型: - model_path: `model_name_or_path`,huggingface的repo-id或本地路径 - device:`LLM_DEVICE` - """ - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - import argparse - - parser = argparse.ArgumentParser() - args = parser.parse_args([]) - - for k, v in kwargs.items(): - setattr(args, k, v) - if worker_class := kwargs.get("langchain_model"): # Langchian支持的模型不用做操作 - from fastchat.serve.base_model_worker import app - worker = "" - # 在线模型API - elif worker_class := kwargs.get("worker_class"): - from fastchat.serve.base_model_worker import app - - worker = worker_class(model_names=args.model_names, - controller_addr=args.controller_address, - 
worker_addr=args.worker_address) - # sys.modules["fastchat.serve.base_model_worker"].worker = worker - sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level) - - # 本地模型 - else: - from configs.model_config import VLLM_MODEL_DICT - if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm": - import fastchat.serve.vllm_worker - from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id - from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs - - args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加 - args.tokenizer_mode = 'auto' - args.trust_remote_code = True - args.download_dir = None - args.load_format = 'auto' - args.dtype = 'auto' - args.seed = 0 - args.worker_use_ray = False - args.pipeline_parallel_size = 1 - args.tensor_parallel_size = 1 - args.block_size = 16 - args.swap_space = 4 # GiB - args.gpu_memory_utilization = 0.90 - args.max_num_batched_tokens = None # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够 - args.max_num_seqs = 256 - args.disable_log_stats = False - args.conv_template = None - args.limit_worker_concurrency = 5 - args.no_register = False - args.num_gpus = 4 # vllm worker的切分是tensor并行,这里填写显卡的数量 - args.engine_use_ray = False - args.disable_log_requests = False - - # 0.2.2 vllm后要加的参数, 但是这里不需要 - args.max_model_len = None - args.revision = None - args.quantization = None - args.max_log_len = None - args.tokenizer_revision = None - - # 0.2.2 vllm需要新加的参数 - args.max_paddings = 256 - - if args.model_path: - args.model = args.model_path - if args.num_gpus > 1: - args.tensor_parallel_size = args.num_gpus - - for k, v in kwargs.items(): - setattr(args, k, v) - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args(engine_args) - - worker = VLLMWorker( - controller_addr=args.controller_address, - worker_addr=args.worker_address, - worker_id=worker_id, - model_path=args.model_path, - model_names=args.model_names, - limit_worker_concurrency=args.limit_worker_concurrency, - no_register=args.no_register, - llm_engine=engine, - conv_template=args.conv_template, - ) - sys.modules["fastchat.serve.vllm_worker"].engine = engine - sys.modules["fastchat.serve.vllm_worker"].worker = worker - sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level) - - else: - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id - - args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3" - args.max_gpu_memory = "22GiB" - args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量 - - args.load_8bit = False - args.cpu_offloading = None - args.gptq_ckpt = None - args.gptq_wbits = 16 - args.gptq_groupsize = -1 - args.gptq_act_order = False - args.awq_ckpt = None - args.awq_wbits = 16 - args.awq_groupsize = -1 - args.model_names = [""] - args.conv_template = None - args.limit_worker_concurrency = 5 - args.stream_interval = 2 - args.no_register = False - args.embed_in_truncate = False - for k, v in kwargs.items(): - setattr(args, k, v) - if args.gpus: - if args.num_gpus is None: - args.num_gpus = len(args.gpus.split(',')) - if len(args.gpus.split(",")) < args.num_gpus: - raise ValueError( - f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 
- ) - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus - gptq_config = GptqConfig( - ckpt=args.gptq_ckpt or args.model_path, - wbits=args.gptq_wbits, - groupsize=args.gptq_groupsize, - act_order=args.gptq_act_order, - ) - awq_config = AWQConfig( - ckpt=args.awq_ckpt or args.model_path, - wbits=args.awq_wbits, - groupsize=args.awq_groupsize, - ) - worker = ModelWorker( - controller_addr=args.controller_address, - worker_addr=args.worker_address, - worker_id=worker_id, - model_path=args.model_path, - model_names=args.model_names, - limit_worker_concurrency=args.limit_worker_concurrency, - no_register=args.no_register, - device=args.device, - num_gpus=args.num_gpus, - max_gpu_memory=args.max_gpu_memory, - load_8bit=args.load_8bit, - cpu_offloading=args.cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, - stream_interval=args.stream_interval, - conv_template=args.conv_template, - embed_in_truncate=args.embed_in_truncate, - ) - sys.modules["fastchat.serve.model_worker"].args = args - sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config - # sys.modules["fastchat.serve.model_worker"].worker = worker - sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level) - - MakeFastAPIOffline(app) - app.title = f"FastChat LLM Server ({args.model_names[0]})" - app._worker = worker - return app - - -def create_openai_api_app( - controller_address: str, - api_keys: List = [], - log_level: str = "INFO", -) -> FastAPI: - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings - from fastchat.utils import build_logger - logger = build_logger("openai_api", "openai_api.log") - logger.setLevel(log_level) - - app.add_middleware( - CORSMiddleware, - allow_credentials=True, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - ) - - sys.modules["fastchat.serve.openai_api_server"].logger = logger - app_settings.controller_address = controller_address - app_settings.api_keys = api_keys - - MakeFastAPIOffline(app) - app.title = "FastChat OpeanAI API Server" - return app - def _set_app_event(app: FastAPI, started_event: mp.Event = None): @app.on_event("startup") @@ -285,154 +51,6 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None): started_event.set() -def run_controller(log_level: str = "INFO", started_event: mp.Event = None): - import uvicorn - import httpx - from fastapi import Body - import time - import sys - from server.utils import set_httpx_config - set_httpx_config() - - app = create_controller_app( - dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"), - log_level=log_level, - ) - _set_app_event(app, started_event) - - # add interface to release and load model worker - @app.post("/release_worker") - def release_worker( - model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]), - # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]), - new_model_name: str = Body(None, description="释放后加载该模型"), - keep_origin: bool = Body(False, description="不释放原模型,加载新模型") - ) -> Dict: - available_models = app._controller.list_models() - if new_model_name in available_models: - msg = f"要切换的LLM模型 {new_model_name} 已经存在" - logger.info(msg) - return {"code": 500, "msg": msg} - - if new_model_name: - logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}") - else: - logger.info(f"即将停止LLM模型: {model_name}") - - if model_name not in available_models: - msg = f"the model {model_name} is not available" - 
logger.error(msg) - return {"code": 500, "msg": msg} - - worker_address = app._controller.get_worker_address(model_name) - if not worker_address: - msg = f"can not find model_worker address for {model_name}" - logger.error(msg) - return {"code": 500, "msg": msg} - - with get_httpx_client() as client: - r = client.post(worker_address + "/release", - json={"new_model_name": new_model_name, "keep_origin": keep_origin}) - if r.status_code != 200: - msg = f"failed to release model: {model_name}" - logger.error(msg) - return {"code": 500, "msg": msg} - - if new_model_name: - timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register - while timer > 0: - models = app._controller.list_models() - if new_model_name in models: - break - time.sleep(1) - timer -= 1 - if timer > 0: - msg = f"sucess change model from {model_name} to {new_model_name}" - logger.info(msg) - return {"code": 200, "msg": msg} - else: - msg = f"failed change model from {model_name} to {new_model_name}" - logger.error(msg) - return {"code": 500, "msg": msg} - else: - msg = f"sucess to release model: {model_name}" - logger.info(msg) - return {"code": 200, "msg": msg} - - host = FSCHAT_CONTROLLER["host"] - port = FSCHAT_CONTROLLER["port"] - - if log_level == "ERROR": - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) - - -def run_model_worker( - model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])), - controller_address: str = "", - log_level: str = "INFO", - q: mp.Queue = None, - started_event: mp.Event = None, -): - import uvicorn - from fastapi import Body - import sys - from server.utils import set_httpx_config - set_httpx_config() - - kwargs = get_model_worker_config(model_name) - host = kwargs.pop("host") - port = kwargs.pop("port") - kwargs["model_names"] = [model_name] - kwargs["controller_address"] = controller_address or fschat_controller_address() - kwargs["worker_address"] = fschat_model_worker_address(model_name) - model_path = kwargs.get("model_path", "") - kwargs["model_path"] = model_path - app = create_model_worker_app(log_level=log_level, **kwargs) - _set_app_event(app, started_event) - if log_level == "ERROR": - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # add interface to release and load model - @app.post("/release") - def release_model( - new_model_name: str = Body(None, description="释放后加载该模型"), - keep_origin: bool = Body(False, description="不释放原模型,加载新模型") - ) -> Dict: - if keep_origin: - if new_model_name: - q.put([model_name, "start", new_model_name]) - else: - if new_model_name: - q.put([model_name, "replace", new_model_name]) - else: - q.put([model_name, "stop", None]) - return {"code": 200, "msg": "done"} - - uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) - - -def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): - import uvicorn - import sys - from server.utils import set_httpx_config - set_httpx_config() - - controller_addr = fschat_controller_address() - app = create_openai_api_app(controller_addr, log_level=log_level) - _set_app_event(app, started_event) - - host = FSCHAT_OPENAI_API["host"] - port = FSCHAT_OPENAI_API["port"] - if log_level == "ERROR": - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - uvicorn.run(app, host=host, port=port) - - def run_api_server(started_event: mp.Event = None, run_mode: str = None): from server.api import create_app import uvicorn @@ -473,6 +91,18 @@ def run_webui(started_event: mp.Event = None, 
run_mode: str = None): p.wait() +def run_loom(started_event: mp.Event = None): + from configs import LOOM_CONFIG + + cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local", + "-f", LOOM_CONFIG + ] + + p = subprocess.Popen(cmd) + started_event.set() + p.wait() + + def parse_args() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( @@ -488,57 +118,14 @@ def parse_args() -> argparse.ArgumentParser: help="run fastchat's controller/openai_api/model_worker servers, run api.py", dest="all_api", ) - parser.add_argument( - "--llm-api", - action="store_true", - help="run fastchat's controller/openai_api/model_worker servers", - dest="llm_api", - ) - parser.add_argument( - "-o", - "--openai-api", - action="store_true", - help="run fastchat's controller/openai_api servers", - dest="openai_api", - ) - parser.add_argument( - "-m", - "--model-worker", - action="store_true", - help="run fastchat's model_worker server with specified model name. " - "specify --model-name if not using default llm models", - dest="model_worker", - ) - parser.add_argument( - "-n", - "--model-name", - type=str, - nargs="+", - default=all_model_names_list, - help="specify model name for model worker. " - "add addition names with space seperated to start multiple model workers.", - dest="model_name", - ) - parser.add_argument( - "-c", - "--controller", - type=str, - help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER", - dest="controller_address", - ) + parser.add_argument( "--api", action="store_true", help="run api.py server", dest="api", ) - parser.add_argument( - "-p", - "--api-worker", - action="store_true", - help="run online model api such as zhipuai", - dest="api_worker", - ) + parser.add_argument( "-w", "--webui", @@ -546,13 +133,6 @@ def parse_args() -> argparse.ArgumentParser: help="run webui.py server", dest="webui", ) - parser.add_argument( - "-q", - "--quiet", - action="store_true", - help="减少fastchat服务log信息", - dest="quiet", - ) parser.add_argument( "-i", "--lite", @@ -578,24 +158,13 @@ def dump_server_info(after_start=False, args=None): print(f"langchain版本:{langchain.__version__}. 
fastchat版本:{fastchat.__version__}") print("\n") - models = list(LLM_MODEL_CONFIG['llm_model'].keys()) - if args and args.model_name: - models = args.model_name - print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}") - print(f"当前启动的LLM模型:{models} @ {llm_device()}") - for model in models: - pprint(get_model_worker_config(model)) print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") if after_start: print("\n") print(f"服务端运行信息:") - if args.openai_api: - print(f" OpenAI API Server: {fschat_openai_api_address()}") - if args.api: - print(f" Chatchat API Server: {api_address()}") if args.webui: print(f" Chatchat WEBUI Server: {webui_address()}") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) @@ -625,7 +194,6 @@ async def start_main_server(): manager = mp.Manager() run_mode = None - queue = manager.Queue() args, parser = parse_args() if args.all_webui: @@ -662,69 +230,16 @@ async def start_main_server(): processes = {"online_api": {}, "model_worker": {}} def process_count(): - return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2 - - if args.quiet or not log_verbose: - log_level = "ERROR" - else: - log_level = "INFO" - - controller_started = manager.Event() - if args.openai_api: - process = Process( - target=run_controller, - name=f"controller", - kwargs=dict(log_level=log_level, started_event=controller_started), - daemon=True, - ) - processes["controller"] = process - - process = Process( - target=run_openai_api, - name=f"openai_api", - daemon=True, - ) - processes["openai_api"] = process - - model_worker_started = [] - if args.model_worker: - for model_name in args.model_name: - config = get_model_worker_config(model_name) - if not config.get("online_api"): - e = manager.Event() - model_worker_started.append(e) - process = Process( - target=run_model_worker, - name=f"model_worker - {model_name}", - kwargs=dict(model_name=model_name, - controller_address=args.controller_address, - log_level=log_level, - q=queue, - started_event=e), - daemon=True, - ) - processes["model_worker"][model_name] = process - - if args.api_worker: - for model_name in args.model_name: - config = get_model_worker_config(model_name) - if (config.get("online_api") - and config.get("worker_class") - and model_name in FSCHAT_MODEL_WORKERS): - e = manager.Event() - model_worker_started.append(e) - process = Process( - target=run_model_worker, - name=f"api_worker - {model_name}", - kwargs=dict(model_name=model_name, - controller_address=args.controller_address, - log_level=log_level, - q=queue, - started_event=e), - daemon=True, - ) - processes["online_api"][model_name] = process + return len(processes) + loom_started = manager.Event() + process = Process( + target=run_loom, + name=f"run_loom Server", + kwargs=dict(started_event=loom_started), + daemon=True, + ) + processes["run_loom"] = process api_started = manager.Event() if args.api: process = Process( @@ -750,25 +265,10 @@ async def start_main_server(): else: try: # 保证任务收到SIGINT后,能够正常退出 - if p := processes.get("controller"): + if p := processes.get("run_loom"): p.start() p.name = f"{p.name} ({p.pid})" - controller_started.wait() # 等待controller启动完成 - - if p := processes.get("openai_api"): - p.start() - p.name = f"{p.name} ({p.pid})" - - for n, p in processes.get("model_worker", {}).items(): - p.start() - p.name = f"{p.name} ({p.pid})" - - for n, p in processes.get("online_api", []).items(): - p.start() - p.name = f"{p.name} ({p.pid})" - - for e in model_worker_started: - e.wait() + loom_started.wait() # 等待Loom启动完成 if p := 
processes.get("api"): p.start() @@ -782,74 +282,10 @@ async def start_main_server(): dump_server_info(after_start=True, args=args) - while True: - cmd = queue.get() # 收到切换模型的消息 - e = manager.Event() - if isinstance(cmd, list): - model_name, cmd, new_model_name = cmd - if cmd == "start": # 运行新模型 - logger.info(f"准备启动新模型进程:{new_model_name}") - process = Process( - target=run_model_worker, - name=f"model_worker - {new_model_name}", - kwargs=dict(model_name=new_model_name, - controller_address=args.controller_address, - log_level=log_level, - q=queue, - started_event=e), - daemon=True, - ) - process.start() - process.name = f"{process.name} ({process.pid})" - processes["model_worker"][new_model_name] = process - e.wait() - logger.info(f"成功启动新模型进程:{new_model_name}") - elif cmd == "stop": - if process := processes["model_worker"].get(model_name): - time.sleep(1) - process.terminate() - process.join() - logger.info(f"停止模型进程:{model_name}") - else: - logger.error(f"未找到模型进程:{model_name}") - elif cmd == "replace": - if process := processes["model_worker"].pop(model_name, None): - logger.info(f"停止模型进程:{model_name}") - start_time = datetime.now() - time.sleep(1) - process.terminate() - process.join() - process = Process( - target=run_model_worker, - name=f"model_worker - {new_model_name}", - kwargs=dict(model_name=new_model_name, - controller_address=args.controller_address, - log_level=log_level, - q=queue, - started_event=e), - daemon=True, - ) - process.start() - process.name = f"{process.name} ({process.pid})" - processes["model_worker"][new_model_name] = process - e.wait() - timing = datetime.now() - start_time - logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。") - else: - logger.error(f"未找到模型进程:{model_name}") + # 等待所有进程退出 + if p := processes.get("webui"): - # for process in processes.get("model_worker", {}).values(): - # process.join() - # for process in processes.get("online_api", {}).values(): - # process.join() - - # for name, process in processes.items(): - # if name not in ["model_worker", "online_api"]: - # if isinstance(p, dict): - # for work_process in p.values(): - # work_process.join() - # else: - # process.join() + p.join() except Exception as e: logger.error(e) logger.warning("Caught KeyboardInterrupt! 
Setting stop event...") diff --git a/tests/api/test_server_state_api.py b/tests/api/test_server_state_api.py index 2edfb496..851b1aee 100644 --- a/tests/api/test_server_state_api.py +++ b/tests/api/test_server_state_api.py @@ -13,21 +13,6 @@ from typing import List api = ApiRequest() -def test_get_default_llm(): - llm = api.get_default_llm_model() - - print(llm) - assert isinstance(llm, tuple) - assert isinstance(llm[0], str) and isinstance(llm[1], bool) - - -def test_server_configs(): - configs = api.get_server_configs() - pprint(configs, depth=2) - - assert isinstance(configs, dict) - assert len(configs) > 0 - @pytest.mark.parametrize("type", ["llm_chat"]) def test_get_prompt_template(type): diff --git a/tests/test_online_api.py b/tests/test_online_api.py index 372fad4c..e514d472 100644 --- a/tests/test_online_api.py +++ b/tests/test_online_api.py @@ -8,13 +8,13 @@ from server.model_workers.base import * from server.utils import get_model_worker_config, list_config_llm_models from pprint import pprint import pytest - - -workers = [] -for x in list_config_llm_models()["online"]: - if x in ONLINE_LLM_MODEL and x not in workers: - workers.append(x) -print(f"all workers to test: {workers}") +# +# +# workers = [] +# for x in list_config_llm_models()["online"]: +# if x in ONLINE_LLM_MODEL and x not in workers: +# workers.append(x) +# print(f"all workers to test: {workers}") # workers = ["fangzhou-api"] @@ -56,4 +56,4 @@ def test_embeddings(worker): assert isinstance(embeddings, list) and len(embeddings) > 0 assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 assert isinstance(embeddings[0][0], float) - print("向量长度:", len(embeddings[0])) \ No newline at end of file + print("向量长度:", len(embeddings[0])) diff --git a/webui.py b/webui.py index b0c30973..8c0387b3 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,6 @@ import streamlit as st + +from webui_pages.openai_plugins import openai_plugins_page from webui_pages.utils import * from streamlit_option_menu import option_menu from webui_pages.dialogue.dialogue import dialogue_page, chat_box @@ -22,9 +24,26 @@ if __name__ == "__main__": 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', 'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues", 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""" - } + }, + layout="wide" + ) + # use the following code to set the app to wide mode and the html markdown to increase the sidebar width + st.markdown( + """ +