@@ -6,7 +6,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@ from einops import rearrange
 
 
 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
 
 
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
 
 
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.
 
     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(
 
 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")
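
The switch from built-in generics (dict[str, torch.Tensor], list[dict]) to typing.Dict / typing.List is presumably for runtime compatibility with Python versions older than 3.9: there, subscripting the built-in dict and list in annotations raises a TypeError at import time unless from __future__ import annotations is in effect. Below is a minimal sketch of the difference; the function name is illustrative and not taken from the patch.

from typing import Dict

import torch


# Portable spelling: typing.Dict is subscriptable on Python 3.7+.
def identity_remap(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Trivial stand-in for the real remapping logic.
    return dict(state_dict)


# The built-in-generic spelling of the same signature,
#     def identity_remap(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: ...
# fails on Python 3.8 with "TypeError: 'type' object is not subscriptable"
# as soon as the module is imported, unless the module starts with
# `from __future__ import annotations`.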