
fix: mistral nemo

AlpinDale committed 7 months ago
commit 639e48e47d
1 changed file with 4 additions and 1 deletion

aphrodite/modeling/models/llama.py (+4, -1)

@@ -83,6 +83,7 @@ class LlamaAttention(nn.Module):
 
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -109,7 +110,8 @@ class LlamaAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = getattr(config, "head_dim",
+                                hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -192,6 +194,7 @@ class LlamaDecoderLayer(nn.Module):
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=getattr(config, "num_key_value_heads",
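
The fallback added here matters because some Llama-architecture checkpoints set head_dim explicitly in their config instead of leaving it to be derived from hidden_size. Below is a minimal sketch of the difference, assuming Mistral Nemo's published config values (hidden_size=5120, num_attention_heads=32, head_dim=128); the SimpleNamespace objects are stand-ins for the real Hugging Face configs, not Aphrodite's LlamaConfig.

from types import SimpleNamespace

# Stand-in configs (assumed values, not loaded from actual checkpoints).
llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
mistral_nemo = SimpleNamespace(hidden_size=5120, num_attention_heads=32,
                               head_dim=128)

def resolve_head_dim(config):
    # Mirrors the patched line: prefer an explicit head_dim, otherwise derive it.
    return getattr(config, "head_dim",
                   config.hidden_size // config.num_attention_heads)

print(resolve_head_dim(llama_like))    # 128 (derived: 4096 // 32)
print(resolve_head_dim(mistral_nemo))  # 128 (explicit; old code would compute 5120 // 32 = 160)

With the old line, Mistral Nemo's q/k/v projection sizes and attention scaling would have been derived from head_dim=160, which does not match the checkpoint's weights; reading head_dim from the config fixes that while keeping the derived value for models that do not set it.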