Merge pull request #3966 from yhfgyyf/glm-4-stream

Fix the bug where glm-4 could not stream its output: request the completion with "stream": True and consume the SSE response line by line instead of returning a single non-streaming reply.
zR 2024-06-08 16:34:58 +08:00 committed by GitHub
commit 3ecff4bce7


@@ -11,7 +11,7 @@ import sys
 from typing import List, Dict, Iterator, Literal, Any
 import jwt
 import time
+import json
 @contextmanager
 def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any):
@@ -68,23 +68,26 @@ class ChatGLMWorker(ApiModelWorker):
             "messages": params.messages,
             "max_tokens": params.max_tokens,
             "temperature": params.temperature,
-            "stream": False
+            "stream": True
         }
         url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
-        with httpx.Client(headers=headers) as client:
-            response = client.post(url, json=data)
-            response.raise_for_status()
-            chunk = response.json()
-            print(chunk)
-            yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]}
-        # with connect_sse(client, "POST", url, json=data) as event_source:
-        #     for sse in event_source.iter_sse():
-        #         chunk = json.loads(sse.data)
-        #         if len(chunk["choices"]) != 0:
-        #             text += chunk["choices"][0]["delta"]["content"]
-        #             yield {"error_code": 0, "text": text}
+        text = ""
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        for chunk in response.iter_lines():
+            if chunk:
+                if chunk.startswith(b'data:'):
+                    json_str = chunk.decode('utf-8')[6:]
+                    try:
+                        data = json.loads(json_str)
+                        if 'finish_reason' in data and data.get('finish_reason') =="stop":
+                            break
+                        else:
+                            msg = data['choices'][0]['delta']['content']
+                            text += msg
+                            yield {"error_code": 0, "text": text}
+                    except json.JSONDecodeError as e:
+                        pass
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
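
For readers who want the streaming pattern outside the worker class, below is a minimal, self-contained sketch of the same idea. It is illustrative only: the helper name `stream_glm4`, the `GLM4_API_KEY` environment variable, and the plain Bearer `Authorization` header are assumptions (the worker builds its own authorization header; note the `jwt` import above). The sketch also reads `finish_reason` from `choices[0]`, which is where the OpenAI-compatible response format reports it.

```python
# Illustrative sketch only: stream_glm4 and GLM4_API_KEY are not part of the PR.
import json
import os
from typing import Dict, Iterator, List

import requests

GLM4_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"


def stream_glm4(messages: List[Dict], api_key: str, model: str = "glm-4") -> Iterator[str]:
    """Yield the accumulated assistant reply as SSE chunks arrive."""
    headers = {
        # Assumption: a plain Bearer token; the worker constructs its own
        # Authorization header elsewhere (note the `jwt` import in the diff).
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        "stream": True,  # the core of the fix: ask the API for an SSE stream
    }
    text = ""
    with requests.post(GLM4_URL, headers=headers, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            # SSE payload lines look like: b'data: {...json...}'
            if not line or not line.startswith(b"data:"):
                continue
            body = line.decode("utf-8")[len("data:"):].strip()
            if body == "[DONE]":  # OpenAI-style sentinel, if the API sends one
                break
            try:
                chunk = json.loads(body)
            except json.JSONDecodeError:
                continue  # skip keep-alives or malformed lines
            choice = (chunk.get("choices") or [{}])[0]
            text += choice.get("delta", {}).get("content", "") or ""
            yield text  # running accumulated text (the worker wraps it in a dict)
            if choice.get("finish_reason") == "stop":
                break


if __name__ == "__main__":
    # Usage example; needs a valid key exported as GLM4_API_KEY.
    for partial in stream_glm4(
        [{"role": "user", "content": "Hello"}],
        api_key=os.environ.get("GLM4_API_KEY", ""),
    ):
        print(partial)
```

Compared with the merged code, the `[DONE]` check and the per-choice `finish_reason` check are defensive guards rather than behavior confirmed by this PR; the merged worker instead relies on the top-level `finish_reason` test and silently skips lines that fail JSON decoding.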