From d9323ea3f6613cd07bd42cd929f4b8641ebe289e Mon Sep 17 00:00:00 2001 From: yhfgyyf <574821834@qq.com> Date: Wed, 8 May 2024 15:05:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dglm-4=20=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/model_workers/zhipu.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 939f411f..cb285a4a 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -11,7 +11,7 @@ import sys from typing import List, Dict, Iterator, Literal, Any import jwt import time - +import json @contextmanager def connect_sse(client: httpx.Client, method: str, url: str, **kwargs: Any): @@ -68,23 +68,26 @@ class ChatGLMWorker(ApiModelWorker): "messages": params.messages, "max_tokens": params.max_tokens, "temperature": params.temperature, - "stream": False + "stream": True } url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" - with httpx.Client(headers=headers) as client: - response = client.post(url, json=data) - response.raise_for_status() - chunk = response.json() - print(chunk) - yield {"error_code": 0, "text": chunk["choices"][0]["message"]["content"]} - - # with connect_sse(client, "POST", url, json=data) as event_source: - # for sse in event_source.iter_sse(): - # chunk = json.loads(sse.data) - # if len(chunk["choices"]) != 0: - # text += chunk["choices"][0]["delta"]["content"] - # yield {"error_code": 0, "text": text} + text = "" + response = requests.post(url, headers=headers, json=data, stream=True) + for chunk in response.iter_lines(): + if chunk: + if chunk.startswith(b'data:'): + json_str = chunk.decode('utf-8')[6:] + try: + data = json.loads(json_str) + if 'finish_reason' in data and data.get('finish_reason') =="stop": + break + else: + msg = data['choices'][0]['delta']['content'] + text += msg + yield {"error_code": 0, "text": text} + except json.JSONDecodeError as e: + pass def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: