mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-02-03 13:03:35 +08:00
🚑️:roll back transformer.py and find that it's multiple chat hsitory have minor accurate error
This commit is contained in:
parent
a148da2cfe
commit
c2b4dc805c
@ -172,13 +172,23 @@ class TransformersInterface(BackendInterfaceBase):
|
|||||||
new_messages[-1]["content"] += m["content"]
|
new_messages[-1]["content"] += m["content"]
|
||||||
else:
|
else:
|
||||||
new_messages.append(m)
|
new_messages.append(m)
|
||||||
|
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||||
|
# input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
|
||||||
|
# else:
|
||||||
|
# input_ids = self.tokenizer.apply_chat_template(
|
||||||
|
# new_messages, return_tensors="pt", add_generation_prompt=True
|
||||||
|
# ).to(self.args.device)
|
||||||
|
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
|
||||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||||
input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device)
|
x = self.generated_ids[:,:self.seq_length]
|
||||||
else:
|
y = input_ids[:,:self.seq_length]
|
||||||
input_ids = self.tokenizer.apply_chat_template(
|
# We can only hope that the input_ids are the same
|
||||||
new_messages, return_tensors="pt", add_generation_prompt=True
|
unequal_mask = torch.ne(x,y)
|
||||||
).to(self.args.device)
|
unequal_positions = torch.nonzero(unequal_mask)
|
||||||
|
num_unequal_elements = unequal_mask.sum().item()
|
||||||
|
logger.warning(f'num_unequal_elements: {num_unequal_elements}')
|
||||||
|
|
||||||
|
input_ids = input_ids[:,self.seq_length:]
|
||||||
logger.debug(f"get input ids of shape {input_ids.shape}")
|
logger.debug(f"get input ids of shape {input_ids.shape}")
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user