|
@@ -252,7 +252,7 @@ class GPTJForCausalLM(nn.Module):
|
|
|
continue
|
|
|
# pylint: disable=unsubscriptable-object
|
|
|
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
|
|
- shard_size = param.shape[1]
|
|
|
+ shard_size = param.shape[0] // 3
|
|
|
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
|
|
(tp_rank + 1)]
|
|
|
param_slice = param.data[shard_size * stride_id:shard_size *
|