diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 939f411f..cb285a4a 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -11,7 +11,7 @@ import sys from typing import List, Dict, Iterator, Literal, Any import jwt import time - +import json @contextmanager def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any): @@ -68,23 +68,26 @@ class ChatGLMWorker(ApiModelWorker): "messages": params.messages, "max_tokens": params.max_tokens, "temperature": params.temperature, - "stream": False + "stream": True } url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" - with httpx.Client(headers=headers) as client: - response = client.post(url, json=data) - response.raise_for_status() - chunk = response.json() - print(chunk) - yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]} - - # with connect_sse(client, "POST", url, json=data) as event_source: - # for sse in event_source.iter_sse(): - # chunk = json.loads(sse.data) - # if len(chunk["choices"]) != 0: - # text += chunk["choices"][0]["delta"]["content"] - # yield {"error_code": 0, "text": text} + text = "" + response = requests.post(url, headers=headers, json=data, stream=True) + for chunk in response.iter_lines(): + if chunk: + if chunk.startswith(b'data:'): + json_str = chunk.decode('utf-8')[6:] + try: + data = json.loads(json_str) + if 'finish_reason' in data and data.get('finish_reason') =="stop": + break + else: + msg = data['choices'][0]['delta']['content'] + text += msg + yield {"error_code": 0, "text": text} + except json.JSONDecodeError as e: + pass def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: