Mirror of https://github.com/RYDE-WORK/ktransformers.git (synced 2026-01-24 15:33:29 +08:00)
Feat: Clear cache during weight loading to prevent OOM on GPUs with <=8GB VRAM
This change explicitly clears the CUDA cache during weight loading to mitigate memory-fragmentation issues, which is particularly beneficial for low-VRAM GPUs.
parent eb039b723d
commit cea07d1998
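The diff below adds a single `torch.cuda.empty_cache()` call before each tensor is materialized on its target device. To see what that buys, compare PyTorch's allocated versus reserved CUDA memory: the gap between the two is cached (and possibly fragmented) space the allocator holds from the driver, which `empty_cache()` releases. A minimal diagnostic sketch; `report_cuda_memory` is a hypothetical helper, not part of the repository:

```python
import torch

def report_cuda_memory(tag: str) -> None:
    """Print live-tensor bytes vs. bytes the caching allocator holds.

    Hypothetical helper for illustration: a large allocated/reserved gap
    is cached (potentially fragmented) space that empty_cache() can
    return to the driver, freeing headroom on small (<=8 GB) GPUs.
    """
    alloc_mib = torch.cuda.memory_allocated() / 2**20
    reserved_mib = torch.cuda.memory_reserved() / 2**20
    print(f"[{tag}] allocated={alloc_mib:.1f} MiB, reserved={reserved_mib:.1f} MiB")
```

Called before and after `torch.cuda.empty_cache()`, the reserved figure should drop toward the allocated one while allocated itself stays unchanged, since live tensors are never freed.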
```diff
@@ -70,6 +70,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
+            torch.cuda.empty_cache()
             # device = "cpu" if "embd" in translated_key else "cuda"
             weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
             set_param(module, name, weights)
```
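The same pattern can be applied to any per-tensor loading loop: release the allocator's cached blocks right before each large allocation so the driver can serve it from contiguous memory. A minimal sketch under stated assumptions: `load_weights_low_vram` and its `tensors` argument are hypothetical stand-ins for the repository's GGUFLoader-based loading path.

```python
import torch
import torch.nn as nn

def load_weights_low_vram(module: nn.Module, tensors: dict[str, torch.Tensor]) -> None:
    """Copy pre-loaded CPU tensors into a module's parameters one at a time,
    releasing PyTorch's cached CUDA blocks between copies.

    `tensors` maps parameter names to CPU tensors; it is a hypothetical
    stand-in for the GGUF tensor loading done in the actual commit.
    """
    params = dict(module.named_parameters())
    for name, cpu_tensor in tensors.items():
        if name not in params:
            continue
        # Return cached-but-unused CUDA blocks to the driver before the
        # next large allocation, reducing fragmentation-induced OOMs.
        torch.cuda.empty_cache()
        with torch.no_grad():
            params[name].copy_(cpu_tensor.to(params[name].device))
```

Note that `empty_cache()` only returns blocks the caching allocator is no longer using; it does not free live tensors, and it adds synchronization overhead, which is why the commit scopes it to the one-time weight-loading path rather than the inference loop.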