@@ -93,6 +93,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -110,6 +111,7 @@ class LlamaAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = ParallelLinear.column(
@@ -134,6 +136,7 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             base=self.rope_theta,
+            rope_scaling=self.rope_scaling,
             max_position=self.max_position_embeddings,
             rotary_dim=self.head_dim,
             num_kv_heads=self.num_kv_heads)
@@ -166,6 +169,7 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = LlamaAttention(
@@ -173,6 +177,7 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
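
For context on what `getattr(config, "rope_scaling", None)` picks up: in Hugging Face Llama configs, `rope_scaling` is typically a dict with `"type"` and `"factor"` keys, and older configs simply omit the attribute. The snippet below is a minimal, self-contained sketch of the two cases the `None` default has to cover; it uses a stand-in `SimpleNamespace` config (an assumption for illustration) rather than a real `LlamaConfig`.

```python
from types import SimpleNamespace

# Config that requests linear RoPE scaling, following the transformers
# convention of a dict with "type" and "factor".
config_with_scaling = SimpleNamespace(
    rope_theta=10000.0,
    rope_scaling={"type": "linear", "factor": 2.0},
    max_position_embeddings=8192,
)

# Older config without the attribute at all.
config_without_scaling = SimpleNamespace(rope_theta=10000.0)

for cfg in (config_with_scaling, config_without_scaling):
    rope_scaling = getattr(cfg, "rope_scaling", None)
    # Prints the dict for the first config and None for the second,
    # which is exactly what gets threaded through LlamaAttention above.
    print(rope_scaling)
```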