
fix: gpt-j loading

AlpinDale, 1 year ago
parent
commit 6dfca19dda
2 changed files with 7 additions and 2 deletions
  1. .gitignore (+1 -0)
  2. aphrodite/modeling/models/gpt_j.py (+6 -2)

+ 1 - 0
.gitignore

@@ -7,4 +7,5 @@ repos
 *.so
 .conda
 build
+tests
 

+ 6 - 2
aphrodite/modeling/models/gpt_j.py

@@ -58,12 +58,12 @@ class GPTJAttention(nn.Module):
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
+
         tp_world_size = get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
 
         scaling = self.head_size**-0.5
-        assert config.rotary
         assert config.rotary_dim % 2 == 0
         self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
                                            scaling, config.rotary_dim)
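
The removed "assert config.rotary" line is the substantive fix here: a stock Hugging Face GPTJConfig defines rotary_dim but, presumably, no boolean rotary attribute, so the assertion would fail as soon as the attention module was constructed during model loading. A minimal sketch of that failure mode, assuming the default transformers GPTJConfig (the cfg name is illustrative, not part of this commit):

    from transformers import GPTJConfig

    # Default GPT-J-style config: rotary_dim is set (64 by default),
    # but no `rotary` attribute exists, so the removed assert would
    # have raised AttributeError at load time.
    cfg = GPTJConfig()

    print(cfg.rotary_dim)                # 64 by default
    print(getattr(cfg, "rotary", None))  # None: attribute is absent
    assert cfg.rotary_dim % 2 == 0       # the check the commit keeps
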
@@ -106,6 +106,8 @@ class GPTJMLP(nn.Module):
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.fc_out(hidden_states)
         return hidden_states
+
+
 class GPTJBlock(nn.Module):
 
     def __init__(self, config: GPTJConfig):
@@ -138,6 +140,8 @@ class GPTJBlock(nn.Module):
         mlp_output = self.mlp(hidden_states)
         hidden_states = attn_output + mlp_output + residual
         return hidden_states
+
+
 class GPTJModel(nn.Module):
 
     def __init__(self, config: GPTJConfig):
@@ -242,4 +246,4 @@ class GPTJForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights, tp_rank)
+                                         self._row_parallel_weights, tp_rank)