diff --git a/server/chat/utils.py b/server/chat/utils.py index 0dd17d39..60ea2d1c 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -4,7 +4,7 @@ from langchain.prompts.chat import ChatMessagePromptTemplate from configs import logger, log_verbose from server.utils import get_model_worker_config, fschat_openai_api_address from langchain.chat_models import ChatOpenAI -from typing import Awaitable, List, Tuple, Dict, Union, Callable +from typing import Awaitable, List, Tuple, Dict, Union, Callable, Any def get_ChatOpenAI( @@ -12,6 +12,7 @@ def get_ChatOpenAI( temperature: float, streaming: bool = True, callbacks: List[Callable] = [], + **kwargs: Any, ) -> ChatOpenAI: config = get_model_worker_config(model_name) model = ChatOpenAI( @@ -22,7 +23,8 @@ def get_ChatOpenAI( openai_api_base=config.get("api_base_url", fschat_openai_api_address()), model_name=model_name, temperature=temperature, - openai_proxy=config.get("openai_proxy") + openai_proxy=config.get("openai_proxy"), + **kwargs ) return model