Mirror of https://github.com/RYDE-WORK/ktransformers.git (synced 2026-02-06 22:55:50 +08:00)
fix: fix server for triton kernel

commit ee24eb8dc3
parent bb1cadfff3
@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 
+warm_uped = False
+
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
 
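The new module-level `warm_uped` flag is shared by every interface instance in the process: it starts as `False` and is flipped only after a first eager decode has run, presumably so that CUDA graph capture is deferred until the Triton kernels have been launched (and JIT-compiled) at least once. A sketch of the full pattern follows the decode hunks below.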
@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
+
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
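With these lines, `torch.cuda.set_device(torch_device)` is always called before any kernels launch, and graph capture is attempted only when both `warm_uped` and `args.use_cuda_graph` are true; during warm-up the method falls through to the eager branch shown in the next hunk, where the flag is finally set. A sketch of the combined pattern follows that hunk.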
@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
         logits = logits[0, -1, :]
+        warm_uped = True
 
         return self.logits_to_token(logits)
 
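`warm_uped = True` is set at the end of the eager path, so the first token generated in a process always pays the Triton compile/warm-up cost and every later token is eligible for graph replay. Below is a minimal, self-contained sketch of the gating pattern these two hunks implement; it uses raw `torch.cuda.CUDAGraph` and a plain dict in place of the project's `CUDAGraphRunner`, and the function and argument names are illustrative only, not the project's API.

```python
import torch

warm_uped = False  # module-level flag, as in the hunk above

def decode_one_token_sketch(model, current_ids, use_cuda_graph, device, graph_state):
    """Sketch: the first decode runs eagerly (Triton kernels JIT-compile on
    their first launch, which has to happen outside CUDA graph capture);
    later decodes may be captured into a graph and replayed."""
    global warm_uped
    torch.cuda.set_device(device)  # kernels launched on "the current device" now hit `device`

    if warm_uped and use_cuda_graph:
        if "graph" not in graph_state:
            # capture a single decode step once, using static input/output buffers
            static_ids = current_ids.clone()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_logits = model(static_ids, return_dict=False)[0]
            graph_state.update(graph=graph, ids=static_ids, logits=static_logits)
        graph_state["ids"].copy_(current_ids)  # refresh inputs, then replay the recorded work
        graph_state["graph"].replay()
        return graph_state["logits"][0, -1, :]

    # warm-up / eager path
    logits = model(current_ids, return_dict=False)[0][0, -1, :]
    warm_uped = True
    return logits
```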
@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        torch.cuda.set_device(device)
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
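The prefill path gets the same device pinning: `torch.cuda.set_device(device)` before the forward pass makes the current CUDA device match the device the embeddings were moved to, which matters because Triton kernel launches typically target the current device rather than inferring it from their tensor arguments. An equivalent, more local way to express the same idea is shown below; the helper name is illustrative.

```python
import torch

def run_on_device(device, fn, *args, **kwargs):
    """Illustrative helper: run `fn` with `device` as the current CUDA device,
    restoring the previous device afterwards (torch.cuda.device is a context
    manager that does exactly that)."""
    with torch.cuda.device(torch.device(device)):
        return fn(*args, **kwargs)

# usage sketch, mirroring the prefill path in the hunk above:
#   inputs_embeds = model.model.embed_tokens(input_ids.to("cpu")).to(device)
#   logits = run_on_device(device, model, inputs_embeds=inputs_embeds, return_dict=False)[0]
```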
@@ -106,9 +106,6 @@ def custom_openapi(app):
 def main():
     cfg = Config()
 
-    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
-    cfg.use_cuda_graph = False
-
     arg_parser = ArgumentParser(cfg)
 
     # Initialize messages
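In the server entry point (the file containing `custom_openapi` and `main()`), the hard-coded `cfg.use_cuda_graph = False` override and its accompanying comment are dropped, so the flag is once again taken from the configuration and command line. A generic sketch of the resulting behaviour, using plain argparse since the project's `ArgumentParser(cfg)` wiring is not shown here:

```python
import argparse

parser = argparse.ArgumentParser()
# default True here is illustrative; in the project the default comes from Config()
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).use_cuda_graph)                       # True  (no override)
print(parser.parse_args(["--no-use_cuda_graph"]).use_cuda_graph)  # False (explicit opt-out)
```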