@@ -58,12 +58,12 @@ class GPTJAttention(nn.Module):
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
+
         tp_world_size = get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
 
         scaling = self.head_size**-0.5
-        assert config.rotary
         assert config.rotary_dim % 2 == 0
         self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
                                            scaling, config.rotary_dim)
@@ -106,6 +106,8 @@ class GPTJMLP(nn.Module):
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.fc_out(hidden_states)
         return hidden_states
+
+
 class GPTJBlock(nn.Module):
 
     def __init__(self, config: GPTJConfig):
@@ -138,6 +140,8 @@ class GPTJBlock(nn.Module):
         mlp_output = self.mlp(hidden_states)
         hidden_states = attn_output + mlp_output + residual
         return hidden_states
+
+
 class GPTJModel(nn.Module):
 
     def __init__(self, config: GPTJConfig):
@@ -242,4 +246,4 @@ class GPTJForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights, tp_rank)
+                                         self._row_parallel_weights, tp_rank)
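
For reference, here is a minimal standalone sketch (not the vLLM code itself) of the two invariants the patched GPTJAttention constructor relies on: the total head count must divide evenly across tensor-parallel ranks, and config.rotary_dim must be even because rotary embeddings rotate pairs of dimensions. The concrete numbers below are illustrative GPT-J-6B-style values; tp_world_size = 4 is a hypothetical tensor-parallel setup.

# Illustrative sketch only; values mirror GPT-J-6B (16 heads, hidden size 4096,
# rotary_dim 64) with a hypothetical tensor-parallel world size of 4.
total_num_heads = 16       # config.num_attention_heads
hidden_size = 4096         # config.n_embd
rotary_dim = 64            # config.rotary_dim
tp_world_size = 4          # stands in for get_tensor_model_parallel_world_size()

head_size = hidden_size // total_num_heads       # 256 dims per head
assert total_num_heads % tp_world_size == 0      # heads must shard evenly across ranks
num_heads = total_num_heads // tp_world_size     # 4 heads owned by this rank

scaling = head_size ** -0.5                      # 1/sqrt(head_size) attention scaling
assert rotary_dim % 2 == 0                       # RoPE rotates dimension pairs

print(num_heads, head_size, scaling)             # -> 4 256 0.0625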