diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 45aa436..95d8086 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -174,7 +174,19 @@ class StaticCache(transformers.StaticCache):
                 self.key_cache[layer_idx].zero_()
             if self.value_cache[layer_idx] is not None:
                 self.value_cache[layer_idx].zero_()
+            self.past_tokens[layer_idx] = 0
+
+    def remove_suffix(self, start_pos):
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            if self.is_MLA:
+                k_cache = self.key_cache[layer_idx]
+                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
+            else:
+                self.key_cache[layer_idx][..., start_pos:, :].zero_()
+                self.value_cache[layer_idx][..., start_pos:, :].zero_()
+            self.past_tokens[layer_idx] = start_pos
 
     def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
         """Returns the maximum shape of the cache."""
-        return self.max_cache_len
\ No newline at end of file
+        return self.max_cache_len
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 44fe7d2..a9df65b 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -90,7 +90,8 @@ class ArgumentParser:
         # user config
         parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
        parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
-        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
+        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
 
         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 4ceb65d..efc23b9 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -121,33 +121,53 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
                 )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -168,6 +188,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -179,4 +200,4 @@ class KTransformersInterface(TransformersInterface):
     async def inference(self, local_messages, thread_id: str):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id):
-                yield v
\ No newline at end of file
+                yield v
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index f18581a..33331d0 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -198,14 +198,28 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature!=0 else logits
+    def prepare_logits_wrapper(self, inputs, device):
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=self.args.top_p,
+            temperature=self.args.temperature,
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+        )
+        self.inputs = inputs
+        self.generation_config = generation_config
+        try: # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config,device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
 
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
 
@@ -239,31 +253,51 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
         if is_new:
-            self.cache.reset()
             self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
                 )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+                self.seq_length = 1
+
+            flat_prev_ids = self.generated_ids.flatten()
+            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                if flat_input_ids[i] == flat_prev_ids[i]:
+                    same_prefix += 1
+                else:
+                    break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -285,6 +319,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
 
@@ -315,6 +350,7 @@ class TransformersInterface(BackendInterfaceBase):
         return True
 
     async def inference(self, local_messages, thread_id: str):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -325,7 +361,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             raise ValueError("local_messages should be List or str")
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("\\n",add_special_tokens=False)],device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
@@ -333,11 +369,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
-        if Config().user_force_think:
-            t = "\n"
-            print(t,end="",flush=True)
-            yield t
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            # output think token after prefill done
+            if Config().user_force_think:
+                think = '\n'
+                print(think, end="",flush=True)
+                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
 
diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py
index 5e01a48..fc1f51a 100644
--- a/ktransformers/server/main.py
+++ b/ktransformers/server/main.py
@@ -105,6 +105,10 @@ def custom_openapi(app):
 
 def main():
     cfg = Config()
+
+    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
+    cfg.use_cuda_graph = False
+
     arg_parser = ArgumentParser(cfg)
 
     # 初始化消息
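
Note on the prefix-reuse change above: on a new request, prefill no longer calls self.cache.reset(). It counts how many leading tokens of the incoming prompt match the tokens already stored in generated_ids, truncates the KV cache back to that point with cache.remove_suffix, and only prefills the remaining suffix. The snippet below is a minimal, self-contained sketch of that comparison; the function name and the example tensors are illustrative and not part of the patch.

import torch

def longest_reusable_prefix(prev_ids: torch.Tensor, new_ids: torch.Tensor) -> int:
    # Count leading tokens of the new prompt that match the cached sequence.
    # The last matching position is deliberately excluded (the "- 1") so that
    # at least one token is always re-run through the model and produces
    # fresh logits for sampling the next token.
    limit = min(prev_ids.shape[0], new_ids.shape[0]) - 1
    same = 0
    for i in range(limit):
        if prev_ids[i] == new_ids[i]:
            same += 1
        else:
            break
    return same

# Even an identical prompt keeps one token to recompute:
prev = torch.tensor([1, 2, 3, 4])
new = torch.tensor([1, 2, 3, 4])
n = longest_reusable_prefix(prev, new)   # 3
# The caller would then call cache.remove_suffix(n) and prefill only new[n:].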
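
The args.py change relies on argparse.BooleanOptionalAction (Python 3.9+), which registers a paired --flag / --no-flag switch so the booleans can actually be turned off from the command line; plain type=bool converts any non-empty string, including "False", to True. A small standalone sketch of the behavior (not the project's ArgumentParser wrapper):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).use_cuda_graph)                       # True (default)
print(parser.parse_args(["--use_cuda_graph"]).use_cuda_graph)     # True
print(parser.parse_args(["--no-use_cuda_graph"]).use_cuda_graph)  # False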
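
prepare_logits_wrapper swaps the hand-rolled temperature division and repetition-penalty loop for the sampling pipeline that transformers builds from a GenerationConfig; the try/except exists because _get_logits_warper is a private helper whose signature gained a device argument around transformers 4.43. As a rough public-API sketch, a warper built from top_k/top_p/temperature behaves like the LogitsProcessorList below (the concrete values are made up for illustration, and repetition penalty is handled by a separate RepetitionPenaltyLogitsProcessor on the processor side of HF generation):

import torch
from transformers.generation import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Public-API approximation of the warper the patch requests from the model.
warper = LogitsProcessorList([
    TemperatureLogitsWarper(0.6),
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.95),
])

input_ids = torch.tensor([[1, 2, 3, 4]])   # (batch, seq) tokens generated so far
logits = torch.randn(1, 32000)             # (batch, vocab) raw scores for the next token
scores = warper(input_ids, logits)         # scaled, then -inf outside top-k / top-p
probs = torch.softmax(scores, dim=-1)      # softmax over warped scores, as logits_to_token does
next_token = torch.multinomial(probs, num_samples=1)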