llama_embedding.py

from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata
from aphrodite.common.sequence import PoolerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.modeling.layers.pooler import Pooler, PoolingType
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.pooling_metadata import PoolingMetadata


class LlamaEmbeddingModel(nn.Module):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, inputs_embeds)
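
    # pooler() turns the hidden states returned by forward() into normalized
    # sentence embeddings (PoolerOutput); the engine calls it after the
    # forward pass when the model is run in embedding mode.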
    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
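        # The checkpoint stores q/k/v and gate/up as separate tensors, while
        # the Aphrodite LlamaModel fuses them into qkv_proj and gate_up_proj,
        # so matching shards are routed through the fused parameter's
        # weight_loader with the shard id given above.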
        params_dict = dict(self.model.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
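

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a rough, standalone
# picture of what `Pooler(pooling_type=PoolingType.LAST, normalize=True)` is
# configured to do, i.e. take each sequence's final hidden state and
# L2-normalize it into an embedding. The helper name and the packed
# [total_tokens, hidden_size] layout are assumptions made for the example;
# the actual behaviour lives in aphrodite.modeling.layers.pooler.Pooler.
def _last_token_pooling_sketch(hidden_states: torch.Tensor,
                               seq_lens: List[int]) -> torch.Tensor:
    # In a packed batch, sequence i ends at offset sum(seq_lens[:i + 1]) - 1.
    last_indices = (torch.cumsum(torch.tensor(seq_lens, dtype=torch.long),
                                 dim=0) - 1).to(hidden_states.device)
    last_hidden = hidden_states[last_indices]
    # L2-normalize so downstream cosine similarity reduces to a dot product.
    return nn.functional.normalize(last_hidden, p=2, dim=-1)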