
fix: specify device when loading lora and embedding tensors

AlpinDale, 6 months ago
Parent
Commit
2a349ca3e1
1 changed file with 3 additions and 2 deletions
  1. aphrodite/lora/models.py (+3, -2)

+ 3 - 2
aphrodite/lora/models.py

@@ -248,7 +248,7 @@ class LoRAModel(AdapterModel):
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-            tensors = torch.load(lora_bin_file_path)
+            tensors = torch.load(lora_bin_file_path, map_location=device)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
 
@@ -257,7 +257,8 @@ class LoRAModel(AdapterModel):
             embeddings = safetensors.torch.load_file(
                 new_embeddings_tensor_path)
         elif os.path.isfile(new_embeddings_bin_file_path):
-            embeddings = torch.load(new_embeddings_bin_file_path)
+            embeddings = torch.load(new_embeddings_bin_file_path,
+                                    map_location=device)
 
         rank = config["r"]
         lora_alpha = config["lora_alpha"]
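
For context, this is what map_location changes: without it, torch.load restores each tensor to the device recorded in the checkpoint (often cuda:0), which can fail on CPU-only hosts or place the LoRA weights on a different GPU than the one the adapter is assigned to. A minimal, self-contained sketch of the effect (the temp path and device string below are illustrative, not from the commit):

    import torch

    # Illustrative only: fabricate a tiny LoRA-style state dict and round-trip it.
    state = {"lora_A.weight": torch.randn(8, 4), "lora_B.weight": torch.randn(4, 8)}
    torch.save(state, "/tmp/adapter_model.bin")

    device = "cpu"  # the device this adapter should live on, e.g. "cuda:1"

    # map_location overrides the device recorded at save time, so the tensors
    # land directly on the requested device instead of wherever they were saved.
    tensors = torch.load("/tmp/adapter_model.bin", map_location=device)
    assert all(t.device == torch.device(device) for t in tensors.values())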