Browse Source

fix: assertion error in gpt-j

AlpinDale 1 year ago
parent
commit
d5abc36994
1 changed file with 1 addition and 1 deletion
  1. 1 1
      aphrodite/modeling/models/gpt_j.py

+ 1 - 1
aphrodite/modeling/models/gpt_j.py

@@ -252,7 +252,7 @@ class GPTJForCausalLM(nn.Module):
                     continue
                 # pylint: disable=unsubscriptable-object
                 param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[1]
+                shard_size = param.shape[0] // 3
                 loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                               (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *