Merge pull request #413 from kvcache-ai/fix_precision_MLA

Fix precision MLA

Commit: 1c1769a579
5  .gitignore (vendored)
@@ -23,3 +23,8 @@ tmp1.txt
test_65_300_1536.txt
test.txt
book
ktransformers/tests/mmlu_result_silicon.json
ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json
mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log
6  Makefile
@@ -18,4 +18,8 @@ dev_install:
	echo "Installing ktransformers"
	KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
	echo "Installation completed successfully"
	echo "Installation completed successfully"

install_numa:
	USE_NUMA=1 make dev_install

install_no_numa:
	env -u USE_NUMA make dev_install
@@ -2,6 +2,8 @@
set -e

# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf ktransformers/ktransformers_ext/build
rm -rf ktransformers/ktransformers_ext/cuda/build
rm -rf ktransformers/ktransformers_ext/cuda/dist
@@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled

custom_models = {
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -170,9 +171,16 @@ def local_chat(
    torch.set_default_dtype(
        torch.bfloat16
    )  # TODO: Remove this, replace dtype using config
    generated = prefill_and_generate(
        model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
    )
    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
        generated = prefill_and_generate(
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
        )
    else:
        generated = prefill_and_generate(
            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
        )


if __name__ == "__main__":
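The new flashinfer branch forwards the MLA geometry straight from the model config into `prefill_and_generate`. A minimal sketch of how those keyword arguments are assembled, assuming DeepSeek-V2/V3 style attribute names on the config object:

```python
from types import SimpleNamespace

# Hypothetical config values, mirroring the DeepSeek-V2/V3 attribute names used above.
config = SimpleNamespace(
    num_attention_heads=128,
    kv_lora_rank=512,
    qk_rope_head_dim=64,
    qk_nope_head_dim=128,
)

# Keyword arguments the flashinfer branch passes to prefill_and_generate.
mla_kwargs = dict(
    use_flashinfer_mla=True,
    num_heads=config.num_attention_heads,               # 128 query heads
    head_dim_ckv=config.kv_lora_rank,                    # 512, compressed-KV ("ckv") width
    head_dim_kpe=config.qk_rope_head_dim,                # 64, rotary ("kpe") width
    q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim,  # 192, used for the softmax scale
)
print(mla_kwargs)
```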
@@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
        page_idx = cache_position // self.page_size
        page_offset = cache_position % self.page_size
        # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
        #print("page_idx", page_idx)
        #print("page_offset", page_offset)
        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
        return k_out, self.page_table_list[layer_idx]
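For context, the paged layout used here stores the compressed KV and the rotary key side by side in one buffer of shape `(max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)`. A standalone sketch of the indexing with toy sizes, plain PyTorch:

```python
import torch

page_size, max_pages = 4, 3
kv_lora_rank, rope_dim = 8, 2   # toy sizes; the real model uses 512 and 64

k_cache = torch.zeros(max_pages, page_size, 1, kv_lora_rank + rope_dim)

cache_position = torch.tensor([5, 6])            # absolute token positions being written
compressed_kv = torch.randn(2, 1, kv_lora_rank)  # passed as "key_states" in the hunk above
k_pe = torch.randn(2, 1, rope_dim)               # passed as "value_states" (the rotary key)

page_idx = cache_position // page_size           # -> [1, 1]
page_offset = cache_position % page_size         # -> [1, 2]

# Scatter both halves of each token into its page slot.
k_cache[page_idx, page_offset, :, :kv_lora_rank] = compressed_kv
k_cache[page_idx, page_offset, :, kv_lora_rank:] = k_pe
```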
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
        #     **kwargs,
        # ):
        #     BaseInjectedModule.__init__(
        #         self, key, gguf_loader, config, orig_module, generate_device, **kwargs
        #         self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        #     )
        #     self.generate_device = generate_device
        #     self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        prefill_device: str = "cuda",
        generate_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, device, **kwargs
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            orig_module.dim,
@@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_with_kvcache, flash_attn_func
from flash_attn import flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref

logger = logging.getLogger("attention")

# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 use_triton: bool = False,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
        self.use_triton = use_triton
        self.mla_wrapper = None

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
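The `get_absorbed` path is the standard MLA "matrix absorption" trick: instead of materializing full per-head keys and values, the `kv_b_proj` weight is split and folded into the query (and, symmetrically, into the output projection), so attention scores can be taken directly against the cached latent `compressed_kv`. A self-contained sketch of the equivalence with toy sizes (not the repository's exact tensor layout):

```python
import torch

torch.manual_seed(0)
heads, d_nope, d_latent = 2, 16, 8        # toy sizes; DeepSeek-V2 uses 128 / 128 / 512

# Per-head slice of kv_b_proj that would produce k_nope from the latent ("q_absorb" above).
q_absorb = torch.randn(heads, d_nope, d_latent)
q_nope = torch.randn(heads, 1, d_nope)    # one decode query per head
latent = torch.randn(5, d_latent)         # 5 cached compressed_kv tokens, shared by all heads

# Naive path: expand the latent into per-head keys, then take dot products.
k_nope = torch.einsum("hdl,tl->htd", q_absorb, latent)   # (heads, 5, d_nope)
scores_naive = q_nope @ k_nope.transpose(-1, -2)         # (heads, 1, 5)

# Absorbed path: fold q_absorb into the query and score against the latent directly.
scores_absorbed = (q_nope @ q_absorb) @ latent.T         # (heads, 1, 5)

torch.testing.assert_close(scores_naive, scores_absorbed, rtol=1e-4, atol=1e-4)
# The output side is symmetric: (attn @ latent) @ out_absorb.mT instead of attn @ v.
```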
@@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        #print(compressed_kv.shape)

        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale

        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
        compressed_kv = compressed_kv.squeeze(1)
        """
@@ -168,8 +173,9 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )

        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)

        attn_output = torch.matmul(attn_output, out_absorb.mT)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
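For reference, the score in this absorbed formulation is the sum of a rotary part and a latent part, scaled by `softmax_scale` (the inverse square root of the full query head dim). A shape-only sketch of that line, with toy head counts:

```python
import torch

bsz, heads, q_len, kv_len = 1, 4, 1, 10
rope_dim, latent_dim = 64, 512

q_pe = torch.randn(bsz, heads, q_len, rope_dim)        # rotary half of the query
q_nope = torch.randn(bsz, heads, q_len, latent_dim)    # query already multiplied by q_absorb
k_pe = torch.randn(bsz, 1, kv_len, rope_dim)           # shared rotary key (one KV head)
compressed_kv = torch.randn(bsz, 1, kv_len, latent_dim)

softmax_scale = (128 + 64) ** (-0.5)                   # q_head_dim ** -0.5
attn_weights = (q_pe @ k_pe.mT + q_nope @ compressed_kv.mT) * softmax_scale
print(attn_weights.shape)   # torch.Size([1, 4, 1, 10]) -> [bsz, heads, q_len, kv_seq_len]
```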
@@ -179,14 +185,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def forward_linux(
    def forward_linux_triton(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
@@ -267,7 +273,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            # use triton attention kernel adapted from vLLM and SGLang for MQA
            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                         page_table,
                                         position_ids.squeeze(0).to(torch.int32), attn_logits,
                                         position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
                                         4, #num_kv_splits # follow vLLM, fix it TODO
                                         self.softmax_scale,
                                         past_key_value.page_size)
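The one-line change above is the precision fix for the triton decode path: the kernel expects the number of valid cache entries per sequence, which is the zero-based position plus one, not the position itself. A tiny sketch of the distinction:

```python
import torch

# Zero-based position of the token currently being decoded.
position_ids = torch.tensor([[9]], dtype=torch.int32)

kv_len_wrong = position_ids.squeeze(0)        # 9  -> the newest cached token would be ignored
kv_len_right = position_ids.squeeze(0) + 1    # 10 -> all cached tokens participate in attention

print(kv_len_wrong.item(), kv_len_right.item())
```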
@@ -325,6 +331,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        ).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def forward_linux_flashinfer(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

        # decode
        if q_len == 1:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
                # k_pe [max_pages, page_size, self.qk_rope_head_dim]
                # compressed_kv [max_pages, page_size, self.kv_lora_rank]

            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
            q_absorb, out_absorb = self.get_absorbed()
            q_nope = q_nope.transpose(1, 2)  # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            q_nope = q_nope.transpose(1, 2)
            assert q_nope.is_contiguous()

            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            q_nope.squeeze_(1)
            q_pe.squeeze_(1)

            # flash attn doesn't support head_dim bigger than 256, use flashinfer
            if self.mla_wrapper is None:
                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
            if self.mla_wrapper.need_plan:
                self.mla_wrapper.need_plan = False
                self.mla_wrapper.plan(None, None, None,
                                      position_ids.squeeze(1)+1,
                                      self.num_heads,
                                      self.kv_lora_rank,
                                      self.qk_rope_head_dim,
                                      past_key_value.page_size,
                                      self.softmax_scale,
                                      q_nope.dtype,
                                      compressed_kv.dtype)
            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)

            """
            k = (
                torch.cat([compressed_kv, k_pe], dim=-1)
                .view(-1, 1, 512 + 64)
                .repeat_interleave(self.num_heads, dim=1)
            )
            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
            lens = position_ids.item() + 1
            #print("lens", lens)
            attn_ref, lse_ref = attention_ref(
                1,
                torch.cat([q_nope, q_pe], dim=-1),
                k[:lens],
                v[:lens],
                False,
                self.softmax_scale
            )
            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """

            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)
            attn_output = torch.matmul(attn_output, out_absorb.mT)

            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output)

            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                k_pe.squeeze(0)
                compressed_kv.squeeze(0)
                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                k_pe.unsqueeze(0)
                compressed_kv.unsqueeze(0)

            k_pe = k_pe[:, :q_len]
            compressed_kv = compressed_kv[:, :q_len]
            kv = (
                self.kv_b_proj(compressed_kv)
                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            )
            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states_padded,
                softmax_scale=self.softmax_scale,
                causal=True,
            )

            if self.q_head_dim != self.v_head_dim:
                attn_output = attn_output[:, :, :, : self.v_head_dim]

            attn_output = attn_output.reshape(
                bsz, q_len, self.num_heads * self.v_head_dim
            ).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value

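In the prefill branch above, MLA's value head (`v_head_dim`, 128) is narrower than its query/key head (`q_head_dim`, 192), so the values are zero-padded up to `q_head_dim` before calling `flash_attn_func` and the extra columns are sliced off afterwards. A small shape sketch of that pad-then-slice pattern:

```python
import torch
import torch.nn.functional as F

bsz, q_len, heads, v_dim, q_dim = 1, 3, 2, 128, 192

value_states = torch.randn(bsz, q_len, heads, v_dim)
# Pad the last dimension on the right so V matches the kernel's expected head_dim.
value_states_padded = F.pad(value_states, [0, q_dim - v_dim], value=0)
print(value_states_padded.shape)            # torch.Size([1, 3, 2, 192])

# Stand-in for the flash_attn_func result; in the real call the padded columns come
# back as zeros, so it is safe to drop them before the output projection.
attn_output = torch.randn(bsz, q_len, heads, q_dim)
attn_output = attn_output[:, :, :, :v_dim]
print(attn_output.shape)                    # torch.Size([1, 3, 2, 128])
```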
    def forward_windows(
        self,
@@ -403,7 +557,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if not self.use_triton: # os.name == 'nt'
        if os.name == 'nt':
            return self.forward_windows(
                hidden_states,
                attention_mask,
@@ -415,16 +569,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                **kwargs,
            )
        else:
            return self.forward_linux(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                **kwargs,
            )
            if flashinfer_enabled:
                return self.forward_linux_flashinfer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                    cache_position,
                    **kwargs,
                )
            else:
                return self.forward_linux_triton(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                    cache_position,
                    **kwargs,
                )

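The net dispatch after this change: Windows falls back to the eager path, Linux uses the flashinfer MLA kernel when the package imports cleanly, and otherwise the triton kernel. A standalone sketch of that selection logic (the function below is an illustrative stand-in, not a method of the class):

```python
import os

def pick_attention_backend(flashinfer_enabled: bool) -> str:
    """Return which attention implementation would run, mirroring the dispatch above."""
    if os.name == "nt":
        return "forward_windows"           # eager torch path on Windows
    if flashinfer_enabled:
        return "forward_linux_flashinfer"  # flashinfer paged MLA decode kernel
    return "forward_linux_triton"          # triton MQA decode kernel

print(pick_attention_backend(flashinfer_enabled=True))
```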
@@ -435,9 +601,10 @@ class KLlamaAttention(BaseInjectedModule):
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 device: str = "cuda",
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        nn.Module.__init__(self)
        nn.Module.__setattr__(self, "orig_module", orig_module)
        object.__setattr__(self, "key", key)
        object.__setattr__(self, "gguf_loader", gguf_loader)
        object.__setattr__(self, "config", config)
        object.__setattr__(self, "device", device)
        object.__setattr__(self, "prefill_device", prefill_device)
        object.__setattr__(self, "generate_device", generate_device)
        object.__setattr__(self, "device", generate_device)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
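This is the signature change the rest of the PR propagates: every injected module now receives an explicit `prefill_device` and `generate_device`, and the legacy `device` attribute is kept as an alias for `generate_device`. A minimal stand-alone stub of that convention (an illustrative reduction, not the full implementation):

```python
import torch.nn as nn

class InjectedModuleStub(nn.Module):
    """Reduced sketch of the device convention in the updated BaseInjectedModule."""
    def __init__(self, orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        super().__init__()
        self.orig_module = orig_module
        self.prefill_device = prefill_device      # where weight loading / prefill work happens
        self.generate_device = generate_device    # where decode work happens
        self.device = generate_device             # legacy alias kept for older call sites

stub = InjectedModuleStub(nn.Linear(4, 4), prefill_device="cuda:0", generate_device="cuda:1")
print(stub.prefill_device, stub.generate_device, stub.device)
```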
@@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
    output_cpu:Tensor = None
    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
    #stream_map:dict = {} # Manage cuda stream on different gpu
    #gguf_loader:GGUFLoader = None
    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
@@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        #if KExpertsCPU.gguf_loader is None:
        #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
        self.gguf_loader = gguf_loader
        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
        self.n_routed_experts = n_routed_experts
        self.out_device = out_device
@@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
                 generate_device: str = "cpu",
                 generate_op: str | None = "KExpertsCPU",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
240  ktransformers/operators/flashinfer_wrapper.py (new file)
@@ -0,0 +1,240 @@
'''
Description  :  flashinfer MLA wrapper
Author       :  Boxin Zhang
Version      :  0.2.2
'''
import torch

flashinfer_enabled = False

try:
    import flashinfer
    flashinfer_enabled = True
    print("found flashinfer")

except ImportError:
    print("flashinfer not found, use triton for linux")

import math

def attention_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> torch.Tensor:
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]
    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )

    #print("attn weights", logits)

    if causal:
        mask = (
            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
            >= torch.arange(0, kv_len).unsqueeze(0)
        ).to(q.device)
    else:
        mask = torch.ones(qo_len, kv_len).to(q.device)

    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return o_ref, lse_ref * math.log2(math.e)

class MLAWrapper():
    def __init__(self,
                 max_batch_size,
                 max_pages,
                 use_cuda_graph = True,
                 device = "cuda",
                 ):
        self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)
        self.max_batch_size = max_batch_size
        self.max_pages = max_pages
        if use_cuda_graph:
            if self.max_batch_size == 1:
                self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            else:
                self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        else:
            self.qo_indptr_buf = None
            self.kv_indptr_buf = None
            self.kv_indices_buf = None
            self.kv_len_arr_buf = None
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=False,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
        )
        self.need_plan = True

    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
             page_size,
             sm_scale,
             q_data_type,
             kv_data_type,
             ):
        if qo_indptr is None:
            assert self.max_batch_size == 1
            qo_indptr = self.qo_indptr_buf
        if kv_indptr is None:
            assert self.max_batch_size == 1
            kv_indptr = self.kv_indptr_buf
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            False, # causal is False for decoding
            sm_scale,
            q_data_type,
            kv_data_type,
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)

class MLAWrapperSingleton():
    wrappers:dict = {}

    @classmethod
    def get_instance(cls, device, *args, **kwargs)->MLAWrapper:
        if device not in cls.wrappers:
            cls.make_instance(device, *args, **kwargs)
        return cls.wrappers[device]

    @classmethod
    def make_instance(cls, device, *args, **kwargs):
        cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)

    @classmethod
    def plan_all(cls, qo_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_len_arr,
                 num_heads,
                 head_dim_ckv,
                 head_dim_kpe,
                 page_size,
                 sm_scale,
                 q_data_type,
                 kv_data_type,):
        for device, wrapper in cls.wrappers.items():
            kv_len_arr_cur_device = kv_len_arr.to(device)
            wrapper.plan(qo_indptr,
                         kv_indptr,
                         kv_indices,
                         kv_len_arr_cur_device,
                         num_heads,
                         head_dim_ckv,
                         head_dim_kpe,
                         page_size,
                         sm_scale,
                         q_data_type,
                         kv_data_type,)


if __name__ == "__main__":
    max_batch_size = 1
    max_pages = 1
    page_size = 64
    num_heads = 128

    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")

    wrapper = MLAWrapperSingleton.get_instance(
        "cuda",
        max_batch_size,
        max_pages,
    )

    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")

    wrapper.plan(
        None,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    k = (
        torch.cat([ckv, k_pe], dim=-1)
        .view(-1, 1, 512 + 64)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

    print(k[:10].shape)
    print(v[:10].shape)

    attn_ref, lse_ref = attention_ref(
        max_batch_size,
        torch.cat([q_nope, q_pe], dim=-1),
        k[:10],
        v[:10],
        False,
        192 ** (-0.5)
    )

    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
    print("test past")
@@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module = None,
                 generate_device: str = "cuda",
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device

@@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
                 prefill_op: str| None = "KLinearTorch",
                 **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        # build all the linear operators
        if prefill_op is not None:
@@ -15,7 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device


warm_uped = False

class KTransformersThreadContext(TransformersThreadContext):
    pass

@@ -77,7 +77,10 @@ class KTransformersInterface(TransformersInterface):
        device_map = self.model.gguf_loader.tensor_device_map
        torch_device = get_device("blk.0.self_attn", device_map)
        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
        if self.args.use_cuda_graph:
            global warm_uped
        torch.cuda.set_device(torch_device)
        if self.args.use_cuda_graph and warm_uped == True:

            if not hasattr(self, "cuda_graph_runner"):
                self.cuda_graph_runner = CUDAGraphRunner()
                self.cuda_graph_runner.capture(
@@ -99,7 +102,10 @@ class KTransformersInterface(TransformersInterface):
            torch.cuda.synchronize()
            logits = logits[0, -1, :]
            return self.logits_to_token(logits)

        if self.args.use_cuda_graph:
            warm_uped = True

        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(torch_device)
            logits = self.model(
@@ -125,6 +131,7 @@ class KTransformersInterface(TransformersInterface):
        logger.debug(f"input_ids: {input_ids.shape}")

        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        device = "cuda:0" if device == "cuda" else device

        if is_new:
            self.cache.reset()
@@ -156,6 +163,7 @@ class KTransformersInterface(TransformersInterface):
        if not (type(self) is TransformersInterface):
            input_ids = input_ids.to("cpu")
        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        torch.cuda.set_device(device)
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,

@@ -18,7 +18,7 @@ import sys, os
from ..base import ThreadContext, BackendInterfaceBase
from ktransformers.server.config.log import logger
from ..args import ConfigArgs, default_args

from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
class TextStreamer:
@@ -291,8 +291,14 @@ class TransformersInterface(BackendInterfaceBase):
    @torch.no_grad
    def generate(self):
        self.profiler.set_counter("decode", 0)
        for _ in range(1, self.args.max_new_tokens):
        for i in range(1, self.args.max_new_tokens):

            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                if i > 1 and flashinfer_enabled:
                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32)+1,
                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                next_token = self.decode_one_tokens()
            self.profiler.inc("decode")
            if next_token == self.tokenizer.eos_token_id:
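The server decode loop now re-plans the flashinfer MLA wrapper before every step after the first, with the KV length set to the active cache position plus one. A reduced sketch of that pattern, using a stub planner so it runs without flashinfer installed (`plan_for_step` is a stand-in for `MLAWrapperSingleton.plan_all`):

```python
import torch

def plan_for_step(kv_len_arr: torch.Tensor, num_heads: int, head_dim_ckv: int,
                  head_dim_kpe: int, page_size: int, sm_scale: float) -> None:
    # Stand-in for MLAWrapperSingleton.plan_all(...): just record what would be planned.
    print(f"plan: kv_len={kv_len_arr.tolist()}, heads={num_heads}, sm_scale={sm_scale:.4f}")

active_cache_position = torch.tensor([7], dtype=torch.int32)  # position of the last cached token
qk_rope_head_dim, qk_nope_head_dim, kv_lora_rank, page_size = 64, 128, 512, 64

for i in range(1, 4):                       # decode a few tokens
    if i > 1:                               # the first decode step was planned during prefill
        plan_for_step(active_cache_position + 1,   # KV length, not the zero-based position
                      num_heads=128,
                      head_dim_ckv=kv_lora_rank,
                      head_dim_kpe=qk_rope_head_dim,
                      page_size=page_size,
                      sm_scale=(qk_rope_head_dim + qk_nope_head_dim) ** (-0.5))
    active_cache_position += 1              # a new token was appended to the cache
```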
195  ktransformers/tests/mmlu_pro_test.py (new file)
@ -0,0 +1,195 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import requests
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
os.environ['https_proxy'] = ''
|
||||
os.environ['http_proxy'] = ''
|
||||
hint = 'There is a single choice question. Answer the question by replying A, B, C, D, E, F, G, H, I, J. No other answers are accepted. Just the letter.'
|
||||
|
||||
|
||||
class DataEvaluator:
|
||||
def __init__(self):
|
||||
# self.template_prompt = template_prompt
|
||||
self.data = []
|
||||
|
||||
def load_data(self, file_path):
|
||||
"""
|
||||
Load data from a Parquet file into a list.
|
||||
Each record in the Parquet file should represent an individual record.
|
||||
"""
|
||||
# 读取 Parquet 文件
|
||||
# dataset = load_dataset('parquet', data_files=file_path)
|
||||
ds = load_dataset("TIGER-Lab/MMLU-Pro")
|
||||
df = pd.DataFrame(ds['test'])
|
||||
# print(ds)
|
||||
# # ds_1 = ds['train']
|
||||
# ds_2 = ds['validation']
|
||||
# ds_3 = ds['test']
|
||||
# # 将数据集转换为 Pandas DataFrame
|
||||
# df_test = pd.DataFrame(ds['test'])
|
||||
# df_val = pd.DataFrame(ds['validation'])
|
||||
|
||||
# for _, row in df.iterrows():
|
||||
# self.data.append(row.to_dict())
|
||||
# df = pd.read_parquet(file_path)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
|
||||
def get_prompt(self, record):
|
||||
"""
|
||||
Combine fields from a record with the template prompt to create a full prompt.
|
||||
:param record: Dictionary containing fields to populate the template.
|
||||
:return: A formatted prompt string.
|
||||
"""
|
||||
# 查看ABCD。。。的选项
|
||||
options_str = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(record['options'])])
|
||||
prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
|
||||
return prompt
|
||||
|
||||
def post_processing(self, text):
|
||||
"""
|
||||
Perform post-processing on the prediction string.
|
||||
:param text: The raw prediction string.
|
||||
:return: Processed prediction string.
|
||||
"""
|
||||
text = text.lstrip('\n').split('\n')[0]
|
||||
return text[:1]
|
||||
|
||||
def score(self, pred, answers):
|
||||
"""
|
||||
Calculate scores between the prediction and the answer.
|
||||
Uses ROUGE scores as the evaluation metric.
|
||||
:param pred: The predicted string.
|
||||
:param answer: The reference answer string.
|
||||
:return: A dictionary containing ROUGE scores.
|
||||
"""
|
||||
for answer in answers:
|
||||
if pred == answer:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
# Function to generate text using API
|
||||
def generate_text(api_url, question, model_name, stream=False):
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
# 添加 API Key
|
||||
'Authorization' : 'Bearer '
|
||||
}
|
||||
data = {
|
||||
"messages": [{"content": question, "role": "user"}],
|
||||
"model": model_name,
|
||||
"stream": stream,
|
||||
# "temperature": 0.0
|
||||
}
|
||||
|
||||
print("POST data:", data)
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
|
||||
else:
|
||||
print(f"API Request failed with status code {response.status_code}")
|
||||
return None
|
||||
|
||||
# Main function to handle multiple evaluations
|
||||
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
|
||||
start_total_time = time.time()
|
||||
|
||||
total_score = 0
|
||||
|
||||
results = []
|
||||
# 设置随机数种子
|
||||
random.seed(42)
|
||||
random.shuffle(data_evaluator.data)
|
||||
for i in range(min(concurrent_requests, len(data_evaluator.data))):
|
||||
# Randomly select a data item from data for each request
|
||||
data_item = data_evaluator.data[i]
|
||||
question = data_evaluator.get_prompt(data_item)
|
||||
# print(question)
|
||||
|
||||
# Start the timer for this evaluation
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Generate prediction using the API
|
||||
prediction = generate_text(api_url, question, model_name)
|
||||
|
||||
if prediction is None:
|
||||
raise Exception(f"Failed to get prediction for {question}")
|
||||
|
||||
answer = data_item['answer']
|
||||
# Compute score
|
||||
score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)
|
||||
|
||||
# Calculate the time taken
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Collect the result data
|
||||
result_data = {
|
||||
"question_id": data_item['question_id'],
|
||||
"answer": answer,
|
||||
"prediction": data_evaluator.post_processing(prediction),
|
||||
"score": score,
|
||||
"time": elapsed_time
|
||||
}
|
||||
|
||||
# Write results to result.json with each field on a new line
|
||||
with open(result_file, 'a', encoding='utf-8') as f:
|
||||
json.dump(result_data, f, ensure_ascii=False, indent=4)
|
||||
f.write("\n") # Ensure each JSON object is on a new line
|
||||
|
||||
results.append(result_data)
|
||||
|
||||
# Aggregate scores
|
||||
total_score += score
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing request {i}: {e}")
|
||||
|
||||
# Calculate total time and throughput
|
||||
total_time = time.time() - start_total_time
|
||||
throughput = concurrent_requests / total_time
|
||||
|
||||
# Log the total time, throughput, and average ROUGE scores
|
||||
with open(log_file, 'a', encoding='utf-8') as log_f:
|
||||
log_f.write(f"Total Time: {total_time:.2f} seconds\n")
|
||||
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
|
||||
log_f.write(f"Average Scores: {total_score / concurrent_requests}\n")
|
||||
log_f.write('-' * 40 + '\n')
|
||||
|
||||
print(f"Results saved to {result_file}")
|
||||
print(f"Log saved to {log_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="API Generate Tester")
|
||||
parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
|
||||
parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
|
||||
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
|
||||
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
|
||||
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
|
||||
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the data from the provided file
|
||||
# template_prompt = hint + "\nQuestion: {question}\nA. {options}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '"
|
||||
# template_prompt_pro = hint + "\nQuestion: {question}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nE. {options[4]}\nF. {options[5]}\nG. \
|
||||
# {options[6]}\nH. {options[7]}\nI. {options[8]}\nJ. {options[9]}\nAnswer: '"
|
||||
|
||||
|
||||
# Load the data from the provided file
|
||||
data_evaluator = DataEvaluator()
|
||||
data_evaluator.load_data(args.file)
|
||||
|
||||
# Run the main function with the specified number of concurrent evaluations
|
||||
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
|
||||
195  ktransformers/tests/mmlu_test.py (new file)
@ -0,0 +1,195 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import requests
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
os.environ['https_proxy'] = ''
|
||||
os.environ['http_proxy'] = ''
|
||||
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
|
||||
|
||||
|
||||
class DataEvaluator:
|
||||
def __init__(self):
|
||||
# self.template_prompt = template_prompt
|
||||
self.data = []
|
||||
|
||||
def load_data(self, file_path):
|
||||
"""
|
||||
Load data from a Parquet file into a list.
|
||||
Each record in the Parquet file should represent an individual record.
|
||||
"""
|
||||
# 读取 Parquet 文件
|
||||
# dataset = load_dataset('parquet', data_files=file_path)
|
||||
ds = load_dataset(file_path,"all")
|
||||
df = pd.DataFrame(ds['test'])
|
||||
# print(ds)
|
||||
# # ds_1 = ds['train']
|
||||
# ds_2 = ds['validation']
|
||||
# ds_3 = ds['test']
|
||||
# # 将数据集转换为 Pandas DataFrame
|
||||
# df_test = pd.DataFrame(ds['test'])
|
||||
# df_val = pd.DataFrame(ds['validation'])
|
||||
|
||||
# for _, row in df.iterrows():
|
||||
# self.data.append(row.to_dict())
|
||||
# df = pd.read_parquet(file_path)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
|
||||
def get_prompt(self, record):
|
||||
"""
|
||||
Combine fields from a record with the template prompt to create a full prompt.
|
||||
:param record: Dictionary containing fields to populate the template.
|
||||
:return: A formatted prompt string.
|
||||
"""
|
||||
# 查看ABCD。。。的选项
|
||||
options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])])
|
||||
prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '"
|
||||
return prompt
|
||||
|
||||
def post_processing(self, text):
|
||||
"""
|
||||
Perform post-processing on the prediction string.
|
||||
:param text: The raw prediction string.
|
||||
:return: Processed prediction string.
|
||||
"""
|
||||
text = text.lstrip('\n').split('\n')[0]
|
||||
return text[:1]
|
||||
|
||||
def score(self, pred, answers):
|
||||
"""
|
||||
Calculate scores between the prediction and the answer.
|
||||
Uses ROUGE scores as the evaluation metric.
|
||||
:param pred: The predicted string.
|
||||
:param answer: The reference answer string.
|
||||
:return: A dictionary containing ROUGE scores.
|
||||
"""
|
||||
for answer in answers:
|
||||
if pred == answer:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
# Function to generate text using API
|
||||
def generate_text(api_url, question, model_name, stream=False):
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
# 添加 API Key
|
||||
'Authorization' : 'Bearer '
|
||||
}
|
||||
data = {
|
||||
"messages": [{"content": question, "role": "user"}],
|
||||
"model": model_name,
|
||||
"stream": stream,
|
||||
# "temperature": 0.0
|
||||
}
|
||||
|
||||
print("POST data:", data)
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()
|
||||
else:
|
||||
print(f"API Request failed with status code {response.status_code}")
|
||||
return None
|
||||
|
||||
# Main function to handle multiple evaluations
|
||||
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
|
||||
start_total_time = time.time()
|
||||
|
||||
total_score = 0
|
||||
|
||||
results = []
|
||||
# 设置随机数种子
|
||||
random.seed(42)
|
||||
random.shuffle(data_evaluator.data)
|
||||
for i in range(min(concurrent_requests, len(data_evaluator.data))):
|
||||
# Randomly select a data item from data for each request
|
||||
data_item = data_evaluator.data[i]
|
||||
question = data_evaluator.get_prompt(data_item)
|
||||
# print(question)
|
||||
|
||||
# Start the timer for this evaluation
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Generate prediction using the API
|
||||
prediction = generate_text(api_url, question, model_name)
|
||||
|
||||
if prediction is None:
|
||||
raise Exception(f"Failed to get prediction for {question}")
|
||||
|
||||
answer = chr(data_item['answer'] + 65)
|
||||
# Compute score
|
||||
score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)
|
||||
|
||||
# Calculate the time taken
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Collect the result data
|
||||
result_data = {
|
||||
"question_id": i,
|
||||
"answer": answer,
|
||||
"prediction": data_evaluator.post_processing(prediction),
|
||||
"score": score,
|
||||
"time": elapsed_time
|
||||
}
|
||||
|
||||
# Write results to result.json with each field on a new line
|
||||
with open(result_file, 'a', encoding='utf-8') as f:
|
||||
json.dump(result_data, f, ensure_ascii=False, indent=4)
|
||||
f.write("\n") # Ensure each JSON object is on a new line
|
||||
|
||||
results.append(result_data)
|
||||
|
||||
# Aggregate scores
|
||||
total_score += score
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing request {i}: {e}")
|
||||
|
||||
# Calculate total time and throughput
|
||||
total_time = time.time() - start_total_time
|
||||
throughput = concurrent_requests / total_time
|
||||
|
||||
# Log the total time, throughput, and average ROUGE scores
|
||||
with open(log_file, 'a', encoding='utf-8') as log_f:
|
||||
log_f.write(f"Total Time: {total_time:.2f} seconds\n")
|
||||
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
|
||||
log_f.write(f"Average Scores: {total_score / concurrent_requests}\n")
|
||||
log_f.write('-' * 40 + '\n')
|
||||
|
||||
print(f"Results saved to {result_file}")
|
||||
print(f"Log saved to {log_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="API Generate Tester")
|
||||
parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
|
||||
parser.add_argument("--file", type=str, default="cais/mmlu", help="Path to the mmlu.jsonl file")
|
||||
parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="Path to save the result JSON file")
|
||||
parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="Path to save the log file")
|
||||
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:10003/v1/chat/completions", help="API URL")
|
||||
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the data from the provided file
|
||||
# template_prompt = hint + "\nQuestion: {question}\nA. {options}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '"
|
||||
# template_prompt_pro = hint + "\nQuestion: {question}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nE. {options[4]}\nF. {options[5]}\nG. \
|
||||
# {options[6]}\nH. {options[7]}\nI. {options[8]}\nJ. {options[9]}\nAnswer: '"
|
||||
|
||||
|
||||
# Load the data from the provided file
|
||||
data_evaluator = DataEvaluator()
|
||||
data_evaluator.load_data(args.file)
|
||||
|
||||
# Run the main function with the specified number of concurrent evaluations
|
||||
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
|
||||
@@ -109,6 +109,7 @@ GGML_TYPES = {
    "Q5_K": 13,
    "Q6_K": 14,
    "IQ4_XS": 23,
    "BF16": 30,
}

GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
@@ -116,6 +117,7 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
GGML_BLOCK_SIZES = {
    "F32": 4,
    "F16": 2,
    "BF16": 2,
    "Q4_0": 2 + 16,
    "Q5_0": 2 + 4 + 16,
    "Q8_0": 2 + 32,
@@ -130,6 +132,7 @@ GGML_BLOCK_SIZES = {
GGML_ELEMENTS_PER_BLOCK = {
    "F32": 1,
    "F16": 1,
    "BF16": 1,
    "Q4_0": 32,
    "Q5_0": 32,
    "Q8_0": 32,
@@ -333,6 +336,8 @@ class GGUFLoader:
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)
        if ggml_name == "BF16":
            values = values.view(torch.bfloat16)
        values = values.view(shape[::-1])
        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
            n_head = self.gguf_file_meta['llama.attention.head_count']
@@ -764,6 +769,7 @@ def dequantize_f16_gpu(data, device):
GGML_DEQUANTIZE = {
    "F32": dequantize_f32,
    "F16": dequantize_f16,
    "BF16": dequantize_f16,
    "Q4_0": dequantize_q4_0,
    "Q5_0": dequantize_q5_0,
    "Q8_0": dequantize_q8_0,
@@ -778,6 +784,7 @@ GGML_DEQUANTIZE = {
GGML_DEQUANTIZE_GPU = {
    "F32": dequantize_f32_gpu,
    "F16": dequantize_f16_gpu,
    "BF16": dequantize_f16_gpu,
    "Q4_0": dequantize_q4_0_gpu,
    "Q5_0": dequantize_q5_0_gpu,
    "Q8_0": dequantize_q8_0_gpu,
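BF16 support rides on the F16 code path: the tensor's raw 2-byte words are read as-is and then reinterpreted as bfloat16 on the torch side, which is why the loader adds the `values.view(torch.bfloat16)` step above. A small sketch of that reinterpretation, using int16 as the 2-byte carrier type in place of the loader's numpy float16 buffer:

```python
import numpy as np
import torch

# Pretend these are the raw little-endian 16-bit words of a BF16 tensor inside a GGUF blob.
original = torch.tensor([1.0, -2.5, 3.25], dtype=torch.bfloat16)
raw_bytes = original.view(torch.int16).numpy().tobytes()

# The F16 code path hands back the same 2-byte words; BF16 just reinterprets them.
words = np.frombuffer(raw_bytes, dtype=np.int16)
restored = torch.from_numpy(words.copy()).view(torch.bfloat16)

torch.testing.assert_close(restored, original)
```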
@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

warm_uped = False

@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
        module.load()

def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                         mode = 'normal', force_think: bool = False):
                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        )
    else:
        past_key_values = None
    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
    generated_ids = torch.zeros(
        batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
    )
@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
    generated_ids[:, seq_length] = next_token
    tokens.append(int(next_token))
    inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
    position_ids = cache_position.unsqueeze(0)
    seq_length += 1

@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
            warm_uped = True
            cuda_graph_runner = CUDAGraphRunner()
            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)

        if i > 1 and use_flashinfer_mla:
            MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1,
                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
        next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
        generated_ids[:, cache_position] = next_token.int()
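The `cache_position` dtype switch from long to int32 looks cosmetic but keeps the position tensors consistent with the int32 buffers used elsewhere in this PR (the flashinfer `kv_len_arr` and the paged index tables). A tiny sketch of the two forms, purely to show that indexing behaviour is unchanged:

```python
import torch

seq_length = 5
cache_position_old = torch.arange(seq_length, dtype=torch.long)   # before this PR
cache_position_new = torch.arange(seq_length, dtype=torch.int32)  # after this PR

print(cache_position_old.dtype, cache_position_new.dtype)  # torch.int64 torch.int32
# Either dtype indexes generated_ids the same way; the change only avoids mixing
# int64 positions with the int32 buffers consumed by the MLA wrapper and triton kernels.
```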