diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 8ae38db6..536c5cdb 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -44,6 +44,9 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = CustomAsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/chat.py b/server/chat/chat.py index acf3ec0c..47ec871c 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -45,7 +45,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 chat_type="llm_chat", query=query) callbacks.append(conversation_callback) - + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None model = get_ChatOpenAI( model_name=model_name, diff --git a/server/chat/completion.py b/server/chat/completion.py index ee5e2d12..beda0261 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -28,6 +28,9 @@ async def completion(query: str = Body(..., description="用户输入", examples echo: bool = echo, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_OpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index ea3475a0..a4db9174 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -114,6 +114,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= async def knowledge_base_chat_iterator() -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 0ea99a6e..a99a045a 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -49,6 +49,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8325b4d9..98b26c6d 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -148,6 +148,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature,