From c2b4dc805c71f05bd61e47be75a8aa44b76bd38f Mon Sep 17 00:00:00 2001 From: liam Date: Fri, 1 Nov 2024 11:01:30 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=EF=B8=8F:roll=20back=20transformer?= =?UTF-8?q?.py=20and=20find=20that=20its=20multiple=20chat=20history=20ha?= =?UTF-8?q?ve=20minor=20accuracy=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/backend/interfaces/transformers.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index f3f0373..f205ac5 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -172,13 +172,23 @@ class TransformersInterface(BackendInterfaceBase): new_messages[-1]["content"] += m["content"] else: new_messages.append(m) - + # if (self.last_request_id is not None) and self.last_request_id == thread_id: + # input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device) + # else: + # input_ids = self.tokenizer.apply_chat_template( + # new_messages, return_tensors="pt", add_generation_prompt=True + # ).to(self.args.device) + input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device) if (self.last_request_id is not None) and self.last_request_id == thread_id: - input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device) - else: - input_ids = self.tokenizer.apply_chat_template( - new_messages, return_tensors="pt", 
add_generation_prompt=True - ).to(self.args.device) + x = self.generated_ids[:,:self.seq_length] + y = input_ids[:,:self.seq_length] + # We can only hope that the input_ids are the same + unequal_mask = torch.ne(x,y) + unequal_positions = torch.nonzero(unequal_mask) + num_unequal_elements = unequal_mask.sum().item() + logger.warning(f'num_unequal_elements: {num_unequal_elements}') + + input_ids = input_ids[:,self.seq_length:] logger.debug(f"get input ids of shape {input_ids.shape}") return input_ids