diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 113f194..434399f 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -172,7 +172,19 @@ class StaticCache(transformers.StaticCache):
             self.key_cache[layer_idx].zero_()
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos
 
     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
-        return self.max_cache_len
\ No newline at end of file
+        return self.max_cache_len
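Example (not part of the patch): the new `remove_suffix` keeps the first `start_pos` cached positions and zeroes everything after them in place, so the tensors keep the fixed addresses that CUDA-graph capture and other cached pointers depend on. A minimal sketch of the idea on a toy, non-MLA cache layout of `[batch, heads, seq, head_dim]`; shapes and names here are illustrative, not the project's real ones:

    import torch

    # Toy stand-in for one layer of a static KV cache: [batch, heads, max_seq, head_dim].
    key_cache = torch.randn(1, 8, 16, 64)
    value_cache = torch.randn(1, 8, 16, 64)

    def remove_suffix(start_pos: int) -> None:
        # Zero the tail in place; reallocating the tensors would invalidate the
        # static addresses a captured CUDA graph (or any cached pointer) relies on.
        key_cache[..., start_pos:, :].zero_()
        value_cache[..., start_pos:, :].zero_()

    remove_suffix(10)                                  # keep the first 10 cached positions
    assert key_cache[..., 10:, :].abs().sum() == 0     # suffix cleared
    assert key_cache[..., :10, :].abs().sum() != 0     # prefix untouched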
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index e3f388a..7915654 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -129,8 +129,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
 
         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj
 
         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -222,6 +222,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
@@ -293,26 +303,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             k_pe.squeeze(0)
             compressed_kv.squeeze(0)
-            past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-            k_pe.unsqueeze(0)
-            compressed_kv.unsqueeze(0)
-
-        k_pe = k_pe[:, :q_len]
-        compressed_kv = compressed_kv[:, :q_len]
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+            k_pe = k_pe[:, :kv_seq_len]
+            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+            compressed_kv = compressed_kv[:, :kv_seq_len]
         kv = (
             self.kv_b_proj(compressed_kv)
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         )
         k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-        key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
         key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+        value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
         attn_output = flash_attn_func(
@@ -362,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
@@ -441,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             k_pe.squeeze(0)
             compressed_kv.squeeze(0)
-            past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-            k_pe.unsqueeze(0)
-            compressed_kv.unsqueeze(0)
-
-        k_pe = k_pe[:, :q_len]
-        compressed_kv = compressed_kv[:, :q_len]
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+            k_pe = k_pe[:, :kv_seq_len]
+            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+            compressed_kv = compressed_kv[:, :kv_seq_len]
         kv = (
             self.kv_b_proj(compressed_kv)
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         )
         k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
         query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-        key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+        key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
         key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+        value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
         attn_output = flash_attn_func(
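Example (not part of the patch): with these changes the MLA path stores the compressed KV and the rotary key side by side, gets the concatenated tensor back from `update()`, then splits it and trims to `kv_seq_len`. A toy sketch of that split/reshape bookkeeping; the sizes are illustrative, not DeepSeek-V2's real ones:

    import torch

    bsz, kv_seq_len = 1, 12
    kv_lora_rank, qk_rope_head_dim = 512, 64

    # What the cache hands back: compressed KV and rotary key stored side by side.
    compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)

    compressed_kv, k_pe = torch.split(
        compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
    )
    k_pe = k_pe.view(bsz, -1, qk_rope_head_dim)[:, :kv_seq_len]
    compressed_kv = compressed_kv.view(bsz, -1, kv_lora_rank)[:, :kv_seq_len]

    print(compressed_kv.shape, k_pe.shape)  # torch.Size([1, 12, 512]) torch.Size([1, 12, 64])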
diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py
index 4da3bc9..f84538a 100644
--- a/ktransformers/server/api/openai/endpoints/chat.py
+++ b/ktransformers/server/api/openai/endpoints/chat.py
@@ -5,18 +5,15 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.config.config import Config
 
 router = APIRouter()
 
-models = [
-    {"id": "0", "name": "ktranformers-model"},
-]
-
 @router.get('/models', tags=['openai'])
 async def list_models():
-    return models
+    return [{"id": Config().model_name, "name": Config().model_name}]
 
 
 @router.post('/chat/completions', tags=['openai'])
@@ -36,7 +33,8 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
                 yield chunk
         return chat_stream_response(request,inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion.chunk',created=int(time()))
+        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
+        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
         async for token in interface.inference(input_message,id):
             comp.append_token(token)
         return comp
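For context (not part of the patch), the two payload changes look roughly like this on the wire; the model name is made up, and the token counts are the hard-coded placeholders from the patch rather than real usage accounting:

    # /models now reports the configured model name instead of a fixed entry.
    models_response = [{"id": "DeepSeek-R1", "name": "DeepSeek-R1"}]   # example name

    # Non-streaming completions now use object='chat.completion' and carry a usage block.
    chat_response_fields = {
        "object": "chat.completion",
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }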
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 44fe7d2..a9df65b 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -90,7 +90,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 34f18b9..edca541 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -15,7 +15,9 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 
+
 warm_uped = False
 
+
 class KTransformersThreadContext(TransformersThreadContext):
     pass
 
@@ -74,13 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
+
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        global warm_uped
         torch.cuda.set_device(torch_device)
-        if self.args.use_cuda_graph and warm_uped == True:
-
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -127,34 +129,54 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
                 )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -176,6 +198,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -187,4 +210,4 @@ class KTransformersInterface(TransformersInterface):
     async def inference(self, local_messages, thread_id: str):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id):
-                yield v
\ No newline at end of file
+                yield v
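Example (not part of the patch): the reworked prefill() no longer resets the whole cache for a new request; it measures how much of the incoming prompt matches what is already cached, truncates the cache and generated_ids to that prefix via remove_suffix(), and only prefills the remainder. A self-contained rendering of the matching loop; the helper name and toy tensors below are illustrative, not from the repo:

    import torch

    def shared_prefix_len(new_ids: torch.Tensor, cached_ids: torch.Tensor, cached_len: int) -> int:
        """Length of the common prefix between the new prompt and the cached tokens.

        Mirrors the loop in prefill(): it scans at most min(cached_len, len(new_ids)) - 1
        positions, so at least the last token is always re-prefilled.
        """
        same = 0
        limit = min(cached_len, new_ids.shape[0]) - 1
        for i in range(limit):
            if new_ids[i] == cached_ids[i]:
                same += 1
            else:
                break
        return same

    cached = torch.tensor([1, 2, 3, 4, 5, 0, 0])   # generated_ids from the previous turn
    new    = torch.tensor([1, 2, 3, 9, 9, 9])      # next prompt, diverges at index 3
    print(shared_prefix_len(new, cached, cached_len=5))  # -> 3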
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index 4021670..deb6cfa 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
         for m in messages[1:]:
             if m["role"] == "user" and new_messages[-1]["role"] == "user":
                 logger.warning("merge two adjacent user messages")
-                new_messages[-1]["content"] += m["content"]
+                new_messages[-1]["content"] += '\n' + m["content"]
             else:
                 new_messages.append(m)
         # if (self.last_request_id is not None) and self.last_request_id == thread_id:
@@ -179,7 +179,11 @@
         #     input_ids = self.tokenizer.apply_chat_template(
         #         new_messages, return_tensors="pt", add_generation_prompt=True
         #     ).to(self.args.device)
-        input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
+        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
+        # drop <think> token in chat template
+        if input_str.endswith('<think>\n'):
+            input_str = input_str[:-len('<think>\n')]
+        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:,:self.seq_length]
             y = input_ids[:,:self.seq_length]
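Example (not part of the patch): format_and_tokenize_input_ids now joins adjacent user messages with a newline instead of concatenating them back to back, and it renders the chat template to a string first so a trailing '<think>\n' that the template may append can be stripped before encoding. The merge step in isolation, with made-up messages:

    messages = [
        {"role": "user", "content": "Summarize the file."},
        {"role": "user", "content": "Keep it under five lines."},
    ]

    merged = [messages[0].copy()]
    for m in messages[1:]:
        if m["role"] == "user" and merged[-1]["role"] == "user":
            # Join with a newline so the two prompts do not run together into one word.
            merged[-1]["content"] += "\n" + m["content"]
        else:
            merged.append(m)

    print(merged[0]["content"])
    # Summarize the file.
    # Keep it under five lines.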
@@ -198,14 +202,28 @@
             self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature!=0 else logits
+    def prepare_logits_wrapper(self, inputs, device):
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=self.args.top_p,
+            temperature=self.args.temperature,
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+        )
+        self.inputs = inputs
+        self.generation_config = generation_config
+        try: # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config,device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
 
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
 
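Example (not part of the patch): sampling now goes through the logits warpers that transformers builds from the generation config instead of a hand-rolled repetition penalty. A rough stand-alone equivalent, assembled by hand for clarity; the exact composition depends on the transformers version, and repetition penalty normally lives in the logits processors rather than the warpers:

    import torch
    from transformers import (
        LogitsProcessorList,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

    warper = LogitsProcessorList([
        RepetitionPenaltyLogitsProcessor(penalty=1.1),
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.95),
    ])

    input_ids = torch.tensor([[1, 5, 7, 5]])      # tokens generated so far
    logits = torch.randn(1, 32000)                # raw scores for the next token
    scores = warper(input_ids, logits)            # penalize repeats, warp, filter
    probs = torch.softmax(scores, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)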
@@ -239,31 +257,51 @@
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
                 )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -285,6 +323,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -321,6 +360,7 @@ class TransformersInterface(BackendInterfaceBase):
         return True
 
     async def inference(self, local_messages, thread_id: str):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -330,8 +370,9 @@ class TransformersInterface(BackendInterfaceBase):
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
+
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -339,11 +380,14 @@ class TransformersInterface(BackendInterfaceBase):
 
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
-        if Config().user_force_think:
-            t = "<think>\n"
-            print(t,end="",flush=True)
-            yield t
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            # output think token after prefill done
+            if Config().user_force_think:
+                think = '<think>\n'
+                print(think, end="",flush=True)
+                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py
index 5e01a48..f536f9c 100644
--- a/ktransformers/server/main.py
+++ b/ktransformers/server/main.py
@@ -105,6 +105,7 @@ def custom_openapi(app):
 def main():
     cfg = Config()
+    arg_parser = ArgumentParser(cfg)
 
     # initialize messages