소스 검색

fix: assertion error in gpt-j

AlpinDale 1 년 전
부모
커밋
d5abc36994
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      aphrodite/modeling/models/gpt_j.py

+ 1 - 1
aphrodite/modeling/models/gpt_j.py

@@ -252,7 +252,7 @@ class GPTJForCausalLM(nn.Module):
                     continue
                 # pylint: disable=unsubscriptable-object
                 param = state_dict[name.replace(att_weight_name, "qkv_proj")]
-                shard_size = param.shape[1]
+                shard_size = param.shape[0] // 3
                 loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                               (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *