
Fix E1136 (#563)

Yuchao Dai · 1 year ago · parent commit 187c2a0635
2 changed files with 10 additions and 9 deletions
  1. flash_attn/models/gpt.py  (+2 -1)
  2. flash_attn/models/llama.py  (+8 -8)
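
Context for the change: E1136 is pylint's "unsubscriptable-object" check. Subscripting the built-in `list` and `dict` types in annotations (e.g. `list[dict[str, torch.Tensor]]`) is only valid from Python 3.9 (PEP 585), so older interpreters and pylint configurations flag it. The commit swaps those annotations for `typing.List` / `typing.Dict`, which behave the same but work on all supported versions. A minimal sketch of the pattern; the function and parameter names here are illustrative, not taken from the diff:

    # Python >= 3.9 only; older pylint/interpreters raise E1136 / TypeError:
    # def pick_first(shards: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: ...

    # Portable form, matching the style adopted in this commit:
    from typing import Dict, List

    import torch

    def pick_first(shards: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Toy stand-in: return the first shard unchanged."""
        return shards[0]
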

flash_attn/models/gpt.py  (+2 -1)

@@ -6,6 +6,7 @@ import re
 from collections import OrderedDict, namedtuple
 from collections.abc import Sequence
 from functools import partial
+from typing import Dict, List
 
 import torch
 import torch.nn as nn
@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict
 
 
-def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
     """Convert the list of sharded state_dict of a GPT model with tensor parallel to
     the state_dict of a standard GPT model.
 

flash_attn/models/llama.py  (+8 -8)

@@ -6,7 +6,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@ from einops import rearrange
 
 
 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
 
 
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
 
 
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
 
     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(
 
 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")