1
0

olmo.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
  4. # https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright (c) Microsoft Corporation.
  8. # Licensed under the MIT license.
  9. #
  10. # BSD 3-Clause License
  11. #
  12. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  13. # All rights reserved.
  14. #
  15. # Redistribution and use in source and binary forms, with or without
  16. # modification, are permitted provided that the following conditions are met:
  17. #
  18. # * Redistributions of source code must retain the above copyright notice, this
  19. # list of conditions and the following disclaimer.
  20. #
  21. # * Redistributions in binary form must reproduce the above copyright notice,
  22. # this list of conditions and the following disclaimer in the documentation
  23. # and/or other materials provided with the distribution.
  24. #
  25. # * Neither the name of the copyright holder nor the names of its
  26. # contributors may be used to endorse or promote products derived from
  27. # this software without specific prior written permission.
  28. #
  29. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  30. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  31. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  32. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  33. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  34. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  35. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  36. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  37. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  38. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  39. """Inference-only OLMo model compatible with HuggingFace weights."""
  40. from typing import List, Optional, Tuple
  41. import torch
  42. import torch.nn.functional as F
  43. from torch import nn
  44. from aphrodite.attention import Attention, AttentionMetadata
  45. from aphrodite.modeling.layers.linear import (
  46. ColumnParallelLinear,
  47. LinearMethodBase,
  48. QKVParallelLinear,
  49. RowParallelLinear,
  50. )
  51. from aphrodite.modeling.layers.rotary_embedding import get_rope
  52. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  53. from aphrodite.modeling.layers.sampler import Sampler
  54. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  55. VocabParallelEmbedding,
  56. ParallelLMHead,
  57. )
  58. from aphrodite.distributed import (
  59. get_tensor_model_parallel_world_size, )
  60. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  61. from aphrodite.modeling.hf_downloader import (
  62. default_weight_loader,
  63. hf_model_weights_iterator,
  64. )
  65. from aphrodite.common.sequence import SamplerOutput
  66. from aphrodite.transformers_utils.configs.olmo import OLMoConfig
  67. class SwiGLU(nn.Module):
  68. def forward(self, x: torch.Tensor) -> torch.Tensor:
  69. x, gate = x.chunk(2, dim=-1)
  70. return F.silu(gate) * x
  71. @property
  72. def output_multiplier(self) -> float:
  73. return 0.5
  74. class OlmoAttention(nn.Module):
  75. """
  76. This is the attention block where the output is computed as
  77. ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
  78. (plus another skip connection).
  79. """
  80. def __init__(
  81. self,
  82. config: OLMoConfig,
  83. linear_method: Optional[LinearMethodBase] = None,
  84. ):
  85. super().__init__()
  86. self.config = config
  87. self.hidden_size = config.d_model
  88. assert config.d_model % config.n_heads == 0
  89. tensor_model_parallel_world_size = (
  90. get_tensor_model_parallel_world_size())
  91. self.total_num_heads = self.config.n_heads
  92. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  93. self.num_heads = (self.total_num_heads //
  94. tensor_model_parallel_world_size)
  95. self.head_dim = self.hidden_size // self.total_num_heads
  96. # Layer norms.
  97. self.attn_norm = nn.LayerNorm(config.d_model,
  98. elementwise_affine=False,
  99. bias=False)
  100. # Attention input projection. Projects x -> (q, k, v)
  101. self.att_proj = QKVParallelLinear(
  102. config.d_model,
  103. self.head_dim,
  104. self.total_num_heads,
  105. bias=config.include_bias,
  106. linear_method=linear_method,
  107. )
  108. # Rotary embeddings.
  109. if self.config.rope:
  110. rope_theta = getattr(config, "rope_theta", 10000)
  111. max_position_embeddings = getattr(config,
  112. "max_position_embeddings", 8192)
  113. self.rotary_emb = get_rope(
  114. self.head_dim,
  115. rotary_dim=self.head_dim,
  116. max_position=max_position_embeddings,
  117. base=rope_theta,
  118. )
  119. self.scaling = self.head_dim**-0.5
  120. self.attn = Attention(self.num_heads,
  121. self.head_dim,
  122. scale=self.scaling)
  123. # Attention output projection.
  124. self.attn_out = RowParallelLinear(
  125. config.d_model,
  126. config.d_model,
  127. bias=config.include_bias,
  128. linear_method=linear_method,
  129. )
  130. def forward(
  131. self,
  132. positions: torch.Tensor,
  133. hidden_states: torch.Tensor,
  134. kv_cache: torch.Tensor,
  135. attn_metadata: AttentionMetadata,
  136. ) -> torch.Tensor:
  137. hidden_states = self.attn_norm(hidden_states)
  138. qkv, _ = self.att_proj(hidden_states)
  139. q, k, v = qkv.chunk(chunks=3, dim=-1)
  140. if self.config.rope:
  141. q, k = self.rotary_emb(positions, q, k)
  142. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  143. output, _ = self.attn_out(attn_output)
  144. return output
  145. class OlmoMLP(nn.Module):
  146. """
  147. This is the MLP block where the output is computed as
  148. ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
  149. (plus another skip connection).
  150. """
  151. def __init__(
  152. self,
  153. config: OLMoConfig,
  154. linear_method: Optional[LinearMethodBase] = None,
  155. ):
  156. super().__init__()
  157. self.config = config
  158. self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
  159. is not None else config.mlp_ratio * config.d_model)
  160. # Layer norms.
  161. self.ff_norm = nn.LayerNorm(config.d_model,
  162. elementwise_affine=False,
  163. bias=False)
  164. # Feed-forward input projection.
  165. self.ff_proj = ColumnParallelLinear(
  166. config.d_model,
  167. self.hidden_size,
  168. bias=config.include_bias,
  169. linear_method=linear_method,
  170. )
  171. # Activation function.
  172. # self.act = SiluAndMul()
  173. # self.act.output_multiplier = 0.5
  174. self.act = SwiGLU()
  175. assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
  176. # Feed-forward output projection.
  177. self.ff_out = RowParallelLinear(
  178. int(self.act.output_multiplier * self.hidden_size),
  179. config.d_model,
  180. bias=config.include_bias,
  181. linear_method=linear_method,
  182. )
  183. def forward(
  184. self,
  185. x: torch.Tensor,
  186. ) -> torch.Tensor:
  187. # Add feed-forward projection.
  188. # shape: (batch_size, seq_len, d_model)
  189. og_x = x
  190. x = self.ff_norm(x)
  191. x, _ = self.ff_proj(x)
  192. x = self.act(x)
  193. x, _ = self.ff_out(x)
  194. x = og_x + x
  195. return x
  196. class OlmoBlock(nn.Module):
  197. """
  198. This is a typical transformer block where the output is computed as
  199. ``MLP(LN(x + Attention(LN(x))))``
  200. (plus another skip connection).
  201. """
  202. def __init__(
  203. self,
  204. config: OLMoConfig,
  205. linear_method: Optional[LinearMethodBase] = None,
  206. ):
  207. super().__init__()
  208. # Attention block.
  209. self.attn = OlmoAttention(config, linear_method)
  210. # MLP block.
  211. self.mlp = OlmoMLP(config, linear_method)
  212. def forward(
  213. self,
  214. positions: torch.Tensor,
  215. hidden_states: torch.Tensor,
  216. kv_cache: torch.Tensor,
  217. attn_metadata: AttentionMetadata,
  218. ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
  219. # Attention block.
  220. og_x = hidden_states
  221. x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
  222. x = x + og_x
  223. # MLP block.
  224. hidden_states = self.mlp(x)
  225. return hidden_states
  226. class OlmoModel(nn.Module):
  227. def __init__(
  228. self,
  229. config: OLMoConfig,
  230. linear_method: Optional[LinearMethodBase] = None,
  231. ):
  232. super().__init__()
  233. self.config = config
  234. self.transformer = nn.ModuleDict(
  235. dict(
  236. wte=VocabParallelEmbedding(
  237. config.embedding_size or config.vocab_size,
  238. config.d_model,
  239. linear_method=linear_method,
  240. ),
  241. ln_f=nn.LayerNorm(config.d_model,
  242. elementwise_affine=False,
  243. bias=False),
  244. ff_out=ParallelLMHead(
  245. config.embedding_size or config.vocab_size,
  246. config.d_model,
  247. bias=config.include_bias,
  248. linear_method=linear_method,
  249. ),
  250. ))
  251. blocks = [
  252. OlmoBlock(config, linear_method) for i in range(config.n_layers)
  253. ]
  254. if self.config.block_group_size > 1:
  255. raise NotImplementedError("Block group size > 1 not supported yet")
  256. else:
  257. self.transformer.update({"blocks": nn.ModuleList(blocks)})
  258. def forward(
  259. self,
  260. input_ids: torch.Tensor,
  261. positions: torch.Tensor,
  262. kv_caches: List[torch.Tensor],
  263. attn_metadata: AttentionMetadata,
  264. ) -> torch.Tensor:
  265. """
  266. :param input_ids: A tensor of shape `(batch_size, seq_len)`.
  267. """
  268. # Get embeddings of input.
  269. # shape: (batch_size, seq_len, d_model)
  270. x = self.transformer.wte(input_ids) # type: ignore
  271. # Apply blocks one-by-one.
  272. for block_idx, block in enumerate(self.transformer.blocks):
  273. # shape: (batch_size, seq_len, d_model)
  274. x = block(
  275. positions,
  276. x,
  277. kv_caches[block_idx],
  278. attn_metadata,
  279. )
  280. # Apply final layer norm.
  281. # shape: (batch_size, seq_len or 1, d_model)
  282. x = self.transformer.ln_f(x) # type: ignore
  283. return x
  284. class OLMoForCausalLM(nn.Module):
  285. """
  286. Extremely barebones HF model wrapper.
  287. """
  288. def __init__(
  289. self,
  290. config: OLMoConfig,
  291. linear_method: Optional[LinearMethodBase] = None,
  292. ):
  293. super().__init__()
  294. self.config = config
  295. self.linear_method = linear_method
  296. self.model = OlmoModel(config, linear_method)
  297. self.logits_processor = LogitsProcessor(config.vocab_size)
  298. self.sampler = Sampler()
  299. def forward(
  300. self,
  301. input_ids: torch.Tensor,
  302. positions: torch.Tensor,
  303. kv_caches: List[torch.Tensor],
  304. attn_metadata: AttentionMetadata,
  305. ) -> torch.Tensor:
  306. hidden_states = self.model(
  307. input_ids=input_ids,
  308. positions=positions,
  309. kv_caches=kv_caches,
  310. attn_metadata=attn_metadata,
  311. )
  312. return hidden_states
  313. def compute_logits(self, hidden_states: torch.Tensor,
  314. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  315. logits = self.logits_processor(self.model.transformer.ff_out,
  316. hidden_states, sampling_metadata)
  317. return logits
  318. def sample(
  319. self,
  320. logits: torch.Tensor,
  321. sampling_metadata: SamplingMetadata,
  322. ) -> Optional[SamplerOutput]:
  323. next_tokens = self.sampler(logits, sampling_metadata)
  324. return next_tokens
  325. def load_weights(
  326. self,
  327. model_name_or_path: str,
  328. cache_dir: Optional[str] = None,
  329. load_format: str = "auto",
  330. revision: Optional[str] = None,
  331. ):
  332. params_dict = dict(self.named_parameters(remove_duplicate=False))
  333. for name, loaded_weight in hf_model_weights_iterator(
  334. model_name_or_path, cache_dir, load_format, revision):
  335. if "wte" in name and self.config.weight_tying:
  336. # Copy word embedding to lm_head
  337. head_name = name.replace("model.transformer.wte",
  338. "model.transformer.ff_out")
  339. if head_name in params_dict:
  340. lm_head_param = params_dict[head_name]
  341. weight_loader = getattr(lm_head_param, "weight_loader",
  342. default_weight_loader)
  343. weight_loader(lm_head_param, loaded_weight)
  344. # attention
  345. if ".att" in name:
  346. name = name.replace(".att", ".attn.att")
  347. # mlp
  348. if ".ff" in name and "transformer.ff_out" not in name:
  349. name = name.replace(".ff", ".mlp.ff")
  350. # there is no bias in olmo
  351. param = params_dict[name]
  352. weight_loader = getattr(param, "weight_loader",
  353. default_weight_loader)
  354. weight_loader(param, loaded_weight)