
fix: yarn (#112)

AlpinDale · 1 year ago · commit 5175605f8d

+ 13 - 4
aphrodite/common/config.py

@@ -359,15 +359,24 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
     if derived_max_model_len == float("inf"):
-        raise ValueError(
-            "The model's config.json must contain one of the following keys "
-            "to determine the original maximum length of the model: "
-            f"{possible_keys}")
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            f"{possible_keys}. Assuming the model's maximum length is "
+            f"{default_max_len}.")
+        derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
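
For context, a minimal Python sketch of the new fallback behaviour, not taken from the commit: the function name derive_max_len and the shortened possible_keys list are illustrative only.

    import math

    def derive_max_len(hf_config, max_model_len=None, default_max_len=2048):
        # Shortened key list for illustration; the real function checks more keys.
        possible_keys = ["max_position_embeddings", "n_positions", "max_sequence_length"]
        derived = min((hf_config[k] for k in possible_keys if k in hf_config),
                      default=math.inf)
        if math.isinf(derived):
            if max_model_len is not None:
                return max_model_len   # explicit user setting wins
            derived = default_max_len  # otherwise warn and assume 2048
        rope_scaling = hf_config.get("rope_scaling")
        if rope_scaling is not None:
            factor = rope_scaling["factor"]
            if rope_scaling["type"] == "yarn":
                # YaRN configs report the scaled length in max_position_embeddings,
                # so scale from the original length instead.
                derived = rope_scaling["original_max_position_embeddings"]
            derived *= factor
        return int(derived) if max_model_len is None else max_model_len

The user-supplied max_model_len still takes precedence; the 2048 default only applies when the config exposes none of the known length keys.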

+ 3 - 3
aphrodite/modeling/layers/attention.py

@@ -339,9 +339,9 @@ class PagedAttentionWithRoPE(PagedAttention):
                     head_size, rotary_dim, max_position, base, is_neox_style,
                     scaling_factor)
             elif scaling_type == "yarn":
-                new_max_position = rope_scaling[
+                original_max_position = rope_scaling[
                     "original_max_position_embeddings"]
-                assert max_position == new_max_position * scaling_factor
+                assert max_position == original_max_position * scaling_factor
                 extra_kwargs = {
                     k: v
                     for k, v in rope_scaling.items()
@@ -349,7 +349,7 @@ class PagedAttentionWithRoPE(PagedAttention):
                              "beta_fast", "beta_slow")
                 }
                 self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, new_max_position, base,
+                    head_size, rotary_dim, original_max_position, base,
                     is_neox_style, scaling_factor, **extra_kwargs)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
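
The rename makes the assertion read correctly: config.py now derives max_position as the original length times the scaling factor. A hedged illustration with assumed Llama-2-style numbers (4096 original positions, factor 4), not values from the commit:

    rope_scaling = {"type": "yarn", "factor": 4.0,
                    "original_max_position_embeddings": 4096,
                    "beta_fast": 32, "beta_slow": 1}
    scaling_factor = rope_scaling["factor"]
    original_max_position = rope_scaling["original_max_position_embeddings"]
    # config.py derives max_position = original * factor, so the assertion
    # compares against the original length, not a "new" one.
    max_position = 16384
    assert max_position == original_max_position * scaling_factor
    extra_kwargs = {k: v for k, v in rope_scaling.items()
                    if k in ("extrapolation_factor", "attn_factor",
                             "beta_fast", "beta_slow")}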

+ 6 - 4
aphrodite/modeling/layers/rotary_embedding.py

@@ -171,6 +171,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+# Inverse dim formula to find dim based on number of rotations
 def _yarn_find_correction_dim(num_rotations: int,
                               dim: int,
                               base: float = 10000,
@@ -180,6 +181,7 @@ def _yarn_find_correction_dim(num_rotations: int,
                                                               math.log(base))
 
 
+# Find dim range bounds based on rotations
 def _yarn_find_correction_range(low_rot: int,
                                 high_rot: int,
                                 dim: int,
@@ -197,7 +199,7 @@ def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                            dtype: torch.dtype,
                            device: torch.device) -> torch.Tensor:
     if low == high:
-        high += 0.001
+        high += 0.001  # Prevent singularity
 
     linear_func = (torch.arange(dim, dtype=dtype, device=device) -
                    low) / (high - low)
@@ -233,9 +235,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.attn_factor = attn_factor
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
         self.mscale = float(
-            _yarn_get_mscale(self.scaling_factor) * attn_factor
-        )  # get n-d magnitude scaling corrected for interpolation
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style)
 
@@ -249,7 +251,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                 self.rotary_dim, self.base,
                                                 self.max_position_embeddings)
-
+        # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
             low, high, self.rotary_dim // 2, dtype=torch.float,
             device="cuda")) * self.extrapolation_factor
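
The new comments describe the standard YaRN helper math. A self-contained restatement for reference; the function names mirror the private helpers above, but the defaults and CPU device are assumptions of this sketch:

    import math
    import torch

    def yarn_find_correction_dim(num_rotations, dim, base=10000.0,
                                 max_position_embeddings=2048):
        # Inverse dim formula: which rotary dimension completes `num_rotations`
        # full rotations over the original context length.
        return (dim * math.log(max_position_embeddings /
                               (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def yarn_linear_ramp_mask(low, high, dim, dtype=torch.float32):
        if low == high:
            high += 0.001  # prevent division by zero
        ramp = (torch.arange(dim, dtype=dtype) - low) / (high - low)
        return torch.clamp(ramp, 0, 1)

    def yarn_get_mscale(scale=1.0):
        # Magnitude correction applied to cos/sin when interpolating.
        return 0.1 * math.log(scale) + 1.0 if scale > 1 else 1.0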

+ 6 - 6
kernels/activation_kernels.cu

@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
     const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;
 
   dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
   int d = input.size(-1);                                                                 \
-  int num_tokens = input.numel() / d;                                                     \
+  int64_t num_tokens = input.numel() / d;                                                     \
   dim3 grid(num_tokens);                                                                  \
   dim3 block(std::min(d, 1024));                                                          \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
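
Why the widening matters: with 32-bit arithmetic the flattened offset token_idx * 2 * d can wrap past INT_MAX once enough tokens are in flight. A rough back-of-the-envelope check in Python (the intermediate size is assumed, not from the commit):

    d = 11008                  # e.g. Llama-2-7B FFN intermediate size
    int32_max = 2**31 - 1
    max_safe_tokens = int32_max // (2 * d)
    print(max_safe_tokens)     # ~97,000 tokens before an int32 offset wraps

The same int64_t widening of num_tokens is applied to the rotary_embedding launcher in pos_encoding_kernels.cu below.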

+ 1 - 1
kernels/pos_encoding_kernels.cu

@@ -84,7 +84,7 @@ void rotary_embedding(
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int64_t num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;