
feat: adapt modeling_llama.py

AlpinDale 1 year ago
parent
commit
42682fdaf9
1 changed file with 189 additions and 2 deletions
  1. aphrodite/modeling/models/llama.py  +189 −2

aphrodite/modeling/models/llama.py  +189 −2

@@ -1,6 +1,8 @@
 # coding=utf-8
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -92,4 +94,189 @@ class LlamaAttention(nn.Module):
             bias=False,
             input_is_parallel=True,
             perform_initialization=False,
-        )
+        )
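+        # Fused paged-attention op that also applies rotary position
+        # embeddings over the full head dimension.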
+        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
+                                           self.scaling,
+                                           rotary_dim=self.head_dim)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
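+        # Split the fused QKV projection output into equal query, key, and
+        # value chunks along the last dimension.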
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(
+            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class LlamaDecoderLayer(nn.Module):
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = LlamaAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+        )
+        self.mlp = LlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class LlamaModel(nn.Module):
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
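+        # The token embedding table is sharded across tensor-parallel ranks.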
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size, perform_initialization=False)
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
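+        # Each decoder layer gets its own KV cache and, if provided, the
+        # matching cache event.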
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class LlamaForCausalLM(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.model = LlamaModel(config)
+        self.lm_head = ColumnParallelLinear(config.hidden_size,
+                                            config.vocab_size,
+                                            bias=False,
+                                            gather_output=False,
+                                            perform_initialization=False)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
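+        # Compute logits from the (column-parallel) lm_head weight and sample
+        # the next token for each sequence.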
+        next_tokens = self.sampler(
+            self.lm_head.weight, hidden_states, input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
+                                "qkv_proj.weight", "gate_proj.weight",
+                                "up_proj.weight"]
+    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, use_np_cache):
+            if "rotary_emb.inv_freq" in name:
+                continue
+
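+            # q_proj/k_proj/v_proj checkpoint weights are merged into the
+            # fused qkv_proj parameter: slice this rank's tensor-parallel
+            # shard and copy it into the corresponding third of qkv_proj.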
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
+                shard_size = param.shape[0] // 3
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                        :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
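+            # gate_proj/up_proj checkpoint weights are likewise merged into
+            # the fused gate_up_proj parameter, one half per projection.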
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                shard_size = param.shape[0] // 2
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                        :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
+
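+            # Everything else is loaded by the generic tensor-parallel helper
+            # using the column/row-parallel weight name lists above.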
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)