# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import OlmoConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tensor_model_parallel_world_size == 0

        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
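

# Illustrative reference only (not used by the model above): a minimal sketch
# of how the fused QKV output is post-processed in OlmoAttention.forward.
# When clip_qkv is configured, the activations are clamped to
# [-clip_qkv, clip_qkv]; since this implementation configures
# QKVParallelLinear with as many key/value heads as query heads, the fused
# tensor splits into three equal chunks. The function name is made up for
# this sketch.
def _split_qkv_reference(
    qkv: torch.Tensor,
    clip_qkv: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if clip_qkv is not None:
        # Same clamp as OlmoAttention.forward, but out of place.
        qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
    # Equal-sized q, k, v chunks along the last (hidden) dimension.
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    return q, k, v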


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as
    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Feed-forward input projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
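

# Illustrative reference only (not used by the model above): what SiluAndMul
# computes on the merged gate/up projection in OlmoMLP.forward, written in
# plain PyTorch. The input is assumed to have shape
# (..., 2 * intermediate_size) with the gate half first, matching the
# [gate_proj, up_proj] shard order used in load_weights below. The function
# name is made up for this sketch.
def _silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    return nn.functional.silu(gate) * up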


class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OlmoConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config, quant_config)

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config)

        # LayerNorm
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                                       attn_metadata)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
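

# Illustrative reference only (not used by the model above): OLMo's LayerNorms
# are non-parametric (no learned scale or bias), which is what the
# elementwise_affine=False, bias=False arguments above select. A plain
# functional equivalent is sketched here; the function name is made up.
def _non_parametric_layer_norm_reference(x: torch.Tensor) -> torch.Tensor:
    # Normalize over the last dimension without weight or bias.
    return nn.functional.layer_norm(x, normalized_shape=(x.shape[-1], ))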


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OlmoConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            OlmoDecoderLayer(config, quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size,
                                 elementwise_affine=False,
                                 bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        # Apply blocks one-by-one.
        for layer_idx, decoder_layer in enumerate(self.layers):
            # shape: (batch_size, seq_len, d_model)
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        hidden_states = self.norm(hidden_states)
        return hidden_states
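

# Illustrative reference only (not used by the model above): OlmoModel.forward
# indexes kv_caches by layer, so callers are expected to pass one cache entry
# per decoder layer. A minimal sanity-check sketch; the function name is made
# up.
def _check_kv_caches_reference(kv_caches: List[torch.Tensor],
                               config: OlmoConfig) -> None:
    assert len(kv_caches) == config.num_hidden_layers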


class OlmoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OlmoConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = OlmoModel(config, quant_config)
        if config.tie_word_embeddings:
            self.lm_head_weight = self.model.embed_tokens.weight
        else:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
            self.lm_head_weight = self.lm_head.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
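

# Illustrative reference only (not used by the loader above): a minimal sketch
# of the name rewriting performed by the stacked_params_mapping loop in
# load_weights. A HuggingFace checkpoint entry such as
# "model.layers.0.self_attn.q_proj.weight" is renamed to the fused parameter
# "model.layers.0.self_attn.qkv_proj.weight" and tagged with shard id "q";
# names that match no entry are returned unchanged. The function name is made
# up for this sketch.
def _map_hf_name_reference(name: str):
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None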