mirror of
https://github.com/RYDE-WORK/ktransformers.git
synced 2026-01-19 21:03:18 +08:00
Merge pull request #382 from ceerRep/server-prefix-cache
fix server and add prefix cache for server
This commit is contained in:
commit
cf4da5fd47
@ -172,7 +172,19 @@ class StaticCache(transformers.StaticCache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
if self.value_cache[layer_idx] is not None:
|
||||
self.value_cache[layer_idx].zero_()
|
||||
self.past_tokens[layer_idx] = 0
|
||||
|
||||
def remove_suffix(self, start_pos):
|
||||
for layer_idx in range(len(self.key_cache)):
|
||||
# In-place ops prevent breaking the static address
|
||||
if self.is_MLA:
|
||||
k_cache = self.key_cache[layer_idx]
|
||||
k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
|
||||
else:
|
||||
self.key_cache[layer_idx][..., start_pos:, :].zero_()
|
||||
self.value_cache[layer_idx][..., start_pos:, :].zero_()
|
||||
self.past_tokens[layer_idx] = start_pos
|
||||
|
||||
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
|
||||
"""Returns the maximum shape of the cache."""
|
||||
return self.max_cache_len
|
||||
return self.max_cache_len
|
||||
|
||||
@ -129,8 +129,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
|
||||
|
||||
q_absorb, out_absorb = self.get_absorbed()
|
||||
if hasattr(self.orig_module, 'kv_b_proj'):
|
||||
del self.orig_module.kv_b_proj
|
||||
# if hasattr(self.orig_module, 'kv_b_proj'):
|
||||
# del self.orig_module.kv_b_proj
|
||||
|
||||
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
|
||||
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
|
||||
@ -222,6 +222,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
|
||||
|
||||
kv_seq_len = q_len
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
|
||||
@ -293,26 +303,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe.squeeze(0)
|
||||
compressed_kv.squeeze(0)
|
||||
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
k_pe.unsqueeze(0)
|
||||
compressed_kv.unsqueeze(0)
|
||||
|
||||
k_pe = k_pe[:, :q_len]
|
||||
compressed_kv = compressed_kv[:, :q_len]
|
||||
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
|
||||
k_pe = k_pe[:, :kv_seq_len]
|
||||
compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
|
||||
compressed_kv = compressed_kv[:, :kv_seq_len]
|
||||
kv = (
|
||||
self.kv_b_proj(compressed_kv)
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
)
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
|
||||
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
|
||||
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
|
||||
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
|
||||
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
@ -362,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
|
||||
|
||||
kv_seq_len = q_len
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
|
||||
@ -441,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe.squeeze(0)
|
||||
compressed_kv.squeeze(0)
|
||||
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
k_pe.unsqueeze(0)
|
||||
compressed_kv.unsqueeze(0)
|
||||
|
||||
k_pe = k_pe[:, :q_len]
|
||||
compressed_kv = compressed_kv[:, :q_len]
|
||||
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
|
||||
k_pe = k_pe[:, :kv_seq_len]
|
||||
compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
|
||||
compressed_kv = compressed_kv[:, :kv_seq_len]
|
||||
kv = (
|
||||
self.kv_b_proj(compressed_kv)
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
)
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
|
||||
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
|
||||
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
|
||||
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
|
||||
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
|
||||
@ -5,18 +5,15 @@ from fastapi import APIRouter
|
||||
from fastapi.requests import Request
|
||||
from ktransformers.server.utils.create_interface import get_interface
|
||||
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
|
||||
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
|
||||
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
|
||||
from ktransformers.server.backend.base import BackendInterfaceBase
|
||||
from ktransformers.server.config.config import Config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
models = [
|
||||
{"id": "0", "name": "ktranformers-model"},
|
||||
]
|
||||
|
||||
@router.get('/models', tags=['openai'])
|
||||
async def list_models():
|
||||
return models
|
||||
return [{"id": Config().model_name, "name": Config().model_name}]
|
||||
|
||||
|
||||
@router.post('/chat/completions', tags=['openai'])
|
||||
@ -36,7 +33,8 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
|
||||
yield chunk
|
||||
return chat_stream_response(request,inner())
|
||||
else:
|
||||
comp = ChatCompletionObject(id=id,object='chat.completion.chunk',created=int(time()))
|
||||
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
|
||||
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
|
||||
async for token in interface.inference(input_message,id):
|
||||
comp.append_token(token)
|
||||
return comp
|
||||
|
||||
@ -90,7 +90,8 @@ class ArgumentParser:
|
||||
# user config
|
||||
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
|
||||
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
|
||||
parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
|
||||
parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
|
||||
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
|
||||
|
||||
# web config
|
||||
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
|
||||
|
||||
@ -15,7 +15,9 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
||||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||
from ktransformers.util.utils import get_device
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
class KTransformersThreadContext(TransformersThreadContext):
|
||||
pass
|
||||
|
||||
@ -74,13 +76,13 @@ class KTransformersInterface(TransformersInterface):
|
||||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
def decode_one_tokens(self):
|
||||
global warm_uped
|
||||
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
global warm_uped
|
||||
torch.cuda.set_device(torch_device)
|
||||
if self.args.use_cuda_graph and warm_uped == True:
|
||||
|
||||
if warm_uped and self.args.use_cuda_graph:
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(
|
||||
@ -127,34 +129,54 @@ class KTransformersInterface(TransformersInterface):
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||
device = "cuda:0" if device == "cuda" else device
|
||||
|
||||
if is_new:
|
||||
self.cache.reset()
|
||||
self.ever_generated_ids.clear()
|
||||
former_seq_length = 0
|
||||
self.seq_length = input_ids_length
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size,
|
||||
self.seq_length + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
same_prefix = 0
|
||||
flat_input_ids = input_ids.flatten()
|
||||
|
||||
if getattr(self, 'generated_ids', None) is None:
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size,
|
||||
input_ids.shape[-1] + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
self.seq_length = 1
|
||||
|
||||
flat_prev_ids = self.generated_ids.flatten()
|
||||
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
|
||||
if flat_input_ids[i] == flat_prev_ids[i]:
|
||||
same_prefix += 1
|
||||
else:
|
||||
break
|
||||
|
||||
logger.debug(f"same prefix len: {same_prefix}")
|
||||
self.cache.remove_suffix(same_prefix)
|
||||
self.seq_length = same_prefix
|
||||
self.generated_ids = self.generated_ids[..., :same_prefix]
|
||||
input_ids = input_ids[..., same_prefix:]
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
|
||||
self.ever_generated_ids.clear()
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
|
||||
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
@ -176,6 +198,7 @@ class KTransformersInterface(TransformersInterface):
|
||||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
@ -187,4 +210,4 @@ class KTransformersInterface(TransformersInterface):
|
||||
async def inference(self, local_messages, thread_id: str):
|
||||
async with self._infer_lock:
|
||||
async for v in super().inference(local_messages, thread_id):
|
||||
yield v
|
||||
yield v
|
||||
|
||||
@ -170,7 +170,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
for m in messages[1:]:
|
||||
if m["role"] == "user" and new_messages[-1]["role"] == "user":
|
||||
logger.warning("merge two adjacent user messages")
|
||||
new_messages[-1]["content"] += m["content"]
|
||||
new_messages[-1]["content"] += '\n' + m["content"]
|
||||
else:
|
||||
new_messages.append(m)
|
||||
# if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
@ -179,7 +179,11 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
# input_ids = self.tokenizer.apply_chat_template(
|
||||
# new_messages, return_tensors="pt", add_generation_prompt=True
|
||||
# ).to(self.args.device)
|
||||
input_ids = self.tokenizer.apply_chat_template(new_messages,return_tensors='pt',add_generation_prompt=True).to(self.args.device)
|
||||
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
|
||||
# drop <think> token in chat template
|
||||
if input_str.endswith('<think>\n'):
|
||||
input_str = input_str[:-len('<think>\n')]
|
||||
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
|
||||
if (self.last_request_id is not None) and self.last_request_id == thread_id:
|
||||
x = self.generated_ids[:,:self.seq_length]
|
||||
y = input_ids[:,:self.seq_length]
|
||||
@ -198,14 +202,28 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
self.seq_length += 1
|
||||
return self.streamer.put(new_tokens)
|
||||
|
||||
def logits_to_token(self, logits: torch.Tensor):
|
||||
logits = logits / self.args.temperature if self.args.temperature!=0 else logits
|
||||
def prepare_logits_wrapper(self, inputs, device):
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
top_k=self.args.top_k,
|
||||
top_p=self.args.top_p,
|
||||
temperature=self.args.temperature,
|
||||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||
)
|
||||
self.inputs = inputs
|
||||
self.generation_config = generation_config
|
||||
try: # transformers==4.43
|
||||
self.logits_warper = (
|
||||
self.model._get_logits_warper(generation_config,device=device)
|
||||
)
|
||||
except:
|
||||
self.logits_warper = (
|
||||
self.model._get_logits_warper(generation_config)
|
||||
)
|
||||
|
||||
for token_idx in self.ever_generated_ids:
|
||||
if logits[token_idx] < 0:
|
||||
logits[token_idx] *= self.args.repetition_penalty
|
||||
else:
|
||||
logits[token_idx] /= self.args.repetition_penalty
|
||||
def logits_to_token(self, logits: torch.Tensor):
|
||||
logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
@ -239,31 +257,51 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
if is_new:
|
||||
self.cache.reset()
|
||||
self.ever_generated_ids.clear()
|
||||
former_seq_length = 0
|
||||
self.seq_length = input_ids_length
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size,
|
||||
self.seq_length + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
same_prefix = 0
|
||||
flat_input_ids = input_ids.flatten()
|
||||
|
||||
if getattr(self, 'generated_ids', None) is None:
|
||||
self.generated_ids = torch.zeros(
|
||||
self.args.batch_size,
|
||||
input_ids.shape[-1] + self.args.max_new_tokens + 1,
|
||||
dtype=torch.int,
|
||||
device=self.args.device,
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
self.seq_length = 1
|
||||
|
||||
flat_prev_ids = self.generated_ids.flatten()
|
||||
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
|
||||
if flat_input_ids[i] == flat_prev_ids[i]:
|
||||
same_prefix += 1
|
||||
else:
|
||||
break
|
||||
|
||||
logger.debug(f"same prefix len: {same_prefix}")
|
||||
self.cache.remove_suffix(same_prefix)
|
||||
self.seq_length = same_prefix
|
||||
self.generated_ids = self.generated_ids[..., :same_prefix]
|
||||
input_ids = input_ids[..., same_prefix:]
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
|
||||
self.ever_generated_ids.clear()
|
||||
self.profiler.set_counter("prefill", input_ids_length)
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
logger.debug(f"generate_ids: {self.generated_ids.shape}")
|
||||
former_seq_length = self.seq_length
|
||||
self.seq_length += input_ids_length
|
||||
expected_length = self.seq_length + self.args.max_new_tokens + 1
|
||||
delta_length = expected_length - self.generated_ids.shape[-1]
|
||||
if delta_length > 0:
|
||||
new_generate_ids = torch.zeros(
|
||||
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
|
||||
)
|
||||
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
|
||||
|
||||
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
|
||||
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
@ -285,6 +323,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
@ -321,6 +360,7 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
return True
|
||||
|
||||
async def inference(self, local_messages, thread_id: str):
|
||||
self.streamer.reset()
|
||||
self.profiler.create_and_start_timer("tokenize")
|
||||
if isinstance(local_messages, List):
|
||||
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
|
||||
@ -330,8 +370,9 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
#input_ids = torch.tensor([[6366]], device=input_ids.device)
|
||||
else:
|
||||
raise ValueError("local_messages should be List or str")
|
||||
|
||||
if Config().user_force_think:
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
|
||||
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, token_thinks], dim=1
|
||||
)
|
||||
@ -339,11 +380,14 @@ class TransformersInterface(BackendInterfaceBase):
|
||||
self.profiler.pause_timer("tokenize")
|
||||
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
if Config().user_force_think:
|
||||
t = "<think>\n"
|
||||
print(t,end="",flush=True)
|
||||
yield t
|
||||
|
||||
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
# output think token after prefill done
|
||||
if Config().user_force_think:
|
||||
think = '<think>\n'
|
||||
print(think, end="",flush=True)
|
||||
yield think
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
yield t
|
||||
|
||||
@ -105,6 +105,7 @@ def custom_openapi(app):
|
||||
|
||||
def main():
|
||||
cfg = Config()
|
||||
|
||||
arg_parser = ArgumentParser(cfg)
|
||||
|
||||
# 初始化消息
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user