Merge pull request #3966 from yhfgyyf/glm-4-stream

修复 glm-4 无法流式输出的 bug
This commit is contained in:
zR 2024-06-08 16:34:58 +08:00 committed by GitHub
commit 3ecff4bce7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,7 +11,7 @@ import sys
from typing import List, Dict, Iterator, Literal, Any from typing import List, Dict, Iterator, Literal, Any
import jwt import jwt
import time import time
import json
@contextmanager @contextmanager
def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any): def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
@ -68,23 +68,26 @@ class ChatGLMWorker(ApiModelWorker):
"messages": params.messages, "messages": params.messages,
"max_tokens": params.max_tokens, "max_tokens": params.max_tokens,
"temperature": params.temperature, "temperature": params.temperature,
"stream": False "stream": True
} }
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
with httpx.Client(headers=headers) as client: text = ""
response = client.post(url, json=data) response = requests.post(url, headers=headers, json=data, stream=True)
response.raise_for_status() for chunk in response.iter_lines():
chunk = response.json() if chunk:
print(chunk) if chunk.startswith(b'data:'):
yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]} json_str = chunk.decode('utf-8')[6:]
try:
# with connect_sse(client, "POST", url, json=data) as event_source: data = json.loads(json_str)
# for sse in event_source.iter_sse(): if 'finish_reason' in data and data.get('finish_reason') =="stop":
# chunk = json.loads(sse.data) break
# if len(chunk["choices"]) != 0: else:
# text += chunk["choices"][0]["delta"]["content"] msg = data['choices'][0]['delta']['content']
# yield {"error_code": 0, "text": text} text += msg
yield {"error_code": 0, "text": text}
except json.JSONDecodeError as e:
pass
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: