From cea07d19984f8a318a400567b9a6393e7f101ea0 Mon Sep 17 00:00:00 2001 From: Yuhao Tsui Date: Mon, 24 Feb 2025 10:09:42 +0800 Subject: [PATCH] Feat: Clear cache during weight loading to prevent OOM on GPUs with <=8GB VRAM This change explicitly clears CUDA cache during weight loading to mitigate memory fragmentation issues, particularly beneficial for low-VRAM GPUs. --- ktransformers/util/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 5c608b1..88559fb 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -70,6 +70,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") + torch.cuda.empty_cache() # device = "cpu" if "embd" in translated_key else "cuda" weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype) set_param(module, name, weights)