# utils.py

from typing import Dict, List, Optional

import torch

from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights


class DummyLoRAManager:

    def __init__(self):
        super().__init__()
        self._loras: Dict[str, LoRALayerWeights] = {}

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        return self._loras[module_name]

    def init_random_lora(self,
                         module_name: str,
                         weight: torch.Tensor,
                         rank: int = 8,
                         generate_embeddings_tensor: int = 0):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([weight.shape[1], rank],
                              dtype=weight.dtype,
                              device="cuda"),
            lora_b=torch.rand([rank, weight.shape[0]],
                              dtype=weight.dtype,
                              device="cuda"),
        )
        if generate_embeddings_tensor:
            lora.embeddings_tensor = torch.rand(5,
                                                generate_embeddings_tensor,
                                                dtype=weight.dtype,
                                                device="cuda")
        self.set_module_lora(module_name, lora)
        return lora

    def init_lora(self,
                  module_name: str,
                  input_dim: int,
                  output_dim: int,
                  rank=8,
                  noop=False,
                  embeddings_tensor=None):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([input_dim, rank], device="cuda"),
            lora_b=torch.rand([rank, output_dim], device="cuda"),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: List[int],
        noop_lora_index: Optional[List[int]] = None,
        rank: int = 8,
    ):
        base_loras: List[LoRALayerWeights] = []
        noop_lora_index_set = set(noop_lora_index or [])

        for i, out_dim in enumerate(output_dims):
            base_lora = self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            base_loras.append(base_lora)
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
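

# Illustrative usage sketch (not part of the original utilities): how a test
# might build per-module LoRA weights with DummyLoRAManager. The module names
# and shapes below are made up, and a CUDA device is assumed because
# init_lora/init_random_lora allocate their tensors on "cuda".
def _example_dummy_lora_manager_usage():
    manager = DummyLoRAManager()
    # Random LoRA for a hypothetical [out_features, in_features] = [256, 128]
    # base weight: lora_a becomes [128, rank], lora_b becomes [rank, 256].
    base_weight = torch.zeros(256, 128, dtype=torch.float16, device="cuda")
    lora = manager.init_random_lora("layers.0.mlp.up_proj",
                                    base_weight,
                                    rank=8)
    assert manager.get_module_lora("layers.0.mlp.up_proj") is lora
    # Packed LoRA covering two stacked projections of width 256 each.
    packed = manager.init_packed_lora("layers.0.mlp.gate_up_proj",
                                      input_dim=128,
                                      output_dims=[256, 256],
                                      rank=8)
    return lora, packed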


def assert_close(a, b):
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
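

# Illustrative sketch (not part of the original utilities): assert_close
# picks its tolerances from the dtype of the first argument, so both tensors
# are expected to share a dtype. The values below are placeholders.
def _example_assert_close_usage():
    a = torch.rand(4, 4, dtype=torch.float16)
    b = a.clone()
    assert_close(a, b)  # float16 tolerances are rtol=6e-2, atol=6e-2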


def ref_torch_groupgemm(
    out_tensor,
    inputs,
    lora_weights,
    lora_indices_tensor,
    seq_len_tensor,
    batches,
    scaling,
    op_type,
) -> None:
    out_list = []
    current_offset = 0
    for lora_index, b_length in zip(range(batches), seq_len_tensor):
        input_weight = inputs[current_offset:b_length + current_offset, :]
        current_offset += b_length
        lora_weight = lora_weights[lora_indices_tensor[lora_index]]
        result = torch.nn.functional.linear(input_weight, lora_weight)
        result *= scaling
        out_list.append(result)
    cat_result = torch.cat(out_list, dim=0)
    # Write the reference result in place: "expand" accumulates into the
    # output, while "shrink" overwrites it.
    if op_type == "expand":
        out_tensor += cat_result
    else:
        out_tensor.copy_(cat_result)
    return
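

# Illustrative sketch (not part of the original utilities): calling the
# reference group-GEMM directly on hand-built tensors for the "shrink" case
# (per-batch inputs @ lora_a^T, scaled). All sizes, the LoRA index choices,
# and the scaling factor are placeholders.
def _example_ref_torch_groupgemm_shrink():
    batches, seq_len, hidden_size, rank, num_loras = 2, 4, 16, 8, 3
    seq_len_tensor = torch.full((batches, ), seq_len, dtype=torch.long)
    total_tokens = batches * seq_len
    inputs = torch.rand(total_tokens, hidden_size)
    lora_weights = torch.rand(num_loras, rank, hidden_size)
    lora_indices_tensor = torch.tensor([0, 2])  # one LoRA id per batch
    out_tensor = torch.zeros(total_tokens, rank)
    ref_torch_groupgemm(out_tensor, inputs, lora_weights,
                        lora_indices_tensor, seq_len_tensor, batches,
                        scaling=0.5, op_type="shrink")
    return out_tensor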


def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
                  op_type, device):
    seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                   (batches, )).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum()
    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size),
                                   dtype=dtype).to(device)
        lora_weights = torch.rand(
            (lora_nums, max_rank, hidden_size),  # col-major
            dtype=dtype,
        ).to(device)
        # The shrink op needs atomic_add, so its output is initialized
        # with zeros.
        ref_out_tensor = torch.zeros((total_tokens, max_rank),
                                     dtype=dtype,
                                     device=inputs_tensor.device)
        # NOTE: the shrink kernel uses torch.float32 as its output dtype.
        our_out_tensor = torch.zeros((total_tokens, max_rank),
                                     dtype=torch.float32).to(device)
    else:
        inputs_tensor = torch.rand(
            (total_tokens, max_rank),
            dtype=dtype,
        ).to(device)
        lora_weights = torch.rand(
            (lora_nums, hidden_size, max_rank),  # col-major
            dtype=dtype,
        ).to(device)
        # The expand op needs to compute y += a @ lora_b, so the output is
        # initialized randomly.
        ref_out_tensor = torch.rand(
            (total_tokens, hidden_size),
            dtype=dtype,
        ).to(device)
        # Ensure both implementations start from the same output values.
        our_out_tensor = ref_out_tensor.clone()
    lora_indices_tensor = torch.randint(0,
                                        lora_nums - 1 if lora_nums > 1 else 1,
                                        (batches, )).to(device)
    indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
    current_offset = 0
    for b_id in range(batches):
        lora_index = lora_indices_tensor[b_id]
        indices[current_offset:current_offset +
                seq_len_tensor[b_id]].copy_(lora_index)
        current_offset += seq_len_tensor[b_id].item()
    return (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )
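

# Illustrative sketch (not part of the original utilities): how a kernel test
# might consume generate_data together with ref_torch_groupgemm to build a
# reference result for the "shrink" op. The sizes, dtype, scaling factor, and
# device are placeholders; a real test would also run the kernel under test
# into our_out_tensor and compare the two with assert_close (after casting,
# since the shrink output buffer is float32).
def _example_generate_data_usage(device="cuda"):
    batches, hidden_size, lora_nums = 4, 128, 2
    max_rank, seq_length = 8, 16
    (inputs_tensor, lora_weights, our_out_tensor, ref_out_tensor,
     b_seq_start_loc, lora_indices_tensor, seq_len_tensor,
     indices) = generate_data(batches, hidden_size, lora_nums, max_rank,
                              seq_length, torch.float16, "shrink", device)
    ref_torch_groupgemm(ref_out_tensor, inputs_tensor, lora_weights,
                        lora_indices_tensor, seq_len_tensor, batches,
                        scaling=0.5, op_type="shrink")
    return our_out_tensor, ref_out_tensor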


def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
                                     seq_length, dtype, nslices, device):
    seq_len_tensor = torch.randint(seq_length, seq_length + 1,
                                   (batches, )).to(device)
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
        dim=0,
    ).to(device)
    total_tokens = seq_len_tensor.sum()
    inputs_tensor = torch.rand(
        (total_tokens, max_rank),
        dtype=dtype,
    ).to(device)
    lora_weights_lst = []
    for _ in range(nslices):
        lora_weights_lst.append(
            torch.rand(
                (lora_nums, hidden_size, max_rank),  # col-major
                dtype=dtype,
            ).to(device))
    # The expand op needs to compute y += a @ lora_b, so the output is
    # initialized randomly.
    ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
                                dtype=dtype).to(device)
    # Ensure both implementations start from the same output values.
    our_out_tensor = ref_out_tensor.clone()
    lora_indices_tensor = torch.randint(0,
                                        lora_nums - 1 if lora_nums > 1 else 1,
                                        (batches, ))
    indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
    current_offset = 0
    for b_id in range(batches):
        lora_index = lora_indices_tensor[b_id]
        indices[current_offset:current_offset +
                seq_len_tensor[b_id]] = lora_index.item()
        current_offset += seq_len_tensor[b_id].item()
    lora_indices_tensor = lora_indices_tensor.to(device)
    return (
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    )
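

# Illustrative sketch (not part of the original utilities): the nslices
# variant returns one weight tensor per slice, and each slice owns its own
# hidden_size-wide column block of the output. A reference result can be
# built by running ref_torch_groupgemm on each column block with
# op_type="expand"; the sizes, dtype, and device below are placeholders.
def _example_expand_nslices_usage(device="cuda"):
    batches, hidden_size, lora_nums = 4, 128, 2
    max_rank, seq_length, nslices = 8, 16, 3
    data = generate_data_for_expand_nslices(batches, hidden_size, lora_nums,
                                            max_rank, seq_length,
                                            torch.float16, nslices, device)
    (inputs_tensor, lora_weights_lst, our_out_tensor, ref_out_tensor,
     b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices) = data
    for slice_idx, slice_weights in enumerate(lora_weights_lst):
        # In-place accumulation on the view updates the full output tensor.
        start = slice_idx * hidden_size
        slice_view = ref_out_tensor[:, start:start + hidden_size]
        ref_torch_groupgemm(slice_view, inputs_tensor, slice_weights,
                            lora_indices_tensor, seq_len_tensor, batches,
                            scaling=1.0, op_type="expand")
    return our_out_tensor, ref_out_tensor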