mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-02-04 21:43:13 +08:00
Fix according to upstream changes
This commit is contained in:
parent
26f7b4af11
commit
b121ca4df8
@ -201,10 +201,9 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
else:
|
else:
|
||||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||||
|
|
||||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
|
||||||
if flashinfer_enabled:
|
if flashinfer_enabled:
|
||||||
MLAWrapperSingleton.reset_buffer()
|
MLAWrapperSingleton.reset_buffer()
|
||||||
self.prepare_logits_wrapper(input_ids, device)
|
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||||
next_token = self.logits_to_token(logits[0, -1, :])
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user