# linear.py
  1. from abc import abstractmethod
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from loguru import logger
  6. from torch.nn.parameter import Parameter, UninitializedParameter
  7. # yapf: disable
  8. from aphrodite.distributed import (divide,
  9. get_current_tp_rank_partition_offset,
  10. get_current_tp_rank_partition_size,
  11. get_tensor_model_parallel_rank,
  12. get_tensor_model_parallel_world_size,
  13. split_tensor_along_last_dim,
  14. tensor_model_parallel_all_gather,
  15. tensor_model_parallel_all_reduce)
  16. from aphrodite.modeling.parameter import (BaseAphroditeParameter,
  17. PackedAphroditeParameter,
  18. PerTensorScaleParameter)
  19. # yapf: enable
  20. from aphrodite.modeling.utils import set_weight_attrs
  21. from aphrodite.quantization.base_config import (QuantizationConfig,
  22. QuantizeMethodBase)
  # Class names of the quantization methods whose weights are loaded through
  # a layer's ``weight_loader_v2`` instead of the legacy ``weight_loader``.
  # Membership is tested against ``quant_method.__class__.__name__``.
  WEIGHT_LOADER_V2_SUPPORTED = [
      "CompressedTensorsLinearMethod",
      "GPTQMarlinLinearMethod",
      "AWQMarlinLinearMethod",
      "AWQLinearMethod",
      "HQQMarlinMethod",
      "Fp8LinearMethod",
  ]
  def adjust_marlin_shard(param, shard_size, shard_offset):
      """Scale a shard's size and offset by the parameter's Marlin tile size.

      Args:
          param: Object that may carry a ``marlin_tile_size`` attribute.
          shard_size: Size of the shard before the adjustment.
          shard_offset: Offset of the shard before the adjustment.

      Returns:
          ``(shard_size, shard_offset)``, each multiplied by
          ``param.marlin_tile_size``. When the attribute is missing or is
          ``None`` the two values are returned unchanged.
      """
      tile = getattr(param, "marlin_tile_size", None)
      if tile is not None:
          return shard_size * tile, shard_offset * tile
      return shard_size, shard_offset
  def adjust_bitsandbytes_shard(param: Parameter,
                                qkv_offsets: Dict[str, Tuple[int, int]],
                                loaded_shard_id: str) -> Tuple[int, int]:
      """Rescale a shard's offset and size to a BitsAndBytes-quantized param.

      Args:
          param: Parameter whose ``data.shape[0]`` gives the quantized extent.
          qkv_offsets: Maps a shard id to an ``(offset, size)`` pair. The
              first element of the ``"total"`` entry is the unquantized
              extent used as the denominator.
          loaded_shard_id: Key of the shard to look up in ``qkv_offsets``.

      Returns:
          ``(quantized_size, quantized_offset)``. Note the order: size
          first, then offset.
      """
      full_extent, _ = qkv_offsets["total"]
      shard_offset, shard_size = qkv_offsets[loaded_shard_id]
      packed_extent = param.data.shape[0]

      # Same ratio for both values: packed_extent / full_extent (floored).
      scaled_size = shard_size * packed_extent // full_extent
      scaled_offset = shard_offset * packed_extent // full_extent
      return scaled_size, scaled_offset
  def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
      """Pick the slot of a fused scale array that a scalar should fill.

      Fused modules (QKV and MLP) keep an array of length N holding one
      scale per "logical" matrix, while ``loaded_weight`` is the scale of a
      single shard read from disk. This selects the matching slot.

      Args:
          param: The length-N array of scales.
          loaded_weight: The scale of one shard; either shapeless or with a
              leading dimension of exactly 1.
          shard_id: An integer index, or one of ``"q"``, ``"k"``, ``"v"``
              (mapped to 0, 1, 2).

      Returns:
          ``(param[shard_id], loaded_weight)`` with ``loaded_weight``
          reduced to its first element when it had a shape.

      Raises:
          ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
          KeyError: If ``shard_id`` is a string other than q/k/v.
      """
      if isinstance(shard_id, str):
          shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
      elif not isinstance(shard_id, int):
          raise ValueError(f"Unknown Shard Id {shard_id}")

      # AutoFP8 scales are shapeless; compressed-tensors scales have a
      # shape, in which case only the single leading element is used.
      has_shape = len(loaded_weight.shape) > 0
      if has_shape:
          assert loaded_weight.shape[0] == 1
          loaded_weight = loaded_weight[0]

      return param[shard_id], loaded_weight
  class LinearMethodBase(QuantizeMethodBase):
      """Interface shared by all (optionally quantized) linear methods."""

      @abstractmethod
      def create_weights(
          self,
          layer: torch.nn.Module,
          input_size_per_partition: int,
          output_partition_sizes: List[int],
          input_size: int,
          output_size: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
          """Create the weights of a linear layer as attributes of ``layer``.

          Args:
              layer: The layer that is using the LinearMethodBase factory.
              input_size_per_partition: Size of the weight input dim on
                  rank X.
              output_partition_sizes: Sizes of the output dim of each
                  logical weight on rank X. For a QKV layer this is the
                  list of the widths of Wq, Wk and Wv on rank X.
              input_size: Size of the input dim of the weight across all
                  ranks.
              output_size: Size of the output dim of the weight across all
                  ranks.
              params_dtype: Datatype of the parameters.
              **extra_weight_attrs: Additional keyword arguments supplied by
                  the calling layer.
          """
          raise NotImplementedError

      @abstractmethod
      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Apply the weights stored in ``layer`` to the input tensor ``x``.

          ``create_weights`` must have been called on ``layer`` beforehand.
          """
          raise NotImplementedError
  class UnquantizedLinearMethod(LinearMethodBase):
      """Linear method that stores and applies a plain, unquantized weight."""

      def create_weights(
          self,
          layer: torch.nn.Module,
          input_size_per_partition: int,
          output_partition_sizes: List[int],
          input_size: int,
          output_size: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
          """Register an uninitialized ``weight`` parameter on ``layer``.

          The weight has shape
          ``(sum(output_partition_sizes), input_size_per_partition)``, does
          not require grad, and is tagged with ``input_dim=1`` and
          ``output_dim=0`` before ``extra_weight_attrs`` are applied.
          """
          total_rows = sum(output_partition_sizes)
          storage = torch.empty(total_rows,
                                input_size_per_partition,
                                dtype=params_dtype)
          weight = Parameter(storage, requires_grad=False)
          set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
          layer.register_parameter("weight", weight)
          set_weight_attrs(weight, extra_weight_attrs)

      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Return ``F.linear(x, layer.weight, bias)``."""
          return F.linear(x, layer.weight, bias)
  class LinearBase(torch.nn.Module):
      """Common state for the linear layers defined in this module.

      Args:
          input_size: input dimension of the linear layer.
          output_size: output dimension of the linear layer.
          skip_bias_add: If true, skip adding bias but instead return it.
          params_dtype: Data type for the parameters. Falls back to
              ``torch.get_default_dtype()`` when ``None``.
          quant_config: Quantization configuration. When ``None`` an
              ``UnquantizedLinearMethod`` is used.
          prefix: Forwarded to ``quant_config.get_quant_method``.
      """

      def __init__(
          self,
          input_size: int,
          output_size: int,
          skip_bias_add: bool = False,
          params_dtype: Optional[torch.dtype] = None,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ):
          super().__init__()

          # Remember the constructor arguments.
          self.input_size = input_size
          self.output_size = output_size
          self.skip_bias_add = skip_bias_add
          self.params_dtype = (torch.get_default_dtype()
                               if params_dtype is None else params_dtype)

          if quant_config is not None:
              self.quant_method: Optional[QuantizeMethodBase] = (
                  quant_config.get_quant_method(self, prefix=prefix))
          else:
              self.quant_method = UnquantizedLinearMethod()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Not implemented here; subclasses provide the forward pass."""
          raise NotImplementedError
  class ReplicatedLinear(LinearBase):
      """Replicated linear layer.

      Args:
          input_size: input dimension of the linear layer.
          output_size: output dimension of the linear layer.
          bias: If true, add bias.
          skip_bias_add: If true, skip adding bias but instead return it.
          params_dtype: Data type for the parameters.
          quant_config: Quantization configure.
          prefix: The name of the layer in the state dict, including all
              parents (e.g. model.layers.0.qkv_proj)
      """

      def __init__(self,
                   input_size: int,
                   output_size: int,
                   bias: bool = True,
                   skip_bias_add: bool = False,
                   params_dtype: Optional[torch.dtype] = None,
                   quant_config: Optional[QuantizationConfig] = None,
                   prefix: str = ""):
          super().__init__(input_size,
                           output_size,
                           skip_bias_add,
                           params_dtype,
                           quant_config,
                           prefix=prefix)

          # All the linear layer supports quant method.
          assert self.quant_method is not None
          # Hand this layer's loader to the quant method, as
          # ColumnParallelLinear does. Without it ``weight_loader`` below is
          # never attached to any parameter and its handling of shapeless
          # scales is dead code.
          self.quant_method.create_weights(self,
                                           self.input_size,
                                           [self.output_size],
                                           self.input_size,
                                           self.output_size,
                                           self.params_dtype,
                                           weight_loader=self.weight_loader,
                                           prefix=prefix)

          if bias:
              self.bias = Parameter(
                  torch.empty(self.output_size, dtype=self.params_dtype))
              set_weight_attrs(self.bias, {
                  "output_dim": 0,
                  "weight_loader": self.weight_loader,
              })
          else:
              self.register_parameter("bias", None)

      def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
          """Copy ``loaded_weight`` into ``param``; the sizes must match.

          A shapeless ``loaded_weight`` (such as scales for AutoFp8) is
          first reshaped to a one-element tensor.
          """
          if len(loaded_weight.shape) == 0:
              loaded_weight = loaded_weight.reshape(1)

          assert param.size() == loaded_weight.size()
          param.data.copy_(loaded_weight)

      def forward(
              self,
              x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Parameter]]:
          """Apply the layer.

          Returns:
              ``(output, output_bias)``. With ``skip_bias_add`` set, the
              bias is not passed to the quant method and is returned as
              ``output_bias``; otherwise ``output_bias`` is ``None``.
          """
          bias = self.bias if not self.skip_bias_add else None
          assert self.quant_method is not None
          output = self.quant_method.apply(self, x, bias)
          output_bias = self.bias if self.skip_bias_add else None
          return output, output_bias

      def extra_repr(self) -> str:
          """Describe the layer's sizes and whether it has a bias."""
          s = f"in_features={self.input_size}"
          s += f", output_features={self.output_size}"
          s += f", bias={self.bias is not None}"
          return s
  class ColumnParallelLinear(LinearBase):
      """Linear layer with column parallelism.

      The linear layer is defined as Y = XA + b. A is parallelized along
      its second dimension as A = [A_1, ..., A_p].

      Args:
          input_size: first dimension of matrix A.
          output_size: second dimension of matrix A.
          bias: If true, add bias.
          gather_output: If true, call all-gather on output and make Y
              available to all GPUs, otherwise, every GPU will have its
              output which is Y_i = XA_i
          skip_bias_add: This was added to enable performance optimizations
              where bias can be fused with other element-wise operations.
              We skip adding bias but instead return it.
          params_dtype: Data type for the parameters.
          quant_config: Quantization configure.
          output_sizes: list of output sizes packed into one output, like
              for QKV the list would be size 3. Accepted for interface
              compatibility; this class does not read it (subclasses
              expose their sizes through ``self.output_sizes`` instead).
          prefix: The name of the layer in the state dict, including all
              parents (e.g. model.layers.0.qkv_proj)
      """

      def __init__(self,
                   input_size: int,
                   output_size: int,
                   bias: bool = True,
                   gather_output: bool = False,
                   skip_bias_add: bool = False,
                   params_dtype: Optional[torch.dtype] = None,
                   quant_config: Optional[QuantizationConfig] = None,
                   output_sizes: Optional[List[int]] = None,
                   prefix: str = ""):
          super().__init__(input_size, output_size, skip_bias_add,
                           params_dtype, quant_config, prefix)

          self.gather_output = gather_output

          # Divide the weight matrix along the last dimension.
          tp_rank = get_tensor_model_parallel_rank()
          tp_size = get_tensor_model_parallel_world_size()
          assert self.quant_method is not None

          # Unquantized layers are sized per rank with
          # get_current_tp_rank_partition_size; quantized layers use an
          # exact ``divide``. weight_loader needs to know which policy was
          # used so that it picks the matching checkpoint offset.
          self._uneven_tp_split = quant_config is None

          if quant_config is None:
              self.output_size_per_partition = (
                  get_current_tp_rank_partition_size(output_size, tp_rank,
                                                     tp_size))
          else:
              self.output_size_per_partition = divide(self.output_size,
                                                      tp_size)
          self.output_partition_sizes = [self.output_size_per_partition]

          # If QKV or MergedColumn, use output size of each partition.
          if hasattr(self, "output_sizes"):
              if quant_config is None:
                  self.output_partition_sizes = [
                      get_current_tp_rank_partition_size(
                          size, tp_rank, tp_size)
                      for size in self.output_sizes
                  ]
              else:
                  self.output_partition_sizes = [
                      divide(size, tp_size) for size in self.output_sizes
                  ]

          self.quant_method.create_weights(
              layer=self,
              input_size_per_partition=self.input_size,
              output_partition_sizes=self.output_partition_sizes,
              input_size=self.input_size,
              output_size=self.output_size,
              params_dtype=self.params_dtype,
              weight_loader=(
                  self.weight_loader_v2
                  if self.quant_method.__class__.__name__
                  in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
              prefix=prefix)

          if bias:
              self.bias = Parameter(
                  torch.empty(self.output_size_per_partition,
                              dtype=self.params_dtype))
              set_weight_attrs(self.bias, {
                  "output_dim": 0,
                  "weight_loader": self.weight_loader,
              })
          else:
              self.register_parameter("bias", None)

      def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
          """Copy this rank's slice of ``loaded_weight`` into ``param``.

          When ``param`` has an ``output_dim`` attribute, ``loaded_weight``
          is narrowed along that dimension to the size of ``param``;
          otherwise it is copied whole. The shapes must match afterwards.
          """
          tp_rank = get_tensor_model_parallel_rank()
          output_dim = getattr(param, "output_dim", None)

          # Special case for GGUF
          is_gguf_weight = getattr(param, "is_gguf_weight", False)
          is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
          if is_gguf_weight_type:
              param.weight_type = loaded_weight.item()

          # Materialize GGUF UninitializedParameter
          if is_gguf_weight and isinstance(param, UninitializedParameter):
              param.materialize(loaded_weight.shape,
                                dtype=loaded_weight.dtype)

          param_data = param.data
          if output_dim is not None:
              shard_size = param_data.shape[output_dim]
              if getattr(self, "_uneven_tp_split", False):
                  # Shards may differ in size between ranks, so
                  # ``tp_rank * shard_size`` is only right for an even
                  # split. Ask the partition helper for the real offset,
                  # as MergedColumnParallelLinear.weight_loader does.
                  start_idx = get_current_tp_rank_partition_offset(
                      loaded_weight.shape[output_dim], tp_rank,
                      get_tensor_model_parallel_world_size())
              else:
                  start_idx = tp_rank * shard_size
              loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                   shard_size)

          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
          if len(loaded_weight.shape) == 0:
              loaded_weight = loaded_weight.reshape(1)

          assert param_data.shape == loaded_weight.shape
          param_data.copy_(loaded_weight)

      def weight_loader_v2(self, param: Parameter,
                           loaded_weight: torch.Tensor):
          """Delegate loading to ``param.load_column_parallel_weight``."""
          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
          if len(loaded_weight.shape) == 0:
              assert loaded_weight.numel() == 1
              loaded_weight = loaded_weight.reshape(1)
          param.load_column_parallel_weight(loaded_weight=loaded_weight)

      def forward(
          self, input_: torch.Tensor
      ) -> Tuple[torch.Tensor, Optional[Parameter]]:
          """Apply the layer to ``input_``.

          Returns:
              ``(output, output_bias)``. ``output`` is all-gathered across
              the partitions when ``gather_output`` is set. With
              ``skip_bias_add`` set, the bias is returned as
              ``output_bias`` instead of being applied; otherwise
              ``output_bias`` is ``None``.
          """
          bias = self.bias if not self.skip_bias_add else None

          # Matrix multiply.
          assert self.quant_method is not None
          output_parallel = self.quant_method.apply(self, input_, bias)
          if self.gather_output:
              # All-gather across the partitions.
              output = tensor_model_parallel_all_gather(output_parallel)
          else:
              output = output_parallel
          output_bias = self.bias if self.skip_bias_add else None
          return output, output_bias

      def extra_repr(self) -> str:
          """Describe the layer's sizes and parallel configuration."""
          s = f"in_features={self.input_size}"
          s += f", output_features={self.output_size_per_partition}"
          s += f", bias={self.bias is not None}"
          s += f", tp_size={get_tensor_model_parallel_world_size()}"
          s += f", gather_output={self.gather_output}"
          return s
  class MergedColumnParallelLinear(ColumnParallelLinear):
      """Packed linear layers with column parallelism.

      Like ColumnParallelLinear, except that the weight matrix is the
      concatenation of several logical matrices along the output dimension.
      When the weight is loaded, each logical matrix is sharded separately.

      Args:
          input_size: input dimension of the linear layer.
          output_sizes: list of output dimensions of the linear layer.
          bias: If true, add bias.
          gather_output: If true, call all-gather on output and make the
              output available to all GPUs, otherwise, every GPU will have
              its own output.
          skip_bias_add: This was added to enable performance optimizations
              where bias can be fused with other element-wise operations.
              We skip adding bias but instead return it.
          params_dtype: Data type for the parameters.
          quant_config: Quantization configure.
          prefix: The name of the layer in the state dict, including all
              parents (e.g. model.layers.0.qkv_proj)
      """

      def __init__(self,
                   input_size: int,
                   output_sizes: List[int],
                   bias: bool = True,
                   gather_output: bool = False,
                   skip_bias_add: bool = False,
                   params_dtype: Optional[torch.dtype] = None,
                   quant_config: Optional[QuantizationConfig] = None,
                   prefix: str = ""):
          # Both attributes are set before the parent constructor runs: it
          # looks for ``output_sizes`` on the instance via ``hasattr``.
          self.output_sizes = output_sizes
          self.quant_config = quant_config

          if quant_config is not None:
              world_size = get_tensor_model_parallel_world_size()
              assert all(width % world_size == 0 for width in output_sizes)

          super().__init__(input_size=input_size,
                           output_size=sum(output_sizes),
                           bias=bias,
                           gather_output=gather_output,
                           skip_bias_add=skip_bias_add,
                           params_dtype=params_dtype,
                           quant_config=quant_config,
                           prefix=prefix)

      def _fused_shard_layout(self) -> List[Tuple[int, int, int]]:
          """List ``(shard_id, offset, size)`` for every logical matrix.

          Offsets are the running sum of ``self.output_sizes``.
          """
          layout: List[Tuple[int, int, int]] = []
          running_offset = 0
          for index, width in enumerate(self.output_sizes):
              layout.append((index, running_offset, width))
              running_offset += width
          return layout

      def weight_loader(self,
                        param: Parameter,
                        loaded_weight: torch.Tensor,
                        loaded_shard_id: Optional[int] = None):
          """Load one logical shard, or a checkpoint that is already fused.

          With ``loaded_shard_id=None`` the weight is treated as fused on
          disk: it is either copied whole (no ``output_dim`` on the param)
          or split per logical matrix, with this method called again for
          each piece. Otherwise the slice of ``loaded_weight`` belonging to
          this rank is copied into the region of ``param`` reserved for
          shard ``loaded_shard_id``.
          """
          # Special case for GGUF: the quantize type arrives as its own
          # "weight" and is only recorded.
          gguf_weight = getattr(param, "is_gguf_weight", False)
          if getattr(param, "is_gguf_weight_type", False):
              param.data[loaded_shard_id].copy_(loaded_weight)
              param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
              return

          # Initialize the GGUF param once the quantize types are known.
          if gguf_weight and isinstance(param, UninitializedParameter):
              from gguf.constants import GGML_QUANT_SIZES

              logical_shape = param.tensor_shape
              shard_types = self.qweight_type.shard_weight_type.values()
              widest_row = max(
                  logical_shape[1] // block_size * type_size
                  for block_size, type_size in (GGML_QUANT_SIZES[kind]
                                                for kind in shard_types))
              param.materialize((logical_shape[0], widest_row),
                                dtype=loaded_weight.dtype)

          target = param.data
          output_dim = getattr(param, "output_dim", None)
          # Special case for AQLM codebooks.
          is_metadata = getattr(param, "is_metadata", False)
          # Special case for per-tensor scale to load scalar into fused
          # array.
          needs_scalar_to_array = getattr(param, "needs_scalar_to_array",
                                          False)

          if loaded_shard_id is None:
              # Loaded weight is already fused on disk (qkv/mlp).
              if output_dim is None:
                  if needs_scalar_to_array:
                      target, loaded_weight = adjust_scalar_to_fused_array(
                          target, loaded_weight, 0)
                  assert target.shape == loaded_weight.shape
                  target.copy_(loaded_weight)
                  return

              packed_on_output = (getattr(param, "packed_dim", None) ==
                                  output_dim)
              for index, offset, width in self._fused_shard_layout():
                  if packed_on_output:
                      # Quantized weights are packed, so offset and size
                      # shrink by the pack factor ...
                      width = width // param.pack_factor
                      offset = offset // param.pack_factor
                      # ... and Marlin may rescale them again.
                      width, offset = adjust_marlin_shard(
                          param, width, offset)
                  piece = loaded_weight.narrow(output_dim, offset, width)
                  self.weight_loader(param, piece, index)
              return

          assert loaded_shard_id < len(self.output_sizes)
          tp_rank = get_tensor_model_parallel_rank()
          tp_size = get_tensor_model_parallel_world_size()
          unquantized = self.quant_config is None

          if output_dim is not None:
              if unquantized:
                  shard_offset = sum(
                      get_current_tp_rank_partition_size(
                          width, tp_rank, tp_size)
                      for width in self.output_sizes[:loaded_shard_id])
                  shard_size = get_current_tp_rank_partition_size(
                      self.output_sizes[loaded_shard_id], tp_rank, tp_size)
              else:
                  shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                                  tp_size)
                  shard_size = self.output_sizes[loaded_shard_id] // tp_size

              # Special case for quantization: account for the packing.
              if getattr(param, "packed_dim", None) == output_dim:
                  shard_size = shard_size // param.pack_factor
                  shard_offset = shard_offset // param.pack_factor
                  # Special case for Marlin.
                  shard_size, shard_offset = adjust_marlin_shard(
                      param, shard_size, shard_offset)

              if getattr(param, "use_bitsandbytes", False):
                  shard_size = loaded_weight.shape[output_dim]
                  shard_offset = shard_size * loaded_shard_id

              if gguf_weight:
                  shard_shape = list(loaded_weight.shape)
                  shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                  param.shard_id.append(loaded_shard_id)
                  param.shard_size[loaded_shard_id] = shard_shape

                  input_dim = getattr(param, "input_dim", None)
                  target = target.narrow(input_dim, 0,
                                         loaded_weight.shape[input_dim])

              target = target.narrow(output_dim, shard_offset, shard_size)
              if unquantized:
                  start_idx = get_current_tp_rank_partition_offset(
                      loaded_weight.shape[output_dim], tp_rank, tp_size)
              else:
                  start_idx = tp_rank * shard_size
              loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                   shard_size)
          elif is_metadata:
              # Special case for AQLM codebooks: metadata indicates fixed
              # size concatenated along dim 0.
              shard_size = loaded_weight.shape[0]
              shard_offset = loaded_shard_id * shard_size
              target = target.narrow(0, shard_offset, shard_size)
          elif needs_scalar_to_array:
              # Special case for per-tensor scales in fused case.
              target, loaded_weight = adjust_scalar_to_fused_array(
                  target, loaded_weight, loaded_shard_id)
          elif not getattr(param, "ignore_warning", False):
              logger.warning(
                  "Loading a weight without `output_dim` attribute in "
                  "MergedColumnParallelLinear, assume the weight is "
                  "the same for all partitions.")

          assert target.shape == loaded_weight.shape
          target.copy_(loaded_weight)

      def _load_fused_module_from_checkpoint(self,
                                             param: BaseAphroditeParameter,
                                             loaded_weight: torch.Tensor):
          """Load MLP layers that are already fused on disk.

          Such checkpoints carry no shard id. The fused weight is split
          along ``param.output_dim`` according to ``self.output_sizes`` and
          each piece is passed to ``weight_loader_v2`` with its index as
          the shard id.

          An example of a model with these fused layers:
          https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
          """
          for index, offset, width in self._fused_shard_layout():
              # Special case for quantization: packed parameters need
              # their offset and size adjusted for the packing.
              if (isinstance(param, PackedAphroditeParameter)
                      and param.packed_dim == param.output_dim):
                  width, offset = param.adjust_shard_indexes_for_packing(
                      shard_size=width, shard_offset=offset)

              piece = loaded_weight.narrow(param.output_dim, offset, width)
              self.weight_loader_v2(param, piece, index)

      def weight_loader_v2(self,
                           param: BaseAphroditeParameter,
                           loaded_weight: torch.Tensor,
                           loaded_shard_id: Optional[int] = None):
          """Load a shard through ``param.load_merged_column_weight``.

          Without a shard id, per-tensor scales are loaded as shard 0,
          parameters whose exact type is ``BaseAphroditeParameter`` are
          loaded as is, and anything else is treated as a fused checkpoint
          and split first.
          """
          if loaded_shard_id is None:
              if isinstance(param, PerTensorScaleParameter):
                  param.load_merged_column_weight(
                      loaded_weight=loaded_weight, shard_id=0)
              elif type(param) is BaseAphroditeParameter:
                  param.load_merged_column_weight(
                      loaded_weight=loaded_weight)
              else:
                  self._load_fused_module_from_checkpoint(
                      param, loaded_weight)
              return

          assert loaded_shard_id < len(self.output_sizes)

          world_size = get_tensor_model_parallel_world_size()
          preceding = sum(self.output_sizes[:loaded_shard_id])
          param.load_merged_column_weight(
              loaded_weight=loaded_weight,
              shard_id=loaded_shard_id,
              shard_offset=preceding // world_size,
              shard_size=self.output_sizes[loaded_shard_id] // world_size)
  549. class QKVParallelLinear(ColumnParallelLinear):
  550. """Linear layers for the attention's QKV transformation.
  551. Linear layers for the linear transformation of the query, key, and value
  552. vectors in the attention layer. The weight matrix is concatenated along
  553. the output dimension. The layer is parallelized along the head dimension.
  554. When the number of key/value heads is smaller than the number of query
  555. heads (e.g., multi-query/grouped-query attention), the key/value head may
  556. be replicated while the query heads are partitioned.
  557. Args:
  558. hidden_size: input hidden state size of the transformer.
  559. head_size: size of each attention head.
  560. total_num_heads: total number of attention query heads.
  561. total_num_kv_heads: total number of attention key/value heads. If
  562. None, assume total_num_kv_heads = total_num_heads.
  563. bias: If true, add bias.
  564. skip_bias_add: This was added to enable performance optimizations where
  565. bias can be fused with other element-wise operations. we
  566. skip adding bias but instead return it.
  567. params_dtype: Data type for the parameters.
  568. quant_config: Quantization configure.
  569. prefix: The name of the layer in the state dict, including all parents
  570. (e.g. model.layers.0.qkv_proj)
  571. """
  def __init__(self,
               hidden_size: int,
               head_size: int,
               total_num_heads: int,
               total_num_kv_heads: Optional[int] = None,
               bias: bool = True,
               skip_bias_add: bool = False,
               params_dtype: Optional[torch.dtype] = None,
               quant_config: Optional[QuantizationConfig] = None,
               prefix: str = ""):
      """Work out the per-rank head counts and build the fused QKV layer.

      ``total_num_kv_heads`` defaults to ``total_num_heads``. The parent
      layer is created with ``input_size=hidden_size``, no output
      gathering, and ``self.output_sizes`` holding the q, k and v widths.
      """
      self.hidden_size = hidden_size
      self.head_size = head_size
      self.total_num_heads = total_num_heads
      self.quant_config = quant_config
      self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                 is None else total_num_kv_heads)

      # Divide the weight matrix along the last dimension.
      tp_size = get_tensor_model_parallel_world_size()
      tp_rank = get_tensor_model_parallel_rank()

      if quant_config is not None:
          self.num_heads = divide(self.total_num_heads, tp_size)
          if tp_size >= self.total_num_kv_heads:
              # One KV head per rank; each KV head lives on several ranks.
              self.num_kv_heads = 1
              self.num_kv_head_replicas = divide(tp_size,
                                                 self.total_num_kv_heads)
          else:
              self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
              self.num_kv_head_replicas = 1
      else:
          # Unquantized: KV heads are assigned per rank by the partition
          # helper and the query heads follow from them.
          self.num_heads_per_kv_head = (self.total_num_heads //
                                        self.total_num_kv_heads)
          self.num_kv_heads = get_current_tp_rank_partition_size(
              self.total_num_kv_heads, tp_rank, tp_size)
          self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
          self.num_kv_head_replicas = 1

      heads_on_rank = self.num_heads + 2 * self.num_kv_heads
      kv_width = self.num_kv_heads * self.head_size * tp_size
      self.output_sizes = [
          self.num_heads * self.head_size * tp_size,  # q_proj
          kv_width,  # k_proj
          kv_width,  # v_proj
      ]

      super().__init__(input_size=self.hidden_size,
                       output_size=heads_on_rank * tp_size * self.head_size,
                       bias=bias,
                       gather_output=False,
                       skip_bias_add=skip_bias_add,
                       params_dtype=params_dtype,
                       quant_config=quant_config,
                       prefix=prefix)
  624. def _get_shard_offset_mapping(self, loaded_shard_id: str):
  625. shard_offset_mapping = {
  626. "q": 0,
  627. "k": self.num_heads * self.head_size,
  628. "v": (self.num_heads + self.num_kv_heads) * self.head_size,
  629. "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
  630. }
  631. return shard_offset_mapping.get(loaded_shard_id)
  632. def _get_shard_size_mapping(self, loaded_shard_id: str):
  633. shard_size_mapping = {
  634. "q": self.num_heads * self.head_size,
  635. "k": self.num_kv_heads * self.head_size,
  636. "v": self.num_kv_heads * self.head_size,
  637. }
  638. return shard_size_mapping.get(loaded_shard_id)
  def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                         loaded_weight: torch.Tensor):
      """Load a QKV weight that is stored already fused in the checkpoint.

      A fused checkpoint tensor carries no shard id, so this method slices
      it into its q, k and v pieces along ``param.output_dim`` (using the
      un-partitioned head counts) and forwards each piece to
      ``weight_loader_v2`` with the matching shard id.

      An example of a model with these fused layers:
      https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

      Args:
          param: Destination parameter.
          loaded_weight: Fused ``[q | k | v]`` tensor read from disk.
      """
      q_width = self.total_num_heads * self.head_size
      kv_width = self.total_num_kv_heads * self.head_size
      # (shard_id, start, width) for each slice of the fused tensor.
      layout = (
          ("q", 0, q_width),
          ("k", q_width, kv_width),
          ("v", q_width + kv_width, kv_width),
      )
      for name, start, width in layout:
          # Quantized parameters packed along the output dimension need
          # their slice bounds converted to packed units.
          packed_on_output = (isinstance(param, PackedAphroditeParameter)
                              and param.packed_dim == param.output_dim)
          if packed_on_output:
              width, start = param.adjust_shard_indexes_for_packing(
                  shard_size=width, shard_offset=start)

          piece = loaded_weight.narrow(param.output_dim, start, width)
          self.weight_loader_v2(param, piece, name)
  def weight_loader_v2(self,
                       param: BaseAphroditeParameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[str] = None):
      """Load a q/k/v shard, or an already-fused QKV tensor, into ``param``.

      Args:
          param: Destination parameter; it performs the actual copy through
              its own ``load_*`` methods.
          loaded_weight: Tensor read from the checkpoint.
          loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single shard,
              or ``None`` when the checkpoint has no per-shard split.
      """
      if loaded_shard_id is not None:
          assert loaded_shard_id in ["q", "k", "v"]
          start = self._get_shard_offset_mapping(loaded_shard_id)
          width = self._get_shard_size_mapping(loaded_shard_id)
          param.load_qkv_weight(loaded_weight=loaded_weight,
                                num_heads=self.num_kv_head_replicas,
                                shard_id=loaded_shard_id,
                                shard_offset=start,
                                shard_size=width)
          return

      # No shard id: special case for certain models.
      if isinstance(param, PerTensorScaleParameter):
          param.load_merged_column_weight(loaded_weight=loaded_weight,
                                          shard_id=0)
      elif type(param) is BaseAphroditeParameter:
          # Exact type match only; subclasses take the fused-split path.
          param.load_merged_column_weight(loaded_weight=loaded_weight)
      else:
          self._load_fused_module_from_checkpoint(param, loaded_weight)
  def weight_loader(self,
                    param: Parameter,
                    loaded_weight: torch.Tensor,
                    loaded_shard_id: Optional[str] = None):
      """Copy a checkpoint tensor into the fused QKV parameter ``param``.

      The loading strategy is selected from optional attributes read off
      ``param`` (``is_gguf_weight``, ``is_gguf_weight_type``, ``output_dim``,
      ``is_metadata``, ``needs_scalar_to_array``, ``packed_dim``,
      ``use_bitsandbytes``, ``ignore_warning``) and from whether
      ``self.quant_config`` is set.

      Args:
          param: Destination parameter.
          loaded_weight: Tensor read from the checkpoint.
          loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single shard, or
              ``None`` when the checkpoint tensor is already fused; in that
              case the tensor is split here and this method calls itself
              once per shard.

      Raises:
          AssertionError: If ``loaded_shard_id`` is not ``None``/q/k/v, or if
              the destination view and the (narrowed) loaded tensor end up
              with different shapes.
      """

      # Special case for GGUF
      # initialize GGUF param after we know the quantize type
      is_gguf_weight = getattr(param, "is_gguf_weight", False)
      is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
      if is_gguf_weight_type and loaded_shard_id is not None:
          # Store the per-shard type both in the parameter data (slot 0/1/2)
          # and in the shard_weight_type dict, then stop: nothing to narrow.
          idx_map = {"q": 0, "k": 1, "v": 2}
          param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
          param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
          return

      if is_gguf_weight and isinstance(param, UninitializedParameter):
          from gguf.constants import GGML_QUANT_SIZES

          ori_shape = param.tensor_shape
          weight_types = self.qweight_type.shard_weight_type.values()
          row_size = []
          for weight_type in weight_types:
              block_size, type_size = GGML_QUANT_SIZES[weight_type]
              row_size.append(ori_shape[1] // block_size * type_size)
          # Materialize with the widest row among the recorded shard types so
          # every shard fits.
          q_shape = (ori_shape[0], max(row_size))
          param.materialize(q_shape, dtype=loaded_weight.dtype)

      param_data = param.data
      output_dim = getattr(param, "output_dim", None)
      # Special case for AQLM codebooks.
      is_metadata = getattr(param, "is_metadata", False)

      # Special case for per-tensor scales in fused case.
      needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

      if loaded_shard_id is None:
          # Loaded weight is already fused on disk (qkv/mlp).
          if output_dim is None:
              if needs_scalar_to_array:
                  param_data, loaded_weight = adjust_scalar_to_fused_array(
                      param_data, loaded_weight, 0)

              assert param_data.shape == loaded_weight.shape
              param_data.copy_(loaded_weight)
              return
          # Slice bounds use the un-partitioned head counts because the
          # fused tensor on disk holds the full q, k and v projections.
          shard_offsets = [
              # (shard_id, shard_offset, shard_size)
              ("q", 0, self.total_num_heads * self.head_size),
              ("k", self.total_num_heads * self.head_size,
               self.total_num_kv_heads * self.head_size),
              ("v", (self.total_num_heads + self.total_num_kv_heads) *
               self.head_size, self.total_num_kv_heads * self.head_size),
          ]
          packed_dim = getattr(param, "packed_dim", None)
          for shard_id, shard_offset, shard_size in shard_offsets:
              # Special case for Quantized Weights.
              # If quantized, we need to adjust the offset and size to account
              # for the packing.
              if packed_dim == output_dim:
                  shard_size = shard_size // param.pack_factor
                  shard_offset = shard_offset // param.pack_factor

                  # Special case for Marlin.
                  shard_size, shard_offset = adjust_marlin_shard(
                      param, shard_size, shard_offset)

              loaded_weight_shard = loaded_weight.narrow(
                  output_dim, shard_offset, shard_size)
              # Re-enter with an explicit shard id to load this slice.
              self.weight_loader(param, loaded_weight_shard, shard_id)
          return

      tp_rank = get_tensor_model_parallel_rank()
      assert loaded_shard_id in ["q", "k", "v"]

      # If output dim is defined, use the default loading process.
      if output_dim is not None:
          # shard_offset/shard_size locate this shard inside the local fused
          # parameter; multiple_of is only needed (and only assigned) on the
          # unquantized path, where it is passed to the partition-offset
          # helper below.
          if loaded_shard_id == "q":
              shard_offset = 0
              shard_size = self.num_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size * self.num_heads_per_kv_head
          elif loaded_shard_id == "k":
              shard_offset = self.num_heads * self.head_size
              shard_size = self.num_kv_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size
          elif loaded_shard_id == "v":
              shard_offset = (self.num_heads +
                              self.num_kv_heads) * self.head_size
              shard_size = self.num_kv_heads * self.head_size
              if self.quant_config is None:
                  multiple_of = self.head_size
          # Special case for Quantized Weights.
          # If quantized, we need to adjust the offset and size to account
          # for the packing.
          packed_dim = getattr(param, "packed_dim", None)
          if packed_dim == output_dim:
              shard_size = shard_size // param.pack_factor
              shard_offset = shard_offset // param.pack_factor

              if self.quant_config is None:
                  multiple_of = multiple_of // param.pack_factor

              # Special case for Marlin.
              shard_size, shard_offset = adjust_marlin_shard(
                  param, shard_size, shard_offset)

          use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
          if use_bitsandbytes:
              # Each entry is (offset, size); "total" carries the combined
              # width with a size of 0.
              orig_qkv_offsets = {
                  "q": (0, self.num_heads * self.head_size),
                  "k": (self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size),
                  "v":
                  ((self.num_heads + self.num_kv_heads) * self.head_size,
                   self.num_kv_heads * self.head_size),
                  "total":
                  ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                   0)
              }
              shard_size, shard_offset = adjust_bitsandbytes_shard(
                  param, orig_qkv_offsets, loaded_shard_id)

          if is_gguf_weight:
              tp_size = get_tensor_model_parallel_world_size()
              output_dim = getattr(param, "output_dim", None)
              # Record the per-rank shape of this shard on the parameter.
              shard_shape = list(loaded_weight.shape)
              shard_shape[output_dim] = shard_shape[output_dim] // tp_size
              param.shard_id.append(loaded_shard_id)
              param.shard_size[loaded_shard_id] = shard_shape

              # The parameter was materialized with the widest row size, so
              # trim the destination view to this shard's input extent.
              input_dim = getattr(param, "input_dim", None)
              input_size = loaded_weight.shape[input_dim]
              param_data = param_data.narrow(input_dim, 0, input_size)

          param_data = param_data.narrow(output_dim, shard_offset,
                                         shard_size)
          if self.quant_config is None:
              # Unquantized: the start offset comes from the partition
              # helper, using multiple_of chosen above.
              tp_size = get_tensor_model_parallel_world_size()
              total_size = loaded_weight.shape[output_dim]
              start_idx = get_current_tp_rank_partition_offset(
                  total_size, tp_rank, tp_size, multiple_of=multiple_of)
          else:
              # Quantized: equal-size slices; ranks that share a replicated
              # KV head read the same k/v slice.
              if loaded_shard_id == "q":
                  shard_id = tp_rank
              else:
                  shard_id = tp_rank // self.num_kv_head_replicas
              start_idx = shard_id * shard_size
          loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                               shard_size)
      # Special case for AQLM codebooks.
      elif is_metadata:
          # metadata indicates fixed size concatenated along dim 0
          shard_size = loaded_weight.shape[0]
          shard_index = ["q", "k", "v"].index(loaded_shard_id)
          param_data = param_data.narrow(0, shard_index * shard_size,
                                         shard_size)
      # Special case for per-tensor scales in fused case.
      elif needs_scalar_to_array:
          param_data, loaded_weight = adjust_scalar_to_fused_array(
              param_data, loaded_weight, loaded_shard_id)
      else:
          ignore_warning = getattr(param, "ignore_warning", False)
          if not ignore_warning:
              logger.warning(
                  "Loading a weight without `output_dim` attribute in "
                  "QKVParallelLinear, assume the weight is the same "
                  "for all partitions.")

      # Every non-returning path ends here with a matching destination view
      # and source tensor.
      assert param_data.shape == loaded_weight.shape
      param_data.copy_(loaded_weight)
  class RowParallelLinear(LinearBase):
      """Linear layer with row parallelism.

      The linear layer is defined as Y = XA + b. A is parallelized along
      its first dimension and X along its second dimension as:
                 -   -
                | A_1 |
                | .   |
            A = | .   |        X = [X_1, ..., X_p]
                | .   |
                | A_p |
                 -   -
      Arguments:
          input_size: first dimension of matrix A.
          output_size: second dimension of matrix A.
          bias: If true, add bias. Note that bias is not parallelized.
          input_is_parallel: If true, we assume that the input is already
                             split across the GPUs and we do not split
                             again.
          skip_bias_add: This was added to enable performance optimization where
                         bias can be fused with other element-wise operations.
                         We skip adding bias but instead return it.
          params_dtype: Data type for the parameters.
          reduce_results: If true (and tp_size > 1), all-reduce the partial
                          outputs across tensor-parallel ranks in forward().
          quant_config: Quantization configure.
          partition_multiple_of: Partitions will be divided,
                                 so each partition is a multiple of this number.
                                 Only used when quant_config is None.
          prefix: Forwarded to LinearBase and to quant_method.create_weights.

      Raises:
          ValueError: If reduce_results is false while a bias would be added
              inside forward() (bias=True and skip_bias_add=False).
      """

      def __init__(self,
                   input_size: int,
                   output_size: int,
                   bias: bool = True,
                   input_is_parallel: bool = True,
                   skip_bias_add: bool = False,
                   params_dtype: Optional[torch.dtype] = None,
                   reduce_results: bool = True,
                   quant_config: Optional[QuantizationConfig] = None,
                   partition_multiple_of: int = 1,
                   prefix: str = ""):
          super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                           quant_config, prefix)

          self.input_is_parallel = input_is_parallel
          self.reduce_results = reduce_results
          self.quant_config = quant_config

          # Split the input dimension of the weight across the TP ranks.
          self.tp_rank = get_tensor_model_parallel_rank()
          self.tp_size = get_tensor_model_parallel_world_size()
          # Always recorded so the attribute exists on quantized layers too;
          # it only influences partitioning on the unquantized path.
          self.partition_multiple_of = partition_multiple_of
          if quant_config is None:
              self.input_size_per_partition = get_current_tp_rank_partition_size(
                  input_size, self.tp_rank, self.tp_size, partition_multiple_of)
          else:
              self.input_size_per_partition = divide(input_size, self.tp_size)
          assert self.quant_method is not None
          self.quant_method.create_weights(
              layer=self,
              input_size_per_partition=self.input_size_per_partition,
              output_partition_sizes=[self.output_size],
              input_size=self.input_size,
              output_size=self.output_size,
              params_dtype=self.params_dtype,
              weight_loader=(
                  self.weight_loader_v2 if self.quant_method.__class__.__name__
                  in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
              prefix=prefix)
          if not reduce_results and (bias and not skip_bias_add):
              raise ValueError("When not reduce the results, adding bias to the "
                               "results can lead to incorrect results")

          if bias:
              self.bias = Parameter(
                  torch.empty(self.output_size, dtype=params_dtype))
              set_weight_attrs(self.bias, {
                  "output_dim": 0,
                  "weight_loader": self.weight_loader,
              })
          else:
              self.register_parameter("bias", None)

      def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
          """Copy this rank's slice of ``loaded_weight`` into ``param``.

          When ``param`` has an ``input_dim`` attribute the checkpoint tensor
          is narrowed along that dimension to this rank's partition; otherwise
          it is copied whole.

          Args:
              param: Destination parameter.
              loaded_weight: Tensor read from the checkpoint.

          Raises:
              AssertionError: If the destination and the (narrowed) source
                  shapes differ.
          """
          input_dim = getattr(param, "input_dim", None)

          # Special case for GGUF
          is_gguf_weight = getattr(param, "is_gguf_weight", False)
          is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
          if is_gguf_weight_type:
              param.weight_type = loaded_weight.item()

          # Materialize GGUF UninitializedParameter
          if is_gguf_weight and isinstance(param, UninitializedParameter):
              weight_shape = list(loaded_weight.shape)
              # `is not None` rather than truthiness: dimension 0 is a valid
              # input dim and must be sharded as well.
              if input_dim is not None:
                  weight_shape[input_dim] = (weight_shape[input_dim] //
                                             self.tp_size)
              param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

          param_data = param.data
          if input_dim is not None:
              shard_size = param_data.shape[input_dim]
              if self.quant_config is None:
                  # Unquantized partitions may be uneven, so ask the helper
                  # for this rank's start offset.
                  start_idx = get_current_tp_rank_partition_offset(
                      self.input_size,
                      self.tp_rank,
                      self.tp_size,
                      multiple_of=self.partition_multiple_of)
              else:
                  start_idx = self.tp_rank * shard_size
              loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                   shard_size)

          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
          if len(loaded_weight.shape) == 0:
              loaded_weight = loaded_weight.reshape(1)

          assert param_data.shape == loaded_weight.shape
          param_data.copy_(loaded_weight)

      def weight_loader_v2(self, param: BaseAphroditeParameter,
                           loaded_weight: torch.Tensor):
          """Load ``loaded_weight`` via ``param.load_row_parallel_weight``.

          Args:
              param: Destination parameter; it performs the partitioning.
              loaded_weight: Tensor read from the checkpoint.
          """
          # Special case for loading scales off disk, which often do not
          # have a shape (such as in the case of AutoFP8).
          if len(loaded_weight.shape) == 0:
              assert loaded_weight.numel() == 1
              loaded_weight = loaded_weight.reshape(1)

          param.load_row_parallel_weight(loaded_weight=loaded_weight)

      def forward(self, input_):
          """Apply the layer to ``input_``.

          Returns:
              A tuple ``(output, output_bias)``. ``output_bias`` is
              ``self.bias`` when ``skip_bias_add`` is set (the bias is then not
              added to ``output``), otherwise ``None``.
          """
          if self.input_is_parallel:
              input_parallel = input_
          else:
              # Keep only this rank's slice of the last dimension.
              splitted_input = split_tensor_along_last_dim(
                  input_, num_partitions=self.tp_size)
              input_parallel = splitted_input[self.tp_rank].contiguous()

          # Matrix multiply.
          assert self.quant_method is not None
          # Only fuse bias add into GEMM for rank 0 (this ensures that
          # bias will not get added more than once in TP>1 case)
          bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
          output_parallel = self.quant_method.apply(self,
                                                    input_parallel,
                                                    bias=bias_)
          if self.reduce_results and self.tp_size > 1:
              output = tensor_model_parallel_all_reduce(output_parallel)
          else:
              output = output_parallel

          output_bias = self.bias if self.skip_bias_add else None

          return output, output_bias

      def extra_repr(self) -> str:
          """Return the layer summary shown in the module's repr."""
          s = f"input_features={self.input_size_per_partition}"
          s += f", output_features={self.output_size}"
          s += f", bias={self.bias is not None}"
          s += f", tp_size={self.tp_size}"
          s += f", reduce_results={self.reduce_results}"
          return s