vocab_parallel_embedding.py

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)

DEFAULT_VOCAB_PADDING_SIZE = 64


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        return F.embedding(input_, layer.weight)


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size up to the next multiple of `pad_to`."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
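
# Illustrative only (not part of the original module): worked examples of the
# rounding above, using the sizes from the VocabParallelEmbedding docstring
# below (original vocab 1010, 16 added tokens, padding of 64).
#
#   >>> pad_vocab_size(1010)        # default pad_to=64
#   1024
#   >>> pad_vocab_size(1024 + 16)   # padded base vocab plus the added tokens
#   1088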


def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank,
                                                     offset=offset)
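
# Illustrative only, matching the class docstring example below: a padded
# base vocab of 1024 split across 2 ranks yields half-open [start, end)
# ranges per rank, and ``offset`` shifts the range (used to place the
# added-vocab shards after the base vocab):
#
#   >>> vocab_range_from_global_vocab_size(1024, rank=0, world_size=2)
#   (0, 512)
#   >>> vocab_range_from_global_vocab_size(1024, rank=1, world_size=2)
#   (512, 1024)
#   >>> vocab_range_from_global_vocab_size(64, rank=0, world_size=2, offset=1010)
#   (1010, 1042)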


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return (self.padded_org_vocab_end_index -
                self.padded_org_vocab_start_index)

    @property
    def num_added_elements_padded(self) -> int:
        return (self.padded_added_vocab_end_index -
                self.padded_added_vocab_start_index)

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert (self.padded_org_vocab_start_index <=
                self.padded_org_vocab_end_index)
        assert (self.padded_added_vocab_start_index <=
                self.padded_added_vocab_end_index)

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert (self.added_vocab_start_index <=
                self.padded_added_vocab_start_index)
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded


@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask
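
# A minimal illustration (hypothetical shard bounds, not taken from the
# module): on a rank that owns base tokens [4, 8) and added tokens [16, 18),
# with 2 rows of base-vocab padding in the local weight, off-shard ids are
# mapped to row 0 and flagged in the returned mask:
#
#   >>> ids = torch.tensor([3, 4, 7, 16, 17, 9])
#   >>> masked, invalid = get_masked_input_and_mask(ids, 4, 8, 2, 16, 18)
#   >>> masked    # local rows: base -> 0..3, added -> 6..7, off-shard -> 0
#   tensor([0, 0, 3, 6, 7, 0])
#   >>> invalid   # True where the id is not owned by this shard
#   tensor([ True, False, False, False, False,  True])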


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.

    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:

    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: | 0 | 1 | ... | 1009     | -1   | ... | -1   | 1010 | ... | 1025 | -1   | ... | -1   |
                     index: | 0 | 1 | ... | 1009     | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: | 0 | 1 | 2 | ... | 497   | 498 | ... | 511       | 1010 | ... | 1025 | -1  | ... | -1   |
                     index: | 0 | 1 | 2 | ... | 497   | 498 | ... | 511       | 512  | ... | 527  | 528 | ... | 543  |

    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ... | -1   | -1  | ... | -1   | -1  | ... | -1    |
                     index: | 0   | 1   | 2   | ... | 497  | 498 | ... | 511  | 512 | ... | 519  | 520 | ... | 543   |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        padding_size = padding_size or get_tensor_model_parallel_world_size()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        linear_method = None
        if quant_config is not None:
            linear_method = quant_config.get_quant_method(self, prefix=prefix)
        if linear_method is None:
            linear_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
        linear_method_implements_embedding = method_has_implemented_embedding(
            type(linear_method))
        if is_embedding_layer and not linear_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(linear_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.linear_method: QuantizeMethodBase = linear_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_per_partition],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following
        the layout outlined in the class docstring, based on the given tp_rank
        and tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)
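
    # Illustrative values for the TP2 case in the class docstring (vocab of
    # 1010 + 16 added ids, padded to 1088). For rank 0:
    #
    #   >>> idx = VocabParallelEmbedding._get_indices(
    #   ...     vocab_size_padded=1088, org_vocab_size_padded=1024,
    #   ...     vocab_size=1026, org_vocab_size=1010, tp_rank=0, tp_size=2)
    #   >>> (idx.org_vocab_start_index, idx.org_vocab_end_index)
    #   (0, 512)
    #   >>> (idx.added_vocab_start_index, idx.added_vocab_end_index)
    #   (1010, 1026)
    #   >>> (idx.num_org_elements_padded, idx.num_added_elements_padded)
    #   (512, 32)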

    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the
        class docstring. However, after the gather, we want to reindex the
        final logits tensor to map index->token_id one-to-one (the index is
        always equal to the token_id it corresponds to). The indices returned
        by this method allow us to do that.
        """
        if self.tp_size < 2:
            return None

        base_embeddings: List[int] = []
        added_embeddings: List[int] = []
        padding: List[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret
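
    # For the TP2 docstring example, rank 0 owns gathered rows 0..543 and
    # rank 1 rows 544..1087, so the returned mapping is (roughly)
    # [0..511, 544..1041] (base) + [512..527] (added) + the padding rows.
    # Indexing the gathered logits with it, e.g. ``logits[:, mapping]``,
    # makes position j hold the logit for token id j for every real token.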

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the parameter does not have an output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If the parameter is packed on the same dim we are sharding on, then
        # we need to adjust the offsets of the loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       param.pack_factor)
            start_idx = start_idx // param.pack_factor
            shard_size = shard_size // param.pack_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0]:].data.fill_(0)
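
    # Example with the TP2 docstring setup: on rank 1 the base-vocab shard
    # covers checkpoint rows [512, 1010), so ``loaded_weight`` is narrowed to
    # 498 rows along ``output_dim``, copied into the first 498 rows of the
    # local parameter, and the remaining padding rows are zero-filled.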

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.embedding(self,
                                                       masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
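
    # Sketch of the tp_size > 1 path for the TP2 docstring example, rank 0:
    # token id 1015 lies in this rank's added-vocab range and is remapped to
    # local row 517 (512 base rows + offset 5); token id 600 belongs to rank
    # 1, so it is remapped to row 0 here, its output row is zeroed by the
    # mask, and the all-reduce sums in rank 1's embedding for it.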

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        s += f", tp_size={self.tp_size}"
        return s


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # The weight lives on this module itself, so use the layer-based
        # apply() of the quantize method (see UnquantizedEmbeddingMethod).
        logits = self.linear_method.apply(self, input_)
        if self.bias is not None:
            logits += self.bias
        return logits


class ParallelTWEHead(torch.nn.Module):
    """Parallelized tie word embeddings head.

    Output logits weight matrices used in the Sampler. The weights are
    read from a VocabParallelEmbedding.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: VocabParallelEmbedding):
        super().__init__()
        self.linear_method = embeddings.linear_method
        # Keep a reference to the embedding layer; its weight is reused
        # (tied) to produce the logits.
        self.embeddings = embeddings

    def forward(self, input_):
        logits = self.linear_method.apply(self.embeddings, input_)
        return logits