linear.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115
  1. from abc import abstractmethod
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from loguru import logger
  6. from torch.nn.parameter import Parameter, UninitializedParameter
  7. # yapf: disable
  8. from aphrodite.distributed import (divide,
  9. get_current_tp_rank_partition_offset,
  10. get_current_tp_rank_partition_size,
  11. get_tensor_model_parallel_rank,
  12. get_tensor_model_parallel_world_size,
  13. split_tensor_along_last_dim,
  14. tensor_model_parallel_all_gather,
  15. tensor_model_parallel_all_reduce)
  16. from aphrodite.modeling.parameter import (BaseAphroditeParameter,
  17. PackedAphroditeParameter,
  18. PerTensorScaleParameter)
  19. # yapf: enable
  20. from aphrodite.modeling.utils import set_weight_attrs
  21. from aphrodite.quantization.base_config import (QuantizationConfig,
  22. QuantizeMethodBase)
  23. WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
  24. def adjust_marlin_shard(param, shard_size, shard_offset):
  25. marlin_tile_size = getattr(param, "marlin_tile_size", None)
  26. if marlin_tile_size is None:
  27. return shard_size, shard_offset
  28. return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
  29. def adjust_bitsandbytes_shard(param: Parameter,
  30. qkv_offsets: Dict[str, Tuple[int, int]],
  31. loaded_shard_id: str) -> Tuple[int, int]:
  32. """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
  33. total, _ = qkv_offsets["total"]
  34. orig_offset, orig_size = qkv_offsets[loaded_shard_id]
  35. quantized_total = param.data.shape[0]
  36. quantized_offset = orig_offset * quantized_total // total
  37. quantized_size = orig_size * quantized_total // total
  38. return quantized_size, quantized_offset
  39. def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
  40. """For fused modules (QKV and MLP) we have an array of length
  41. N that holds 1 scale for each "logical" matrix. So the param
  42. is an array of length N. The loaded_weight corresponds to
  43. one of the shards on disk. Here, we slice the param based on
  44. the shard_id for loading.
  45. """
  46. qkv_idxs = {"q": 0, "k": 1, "v": 2}
  47. if isinstance(shard_id, str):
  48. shard_id = qkv_idxs[shard_id]
  49. elif not isinstance(shard_id, int):
  50. raise ValueError(f"Unknown Shard Id {shard_id}")
  51. # AutoFP8 scales do not have a shape
  52. # compressed-tensors scales do have a shape
  53. if len(loaded_weight.shape) != 0:
  54. assert loaded_weight.shape[0] == 1
  55. loaded_weight = loaded_weight[0]
  56. return param[shard_id], loaded_weight
  57. class LinearMethodBase(QuantizeMethodBase):
  58. """Base class for different (maybe quantized) linear methods."""
  59. @abstractmethod
  60. def create_weights(self, layer: torch.nn.Module,
  61. input_size_per_partition: int,
  62. output_partition_sizes: List[int], input_size: int,
  63. output_size: int, params_dtype: torch.dtype,
  64. **extra_weight_attrs):
  65. """Create weights for a linear layer.
  66. The weights will be set as attributes of the layer.
  67. Args:
  68. layer: The layer that is using the LinearMethodBase factory.
  69. input_size_per_partition: Size of the weight input dim on rank X.
  70. output_partition_sizes: Sizes of the output dim of each logical
  71. weight on rank X. E.g., output_partition_sizes for QKVLinear
  72. is a list contains the width of Wq, Wk, Wv on rank X.
  73. input_size: Size of the input dim of the weight across all ranks.
  74. output_size: Size of the output dim of the weight across all ranks.
  75. params_dtype: Datatype of the parameters.
  76. """
  77. raise NotImplementedError
  78. @abstractmethod
  79. def apply(self,
  80. layer: torch.nn.Module,
  81. x: torch.Tensor,
  82. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  83. """Apply the weights in layer to the input tensor.
  84. Expects create_weights to have been called before on the layer."""
  85. raise NotImplementedError
  86. class UnquantizedLinearMethod(LinearMethodBase):
  87. """Linear method without quantization."""
  88. def create_weights(self, layer: torch.nn.Module,
  89. input_size_per_partition: int,
  90. output_partition_sizes: List[int], input_size: int,
  91. output_size: int, params_dtype: torch.dtype,
  92. **extra_weight_attrs):
  93. weight = Parameter(torch.empty(sum(output_partition_sizes),
  94. input_size_per_partition,
  95. dtype=params_dtype),
  96. requires_grad=False)
  97. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  98. layer.register_parameter("weight", weight)
  99. set_weight_attrs(weight, extra_weight_attrs)
  100. def apply(self,
  101. layer: torch.nn.Module,
  102. x: torch.Tensor,
  103. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  104. return F.linear(x, layer.weight, bias)
  105. class LinearBase(torch.nn.Module):
  106. """Base linear layer.
  107. Args:
  108. input_size: input dimension of the linear layer.
  109. output_size: output dimension of the linear layer.
  110. bias: If true, add bias.
  111. skip_bias_add: If true, skip adding bias but instead return it.
  112. params_dtype: Data type for the parameters.
  113. quant_config: Quantization configure.
  114. """
  115. def __init__(
  116. self,
  117. input_size: int,
  118. output_size: int,
  119. skip_bias_add: bool = False,
  120. params_dtype: Optional[torch.dtype] = None,
  121. quant_config: Optional[QuantizationConfig] = None,
  122. prefix: str = "",
  123. ):
  124. super().__init__()
  125. # Keep input parameters
  126. self.input_size = input_size
  127. self.output_size = output_size
  128. self.skip_bias_add = skip_bias_add
  129. if params_dtype is None:
  130. params_dtype = torch.get_default_dtype()
  131. self.params_dtype = params_dtype
  132. if quant_config is None:
  133. self.quant_method: Optional[
  134. QuantizeMethodBase] = UnquantizedLinearMethod()
  135. else:
  136. self.quant_method = quant_config.get_quant_method(self,
  137. prefix=prefix)
  138. def forward(self, x: torch.Tensor) -> torch.Tensor:
  139. raise NotImplementedError
  140. class ReplicatedLinear(LinearBase):
  141. """Replicated linear layer.
  142. Args:
  143. input_size: input dimension of the linear layer.
  144. output_size: output dimension of the linear layer.
  145. bias: If true, add bias.
  146. skip_bias_add: If true, skip adding bias but instead return it.
  147. params_dtype: Data type for the parameters.
  148. quant_config: Quantization configure.
  149. prefix: The name of the layer in the state dict, including all parents
  150. (e.g. model.layers.0.qkv_proj)
  151. """
  152. def __init__(self,
  153. input_size: int,
  154. output_size: int,
  155. bias: bool = True,
  156. skip_bias_add: bool = False,
  157. params_dtype: Optional[torch.dtype] = None,
  158. quant_config: Optional[QuantizationConfig] = None,
  159. prefix: str = ""):
  160. super().__init__(input_size,
  161. output_size,
  162. skip_bias_add,
  163. params_dtype,
  164. quant_config,
  165. prefix=prefix)
  166. # All the linear layer supports quant method.
  167. assert self.quant_method is not None
  168. self.quant_method.create_weights(self,
  169. self.input_size, [self.output_size],
  170. self.input_size,
  171. self.output_size,
  172. self.params_dtype,
  173. prefix=prefix)
  174. if bias:
  175. self.bias = Parameter(
  176. torch.empty(self.output_size, dtype=self.params_dtype))
  177. set_weight_attrs(self.bias, {"output_dim": 0})
  178. else:
  179. self.register_parameter("bias", None)
  180. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  181. # If the weight on disk does not have a shape, give it one
  182. # (such scales for AutoFp8).
  183. if len(loaded_weight.shape) == 0:
  184. loaded_weight = loaded_weight.reshape(1)
  185. assert param.size() == loaded_weight.size()
  186. param.data.copy_(loaded_weight)
  187. def forward(self, x: torch.Tensor) -> torch.Tensor:
  188. bias = self.bias if not self.skip_bias_add else None
  189. assert self.quant_method is not None
  190. output = self.quant_method.apply(self, x, bias)
  191. output_bias = self.bias if self.skip_bias_add else None
  192. return output, output_bias
  193. def extra_repr(self) -> str:
  194. s = f"in_features={self.input_size}"
  195. s += f", output_features={self.output_size}"
  196. s += f", bias={self.bias is not None}"
  197. return s
  198. class ColumnParallelLinear(LinearBase):
  199. """Linear layer with column parallelism.
  200. The linear layer is defined as Y = XA + b. A is parallelized along
  201. its second dimension as A = [A_1, ..., A_p].
  202. Args:
  203. input_size: first dimension of matrix A.
  204. output_size: second dimension of matrix A.
  205. bias: If true, add bias.
  206. gather_output: If true, call all-gather on output and make Y available
  207. to all GPUs, otherwise, every GPU will have its output
  208. which is Y_i = XA_i
  209. skip_bias_add: This was added to enable performance optimizations where
  210. bias can be fused with other element-wise operations. we
  211. skip adding bias but instead return it.
  212. params_dtype: Data type for the parameters.
  213. quant_config: Quantization configure.
  214. output_sizes: list of output sizes packed into one output, like for QKV
  215. the list would be size 3.
  216. prefix: The name of the layer in the state dict, including all parents
  217. (e.g. model.layers.0.qkv_proj)
  218. """
  219. def __init__(self,
  220. input_size: int,
  221. output_size: int,
  222. bias: bool = True,
  223. gather_output: bool = False,
  224. skip_bias_add: bool = False,
  225. params_dtype: Optional[torch.dtype] = None,
  226. quant_config: Optional[QuantizationConfig] = None,
  227. output_sizes: Optional[List[int]] = None,
  228. prefix: str = ""):
  229. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  230. quant_config, prefix)
  231. self.gather_output = gather_output
  232. # Divide the weight matrix along the last dimension.
  233. tp_rank = get_tensor_model_parallel_rank()
  234. tp_size = get_tensor_model_parallel_world_size()
  235. assert self.quant_method is not None
  236. if quant_config is None:
  237. self.output_size_per_partition = get_current_tp_rank_partition_size(
  238. output_size, tp_rank, tp_size)
  239. else:
  240. self.output_size_per_partition = divide(self.output_size, tp_size)
  241. self.output_partition_sizes = [self.output_size_per_partition]
  242. # If QKV or MergedColumn, use output size of each partition.
  243. if hasattr(self, "output_sizes"):
  244. if quant_config is None:
  245. self.output_partition_sizes = [
  246. get_current_tp_rank_partition_size(output_size, tp_rank,
  247. tp_size)
  248. for output_size in self.output_sizes
  249. ]
  250. else:
  251. self.output_partition_sizes = [
  252. divide(output_size, tp_size)
  253. for output_size in self.output_sizes
  254. ]
  255. if output_sizes is None:
  256. output_sizes = [output_size]
  257. self.quant_method.create_weights(
  258. layer=self,
  259. input_size_per_partition=self.input_size,
  260. output_partition_sizes=self.output_partition_sizes,
  261. input_size=self.input_size,
  262. output_size=self.output_size,
  263. params_dtype=self.params_dtype,
  264. weight_loader=(
  265. self.weight_loader_v2 if self.quant_method.__class__.__name__
  266. in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
  267. prefix=prefix)
  268. if bias:
  269. self.bias = Parameter(
  270. torch.empty(self.output_size_per_partition,
  271. dtype=params_dtype))
  272. set_weight_attrs(self.bias, {
  273. "output_dim": 0,
  274. "weight_loader": self.weight_loader,
  275. })
  276. else:
  277. self.register_parameter("bias", None)
  278. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  279. tp_rank = get_tensor_model_parallel_rank()
  280. output_dim = getattr(param, "output_dim", None)
  281. # Special case for GGUF
  282. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  283. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  284. if is_gguf_weight_type:
  285. param.weight_type = loaded_weight.item()
  286. # Materialize GGUF UninitializedParameter
  287. if is_gguf_weight and isinstance(param, UninitializedParameter):
  288. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  289. param_data = param.data
  290. if output_dim is not None:
  291. shard_size = param_data.shape[output_dim]
  292. start_idx = tp_rank * shard_size
  293. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  294. shard_size)
  295. # Special case for loading scales off disk, which often do not
  296. # have a shape (such as in the case of AutoFP8).
  297. if len(loaded_weight.shape) == 0:
  298. loaded_weight = loaded_weight.reshape(1)
  299. assert param_data.shape == loaded_weight.shape
  300. param_data.copy_(loaded_weight)
  301. def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
  302. param.load_column_parallel_weight(loaded_weight=loaded_weight)
  303. def forward(self, input_):
  304. bias = self.bias if not self.skip_bias_add else None
  305. # Matrix multiply.
  306. assert self.quant_method is not None
  307. output_parallel = self.quant_method.apply(self, input_, bias)
  308. if self.gather_output:
  309. # All-gather across the partitions.
  310. output = tensor_model_parallel_all_gather(output_parallel)
  311. else:
  312. output = output_parallel
  313. output_bias = self.bias if self.skip_bias_add else None
  314. return output, output_bias
  315. def extra_repr(self) -> str:
  316. s = f"in_features={self.input_size}"
  317. s += f", output_features={self.output_size_per_partition}"
  318. s += f", bias={self.bias is not None}"
  319. s += f", tp_size={get_tensor_model_parallel_world_size()}"
  320. s += f", gather_output={self.gather_output}"
  321. return s
  322. class MergedColumnParallelLinear(ColumnParallelLinear):
  323. """Packed linear layers with column parallelism.
  324. Similar to ColumnParallelLinear, but the weight matrix is concatenated
  325. along the output dimension. When the weight matrix is loaded, the
  326. different partitions are sharded separately.
  327. Args:
  328. input_size: input dimension of the linear layer.
  329. output_sizes: list of output dimensions of the linear layer.
  330. bias: If true, add bias.
  331. gather_output: If true, call all-gather on output and make the output
  332. available to all GPUs, otherwise, every GPU will have
  333. its own output.
  334. skip_bias_add: This was added to enable performance optimizations where
  335. bias can be fused with other element-wise operations. we
  336. skip adding bias but instead return it.
  337. params_dtype: Data type for the parameters.
  338. quant_config: Quantization configure.
  339. prefix: The name of the layer in the state dict, including all parents
  340. (e.g. model.layers.0.qkv_proj)
  341. """
  342. def __init__(self,
  343. input_size: int,
  344. output_sizes: List[int],
  345. bias: bool = True,
  346. gather_output: bool = False,
  347. skip_bias_add: bool = False,
  348. params_dtype: Optional[torch.dtype] = None,
  349. quant_config: Optional[QuantizationConfig] = None,
  350. prefix: str = ""):
  351. self.output_sizes = output_sizes
  352. self.quant_config = quant_config
  353. if quant_config is not None:
  354. tp_size = get_tensor_model_parallel_world_size()
  355. assert all(output_size % tp_size == 0
  356. for output_size in output_sizes)
  357. super().__init__(input_size=input_size,
  358. output_size=sum(output_sizes),
  359. bias=bias,
  360. gather_output=gather_output,
  361. skip_bias_add=skip_bias_add,
  362. params_dtype=params_dtype,
  363. quant_config=quant_config,
  364. prefix=prefix)
  365. def weight_loader(self,
  366. param: Parameter,
  367. loaded_weight: torch.Tensor,
  368. loaded_shard_id: Optional[int] = None):
  369. # Special case for GGUF
  370. # initialize GGUF param after we know the quantize type
  371. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  372. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  373. if is_gguf_weight_type:
  374. param.data[loaded_shard_id].copy_(loaded_weight)
  375. param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
  376. return
  377. if is_gguf_weight and isinstance(param, UninitializedParameter):
  378. from gguf.constants import GGML_QUANT_SIZES
  379. ori_shape = param.tensor_shape
  380. weight_types = self.qweight_type.shard_weight_type.values()
  381. row_size = []
  382. for weight_type in weight_types:
  383. block_size, type_size = GGML_QUANT_SIZES[weight_type]
  384. row_size.append(ori_shape[1] // block_size * type_size)
  385. q_shape = (ori_shape[0], max(row_size))
  386. param.materialize(q_shape, dtype=loaded_weight.dtype)
  387. param_data = param.data
  388. output_dim = getattr(param, "output_dim", None)
  389. # Special case for AQLM codebooks.
  390. is_metadata = getattr(param, "is_metadata", False)
  391. # Special case for per-tensor scale to load scalar into fused array.
  392. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
  393. if loaded_shard_id is None:
  394. # Loaded weight is already fused on disk (qkv/mlp).
  395. if output_dim is None:
  396. if needs_scalar_to_array:
  397. param_data, loaded_weight = adjust_scalar_to_fused_array(
  398. param_data, loaded_weight, 0)
  399. assert param_data.shape == loaded_weight.shape
  400. param_data.copy_(loaded_weight)
  401. return
  402. current_shard_offset = 0
  403. shard_offsets: List[Tuple[int, int, int]] = []
  404. for i, output_size in enumerate(self.output_sizes):
  405. shard_offsets.append((i, current_shard_offset, output_size))
  406. current_shard_offset += output_size
  407. packed_dim = getattr(param, "packed_dim", None)
  408. for shard_id, shard_offset, shard_size in shard_offsets:
  409. # Special case for Quantization.
  410. # If quantized, we need to adjust the offset and size to account
  411. # for the packing.
  412. if packed_dim == output_dim:
  413. shard_size = shard_size // param.pack_factor
  414. shard_offset = shard_offset // param.pack_factor
  415. # Special case for Marlin.
  416. shard_size, shard_offset = adjust_marlin_shard(
  417. param, shard_size, shard_offset)
  418. loaded_weight_shard = loaded_weight.narrow(
  419. output_dim, shard_offset, shard_size)
  420. self.weight_loader(param, loaded_weight_shard, shard_id)
  421. return
  422. assert loaded_shard_id < len(self.output_sizes)
  423. tp_rank = get_tensor_model_parallel_rank()
  424. tp_size = get_tensor_model_parallel_world_size()
  425. if output_dim is not None:
  426. if self.quant_config is None:
  427. shard_offset = sum(
  428. get_current_tp_rank_partition_size(output_size, tp_rank,
  429. tp_size)
  430. for output_size in self.output_sizes[:loaded_shard_id])
  431. shard_size = get_current_tp_rank_partition_size(
  432. self.output_sizes[loaded_shard_id], tp_rank, tp_size)
  433. else:
  434. shard_offset = sum(
  435. self.output_sizes[:loaded_shard_id]) // tp_size
  436. shard_size = self.output_sizes[loaded_shard_id] // tp_size
  437. # Special case for quantization.
  438. # If quantized, we need to adjust the offset and size to account
  439. # for the packing.
  440. packed_dim = getattr(param, "packed_dim", None)
  441. if packed_dim == output_dim:
  442. shard_size = shard_size // param.pack_factor
  443. shard_offset = shard_offset // param.pack_factor
  444. # Special case for Marlin.
  445. shard_size, shard_offset = adjust_marlin_shard(
  446. param, shard_size, shard_offset)
  447. use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
  448. if use_bitsandbytes:
  449. shard_size = loaded_weight.shape[output_dim]
  450. shard_offset = loaded_weight.shape[output_dim] * \
  451. loaded_shard_id
  452. if is_gguf_weight:
  453. tp_size = get_tensor_model_parallel_world_size()
  454. output_dim = getattr(param, "output_dim", None)
  455. shard_shape = list(loaded_weight.shape)
  456. shard_shape[output_dim] = shard_shape[output_dim] // tp_size
  457. param.shard_id.append(loaded_shard_id)
  458. param.shard_size[loaded_shard_id] = shard_shape
  459. input_dim = getattr(param, "input_dim", None)
  460. input_size = loaded_weight.shape[input_dim]
  461. param_data = param_data.narrow(input_dim, 0, input_size)
  462. param_data = param_data.narrow(output_dim, shard_offset,
  463. shard_size)
  464. if self.quant_config is None:
  465. start_idx = get_current_tp_rank_partition_offset(
  466. loaded_weight.shape[output_dim], tp_rank, tp_size)
  467. else:
  468. start_idx = tp_rank * shard_size
  469. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  470. shard_size)
  471. # Special case for AQLM codebooks.
  472. elif is_metadata:
  473. # metadata indicates fixed size concatenated along dim 0
  474. shard_size = loaded_weight.shape[0]
  475. shard_offset = loaded_shard_id * shard_size
  476. param_data = param_data.narrow(0, shard_offset, shard_size)
  477. # Special case for per-tensor scales in fused case.
  478. elif needs_scalar_to_array:
  479. param_data, loaded_weight = adjust_scalar_to_fused_array(
  480. param_data, loaded_weight, loaded_shard_id)
  481. else:
  482. ignore_warning = getattr(param, "ignore_warning", False)
  483. if not ignore_warning:
  484. logger.warning(
  485. "Loading a weight without `output_dim` attribute in "
  486. "MergedColumnParallelLinear, assume the weight is "
  487. "the same for all partitions.")
  488. assert param_data.shape == loaded_weight.shape
  489. param_data.copy_(loaded_weight)
  490. def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
  491. loaded_weight: torch.Tensor):
  492. """
  493. Handle special case for models where MLP layers are already
  494. fused on disk. In this case, we have no shard id. This function
  495. determines the shard id by splitting these layers and then calls
  496. the weight loader using the shard id.
  497. An example of a model with these fused layers:
  498. https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
  499. """
  500. current_shard_offset = 0
  501. shard_offsets: List[Tuple[int, int, int]] = []
  502. for i, output_size in enumerate(self.output_sizes):
  503. shard_offsets.append((i, current_shard_offset, output_size))
  504. current_shard_offset += output_size
  505. for shard_id, shard_offset, shard_size in shard_offsets:
  506. # Special case for Quantization.
  507. # If quantized, we need to adjust the offset and size to account
  508. # for the packing.
  509. if isinstance(param, PackedAphroditeParameter
  510. ) and param.packed_dim == param.output_dim:
  511. param.adjust_shard_indexes_for_packing(
  512. shard_size=shard_size, shard_offset=shard_offset)
  513. loaded_weight_shard = loaded_weight.narrow(param.output_dim,
  514. shard_offset,
  515. shard_size)
  516. self.weight_loader_v2(param, loaded_weight_shard, shard_id)
  517. def weight_loader_v2(self,
  518. param: BaseAphroditeParameter,
  519. loaded_weight: torch.Tensor,
  520. loaded_shard_id: Optional[int] = None):
  521. if loaded_shard_id is None:
  522. if isinstance(param, PerTensorScaleParameter):
  523. param.load_merged_column_weight(loaded_weight=loaded_weight,
  524. shard_id=0)
  525. return
  526. elif type(param) is BaseAphroditeParameter:
  527. param.load_merged_column_weight(loaded_weight=loaded_weight)
  528. return
  529. self._load_fused_module_from_checkpoint(param, loaded_weight)
  530. return
  531. assert loaded_shard_id < len(self.output_sizes)
  532. tp_size = get_tensor_model_parallel_world_size()
  533. shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
  534. shard_size = self.output_sizes[loaded_shard_id] // tp_size
  535. param.load_merged_column_weight(loaded_weight=loaded_weight,
  536. shard_id=loaded_shard_id,
  537. shard_offset=shard_offset,
  538. shard_size=shard_size)
  539. class QKVParallelLinear(ColumnParallelLinear):
  540. """Linear layers for the attention's QKV transformation.
  541. Linear layers for the linear transformation of the query, key, and value
  542. vectors in the attention layer. The weight matrix is concatenated along
  543. the output dimension. The layer is parallelized along the head dimension.
  544. When the number of key/value heads is smaller than the number of query
  545. heads (e.g., multi-query/grouped-query attention), the key/value head may
  546. be replicated while the query heads are partitioned.
  547. Args:
  548. hidden_size: input hidden state size of the transformer.
  549. head_size: size of each attention head.
  550. total_num_heads: total number of attention query heads.
  551. total_num_kv_heads: total number of attention key/value heads. If
  552. None, assume total_num_kv_heads = total_num_heads.
  553. bias: If true, add bias.
  554. skip_bias_add: This was added to enable performance optimizations where
  555. bias can be fused with other element-wise operations. we
  556. skip adding bias but instead return it.
  557. params_dtype: Data type for the parameters.
  558. quant_config: Quantization configure.
  559. prefix: The name of the layer in the state dict, including all parents
  560. (e.g. model.layers.0.qkv_proj)
  561. """
  562. def __init__(self,
  563. hidden_size: int,
  564. head_size: int,
  565. total_num_heads: int,
  566. total_num_kv_heads: Optional[int] = None,
  567. bias: bool = True,
  568. skip_bias_add: bool = False,
  569. params_dtype: Optional[torch.dtype] = None,
  570. quant_config: Optional[QuantizationConfig] = None,
  571. prefix: str = ""):
  572. self.hidden_size = hidden_size
  573. self.head_size = head_size
  574. self.total_num_heads = total_num_heads
  575. self.quant_config = quant_config
  576. if total_num_kv_heads is None:
  577. total_num_kv_heads = total_num_heads
  578. self.total_num_kv_heads = total_num_kv_heads
  579. # Divide the weight matrix along the last dimension.
  580. tp_size = get_tensor_model_parallel_world_size()
  581. tp_rank = get_tensor_model_parallel_rank()
  582. if quant_config is None:
  583. self.num_heads_per_kv_head = (self.total_num_heads //
  584. self.total_num_kv_heads)
  585. self.num_kv_heads = get_current_tp_rank_partition_size(
  586. self.total_num_kv_heads, tp_rank, tp_size)
  587. self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
  588. self.num_kv_head_replicas = 1
  589. else:
  590. self.num_heads = divide(self.total_num_heads, tp_size)
  591. if tp_size >= self.total_num_kv_heads:
  592. self.num_kv_heads = 1
  593. self.num_kv_head_replicas = divide(tp_size,
  594. self.total_num_kv_heads)
  595. elif tp_size < self.total_num_kv_heads and quant_config is not None:
  596. self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
  597. self.num_kv_head_replicas = 1
  598. input_size = self.hidden_size
  599. output_size = (self.num_heads +
  600. 2 * self.num_kv_heads) * tp_size * self.head_size
  601. self.output_sizes = [
  602. self.num_heads * self.head_size * tp_size, # q_proj
  603. self.num_kv_heads * self.head_size * tp_size, # k_proj
  604. self.num_kv_heads * self.head_size * tp_size, # v_proj
  605. ]
  606. super().__init__(input_size=input_size,
  607. output_size=output_size,
  608. bias=bias,
  609. gather_output=False,
  610. skip_bias_add=skip_bias_add,
  611. params_dtype=params_dtype,
  612. quant_config=quant_config,
  613. prefix=prefix)
  614. def _get_shard_offset_mapping(self, loaded_shard_id: str):
  615. shard_offset_mapping = {
  616. "q": 0,
  617. "k": self.num_heads * self.head_size,
  618. "v": (self.num_heads + self.num_kv_heads) * self.head_size,
  619. "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
  620. }
  621. return shard_offset_mapping.get(loaded_shard_id)
  622. def _get_shard_size_mapping(self, loaded_shard_id: str):
  623. shard_size_mapping = {
  624. "q": self.num_heads * self.head_size,
  625. "k": self.num_kv_heads * self.head_size,
  626. "v": self.num_kv_heads * self.head_size,
  627. }
  628. return shard_size_mapping.get(loaded_shard_id)
  def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                         loaded_weight: torch.Tensor):
      """
      Handle the special case of models whose QKV layers are already fused
      on disk, so no shard id is supplied. The fused tensor is split into
      its q, k and v pieces along ``param.output_dim`` (using the total,
      unsharded head counts) and each piece is passed to
      ``weight_loader_v2`` together with its shard id.

      An example of a model with these fused layers:
      https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
      """
      q_size = self.total_num_heads * self.head_size
      kv_size = self.total_num_kv_heads * self.head_size
      shard_offsets = [
          # (shard_id, shard_offset, shard_size)
          ("q", 0, q_size),
          ("k", q_size, kv_size),
          ("v", q_size + kv_size, kv_size),
      ]
      for shard_id, shard_offset, shard_size in shard_offsets:
          # Special case for Quantization.
          # If the packed dim is the output dim, offset and size must be
          # rescaled to account for the packing. shard_size/shard_offset are
          # plain ints, so the helper cannot adjust them in place: its
          # returned (shard_size, shard_offset) pair has to be captured,
          # otherwise the narrow below uses unpacked indexes.
          if isinstance(param, PackedAphroditeParameter
                        ) and param.packed_dim == param.output_dim:
              shard_size, shard_offset = (
                  param.adjust_shard_indexes_for_packing(
                      shard_size=shard_size, shard_offset=shard_offset))

          loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                     shard_offset, shard_size)
          self.weight_loader_v2(param, loaded_weight_shard, shard_id)
  def weight_loader_v2(self,
                       param: BaseAphroditeParameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[str] = None):
      """Load a checkpoint tensor through the parameter's own load methods.

      With a shard id ("q", "k" or "v") the tensor goes to
      ``param.load_qkv_weight`` together with the shard's local offset and
      size. Without a shard id the tensor is treated as fused on disk:
      ``PerTensorScaleParameter`` instances load it as merged shard 0,
      parameters whose exact type is ``BaseAphroditeParameter`` (subclasses
      excluded) load it whole, and anything else is split by
      ``_load_fused_module_from_checkpoint``.
      """
      if loaded_shard_id is not None:
          assert loaded_shard_id in ["q", "k", "v"]
          offset = self._get_shard_offset_mapping(loaded_shard_id)
          width = self._get_shard_size_mapping(loaded_shard_id)
          param.load_qkv_weight(loaded_weight=loaded_weight,
                                num_heads=self.num_kv_head_replicas,
                                shard_id=loaded_shard_id,
                                shard_offset=offset,
                                shard_size=width)
          return

      # No shard id: special case for certain models with fused weights.
      if isinstance(param, PerTensorScaleParameter):
          param.load_merged_column_weight(loaded_weight=loaded_weight,
                                          shard_id=0)
      elif type(param) is BaseAphroditeParameter:
          param.load_merged_column_weight(loaded_weight=loaded_weight)
      else:
          self._load_fused_module_from_checkpoint(param, loaded_weight)
  def weight_loader(self,
                    param: Parameter,
                    loaded_weight: torch.Tensor,
                    loaded_shard_id: Optional[str] = None):
      """Copy a checkpoint tensor into this rank's slice of the fused QKV
      parameter (the non-``weight_loader_v2`` loading path).

      Args:
          param: Destination parameter. Optional attributes read from it
              steer loading: ``output_dim``, ``packed_dim``/``pack_factor``,
              ``input_dim``, ``is_metadata``, ``needs_scalar_to_array``,
              ``use_bitsandbytes``, ``is_gguf_weight``,
              ``is_gguf_weight_type`` and ``ignore_warning``.
          loaded_weight: Tensor read from the checkpoint.
          loaded_shard_id: "q", "k" or "v" when the projections are stored
              separately on disk; None when they are already fused, in which
              case the tensor is split into q/k/v pieces and this method
              recurses once per piece.
      """
      # Special case for GGUF
      # initialize GGUF param after we know the quantize type
      is_gguf_weight = getattr(param, "is_gguf_weight", False)
      is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
      if is_gguf_weight_type and loaded_shard_id is not None:
          # GGUF quant-type tensor: keep one entry per shard (rows 0/1/2 for
          # q/k/v) and record its scalar value per shard id.
          idx_map = {"q": 0, "k": 1, "v": 2}
          param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
          param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
          return

      if is_gguf_weight and isinstance(param, UninitializedParameter):
          # Allocate the GGUF buffer now that per-shard quant types are known:
          # each row is sized for the widest encoded row among the shards.
          from gguf.constants import GGML_QUANT_SIZES

          ori_shape = param.tensor_shape
          weight_types = self.qweight_type.shard_weight_type.values()
          row_size = []
          for weight_type in weight_types:
              block_size, type_size = GGML_QUANT_SIZES[weight_type]
              row_size.append(ori_shape[1] // block_size * type_size)
          q_shape = (ori_shape[0], max(row_size))
          param.materialize(q_shape, dtype=loaded_weight.dtype)

      param_data = param.data
      output_dim = getattr(param, "output_dim", None)
      # Special case for AQLM codebooks.
      is_metadata = getattr(param, "is_metadata", False)

      # Special case for per-tensor scales in fused case.
      needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

      if loaded_shard_id is None:
          # Loaded weight is already fused on disk (qkv/mlp).
          if output_dim is None:
              # Nothing to split along: copy the whole tensor (optionally
              # after expanding a scalar into the fused array).
              if needs_scalar_to_array:
                  param_data, loaded_weight = adjust_scalar_to_fused_array(
                      param_data, loaded_weight, 0)

              assert param_data.shape == loaded_weight.shape
              param_data.copy_(loaded_weight)
              return
          # Offsets/sizes use the total (unsharded) head counts; each piece
          # is then sharded for this rank by the recursive call.
          shard_offsets = [
              # (shard_id, shard_offset, shard_size)
              ("q", 0, self.total_num_heads * self.head_size),
              ("k", self.total_num_heads * self.head_size,
               self.total_num_kv_heads * self.head_size),
              ("v", (self.total_num_heads + self.total_num_kv_heads) *
               self.head_size, self.total_num_kv_heads * self.head_size),
          ]
          packed_dim = getattr(param, "packed_dim", None)
          for shard_id, shard_offset, shard_size in shard_offsets:
              # Special case for Quantized Weights.
              # If quantized, we need to adjust the offset and size to account
              # for the packing.
              if packed_dim == output_dim:
                  shard_size = shard_size // param.pack_factor
                  shard_offset = shard_offset // param.pack_factor

                  # Special case for Marlin.
                  shard_size, shard_offset = adjust_marlin_shard(
                      param, shard_size, shard_offset)

              loaded_weight_shard = loaded_weight.narrow(
                  output_dim, shard_offset, shard_size)
              self.weight_loader(param, loaded_weight_shard, shard_id)
          return

      tp_rank = get_tensor_model_parallel_rank()
      assert loaded_shard_id in ["q", "k", "v"]

      # If output dim is defined, use the default loading process.
      if output_dim is not None:
          # shard_offset/shard_size locate the shard inside this rank's
          # param (per-rank head counts). multiple_of is the granularity used
          # to place this rank's slice of the checkpoint tensor when the
          # layer is unquantized (q: whole groups of num_heads_per_kv_head
          # heads, k/v: whole heads).
          if loaded_shard_id == "q":
              shard_offset = 0
              shard_size = self.num_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size * self.num_heads_per_kv_head
          elif loaded_shard_id == "k":
              shard_offset = self.num_heads * self.head_size
              shard_size = self.num_kv_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size
          elif loaded_shard_id == "v":
              shard_offset = (self.num_heads +
                              self.num_kv_heads) * self.head_size
              shard_size = self.num_kv_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size
          # Special case for Quantized Weights.
          # If quantized, we need to adjust the offset and size to account
          # for the packing.
          packed_dim = getattr(param, "packed_dim", None)
          if packed_dim == output_dim:
              shard_size = shard_size // param.pack_factor
              shard_offset = shard_offset // param.pack_factor
              if self.quant_config is None:
                  multiple_of = multiple_of // param.pack_factor

              # Special case for Marlin.
              shard_size, shard_offset = adjust_marlin_shard(
                  param, shard_size, shard_offset)

          use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
          if use_bitsandbytes:
              # bitsandbytes: offsets/sizes come from adjust_bitsandbytes_shard
              # given the unpacked per-rank q/k/v layout.
              orig_qkv_offsets = {
                  "q": (0, self.num_heads * self.head_size),
                  "k": (self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size),
                  "v":
                  ((self.num_heads + self.num_kv_heads) * self.head_size,
                   self.num_kv_heads * self.head_size),
                  "total":
                  ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                   0)
              }
              shard_size, shard_offset = adjust_bitsandbytes_shard(
                  param, orig_qkv_offsets, loaded_shard_id)

          if is_gguf_weight:
              # Record the per-shard shape (output dim divided by tp_size) and
              # trim param_data to the checkpoint's input width.
              tp_size = get_tensor_model_parallel_world_size()
              output_dim = getattr(param, "output_dim", None)
              shard_shape = list(loaded_weight.shape)
              shard_shape[output_dim] = shard_shape[output_dim] // tp_size
              param.shard_id.append(loaded_shard_id)
              param.shard_size[loaded_shard_id] = shard_shape

              input_dim = getattr(param, "input_dim", None)
              input_size = loaded_weight.shape[input_dim]
              param_data = param_data.narrow(input_dim, 0, input_size)

          param_data = param_data.narrow(output_dim, shard_offset,
                                         shard_size)
          if self.quant_config is None:
              # Unquantized: this rank's start row in the checkpoint tensor
              # comes from the (possibly uneven) partition helper.
              tp_size = get_tensor_model_parallel_world_size()
              total_size = loaded_weight.shape[output_dim]
              start_idx = get_current_tp_rank_partition_offset(
                  total_size, tp_rank, tp_size, multiple_of=multiple_of)
          else:
              # Quantized: even split. When there are fewer kv heads than
              # ranks, num_kv_head_replicas consecutive ranks share one k/v
              # shard.
              if loaded_shard_id == "q":
                  shard_id = tp_rank
              else:
                  shard_id = tp_rank // self.num_kv_head_replicas
              start_idx = shard_id * shard_size
          loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                               shard_size)
      # Special case for AQLM codebooks.
      elif is_metadata:
          # metadata indicates fixed size concatenated along dim 0
          shard_size = loaded_weight.shape[0]
          shard_index = ["q", "k", "v"].index(loaded_shard_id)
          param_data = param_data.narrow(0, shard_index * shard_size,
                                         shard_size)
      # Special case for per-tensor scales in fused case.
      elif needs_scalar_to_array:
          param_data, loaded_weight = adjust_scalar_to_fused_array(
              param_data, loaded_weight, loaded_shard_id)
      else:
          ignore_warning = getattr(param, "ignore_warning", False)
          if not ignore_warning:
              logger.warning(
                  "Loading a weight without `output_dim` attribute in "
                  "QKVParallelLinear, assume the weight is the same "
                  "for all partitions.")

      assert param_data.shape == loaded_weight.shape
      param_data.copy_(loaded_weight)
  class RowParallelLinear(LinearBase):
      """Linear layer with row parallelism.

      The linear layer is defined as Y = XA + b. A is parallelized along
      its first dimension and X along its second dimension as:

                 -   -
                | A_1 |
                | .   |
            A = | .   |        X = [X_1, ..., X_p]
                | .   |
                | A_p |
                 -   -

      Arguments:
          input_size: first dimension of matrix A.
          output_size: second dimension of matrix A.
          bias: If true, add bias. Note that bias is not parallelized.
          input_is_parallel: If true, we assume that the input is already
                             split across the GPUs and we do not split
                             again.
          skip_bias_add: This was added to enable performance optimization
                         where bias can be fused with other element-wise
                         operations. We skip adding bias but instead
                         return it.
          params_dtype: Data type for the parameters.
          reduce_results: If true, all-reduce the partial outputs across
                          tensor-parallel ranks in ``forward`` (when
                          tp_size > 1). Must be true when the bias is added
                          inside the layer.
          quant_config: Quantization configuration.
          partition_multiple_of: Partitions will be divided so that each
                                 partition is a multiple of this number.
                                 Only used when ``quant_config`` is None.
          prefix: Name prefix passed to ``LinearBase`` and to the
                  quantization method's ``create_weights``.
      """

      def __init__(self,
                   input_size: int,
                   output_size: int,
                   bias: bool = True,
                   input_is_parallel: bool = True,
                   skip_bias_add: bool = False,
                   params_dtype: Optional[torch.dtype] = None,
                   reduce_results: bool = True,
                   quant_config: Optional[QuantizationConfig] = None,
                   partition_multiple_of: int = 1,
                   prefix: str = ""):
          super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                           quant_config, prefix)

          self.input_is_parallel = input_is_parallel
          self.reduce_results = reduce_results
          self.quant_config = quant_config

          # Divide the weight matrix along the input dimension.
          self.tp_rank = get_tensor_model_parallel_rank()
          self.tp_size = get_tensor_model_parallel_world_size()
          if quant_config is None:
              # Per-rank size comes from a helper that takes tp_rank and a
              # partition_multiple_of granularity, so sizes may differ
              # between ranks.
              self.partition_multiple_of = partition_multiple_of
              self.input_size_per_partition = (
                  get_current_tp_rank_partition_size(
                      input_size, self.tp_rank, self.tp_size,
                      partition_multiple_of))
          else:
              self.input_size_per_partition = divide(input_size, self.tp_size)
          assert self.quant_method is not None

          self.quant_method.create_weights(
              layer=self,
              input_size_per_partition=self.input_size_per_partition,
              output_partition_sizes=[self.output_size],
              input_size=self.input_size,
              output_size=self.output_size,
              params_dtype=self.params_dtype,
              weight_loader=(
                  self.weight_loader_v2 if self.quant_method.__class__.__name__
                  in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
              prefix=prefix)
          if not reduce_results and (bias and not skip_bias_add):
              raise ValueError("When not reduce the results, adding bias to the "
                               "results can lead to incorrect results")

          if bias:
              # The bias is not partitioned: every rank holds the full vector.
              self.bias = Parameter(
                  torch.empty(self.output_size, dtype=params_dtype))
              set_weight_attrs(self.bias, {
                  "output_dim": 0,
                  "weight_loader": self.weight_loader,
              })
          else:
              self.register_parameter("bias", None)

      def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
          """Copy this rank's slice of a checkpoint tensor into ``param``.

          Parameters with an ``input_dim`` attribute receive the slice along
          that dim belonging to this rank; parameters without one (e.g. the
          bias) receive the whole tensor. GGUF parameters are materialized
          first, and 0-d tensors are reshaped to shape (1,) before copying.
          """
          tp_size = get_tensor_model_parallel_world_size()
          input_dim = getattr(param, "input_dim", None)

          # Special case for GGUF
          is_gguf_weight = getattr(param, "is_gguf_weight", False)
          is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
          if is_gguf_weight_type:
              param.weight_type = loaded_weight.item()

          # Materialize GGUF UninitializedParameter
          if is_gguf_weight and isinstance(param, UninitializedParameter):
              weight_shape = list(loaded_weight.shape)
              # Compare against None: 0 is a valid dim index.
              if input_dim is not None:
                  weight_shape[input_dim] = weight_shape[input_dim] // tp_size
              param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

          param_data = param.data
          if input_dim is not None:
              shard_size = param_data.shape[input_dim]
              if self.quant_config is None:
                  # Same (possibly uneven) partitioning as in __init__.
                  start_idx = get_current_tp_rank_partition_offset(
                      self.input_size,
                      self.tp_rank,
                      self.tp_size,
                      multiple_of=self.partition_multiple_of)
              else:
                  start_idx = self.tp_rank * shard_size
              loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                   shard_size)

          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
          if len(loaded_weight.shape) == 0:
              loaded_weight = loaded_weight.reshape(1)

          assert param_data.shape == loaded_weight.shape
          param_data.copy_(loaded_weight)

      def weight_loader_v2(self, param: BaseAphroditeParameter,
                           loaded_weight: torch.Tensor):
          """Delegate loading to the parameter's row-parallel loader."""
          param.load_row_parallel_weight(loaded_weight=loaded_weight)

      def forward(self, input_):
          """Apply the row-parallel matmul.

          Returns ``(output, output_bias)``. ``output_bias`` is ``self.bias``
          when ``skip_bias_add`` is set, otherwise None.
          """
          if self.input_is_parallel:
              input_parallel = input_
          elif self.quant_config is None:
              # The weight shard was cut with
              # get_current_tp_rank_partition_offset/size, which may give
              # ranks unequal widths; slice the input identically so it
              # matches this rank's weight shard.
              start_idx = get_current_tp_rank_partition_offset(
                  self.input_size,
                  self.tp_rank,
                  self.tp_size,
                  multiple_of=self.partition_multiple_of)
              input_parallel = input_.narrow(
                  -1, start_idx, self.input_size_per_partition).contiguous()
          else:
              splitted_input = split_tensor_along_last_dim(
                  input_, num_partitions=self.tp_size)
              input_parallel = splitted_input[self.tp_rank].contiguous()

          # Matrix multiply.
          assert self.quant_method is not None
          # Only fuse bias add into GEMM for rank 0 (this ensures that
          # bias will not get added more than once in TP>1 case)
          bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
          output_parallel = self.quant_method.apply(self,
                                                    input_parallel,
                                                    bias=bias_)
          if self.reduce_results and self.tp_size > 1:
              output = tensor_model_parallel_all_reduce(output_parallel)
          else:
              output = output_parallel

          output_bias = self.bias if self.skip_bias_add else None

          return output, output_bias

      def extra_repr(self) -> str:
          s = f"input_features={self.input_size_per_partition}"
          s += f", output_features={self.output_size}"
          s += f", bias={self.bias is not None}"
          s += f", tp_size={self.tp_size}"
          s += f", reduce_results={self.reduce_results}"
          return s