|
@@ -276,7 +276,7 @@ class LlamaForCausalLM(nn.Module):
|
|
|
|
|
|
|
|
|
param = state_dict[name]
|
|
|
- load_tensor_parallel_weights(param, loaded_weights, name,
|
|
|
+ load_tensor_parallel_weights(param, loaded_weight, name,
|
|
|
self._column_parallel_weights,
|
|
|
self._row_parallel_weights,
|
|
|
tensor_model_parallel_rank)
|