# vocab_parallel_embedding.py
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size so it is a multiple of `pad_to`."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
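
# A worked example of the padding math, using the numbers from the
# VocabParallelEmbedding docstring below (illustration only):
#   pad_vocab_size(1010)       == 1024  # base vocab rounded up to a multiple of 64
#   pad_vocab_size(1024 + 16)  == 1088  # padded base vocab plus 16 LoRA-added tokens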


def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank,
                                                     offset=offset)
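
# For example (illustration only), splitting a padded vocab of 1024 across
# two tensor-parallel ranks:
#   vocab_range_from_global_vocab_size(1024, rank=0, world_size=2) == (0, 512)
#   vocab_range_from_global_vocab_size(1024, rank=1, world_size=2) == (512, 1024)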


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return (self.padded_org_vocab_end_index -
                self.padded_org_vocab_start_index)

    @property
    def num_added_elements_padded(self) -> int:
        return (self.padded_added_vocab_end_index -
                self.padded_added_vocab_start_index)

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert (self.padded_org_vocab_start_index <=
                self.padded_org_vocab_end_index)
        assert (self.padded_added_vocab_start_index <=
                self.padded_added_vocab_end_index)

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert (self.added_vocab_start_index <=
                self.padded_added_vocab_start_index)
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded


@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
                                                          org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask
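
# Example of the remapping above (illustration only), using the TP2 / rank 0
# shard from the VocabParallelEmbedding docstring below: org vocab range
# [0, 512), no org padding, added vocab range [1010, 1026).
#   token id  100 -> (100, False)  # lives in this rank's base shard
#   token id 1010 -> (512, False)  # first added token, placed right after the base shard
#   token id  600 -> (0,   True)   # owned by another rank, so it is masked out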


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.

    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:

    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< -------------------BASE-------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 |  -1 | ... |  -1  |
                     index: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 512  | ... | 527  | 528 | ... | 543  |

    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< ----------LORA PADDING---------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ... |  -1  | -1  | ... |  -1  | -1  | ... |  -1  |
                     index: |  0  |  1  |  2  | ... |  497 | 498 | ... |  511 | 512 | ... |  519 | 520 | ... |  543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        padding_size = padding_size or get_tensor_model_parallel_world_size()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        linear_method = None
        if quant_config is not None:
            linear_method = quant_config.get_quant_method(self, prefix=prefix)
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method: QuantizeMethodBase = linear_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_per_partition],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following
        the layout outlined in the class docstring, based on the given tp_rank
        and tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)
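
    # For the docstring example (vocab_size_padded=1088, org_vocab_size_padded=1024,
    # vocab_size=1026, org_vocab_size=1010) with tp_size=2, _get_indices yields
    # (illustration only):
    #   rank 0: padded org [0, 512),    org [0, 512),    padded added [1010, 1042), added [1010, 1026)
    #   rank 1: padded org [512, 1024), org [512, 1010), padded added [1042, 1074), added [1026, 1026)
    # i.e. rank 0 owns the first half of the base vocab plus all 16 added tokens,
    # while rank 1 owns the second half of the base vocab and only padding.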

    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the
        class docstring. However, after the gather, we want to reindex the
        final logits tensor to map index->token_id one-to-one (the index is
        always equal to the token_id it corresponds to). The indices returned
        by this method allow us to do that.
        """
        if self.tp_size < 2:
            return None

        base_embeddings: List[int] = []
        added_embeddings: List[int] = []
        padding: List[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret
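
    # For the TP2 docstring example this returns (illustration only)
    #   [0..511, 544..1041, 512..527, <padding slots>]
    # so that, after gathering logits from both ranks, reindexing with this
    # list makes position i hold the logit for token id i for every real
    # (non-padding) token.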

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If parameter does not have output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       param.pack_factor)
            start_idx = start_idx // param.pack_factor
            shard_size = shard_size // param.pack_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0]:].data.fill_(0)

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input.long(), self.weight)
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        s += f", tp_size={self.tp_size}"
        return s


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        logits = self.linear_method.apply_weights(self.linear_weights, input_)
        if self.bias is not None:
            logits += self.bias
        return logits


class ParallelTWEHead(torch.nn.Module):
    """Parallelized tied-word-embeddings head.

    Output logits weight matrices used in the Sampler. The weight
    tensors are read from a VocabParallelEmbedding.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: VocabParallelEmbedding):
        super().__init__()
        self.linear_method = embeddings.linear_method
        self.linear_weights = embeddings.linear_weights

    def forward(self, input_):
        logits = self.linear_method.apply_weights(self.linear_weights, input_)
        return logits
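
# A minimal usage sketch (illustration only; it assumes the tensor-parallel
# process group in aphrodite.distributed has already been initialized by the
# engine, and the sizes below are made up):
#
#   embedding = VocabParallelEmbedding(num_embeddings=32016,
#                                      embedding_dim=4096,
#                                      org_num_embeddings=32000)
#   token_ids = torch.randint(0, 32000, (2, 16))
#   hidden_states = embedding(token_ids)  # -> [2, 16, 4096], all-reduced over TP ranks
#
# ParallelLMHead takes the same constructor arguments plus an optional bias;
# per its docstring, its (sharded) weight is what the Sampler uses to compute
# output logits.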