
feat: fbgemm quantization support (#601)

* feat: fbgemm support

* missed this one

* register the quant
AlpinDale committed 9be43994fe · 7 months ago

+ 2 - 1
aphrodite/attention/layer.py

@@ -34,6 +34,7 @@ class Attention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -56,7 +57,7 @@ class Attention(nn.Module):
         self._k_scale = 1.0
         self._v_scale = 1.0
         quant_method = quant_config.get_quant_method(
-            self) if quant_config else None
+            self, prefix=prefix) if quant_config else None
         if quant_method is not None:
             assert isinstance(quant_method, Fp8KVCacheMethod)
             # TODO: kv cache dtype should be specified in the FP8

+ 1 - 1
aphrodite/common/config.py

@@ -332,7 +332,7 @@ class ModelConfig:
                     "supported in ROCm.")
             if (self.quantization
                     not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
-                            "compressed_tensors")):
+                            "fbgemm_fp8", "compressed_tensors")):
                 logger.warning(
                     f"{self.quantization} quantization is not fully "
                     "optimized yet. The speed can be slower than "

+ 1 - 1
aphrodite/modeling/layers/fused_moe/layer.py

@@ -176,7 +176,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method: Optional[QuantizeMethodBase] = (
                 UnquantizedFusedMoEMethod())
         else:
-            self.quant_method = quant_config.get_quant_method(self)
+            self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
         self.quant_method.create_weights(

+ 16 - 10
aphrodite/modeling/layers/linear.py

@@ -139,6 +139,7 @@ class LinearBase(torch.nn.Module):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -153,7 +154,8 @@ class LinearBase(torch.nn.Module):
             self.quant_method: Optional[
                 QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
-            self.quant_method = quant_config.get_quant_method(self)
+            self.quant_method = quant_config.get_quant_method(self,
+                                                              prefix=prefix)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -180,9 +182,13 @@ class ReplicatedLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                 prefix: str = ""):
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix=prefix)
 
         # All the linear layer supports quant method.
         assert self.quant_method is not None
@@ -256,9 +262,9 @@ class ColumnParallelLinear(LinearBase):
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[List[int]] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                         quant_config, prefix)
 
         self.gather_output = gather_output
 
@@ -372,7 +378,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         self.output_sizes = output_sizes
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
@@ -520,7 +526,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -722,10 +728,10 @@ class RowParallelLinear(LinearBase):
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None,
+                 prefix: str = "",
                  partition_multiple_of: int = 1):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                         quant_config, prefix)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
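
The prefix now threaded through LinearBase is what allows prefix-aware configs to decide, per layer, whether to quantize at all. Below is a minimal sketch of that dispatch using the FBGEMM FP8 config added in this commit; the sizes, module prefixes, and ignore list are illustrative assumptions, not values from any real model.

from aphrodite.modeling.layers.linear import (LinearBase,
                                              UnquantizedLinearMethod)
from aphrodite.quantization.fbgemm_fp8 import (FBGEMMFp8Config,
                                               FBGEMMFp8LinearMethod)

# Illustrative config: skip only a hypothetical "model.lm_head" module.
cfg = FBGEMMFp8Config(ignore_list=["model.lm_head"], input_scale_ub=1200.0)

# A layer whose prefix is not ignored gets the FP8 linear method ...
quantized = LinearBase(1024, 1024, quant_config=cfg,
                       prefix="model.layers.0.mlp.down_proj")
assert isinstance(quantized.quant_method, FBGEMMFp8LinearMethod)

# ... while an ignored module falls back to the unquantized method.
skipped = LinearBase(1024, 1024, quant_config=cfg, prefix="model.lm_head")
assert isinstance(skipped.quant_method, UnquantizedLinearMethod)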

+ 8 - 4
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -160,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
         quant_config: quant config for the layer.
+        prefix: full name of the layer in the state dict
     """  # noqa: E501
 
     def __init__(self,
@@ -168,7 +169,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         padding_size = padding_size or get_tensor_model_parallel_world_size()
@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         linear_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self)
+            linear_method = quant_config.get_quant_method(self, prefix=prefix)
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method: QuantizeMethodBase = linear_method
@@ -381,9 +383,11 @@ class ParallelLMHead(VocabParallelEmbedding):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config)
+                         org_num_embeddings, padding_size, quant_config,
+                         prefix)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,

+ 2 - 0
aphrodite/quantization/__init__.py

@@ -10,6 +10,7 @@ from aphrodite.quantization.compressed_tensors.compressed_tensors import \
 from aphrodite.quantization.deepspeedfp import DeepSpeedFPConfig
 from aphrodite.quantization.eetq import EETQConfig
 from aphrodite.quantization.exl2 import Exl2Config
+from aphrodite.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from aphrodite.quantization.fp8 import Fp8Config
 from aphrodite.quantization.gguf import GGUFConfig
 from aphrodite.quantization.gptq import GPTQConfig
@@ -27,6 +28,7 @@ QUANTIZATION_METHODS = {
     "eetq": EETQConfig,
     "exl2": Exl2Config,
     "fp8": Fp8Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
     "gguf": GGUFConfig,
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)

+ 2 - 2
aphrodite/quantization/aqlm.py

@@ -206,8 +206,8 @@ class AQLMConfig(QuantizationConfig):
         return cls(in_group_size, nbits_per_codebook, num_code_books,
                    out_group_size)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["AQLMLinearMethod"]:
         if isinstance(layer, LinearBase):
             return AQLMLinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/autoquant.py

@@ -81,8 +81,8 @@ class AutoQuantConfig(QuantizationConfig):
             quant_mode = "weight_only"
         return cls(weight_bits, group_size, zero_point, from_float, quant_mode)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["AutoQuantLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["AutoQuantLinearMethod"]:
         if isinstance(layer, LinearBase):
             return AutoQuantLinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/awq.py

@@ -60,8 +60,8 @@ class AWQConfig(QuantizationConfig):
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, group_size, zero_point)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["AWQLinearMethod"]:
         if isinstance(layer, LinearBase):
             return AWQLinearMethod(self)
         return None

+ 2 - 1
aphrodite/quantization/base_config.py

@@ -94,7 +94,8 @@ class QuantizationConfig(ABC):
             return default
 
     @abstractmethod
-    def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> QuantizeMethodBase:
         """Get the quantize method to use for the quantized layer."""
         raise NotImplementedError
 

+ 2 - 3
aphrodite/quantization/bitsandbytes.py

@@ -58,9 +58,8 @@ class BitsAndBytesConfig(QuantizationConfig):
             target_modules = cls.get_from_keys(config, ["target_modules"])
         return cls(adapter_name, target_modules)
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
         if isinstance(layer, LinearBase):
             return BitsAndBytesLinearMethod(self)
         return None

+ 5 - 1
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -43,8 +43,12 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_name(self) -> str:
         return "compressed_tensors"
 
+    # TODO: do layer skipping through here
+    # rather than through create_weights to match other methods
     def get_quant_method(
-            self, layer: torch.nn.Module
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
     ) -> Optional["CompressedTensorsLinearMethod"]:
         if isinstance(layer, LinearBase):
             return CompressedTensorsLinearMethod(self)

+ 2 - 3
aphrodite/quantization/deepspeedfp.py

@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
             "quantize_config.json",
         ]
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
         if isinstance(layer, LinearBase):
             return DeepSpeedFPLinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/eetq.py

@@ -59,8 +59,8 @@ class EETQConfig(QuantizationConfig):
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, zero_point)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["EETQLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["EETQLinearMethod"]:
         if isinstance(layer, LinearBase):
             return EETQLinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/exl2.py

@@ -53,8 +53,8 @@ class Exl2Config(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
         return cls()
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["Exl2LinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["Exl2LinearMethod"]:
         if isinstance(layer, LinearBase):
             return Exl2LinearMethod(self)
         return None

+ 155 - 0
aphrodite/quantization/fbgemm_fp8.py

@@ -0,0 +1,155 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
+                                              UnquantizedLinearMethod)
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
+from aphrodite.quantization.utils.w8a8_utils import (
+    apply_fp8_linear, create_per_channel_scale_param)
+
+# Note: this is a hack. We should update each model to register the
+# stacked params and get it from there instead in a future PR.
+# fused_name: List[shard_name]
+_FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"]
+}
+
+
+class FBGEMMFp8Config(QuantizationConfig):
+    """Config class for FBGEMM Fp8."""
+
+    def __init__(self, ignore_list: List[str], input_scale_ub: float):
+        self.ignore_list = ignore_list
+        self.input_scale_ub = input_scale_ub
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fbgemm_fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
+        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
+        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
+        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+
+    def _is_layer_skipped(self, prefix: str) -> bool:
+        # prefix: model.layers.0.self_attn.q_proj
+        # proj_name: q_proj
+        proj_name = prefix.split(".")[-1]
+        if proj_name in _FUSED_LAYER_NAME_MAPPING:
+            shard_prefixes = [
+                prefix.replace(proj_name, shard_proj_name)
+                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
+            ]
+
+            is_skipped = None
+            for shard_prefix in shard_prefixes:
+                is_shard_skipped = shard_prefix in self.ignore_list
+
+                if is_skipped is None:
+                    is_skipped = is_shard_skipped
+                elif is_shard_skipped != is_skipped:
+                    raise ValueError(
+                        f"Detected some but not all shards of {prefix} "
+                        "are quantized. All shards of fused layers "
+                        "must have the same precision.")
+        else:
+            is_skipped = prefix in self.ignore_list
+
+        assert is_skipped is not None
+        return is_skipped
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            if self._is_layer_skipped(prefix):
+                return UnquantizedLinearMethod()
+            return FBGEMMFp8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class FBGEMMFp8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quant_config: FBGEMMFp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
+                                       dtype=torch.float8_e4m3fn),
+                           requires_grad=False)
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+            **extra_weight_attrs,
+        })
+
+        # WEIGHT SCALE
+        weight_scale = create_per_channel_scale_param(output_partition_sizes,
+                                                      **extra_weight_attrs)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE UPPER BOUND
+        input_scale_ub = torch.nn.Parameter(torch.tensor(
+            (self.quant_config.input_scale_ub), dtype=torch.float32),
+                                            requires_grad=False)
+        layer.input_scale_ub = input_scale_ub
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        weight = layer.weight
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        return apply_fp8_linear(input=x,
+                                weight=layer.weight,
+                                weight_scale=layer.weight_scale,
+                                input_scale=None,
+                                input_scale_ub=layer.input_scale_ub,
+                                bias=bias,
+                                cutlass_fp8_supported=True,
+                                use_per_token_if_dynamic=True)
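
The _FUSED_LAYER_NAME_MAPPING hack above maps a fused qkv_proj or gate_up_proj prefix back onto its original shard names before consulting the ignore list, and rejects fused layers whose shards disagree. A quick illustration; the prefixes and ignore list are made up for the example.

from aphrodite.quantization.fbgemm_fp8 import FBGEMMFp8Config

cfg = FBGEMMFp8Config(
    ignore_list=[
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ],
    input_scale_ub=1200.0,
)

# All three shards are in the ignore list, so the fused layer is skipped ...
assert cfg._is_layer_skipped("model.layers.0.self_attn.qkv_proj")
# ... whereas a layer whose shards are absent stays quantized.
assert not cfg._is_layer_skipped("model.layers.1.self_attn.qkv_proj")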

+ 2 - 2
aphrodite/quantization/fp8.py

@@ -64,8 +64,8 @@ class Fp8Config(QuantizationConfig):
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                    activation_scheme=activation_scheme)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         from aphrodite.attention.layer import \
             Attention  # Avoid circular import
 

+ 2 - 2
aphrodite/quantization/gguf.py

@@ -58,8 +58,8 @@ class GGUFConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
         return cls()
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["GGUFLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["GGUFLinearMethod"]:
         if isinstance(layer, LinearBase):
             return GGUFLinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/gptq.py

@@ -68,8 +68,8 @@ class GPTQConfig(QuantizationConfig):
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, lm_head_quantized)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["GPTQLinearMethod"]:
         if (isinstance(layer, LinearBase) or
             (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQLinearMethod(self)

+ 2 - 3
aphrodite/quantization/gptq_marlin.py

@@ -91,9 +91,8 @@ class GPTQMarlinConfig(QuantizationConfig):
                         " faster inference")
         return None
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
         if (isinstance(layer, LinearBase) or
             (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQMarlinLinearMethod(self)

+ 2 - 3
aphrodite/quantization/gptq_marlin_24.py

@@ -106,9 +106,8 @@ class GPTQMarlin24Config(QuantizationConfig):
 
         return None
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
         if isinstance(layer, LinearBase):
             return GPTQMarlin24LinearMethod(self)
         return None

+ 2 - 2
aphrodite/quantization/marlin.py

@@ -97,8 +97,8 @@ class MarlinConfig(QuantizationConfig):
 
         return None
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["MarlinLinearMethod"]:
         if (isinstance(layer, LinearBase) or
             (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return MarlinLinearMethod(self)

+ 2 - 2
aphrodite/quantization/quip.py

@@ -52,8 +52,8 @@ class QuipConfig(QuantizationConfig):
         use_rand = cls.get_from_keys(config, ["use_rand"])
         return cls(codebook, use_rand)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["QuipLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuipLinearMethod"]:
         if isinstance(layer, LinearBase):
             return QuipLinearMethod(self)
         return None

+ 2 - 3
aphrodite/quantization/squeezellm.py

@@ -52,9 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["wbits"])
         return cls(weight_bits)
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return SqueezeLLMLinearMethod(self)
         return

+ 2 - 0
aphrodite/quantization/utils/w8a8_utils.py

@@ -105,6 +105,7 @@ def apply_fp8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: torch.Tensor,
+    input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
     use_per_token_if_dynamic: bool = False,
@@ -118,6 +119,7 @@ def apply_fp8_linear(
         qinput, x_scale = ops.scaled_fp8_quant(
             input,
             input_scale,
+            scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
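
The new scale_ub caps the per-token statistics used by dynamic FP8 activation quantization, which is what the activation_scale_ub value from FBGEMM checkpoints feeds into. The snippet below is only a rough torch-level sketch of the idea, under the assumption that the bound clamps each token's absolute maximum; the real work happens inside ops.scaled_fp8_quant.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_fp8_quant_sketch(x: torch.Tensor, scale_ub: torch.Tensor):
    # Per-token absolute maximum, clamped by the activation upper bound so
    # a single outlier token cannot inflate its own scale.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    amax = torch.minimum(amax, scale_ub)
    scale = (amax / FP8_MAX).clamp(min=1e-12)
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale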