# vocab_parallel_embedding.py

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.parameter import BaseAphroditeParameter
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)

DEFAULT_VOCAB_PADDING_SIZE = 64


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        return F.embedding(input_, layer.weight)


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size up to the next multiple of `pad_to`."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
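
# A quick sanity check of the rounding above (doctest-style; the numbers are
# taken from the worked example in the VocabParallelEmbedding docstring below):
#
#   >>> pad_vocab_size(1010, 64)
#   1024
#   >>> pad_vocab_size(1024 + 16, 64)
#   1088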


def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank,
                                                     offset=offset)
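
# Example (again using the numbers from the class docstring below): a padded
# base vocab of 1024 split across two ranks gives half-open ranges per rank.
#
#   >>> vocab_range_from_global_vocab_size(1024, rank=0, world_size=2)
#   (0, 512)
#   >>> vocab_range_from_global_vocab_size(1024, rank=1, world_size=2)
#   (512, 1024)
#   >>> # The added (LoRA) vocab is sharded the same way, offset past the
#   >>> # unpadded base vocab:
#   >>> vocab_range_from_global_vocab_size(64, rank=0, world_size=2, offset=1010)
#   (1010, 1042)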


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        return self.org_vocab_end_index - self.org_vocab_start_index

    @property
    def num_added_elements(self) -> int:
        return self.added_vocab_end_index - self.added_vocab_start_index

    @property
    def num_org_elements_padded(self) -> int:
        return (self.padded_org_vocab_end_index -
                self.padded_org_vocab_start_index)

    @property
    def num_added_elements_padded(self) -> int:
        return (self.padded_added_vocab_end_index -
                self.padded_added_vocab_start_index)

    @property
    def num_org_vocab_padding(self) -> int:
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        return self.num_org_elements_padded + self.num_added_elements_padded

    def __post_init__(self):
        # sanity checks
        assert (self.padded_org_vocab_start_index <=
                self.padded_org_vocab_end_index)
        assert (self.padded_added_vocab_start_index <=
                self.padded_added_vocab_end_index)

        assert self.org_vocab_start_index <= self.org_vocab_end_index
        assert self.added_vocab_start_index <= self.added_vocab_end_index

        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
        assert (self.added_vocab_start_index <=
                self.padded_added_vocab_start_index)
        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

        assert self.num_org_elements <= self.num_org_elements_padded
        assert self.num_added_elements <= self.num_added_elements_padded
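
# For TP2, rank 1 in the class docstring example below, the derived counts are:
# num_org_elements = 1010 - 512 = 498, num_org_elements_padded = 512 (so 14
# rows of base padding), num_added_elements = 0, num_added_elements_padded = 32,
# and num_elements_padded = 544 rows held by that rank.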


@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
                                                          org_vocab_end_index)
    added_vocab_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    valid_offset = (org_vocab_start_index *
                    org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask
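
# A small worked example (doctest-style), using the TP2 rank 0 shard from the
# class docstring below: base range [0, 512), no base padding on this rank,
# added range [1010, 1026). Token 600 belongs to rank 1, so it is remapped to
# local index 0 and flagged in the returned mask; token 1010 (the first LoRA
# token) lands right after the 512 base rows.
#
#   >>> ids = torch.tensor([3, 600, 1010])
#   >>> get_masked_input_and_mask(ids, 0, 512, 0, 1010, 1026)
#   (tensor([  3,   0, 512]), tensor([False,  True, False]))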


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.

    In the example below, the original vocab size is 1010, the added vocab
    size is 16, and the padding size is 64. The total vocab size with padding
    is therefore 1088 (we first pad 1010 to 1024, add 16, and then pad to
    1088). The tensor layout looks like the following:

    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< -------------------BASE------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 |  -1 | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497 | 498 | ... | 511 |  512 | ... |  527 | 528 | ... | 543 |

    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< ---------LORA PADDING--------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 |  -1 | ... |  -1  |  -1 | ... |  -1 | -1 | ... |  -1 |
                     index: |  0  |  1  |  2  | ... |  497 | 498 | ... |  511 | 512 | ... | 519 | 520 | ... | 543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        padding_size = padding_size or get_tensor_model_parallel_world_size()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        linear_method = None
        if quant_config is not None:
            linear_method = quant_config.get_quant_method(self, prefix=prefix)
        if linear_method is None:
            linear_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
        linear_method_implements_embedding = method_has_implemented_embedding(
            type(linear_method))
        if is_embedding_layer and not linear_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(linear_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.linear_method: QuantizeMethodBase = linear_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_per_partition],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)
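
    # Working through the numbers from the class docstring (an assumed call
    # with num_embeddings=1026, org_num_embeddings=1010, padding_size=64 and
    # tp_size=2):
    #   org_vocab_size_padded        = pad_vocab_size(1010, 64)      = 1024
    #   num_embeddings_padded        = pad_vocab_size(1024 + 16, 64) = 1088
    #   num_embeddings_per_partition = 1088 // 2                     = 544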

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following
        the layout outlined in the class docstring, based on the given tp_rank
        and tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)
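
    # For the docstring example (vocab_size_padded=1088,
    # org_vocab_size_padded=1024, vocab_size=1026, org_vocab_size=1010,
    # tp_size=2): rank 0 gets padded base range [0, 512) and padded added
    # range [1010, 1042), which clamp to [0, 512) and [1010, 1026) once the
    # padding is removed; rank 1 gets [512, 1024) / [1042, 1074), clamping to
    # [512, 1010) / [1026, 1026).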

    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the
        class docstring. However, after the gather, we want to reindex the
        final logits tensor to map index->token_id one-to-one (the index is
        always equal to the token_id it corresponds to). The indices returned
        by this method allow us to do that.
        """
        if self.tp_size < 2:
            return None

        base_embeddings: List[int] = []
        added_embeddings: List[int] = []
        padding: List[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret
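
    # A sketch of how a caller might apply the mapping (the variable names
    # below are hypothetical, not part of this module): after gathering the
    # per-rank logits along the vocab dimension, selecting the returned
    # indices puts column i back into token-id order, with all padding
    # columns moved to the end.
    #
    #   mapping = lm_head.get_sharded_to_full_mapping()
    #   if mapping is not None:
    #       indices = torch.tensor(mapping, device=gathered_logits.device)
    #       gathered_logits = gathered_logits.index_select(-1, indices)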

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            shape = list(loaded_weight.shape)
            if output_dim is not None:
                shape[output_dim] = shape[output_dim] // self.tp_size
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # If parameter does not have output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = param.packed_factor if isinstance(
                param, BaseAphroditeParameter) else param.pack_factor
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       packed_factor)
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0]:].data.fill_(0)
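
    # Worked example for the docstring numbers with tp_size=2, rank 1 and an
    # unpacked weight: org_vocab_start_index=512 and shard_size=1010-512=498,
    # so the loader narrows the full [1010, hidden] checkpoint tensor to rows
    # 512..1009, copies them into the first 498 rows of the 544-row shard, and
    # zero-fills the remaining rows (base padding plus any LoRA slots).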

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.embedding(self,
                                                       masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        s += f", tp_size={self.tp_size}"
        return s
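
# A hedged usage sketch (hypothetical model code, not part of this module):
# this is roughly how a model would instantiate and call the layer, assuming
# the tensor-parallel process group has already been initialized elsewhere.
#
#   self.embed_tokens = VocabParallelEmbedding(
#       num_embeddings=config.vocab_size,
#       embedding_dim=config.hidden_size,
#       org_num_embeddings=config.vocab_size,
#       quant_config=quant_config,
#       prefix=f"{prefix}.embed_tokens")
#   ...
#   hidden_states = self.embed_tokens(input_ids)  # [num_tokens, hidden_size]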


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # The weights created by the linear method are registered on this
        # module itself (see create_weights above), so apply() takes `self`.
        logits = self.linear_method.apply(self, input_)
        if self.bias is not None:
            logits += self.bias
        return logits
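
# A hedged usage sketch (hypothetical caller code): the LM head is typically
# constructed next to the embedding layer and produces a per-rank slice of the
# logits, which the sampling code then gathers and reorders.
#
#   self.lm_head = ParallelLMHead(config.vocab_size,
#                                 config.hidden_size,
#                                 quant_config=quant_config,
#                                 prefix=f"{prefix}.lm_head")
#   logits = self.lm_head(hidden_states)  # [num_tokens, vocab_shard_size]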


class ParallelTWEHead(torch.nn.Module):
    """Parallelized tied word embeddings head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are read from a VocabParallelEmbedding.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: VocabParallelEmbedding):
        super().__init__()
        self.embeddings = embeddings
        self.linear_method = embeddings.linear_method

    def forward(self, input_):
        # Reuse the tied embedding weights: apply() reads `weight` from the
        # wrapped VocabParallelEmbedding module.
        logits = self.linear_method.apply(self.embeddings, input_)
        return logits
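
# A hedged usage sketch (hypothetical caller code) for tied word embeddings:
# instead of allocating a separate lm_head weight, the head reuses the
# embedding layer's shard.
#
#   if config.tie_word_embeddings:
#       self.lm_head = ParallelTWEHead(self.embed_tokens)
#   logits = self.lm_head(hidden_states)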