diff --git a/server/llm_api.py b/server/llm_api.py
index a9e5ab6c..b028747b 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,7 +1,7 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address, list_llm_models
-import httpx
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
+
 
 def list_running_models(
@@ -13,8 +13,9 @@ def list_running_models(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(controller_address + "/list_models")
-        return BaseResponse(data=r.json()["models"])
+        with get_httpx_client() as client:
+            r = client.post(controller_address + "/list_models")
+            return BaseResponse(data=r.json()["models"])
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -41,11 +42,12 @@ def stop_llm_model(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(
-            controller_address + "/release_worker",
-            json={"model_name": model_name},
-        )
-        return r.json()
+        with get_httpx_client() as client:
+            r = client.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name},
+            )
+            return r.json()
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
@@ -64,12 +66,13 @@ def change_llm_model(
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
-        r = httpx.post(
-            controller_address + "/release_worker",
-            json={"model_name": model_name, "new_model_name": new_model_name},
-            timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for the new model_worker
-        )
-        return r.json()
+        with get_httpx_client() as client:
+            r = client.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name, "new_model_name": new_model_name},
+                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for the new model_worker
+            )
+            return r.json()
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index 39ff293d..9079ea44 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -2,7 +2,7 @@ from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
 import json
-import httpx
+from server.utils import get_httpx_client
 from pprint import pprint
 from typing import List, Dict
 
@@ -63,22 +63,23 @@ class MiniMaxWorker(ApiModelWorker):
         }
         print("request data sent to minimax:")
         pprint(data)
-        response = httpx.stream("POST",
-                                self.BASE_URL.format(pro=pro, group_id=group_id),
-                                headers=headers,
-                                json=data)
-        with response as r:
-            text = ""
-            for e in r.iter_text():
-                if e.startswith("data: "):  # what an "excellent" response format
-                    data = json.loads(e[6:])
-                    if not data.get("usage"):
-                        if choices := data.get("choices"):
-                            chunk = choices[0].get("delta", "").strip()
-                            if chunk:
-                                print(chunk)
-                                text += chunk
-                                yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
+        with get_httpx_client() as client:
+            response = client.stream("POST",
+                                     self.BASE_URL.format(pro=pro, group_id=group_id),
+                                     headers=headers,
+                                     json=data)
+            with response as r:
+                text = ""
+                for e in r.iter_text():
+                    if e.startswith("data: "):  # what an "excellent" response format
+                        data = json.loads(e[6:])
+                        if not data.get("usage"):
+                            if choices := data.get("choices"):
+                                chunk = choices[0].get("delta", "").strip()
+                                if chunk:
+                                    print(chunk)
+                                    text += chunk
+                                    yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
 
     def get_embeddings(self, params):
         # TODO: support embeddings
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 387d4b7c..5eefd407 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -5,7 +5,7 @@ import sys
 import json
 import httpx
 from cachetools import cached, TTLCache
-from server.utils import get_model_worker_config
+from server.utils import get_model_worker_config, get_httpx_client
 from typing import List, Literal, Dict
 
 
@@ -54,7 +54,8 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
     url = "https://aip.baidubce.com/oauth/2.0/token"
     params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
     try:
-        return httpx.get(url, params=params).json().get("access_token")
+        with get_httpx_client() as client:
+            return client.get(url, params=params).json().get("access_token")
     except Exception as e:
         print(f"failed to get token from baidu: {e}")
 
@@ -91,14 +92,15 @@ def request_qianfan_api(
         'Accept': 'application/json',
     }
 
-    with httpx.stream("POST", url, headers=headers, json=payload) as response:
-        for line in response.iter_lines():
-            if not line.strip():
-                continue
-            if line.startswith("data: "):
-                line = line[6:]
-            resp = json.loads(line)
-            yield resp
+    with get_httpx_client() as client:
+        with client.stream("POST", url, headers=headers, json=payload) as response:
+            for line in response.iter_lines():
+                if not line.strip():
+                    continue
+                if line.startswith("data: "):
+                    line = line[6:]
+                resp = json.loads(line)
+                yield resp
 
 
 class QianFanWorker(ApiModelWorker):
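All three files above apply the same refactor: module-level `httpx.post`/`httpx.get`/`httpx.stream` calls become calls on a client obtained from `get_httpx_client()`, so the project-wide timeout and proxy-bypass defaults are applied in one place. A minimal sketch of the streaming variant of this pattern, assuming a hypothetical SSE endpoint that prefixes events with `data: ` the way the MiniMax and Qianfan APIs do:

```python
import json
from server.utils import get_httpx_client

def stream_events(url: str, payload: dict):
    # `url` and `payload` are placeholders; the structure mirrors the workers above
    with get_httpx_client() as client:
        with client.stream("POST", url, json=payload) as response:
            for line in response.iter_lines():
                if line.startswith("data: "):  # strip the SSE prefix before decoding
                    yield json.loads(line[len("data: "):])
```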
diff --git a/server/utils.py b/server/utils.py
index 0efe1224..b6a3945b 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -7,11 +7,12 @@ import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
-                     FSCHAT_MODEL_WORKERS)
+                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable
+import httpx
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
 
 
 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -376,19 +377,63 @@ def get_prompt_template(name: str) -> Optional[str]:
     return prompt_config.PROMPT_TEMPLATES.get(name)
 
 
-def set_httpx_timeout(timeout: float = None):
+def set_httpx_config(
+    timeout: float = HTTPX_DEFAULT_TIMEOUT,
+    proxy: Union[str, Dict] = None,
+):
     '''
-    Set the default httpx timeout.
-    The httpx default timeout is 5 seconds, which is not enough when waiting for an LLM answer.
+    Set the default httpx timeout. The httpx default of 5 seconds is not enough when waiting for an LLM answer.
+    Add this project's own services to the no-proxy list to avoid request errors against the fastchat servers. (Has no effect on Windows.)
+    Online APIs such as ChatGPT need their proxy configured manually. How to handle proxies for search engines still needs consideration.
     '''
     import httpx
-    from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+    import os
 
-    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
 
+    # set system-level proxies for the whole process
+    proxies = {}
+    if isinstance(proxy, str):
+        for n in ["http", "https", "all"]:
+            proxies[n + "_proxy"] = proxy
+    elif isinstance(proxy, dict):
+        for n in ["http", "https", "all"]:
+            if p := proxy.get(n):
+                proxies[n + "_proxy"] = p
+            elif p := proxy.get(n + "_proxy"):
+                proxies[n + "_proxy"] = p
+
+    for k, v in proxies.items():
+        os.environ[k] = v
+
+    # set hosts that bypass the proxy
+    no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
+    no_proxy += [
+        # do not use a proxy for localhost
+        "http://127.0.0.1",
+        "http://localhost",
+    ]
+    # do not use a proxy for user-deployed fastchat servers
+    for x in [
+        fschat_controller_address(),
+        fschat_model_worker_address(),
+        fschat_openai_api_address(),
+    ]:
+        host = ":".join(x.split(":")[:2])
+        if host not in no_proxy:
+            no_proxy.append(host)
+    os.environ["NO_PROXY"] = ",".join(no_proxy)
+
+    # TODO: simply clearing the system proxy is not a good option; it affects too much. Modifying the proxy server's bypass list seems better.
+    # patch requests to use custom proxies instead of system settings
+    # def _get_proxies():
+    #     return {}
+
+    # import urllib.request
+    # urllib.request.getproxies = _get_proxies
+
 
 # Automatically detect the device available to torch. In distributed deployments, machines that do not run an LLM need not install torch.
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
@@ -436,3 +481,51 @@ def run_in_thread_pool(
     for obj in as_completed(tasks):
         yield obj.result()
 
+
+def get_httpx_client(
+    use_async: bool = False,
+    proxies: Union[str, Dict] = None,
+    timeout: float = HTTPX_DEFAULT_TIMEOUT,
+    **kwargs,
+) -> Union[httpx.Client, httpx.AsyncClient]:
+    '''
+    helper to get an httpx client with default proxies that bypass local addresses.
+    '''
+    default_proxies = {
+        # do not use a proxy for localhost
+        "all://127.0.0.1": None,
+        "all://localhost": None,
+    }
+    # do not use a proxy for user-deployed fastchat servers
+    for x in [
+        fschat_controller_address(),
+        fschat_model_worker_address(),
+        fschat_openai_api_address(),
+    ]:
+        host = ":".join(x.split(":")[:2])
+        default_proxies.update({host: None})
+
+    # get proxies from the system environment
+    default_proxies.update({
+        "http://": os.environ.get("http_proxy"),
+        "https://": os.environ.get("https_proxy"),
+        "all://": os.environ.get("all_proxy"),
+    })
+    for host in os.environ.get("no_proxy", "").split(","):
+        if host := host.strip():
+            default_proxies.update({host: None})
+
+    # merge default proxies with user-provided proxies
+    if isinstance(proxies, str):
+        proxies = {"all://": proxies}
+
+    if isinstance(proxies, dict):
+        default_proxies.update(proxies)
+
+    # construct the client
+    kwargs.update(timeout=timeout, proxies=default_proxies)
+    if use_async:
+        return httpx.AsyncClient(**kwargs)
+    else:
+        return httpx.Client(**kwargs)
+
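The mapping `get_httpx_client` builds relies on httpx's URL-pattern proxy routing, where a value of `None` means "connect directly"; that is what keeps localhost and the fastchat servers reachable even when `http_proxy`/`all_proxy` are set. A small sketch of the same mechanism with a hand-written mapping (the proxy URL and the port are made-up values, and it assumes an httpx version that still accepts the `proxies=` keyword):

```python
import httpx

proxies = {
    "all://127.0.0.1": None,  # loopback traffic is never proxied
    "all://localhost": None,
    "all://": "http://proxy.example.com:8118",  # everything else goes through the proxy
}

with httpx.Client(proxies=proxies, timeout=300.0) as client:
    client.get("http://127.0.0.1:20001/list_models")  # direct connection
    client.get("https://api.example.com/v1/models")   # routed through the proxy
```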
"keep_origin": keep_origin}) - if r.status_code != 200: - msg = f"failed to release model: {model_name}" - logger.error(msg) - return {"code": 500, "msg": msg} + with get_httpx_client() as client: + r = client.post(worker_address + "/release", + json={"new_model_name": new_model_name, "keep_origin": keep_origin}) + if r.status_code != 200: + msg = f"failed to release model: {model_name}" + logger.error(msg) + return {"code": 500, "msg": msg} if new_model_name: timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register @@ -299,6 +301,8 @@ def run_model_worker( import uvicorn from fastapi import Body import sys + from server.utils import set_httpx_config + set_httpx_config() kwargs = get_model_worker_config(model_name) host = kwargs.pop("host") @@ -337,6 +341,8 @@ def run_model_worker( def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): import uvicorn import sys + from server.utils import set_httpx_config + set_httpx_config() controller_addr = fschat_controller_address() app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. @@ -353,6 +359,8 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): def run_api_server(started_event: mp.Event = None): from server.api import create_app import uvicorn + from server.utils import set_httpx_config + set_httpx_config() app = create_app() _set_app_event(app, started_event) @@ -364,6 +372,9 @@ def run_api_server(started_event: mp.Event = None): def run_webui(started_event: mp.Event = None): + from server.utils import set_httpx_config + set_httpx_config() + host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] diff --git a/webui_pages/utils.py b/webui_pages/utils.py index cca30fa2..a2113685 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -26,7 +26,7 @@ import contextlib import json import os from io import BytesIO -from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address +from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client from configs.model_config import NLTK_DATA_PATH import nltk @@ -35,7 +35,7 @@ from pprint import pprint KB_ROOT_PATH = Path(KB_ROOT_PATH) -set_httpx_timeout() +set_httpx_config() class ApiRequest: @@ -53,6 +53,8 @@ class ApiRequest: self.base_url = base_url self.timeout = timeout self.no_remote_api = no_remote_api + self._client = get_httpx_client() + self._aclient = get_httpx_client(use_async=True) if no_remote_api: logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。") @@ -79,9 +81,9 @@ class ApiRequest: while retry > 0: try: if stream: - return httpx.stream("GET", url, params=params, **kwargs) + return self._client.stream("GET", url, params=params, **kwargs) else: - return httpx.get(url, params=params, **kwargs) + return self._client.get(url, params=params, **kwargs) except Exception as e: msg = f"error when get {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', @@ -98,18 +100,18 @@ class ApiRequest: ) -> Union[httpx.Response, None]: url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) - async with httpx.AsyncClient() as client: - while retry > 0: - try: - if stream: - return await client.stream("GET", url, params=params, **kwargs) - else: - return await client.get(url, params=params, **kwargs) - except Exception as e: - msg = f"error when aget {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 + + while retry > 0: + try: + if stream: + return await 
self._aclient.stream("GET", url, params=params, **kwargs) + else: + return await self._aclient.get(url, params=params, **kwargs) + except Exception as e: + msg = f"error when aget {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + retry -= 1 def post( self, @@ -124,11 +126,10 @@ class ApiRequest: kwargs.setdefault("timeout", self.timeout) while retry > 0: try: - # return requests.post(url, data=data, json=json, stream=stream, **kwargs) if stream: - return httpx.stream("POST", url, data=data, json=json, **kwargs) + return self._client.stream("POST", url, data=data, json=json, **kwargs) else: - return httpx.post(url, data=data, json=json, **kwargs) + return self._client.post(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when post {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', @@ -146,18 +147,18 @@ class ApiRequest: ) -> Union[httpx.Response, None]: url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) - async with httpx.AsyncClient() as client: - while retry > 0: - try: - if stream: - return await client.stream("POST", url, data=data, json=json, **kwargs) - else: - return await client.post(url, data=data, json=json, **kwargs) - except Exception as e: - msg = f"error when apost {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 + + while retry > 0: + try: + if stream: + return await self._client.stream("POST", url, data=data, json=json, **kwargs) + else: + return await self._client.post(url, data=data, json=json, **kwargs) + except Exception as e: + msg = f"error when apost {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + retry -= 1 def delete( self, @@ -173,9 +174,9 @@ class ApiRequest: while retry > 0: try: if stream: - return httpx.stream("DELETE", url, data=data, json=json, **kwargs) + return self._client.stream("DELETE", url, data=data, json=json, **kwargs) else: - return httpx.delete(url, data=data, json=json, **kwargs) + return self._client.delete(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when delete {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', @@ -193,18 +194,18 @@ class ApiRequest: ) -> Union[httpx.Response, None]: url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) - async with httpx.AsyncClient() as client: - while retry > 0: - try: - if stream: - return await client.stream("DELETE", url, data=data, json=json, **kwargs) - else: - return await client.delete(url, data=data, json=json, **kwargs) - except Exception as e: - msg = f"error when adelete {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 + + while retry > 0: + try: + if stream: + return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs) + else: + return await self._aclient.delete(url, data=data, json=json, **kwargs) + except Exception as e: + msg = f"error when adelete {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + retry -= 1 def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): '''
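Taken together, the diff standardizes one lifecycle: every `run_*` entry point calls `set_httpx_config()` once at process startup, and `ApiRequest` reuses a long-lived sync client and async client instead of building a throwaway `httpx.AsyncClient` per call, so connections are pooled across requests. A reduced sketch of that persistent-client retry pattern (class and method names here are abbreviations, not the project's actual API):

```python
import httpx
from typing import Union

class MiniApiRequest:
    """Reduced illustration of ApiRequest's persistent-client design."""

    def __init__(self, base_url: str, timeout: float = 300.0):
        self.base_url = base_url
        # one sync and one async client per instance: connections are pooled
        # and every call inherits the configured timeout
        self._client = httpx.Client(timeout=timeout)
        self._aclient = httpx.AsyncClient(timeout=timeout)

    def get(self, path: str, retry: int = 3) -> Union[httpx.Response, None]:
        while retry > 0:
            try:
                return self._client.get(self.base_url + path)
            except Exception:
                retry -= 1  # swallow and retry; returns None when retries are exhausted

    async def aget(self, path: str, retry: int = 3) -> Union[httpx.Response, None]:
        while retry > 0:
            try:
                return await self._aclient.get(self.base_url + path)
            except Exception:
                retry -= 1
```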