olmo.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
  4. # https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright (c) Microsoft Corporation.
  8. # Licensed under the MIT license.
  9. #
  10. # BSD 3-Clause License
  11. #
  12. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  13. # All rights reserved.
  14. #
  15. # Redistribution and use in source and binary forms, with or without
  16. # modification, are permitted provided that the following conditions are met:
  17. #
  18. # * Redistributions of source code must retain the above copyright notice, this
  19. # list of conditions and the following disclaimer.
  20. #
  21. # * Redistributions in binary form must reproduce the above copyright notice,
  22. # this list of conditions and the following disclaimer in the documentation
  23. # and/or other materials provided with the distribution.
  24. #
  25. # * Neither the name of the copyright holder nor the names of its
  26. # contributors may be used to endorse or promote products derived from
  27. # this software without specific prior written permission.
  28. #
  29. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  30. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  31. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  32. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  33. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  34. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  35. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  36. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  37. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  38. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  39. """Inference-only OLMo model compatible with HuggingFace weights."""
  from typing import List, Optional, Tuple

  import torch
  import torch.nn.functional as F
  from torch import nn

  from aphrodite.modeling.metadata import InputMetadata
  from aphrodite.modeling.layers.attention import PagedAttention
  from aphrodite.modeling.layers.linear import (
      ColumnParallelLinear,
      LinearMethodBase,
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
  from aphrodite.modeling.layers.rotary_embedding import get_rope
  from aphrodite.modeling.layers.sampler import Sampler
  from aphrodite.modeling.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
      ParallelLMHead,
  )
  from aphrodite.modeling.megatron.parallel_state import (
      get_tensor_model_parallel_world_size, )
  from aphrodite.modeling.sampling_metadata import SamplingMetadata
  from aphrodite.modeling.hf_downloader import (
      default_weight_loader,
      hf_model_weights_iterator,
  )
  from aphrodite.common.sequence import SamplerOutput
  from aphrodite.transformers_utils.configs.olmo import OLMoConfig
  # Per-layer cache handed to the attention layer as a
  # (key_cache, value_cache) pair of tensors.
  KVCache = Tuple[torch.Tensor, torch.Tensor]


  class SwiGLU(nn.Module):
      """Gated SiLU activation over the two halves of the last dimension.

      The input's last dimension is split into a ``value`` half followed by a
      ``gate`` half and the result is ``silu(gate) * value``. The output is
      therefore half as wide as the input (see ``output_multiplier``).
      """

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          value, gate = torch.chunk(x, 2, dim=-1)
          activated_gate = F.silu(gate)
          return activated_gate * value

      @property
      def output_multiplier(self) -> float:
          """Ratio of output width to input width: the input is halved."""
          return 0.5
  class OlmoAttention(nn.Module):
      """Self-attention sub-layer of an OLMo block.

      Computes the ``Attention(LN(x))`` term of
      ``MLP(LN(x + Attention(LN(x))))``. The residual connection around it is
      not added here; the caller is responsible for it.
      """

      def __init__(
          self,
          config: OLMoConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.hidden_size = config.d_model
          assert config.d_model % config.n_heads == 0

          # Attention heads are divided evenly across tensor-parallel ranks.
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = self.config.n_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size
          self.head_dim = self.hidden_size // self.total_num_heads

          # Pre-attention layer norm with no learnable weight or bias.
          self.attn_norm = nn.LayerNorm(config.d_model,
                                        elementwise_affine=False,
                                        bias=False)

          # Fused input projection: x -> (q, k, v).
          self.att_proj = QKVParallelLinear(
              config.d_model,
              self.head_dim,
              self.total_num_heads,
              bias=config.include_bias,
              linear_method=linear_method,
          )

          # Rotary position embeddings are only built when the config asks
          # for them; forward() checks the same flag before using them.
          if self.config.rope:
              theta = getattr(config, "rope_theta", 10000)
              max_positions = getattr(config, "max_position_embeddings", 8192)
              self.rotary_emb = get_rope(
                  self.head_dim,
                  rotary_dim=self.head_dim,
                  max_position=max_positions,
                  base=theta,
              )

          # Standard 1/sqrt(head_dim) softmax scaling.
          self.scaling = self.head_dim**-0.5
          self.attn = PagedAttention(self.num_heads,
                                     self.head_dim,
                                     scale=self.scaling)

          # Output projection back to d_model.
          self.attn_out = RowParallelLinear(
              config.d_model,
              config.d_model,
              bias=config.include_bias,
              linear_method=linear_method,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Normalize, project to q/k/v, attend with the paged cache, project out.

          Args:
              positions: Token positions, consumed only by the rotary embedding.
              hidden_states: Block input (pre-norm).
              kv_cache: ``(key_cache, value_cache)`` for this layer.
              input_metadata: Passed through unchanged to ``PagedAttention``.

          Returns:
              The attention output after ``attn_out``, without the residual.
          """
          normed = self.attn_norm(hidden_states)
          qkv, _ = self.att_proj(normed)
          # q, k and v are equal-width slices of the fused projection.
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          if self.config.rope:
              q, k = self.rotary_emb(positions, q, k)
          key_cache, value_cache = kv_cache
          context = self.attn(q, k, v, key_cache, value_cache, input_metadata)
          projected, _ = self.attn_out(context)
          return projected
  class OlmoMLP(nn.Module):
      """Feed-forward sub-layer of an OLMo block, including its residual.

      Computes ``x + ff_out(SwiGLU(ff_proj(LN(x))))``: the ``MLP(LN(...))``
      term of ``MLP(LN(x + Attention(LN(x))))`` plus its skip connection.
      """

      def __init__(
          self,
          config: OLMoConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          # Full width of the fused ff_proj output (value half + gate half).
          self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
                              is not None else config.mlp_ratio * config.d_model)
          # Pre-MLP layer norm with no learnable weight or bias.
          self.ff_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)
          # SwiGLU halves its input, so ff_out consumes half of hidden_size.
          self.act = SwiGLU()
          assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
          half_size = int(self.act.output_multiplier * self.hidden_size)
          # ff_proj's weight is a single fused [value; gate] matrix. A plain
          # ColumnParallelLinear shards it contiguously across tensor-parallel
          # ranks. With two ranks, for example, rank 0 would hold only value
          # rows, and SwiGLU's local chunk(2) would pair the wrong activations.
          # Declaring the two halves as separate outputs shards each half
          # independently, so every rank holds [value_shard; gate_shard]. That
          # matches the input slice ff_out (RowParallelLinear) expects.
          # Loading the fused checkpoint tensor without a shard id lets the
          # merged layer split it into the two halves itself.
          self.ff_proj = MergedColumnParallelLinear(
              config.d_model,
              [half_size] * 2,
              bias=config.include_bias,
              linear_method=linear_method,
          )
          # Feed-forward output projection back to d_model.
          self.ff_out = RowParallelLinear(
              half_size,
              config.d_model,
              bias=config.include_bias,
              linear_method=linear_method,
          )

      def forward(
          self,
          x: torch.Tensor,
      ) -> torch.Tensor:
          """Apply norm -> ff_proj -> SwiGLU -> ff_out and add the residual.

          Args:
              x: Sub-layer input; the last dimension is d_model.

          Returns:
              ``x`` plus the feed-forward output, same shape as ``x``.
          """
          og_x = x
          x = self.ff_norm(x)
          x, _ = self.ff_proj(x)
          x = self.act(x)
          x, _ = self.ff_out(x)
          x = og_x + x
          return x
  class OlmoBlock(nn.Module):
      """One OLMo transformer layer.

      Computes ``h = x + Attention(LN(x))`` and then ``h + MLP(LN(h))``. The
      attention residual is added in ``forward`` below; the MLP residual is
      added inside ``OlmoMLP``.
      """

      def __init__(
          self,
          config: OLMoConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          # Attention sub-layer (no residual of its own).
          self.attn = OlmoAttention(config, linear_method)
          # MLP sub-layer (applies its own residual).
          self.mlp = OlmoMLP(config, linear_method)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Run attention plus residual, then the MLP sub-layer.

          Args:
              positions: Token positions forwarded to the attention sub-layer.
              hidden_states: Layer input.
              kv_cache: ``(key_cache, value_cache)`` for this layer.
              input_metadata: Forwarded unchanged to the attention sub-layer.

          Returns:
              The layer's output hidden states as a single tensor (no cache is
              returned; the paged cache is handled by the attention layer).
          """
          og_x = hidden_states
          x = self.attn(positions, hidden_states, kv_cache, input_metadata)
          x = x + og_x
          hidden_states = self.mlp(x)
          return hidden_states
  class OlmoModel(nn.Module):
      """OLMo decoder stack: token embedding, transformer blocks, final norm.

      Submodules are stored in ``self.transformer`` under the names ``wte``,
      ``ln_f``, ``ff_out`` and ``blocks``. Checkpoint tensor names are mapped
      onto these names at load time.
      """

      def __init__(
          self,
          config: OLMoConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          # Reject the unsupported layout before allocating any layer weights;
          # previously every block was constructed and then discarded.
          if self.config.block_group_size > 1:
              raise NotImplementedError("Block group size > 1 not supported yet")

          # Prefer embedding_size when set, otherwise fall back to vocab_size.
          num_embeddings = config.embedding_size or config.vocab_size
          self.transformer = nn.ModuleDict(
              dict(
                  wte=VocabParallelEmbedding(
                      num_embeddings,
                      config.d_model,
                      linear_method=linear_method,
                  ),
                  # Final layer norm with no learnable weight or bias.
                  ln_f=nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False),
                  # LM head; its weight is copied from wte at load time when
                  # the config enables weight tying.
                  ff_out=ParallelLMHead(
                      num_embeddings,
                      config.d_model,
                      bias=config.include_bias,
                      linear_method=linear_method,
                  ),
              ))
          self.transformer.update({
              "blocks":
              nn.ModuleList([
                  OlmoBlock(config, linear_method)
                  for _ in range(config.n_layers)
              ])
          })

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every block, and apply the final norm.

          :param input_ids: Token ids to embed. The upstream OLMo code documents
              them as ``(batch_size, seq_len)``; TODO confirm for this runner.
          :param positions: Token positions forwarded to every block.
          :param kv_caches: One ``(key_cache, value_cache)`` entry per block,
              indexed by block position.
          :param input_metadata: Forwarded unchanged to every block.
          :return: Final-normed hidden states (before the LM head).
          """
          x = self.transformer.wte(input_ids)  # type: ignore
          for block_idx, block in enumerate(self.transformer.blocks):
              x = block(
                  positions,
                  x,
                  kv_caches[block_idx],
                  input_metadata,
              )
          x = self.transformer.ln_f(x)  # type: ignore
          return x
  class OLMoForCausalLM(nn.Module):
      """Minimal causal-LM wrapper around ``OlmoModel``.

      ``forward`` returns the final hidden states. ``sample`` runs them through
      the ``ff_out`` LM head and hands the result to the sampler.
      ``load_weights`` maps checkpoint tensor names onto this module tree.
      """

      def __init__(
          self,
          config: OLMoConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = OlmoModel(config, linear_method)
          self.sampler = Sampler(config.vocab_size)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Return the decoder's final hidden states (LM head not applied)."""
          return self.model(
              input_ids=input_ids,
              positions=positions,
              kv_caches=kv_caches,
              input_metadata=input_metadata,
          )

      def sample(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Apply the ``ff_out`` LM head and sample next tokens."""
          lm_head = self.model.transformer.ff_out
          head_output = lm_head(hidden_states)
          return self.sampler(head_output, sampling_metadata)

      @staticmethod
      def _to_module_name(ckpt_name: str) -> str:
          """Rename a checkpoint tensor to the matching parameter name here.

          Names containing ``.att`` move under ``.attn``. Then names containing
          ``.ff`` move under ``.mlp``, except the top-level
          ``transformer.ff_out`` LM head.
          """
          renamed = ckpt_name
          if ".att" in renamed:
              renamed = renamed.replace(".att", ".attn.att")
          if ".ff" in renamed and "transformer.ff_out" not in renamed:
              renamed = renamed.replace(".ff", ".mlp.ff")
          return renamed

      def load_weights(
          self,
          model_name_or_path: str,
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None,
      ):
          """Load checkpoint tensors into this model's parameters.

          Each parameter's own ``weight_loader`` is used when present,
          otherwise ``default_weight_loader``. A checkpoint name with no
          matching parameter raises ``KeyError``.
          """
          params_dict = dict(self.named_parameters(remove_duplicate=False))

          def load_into(param, tensor):
              loader = getattr(param, "weight_loader", default_weight_loader)
              loader(param, tensor)

          checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                                 load_format, revision)
          for ckpt_name, tensor in checkpoint:
              if "wte" in ckpt_name and self.config.weight_tying:
                  # Tied embeddings: the embedding also fills the LM head.
                  head_name = ckpt_name.replace("model.transformer.wte",
                                                "model.transformer.ff_out")
                  if head_name in params_dict:
                      load_into(params_dict[head_name], tensor)
              load_into(params_dict[self._to_module_name(ckpt_name)], tensor)