
llama_embedding.py 3.3 KB

from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata
from aphrodite.common.sequence import PoolerOutput
from aphrodite.modeling.layers.pooler import Pooler, PoolingType
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.pooling_metadata import PoolingMetadata


class LlamaEmbeddingModel(nn.Module):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(**kwargs)
        # Each prompt's embedding is the hidden state of its final token,
        # L2-normalized (see the illustrative sketch at the end of this file).
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, inputs_embeds)
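
    # NOTE: at serving time the engine's model runner typically calls
    # forward() to produce hidden states for a scheduled batch and then
    # pooler() to reduce them to one embedding per prompt.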

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
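        # For example, a checkpoint tensor whose name ends in
        # "self_attn.k_proj.weight" is renamed to the fused
        # "self_attn.qkv_proj.weight" and loaded into its "k" shard; likewise
        # "mlp.up_proj.weight" becomes shard 1 of "mlp.gate_up_proj.weight".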
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the tensor directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
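

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): a minimal, standalone
# approximation of what LAST pooling with normalization computes. The real
# Pooler consumes Aphrodite's PoolingMetadata; here we only assume a flattened
# batch of hidden states and hypothetical per-prompt token counts.
if __name__ == "__main__":
    hidden_size = 8
    prompt_lens = [3, 5]  # two prompts, 3 and 5 tokens long (made up)
    hidden_states = torch.randn(sum(prompt_lens), hidden_size)

    # The embedding of each prompt is the hidden state of its final token...
    last_token_idx = torch.cumsum(torch.tensor(prompt_lens), dim=0) - 1
    embeddings = hidden_states[last_token_idx]
    # ...L2-normalized, so cosine similarity reduces to a dot product.
    embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)

    print(embeddings.shape)         # torch.Size([2, 8])
    print(embeddings.norm(dim=-1))  # ~1.0 for each embedding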