# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import OlmoConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tensor_model_parallel_world_size == 0

        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
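

# Illustrative sketch, not used by the model: OLMo optionally clamps the
# fused QKV activations to [-clip_qkv, clip_qkv], then splits them into
# equal thirds, since queries, keys, and values all have the same width
# here (no GQA). The hypothetical helper below mirrors the split performed
# in OlmoAttention.forward above.
def _qkv_split_sketch(
        qkv: torch.Tensor,
        clip_qkv: Optional[float]) -> Tuple[torch.Tensor, ...]:
    # qkv: (num_tokens, 3 * num_heads * head_dim) on this TP rank.
    if clip_qkv is not None:
        qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
    # q, k, v: each (num_tokens, num_heads * head_dim).
    return qkv.chunk(chunks=3, dim=-1)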


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as
    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Feed-forward input projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
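

# Illustrative sketch, not used by the model: SiluAndMul consumes the
# concatenated [gate, up] projection produced by gate_up_proj and computes
# silu(gate) * up. The hypothetical helper below restates that contract
# with plain torch ops.
def _silu_and_mul_sketch(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: (num_tokens, 2 * intermediate_size).
    d = gate_up.shape[-1] // 2
    gate, up = gate_up[..., :d], gate_up[..., d:]
    return nn.functional.silu(gate) * up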


class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config, cache_config, quant_config)

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config)

        # LayerNorm
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                                       attn_metadata)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
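

# Illustrative sketch, not used by the model: the decoder layer applies the
# standard pre-norm residual pattern twice, once around attention and once
# around the MLP. The hypothetical helper below captures that shape; note
# that OLMo's LayerNorms are non-parametric (elementwise_affine=False).
def _pre_norm_residual_sketch(x: torch.Tensor, norm: nn.Module,
                              block) -> torch.Tensor:
    # Equivalent to: residual = x; x = norm(x); x = block(x); x + residual.
    return x + block(norm(x))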


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            OlmoDecoderLayer(config, cache_config, quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size,
                                 elementwise_affine=False,
                                 bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        # Apply blocks one-by-one.
        for layer_idx, decoder_layer in enumerate(self.layers):
            # shape: (batch_size, seq_len, d_model)
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class OlmoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = OlmoModel(config, cache_config, quant_config)
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors
                # in the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight. The
            # weight can still appear in checkpoint files produced by
            # quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
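

# Illustrative sketch, not used by the model: stacked_params_mapping folds
# separate checkpoint tensors into the fused parameters defined above. The
# hypothetical helper below shows only the renaming step; the shard id
# tells the fused parameter's weight_loader which slice to fill.
def _remap_stacked_name_sketch(name: str):
    mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    for param_name, weight_name, shard_id in mapping:
        if weight_name in name:
            # e.g. "model.layers.0.self_attn.q_proj.weight"
            #  ->  "model.layers.0.self_attn.qkv_proj.weight", shard "q".
            return name.replace(weight_name, param_name), shard_id
    return name, None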