From 7c94df4bcf55b302f4db075529a6d5d7ecd8ce52 Mon Sep 17 00:00:00 2001
From: liam
Date: Mon, 28 Oct 2024 21:09:40 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=EF=B8=8F:=20back=20transformer.py?=
 =?UTF-8?q?=20bugs=20version,=20and=20fix=20typo=20error=20in=20local=5Fch?=
 =?UTF-8?q?at.py?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ktransformers/local_chat.py                       |  2 +-
 .../server/backend/interfaces/transformers.py     | 24 ++++++++++++++-----
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 80ada29..3e0d2ef 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -91,7 +91,7 @@ def local_chat():
         generated = asyncio.run(async_inference(messages))
         his_content += [
             {"role": "user", "content": content},
-            {"role": "assitant", "content": generated},
+            {"role": "assistant", "content": generated},
         ]
 
 
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 7f569c4..cddc198 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -164,7 +164,6 @@ class TransformersInterface(BackendInterfaceBase):
             if m["role"] == "system":
                 logger.warning(f'change {m["role"]} to user')
                 m["role"] = "user"
-
         new_messages = [messages[0]]
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
@@ -173,12 +172,25 @@
             else:
                 new_messages.append(m)
 
+        # if (self.last_request_id is not None) and self.last_request_id == thread_id:
+        #     logger.debug(f"last message: {new_messages[-1]}")
+        #     input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",add_generation_prompt=False).to(self.args.device)
+        # else:
+        #     input_ids = self.tokenizer.apply_chat_template(
+        #         new_messages, return_tensors="pt", add_generation_prompt=True
+        #     ).to(self.args.device)
+
+        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
-            input_ids = self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt").to(self.args.device)
-        else:
-            input_ids = self.tokenizer.apply_chat_template(
-                new_messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.args.device)
+            x = self.generated_ids[:,:self.seq_length]
+            y = input_ids[:,:self.seq_length]
+            # We can only hope that the input_ids are the same
+            unequal_mask = torch.ne(x,y)
+            unequal_positions = torch.nonzero(unequal_mask)
+            num_unequal_elements = unequal_mask.sum().item()
+            logger.warning(f'num_unequal_elements: {num_unequal_elements}')
+
+            input_ids = input_ids[:,self.seq_length:]
 
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids