diff --git a/install.sh b/install.sh
index ffb7aca..c5773ec 100644
--- a/install.sh
+++ b/install.sh
@@ -2,6 +2,8 @@
 set -e
 
 # clear build dirs
+rm -rf build
+rm -rf *.egg-info
 rm -rf ktransformers/ktransformers_ext/build
 rm -rf ktransformers/ktransformers_ext/cuda/build
 rm -rf ktransformers/ktransformers_ext/cuda/dist
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 7ecc637..edca541 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -104,7 +104,10 @@ class KTransformersInterface(TransformersInterface):
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)
-
+
+        if self.args.use_cuda_graph:
+            warm_uped = True
+
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
@@ -118,7 +121,6 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
         logits = logits[0, -1, :]
-        warm_uped = True
         return self.logits_to_token(logits)
 
 
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 5086a3b..deb6cfa 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -18,7 +18,7 @@ import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
-
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
 
 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -330,8 +330,14 @@ class TransformersInterface(BackendInterfaceBase):
 
     @torch.no_grad
     def generate(self):
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
+        for i in range(1, self.args.max_new_tokens):
+            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+                if i > 1 and flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
+                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
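Context for the `ktransformers.py` hunks: `warm_uped = True` is now set only when CUDA graphs are enabled, and before the eager forward pass rather than after it, so the next `decode_one_tokens()` call can build the graph runner from an already warmed-up model. For readers unfamiliar with why a warm-up pass must precede capture, the sketch below shows the generic PyTorch warm-up / capture / replay pattern; `step`, `static_input`, and the iteration counts are illustrative assumptions, not code from this repo.

```python
import torch

def step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for one eager decode forward pass (hypothetical).
    return torch.nn.functional.relu(x) * 2.0

static_input = torch.randn(8, device="cuda")
static_output = torch.empty_like(static_input)

# 1) Warm-up: run a few eager iterations on a side stream so lazy allocations
#    and autotuning happen outside of capture. Setting warm_uped = True before
#    the eager forward pass plays this role in decode_one_tokens().
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        static_output = step(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# 2) Capture: record one decode step into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = step(static_input)

# 3) Replay: later steps just refill the static input buffer and replay.
static_input.copy_(torch.randn(8, device="cuda"))
graph.replay()
print(static_output)
```

Replaying a captured graph only re-executes the recorded kernels on fixed buffers, which is why the eager path (and any lazy initialization it triggers) has to run at least once before capture.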
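Context for the `transformers.py` hunk: each decode iteration re-plans the flashinfer MLA wrapper with the current KV length (`active_cache_position + 1`) so its paged-attention metadata tracks the growing cache, and the softmax scale is taken over the full query head dim (`qk_nope_head_dim + qk_rope_head_dim`); the `i > 1` guard presumably skips the step whose planning is still valid from prefill. The sketch below only illustrates that plan-per-step bookkeeping; `ToyMLAWrapper`, `prompt_len`, and the config numbers are stand-ins, not the real `MLAWrapperSingleton` or flashinfer API.

```python
import torch

class ToyMLAWrapper:
    """Hypothetical stand-in that only records the metadata a real
    plan call would use to schedule paged MLA attention."""

    def __init__(self) -> None:
        self.planned_kv_len = None
        self.sm_scale = None

    def plan(self, kv_len: int, sm_scale: float) -> None:
        # The real wrapper rebuilds its scheduling metadata here.
        self.planned_kv_len = kv_len
        self.sm_scale = sm_scale

# Example head dims in the style of DeepSeek MLA configs (assumed values).
qk_nope_head_dim = 128
qk_rope_head_dim = 64
# The scale divides by sqrt of the *full* query head dim (nope + rope parts),
# matching sm_scale=(qk_rope_head_dim + qk_nope_head_dim) ** (-0.5) in the diff.
sm_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)

wrapper = ToyMLAWrapper()
prompt_len = 16  # tokens already cached after prefill (assumed)
for i in range(1, 5):
    # The decode step attends over every cached token up to and including the
    # current position, hence kv_len = active_cache_position + 1.
    active_cache_position = torch.tensor(prompt_len + i - 1, dtype=torch.int32)
    wrapper.plan(kv_len=int(active_cache_position) + 1, sm_scale=sm_scale)
    # ... decode_one_tokens() would run the model forward here ...
    print(f"step {i}: planned kv_len={wrapper.planned_kv_len}")
```

For the example head dims above this gives `sm_scale = 1/sqrt(192)` rather than a per-part `1/sqrt(64)` or `1/sqrt(128)`.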