test_layers.py

import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
import torch.nn.functional as F

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.lora.layers import (BaseLayerWithLoRA,
                                   ColumnParallelLinearWithLoRA,
                                   LinearScalingRotaryEmbeddingWithLora,
                                   LogitsProcessorWithLoRA, LoRAMapping,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   ReplicatedLinearWithLoRA,
                                   RowParallelLinearWithLoRA,
                                   VocabParallelEmbeddingWithLoRA)
# yapf: enable
from aphrodite.lora.models import (LongContextLoRAContext, LoRALayerWeights,
                                   PackedLoRALayerWeights)
from aphrodite.lora.punica import PunicaWrapper
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from aphrodite.modeling.utils import set_random_seed

from .utils import DummyLoRAManager

TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# Different Triton kernels are launched for the prefill and decode stages, so
# both paths need to be verified: True = prefill stage, False = decode stage.
STAGES = [True, False]
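
# Illustrative note: the layer tests below build
# LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) and pass it to
# PunicaWrapper.update_metadata, so both the prefill and decode kernel paths
# get exercised.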


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be at least
            as large as num_loras.
        log: Whether to log the output.
    """
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: List[Optional[int]] = [None] * num_slots
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots
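
# Illustrative example (not executed): get_random_id_to_index(2, 4) may return
# something like [None, 1, None, 2], meaning lora id 1 occupies slot 1 and
# lora id 2 occupies slot 3, while slots 0 and 2 are free.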


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: whether to generate an
            embeddings tensor for each LoRA.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.
    """

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: Dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is not None:
            subloras: List[LoRALayerWeights] = []
            sublora_len = layer_weights.shape[0] // repeats
            for i in range(repeats):
                sublora = DummyLoRAManager().init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                    generate_embeddings_tensor=generate_embeddings_tensor,
                )
                sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                    i):(sublora_len * (i + 1))]
                sublora.optimize()
                subloras.append(sublora)

            lora = PackedLoRALayerWeights.pack(
                subloras) if repeats > 1 else subloras[0]

            layer.set_lora(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
                embeddings_tensor=lora.embeddings_tensor,
            )

            lora_dict[lora_id] = lora
            sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict
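
# Note on packed layers (illustrative): with repeats > 1, each populated slot
# holds a PackedLoRALayerWeights built from `repeats` subloras, and each
# sublora's lora_b covers one contiguous sublora_len-wide slice of the packed
# output dimension. The returned sublora_dict is what the packed-layer test
# uses to reconstruct the expected output slice by slice.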


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

    inputs: List[torch.Tensor] = []
    index_mapping: List[int] = []
    prompt_mapping: List[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low), high=int(high), size=input_size))
        else:
            inputs.append(
                torch.rand(size=input_size, dtype=input_type) * high + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping
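
# Illustrative example (not executed): create_random_inputs(
#     active_lora_ids=[1, 2], num_inputs=4, input_size=(1, 16),
#     input_range=(0, 1), input_type=torch.float16)
# returns four (1, 16) tensors plus a per-token index_mapping and a per-prompt
# prompt_mapping, which the tests below feed into LoRAMapping. Each test then
# checks the LoRA-wrapped layer against a reference of the form
#     expected = base_layer(x) + x @ lora_a @ lora_b * scaling
# (for embedding layers, x @ lora_a becomes an embedding lookup into lora_a).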


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
# @pytest.mark.skip(
#     reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )
        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1
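        # Worked example (illustrative): with vocab_size=512 and
        # embeddings_tensor_len=256, a prompt mapped to lora id 2
        # (embedding_id=1) gets input_[-1]=768 and input_[-2]=1023, i.e. the
        # first and last rows of lora 2's block in the expanded embedding
        # table, while original_input_ keeps 512 and 767, the shared per-lora
        # added-vocab range that the LoRA-wrapped embedding resolves per lora.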
        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: List[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        input_ = torch.rand(20, 1024)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_linear_replicated_layer():
        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage) -> None:
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage) -> None:
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_column_parallel_packed_layer():
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLora(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLora(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLora(
                linear
            ) if not fully_shard else QKVParallelLinearWithShardedLora(linear)

        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            for i, sublora in enumerate(subloras):
                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
                       (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
                                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        # lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    dtype = torch.float16
    seed = 0
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
    lora_rope.set_mapping(punica_wrapper)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
    )
    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    for i in range(len(scaling_factors)):
        long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
            scaling_factors[i], 0)
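    # Worked example (illustrative): with scaling_factors=(4.0, 8.0) and
    # max_position=4096, scaling_factor_to_offset is expected to be
    # {4.0: 0, 8.0: 16384} (16384 = 4.0 * 4096), i.e. each scaling factor's
    # block of rotary positions starts where the previous block ends.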
    punica_wrapper.update_metadata(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )
    # lora_rope.set_mapping(*mapping_info)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    ref_q, ref_k = linear_rope(positions, query, key)
    actual_q, actual_k = lora_rope(positions, query, key)

    # torch.allclose returns a bool; assert it so a mismatch actually fails
    # the test.
    assert torch.allclose(ref_q, actual_q)
    assert torch.allclose(ref_k, actual_k)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1
    all_org_tokens: List[int] = []
    all_added_tokens: List[int] = []
    token_ids: List[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "aphrodite.modeling.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "aphrodite.modeling.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Ensure that we are not exceeding the vocab size
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Ensure that the ranges are not overlapping
        all_org_tokens.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        all_added_tokens.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))

        token_ids.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # base tp 1 case, no padding
    modified_x, _ = get_masked_input_and_mask(x,
                                              org_vocab_start_index=0,
                                              org_vocab_end_index=8,
                                              added_vocab_start_index=8,
                                              added_vocab_end_index=12,
                                              num_org_vocab_padding=0)
    assert torch.equal(x, modified_x)

    # tp 2 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=0)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=0)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

    # tp 4 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=2,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=9,
                                                     num_org_vocab_padding=0)
    modified_x_rank_1, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=2,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=9,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=0)
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=0)
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=0)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
    assert torch.equal(modified_x_rank_2,
                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
    assert torch.equal(modified_x_rank_3,
                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

    # base tp 1 case, with padding
    modified_x, _ = get_masked_input_and_mask(x,
                                              org_vocab_start_index=0,
                                              org_vocab_end_index=8,
                                              added_vocab_start_index=8,
                                              added_vocab_end_index=12,
                                              num_org_vocab_padding=2)
    assert torch.equal(modified_x,
                       torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

    # tp 2 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=2)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=2)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

    # tp 4 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=2,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=9,
                                                     num_org_vocab_padding=2)
    modified_x_rank_1, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=2,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=9,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=2)
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=2)
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=2)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
    assert torch.equal(modified_x_rank_2,
                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
    assert torch.equal(modified_x_rank_3,
                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))