diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 3fcbf6f..1057e82 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -67,6 +67,7 @@ def local_chat( tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if mode == 'long_context': + assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" torch.set_default_dtype(torch.float16) else: torch.set_default_dtype(config.torch_dtype) @@ -143,8 +144,9 @@ def local_chat( input_tensor = tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ) - assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ - "please change max_seq_len in ~/.ktransformers/config.yaml" + if mode == 'long_context': + assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ + "please change max_seq_len in ~/.ktransformers/config.yaml" torch.set_default_dtype( torch.bfloat16 ) # TODO: Remove this, replace dtype using config diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index d84b063..44f0037 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-25 11:25:24 Version : 0.1.0 LastEditors : Azure -LastEditTime : 2024-08-27 03:50:23 +LastEditTime : 2024-08-29 09:41:10 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' @@ -202,7 +202,7 @@ class KExpertsCPU(KExpertsBase): def forward(self, input_tensor, expert_ids, weights): # generate, capture and run cuda graph # print(expert_ids) - if input_tensor.size(0)==1: + if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing(): # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible #print("capturing experts") KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) @@ -636,7 +636,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 146fb85..7cdb204 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang Date : 2024-07-25 11:25:24 Version : 0.1.0 LastEditors : Azure -LastEditTime : 2024-08-14 14:57:04 +LastEditTime : 2024-08-29 09:11:16 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' @@ -277,7 +277,7 @@ class KLinearCPUInfer(KLinearBase): def forward(self, x: torch.Tensor) -> torch.Tensor: origin_shape = x.shape # [batch_size, q_len, hidden_size] - if origin_shape[1] == 1: + if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing(): out_device = x.device self.input_tensor_cpu.copy_(x, non_blocking=True) qlen = origin_shape[1] diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py index d6cdc47..f6e85c0 100644 --- a/ktransformers/operators/models.py +++ b/ktransformers/operators/models.py @@ -670,11 +670,12 @@ class KDeepseekV2Model(BaseInjectedModule): if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] - if cur_device not in self.stream_device_map: + if cur_device not in self.stream_device_map and cur_device.lower() != "cpu": self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) - torch.cuda.set_device(cur_device) - self.stream_device_map[cur_device].wait_stream(prev_stream) - torch.cuda.set_stream(self.stream_device_map[cur_device]) + if cur_device.lower() != "cpu": + torch.cuda.set_device(cur_device) + self.stream_device_map[cur_device].wait_stream(prev_stream) + torch.cuda.set_stream(self.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.transfer_map[i], non_blocking=True )