mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-02-04 21:43:13 +08:00
Left out
This commit is contained in:
parent
91062a834f
commit
07eb712a73
@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
logger.debug(f"input_ids: {input_ids.shape}")
|
logger.debug(f"input_ids: {input_ids.shape}")
|
||||||
|
|
||||||
@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
|
|||||||
else:
|
else:
|
||||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||||
|
|
||||||
self.prepare_logits_wrapper(input_ids, device)
|
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||||
next_token = self.logits_to_token(logits[0, -1, :])
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user