@@ -268,8 +268,7 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -286,7 +285,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -336,4 +335,4 @@ class LlamaForCausalLM(nn.Module):
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         tensor_model_parallel_rank)