Browse Source

feat: add TP support for bitsandbytes

AlpinDale 4 months ago
parent
commit
6154031a85

+ 0 - 6
aphrodite/common/config.py

@@ -456,12 +456,6 @@ class ModelConfig:
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if self.quantization == "bitsandbytes" and (
-                parallel_config.tensor_parallel_size > 1
-                or parallel_config.pipeline_parallel_size > 1):
-            raise ValueError(
-                "BitsAndBytes quantization with TP/PP is not supported yet.")
-
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
             raise ValueError(
                 "BitsAndBytes with enforce_eager=False is not supported yet.")

+ 18 - 5
aphrodite/modeling/layers/linear.py

@@ -523,6 +523,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset)
 
             use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
             if use_bitsandbytes:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
@@ -547,8 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     loaded_weight.shape[output_dim], tp_rank, tp_size)
             else:
                 start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -894,6 +899,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset)
 
             use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
             if use_bitsandbytes:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
@@ -934,8 +941,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                 else:
                     shard_id = tp_rank // self.num_kv_head_replicas
                 start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1044,6 +1054,7 @@ class RowParallelLinear(LinearBase):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -1058,7 +1069,9 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             if self.quant_config is None:
                 start_idx = get_current_tp_rank_partition_offset(

+ 37 - 1
aphrodite/modeling/model_loader/loader.py

@@ -25,6 +25,8 @@ from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
 from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -661,6 +663,8 @@ class ShardedStateLoader(BaseModelLoader):
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
     default_target_modules = [
         "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
         "o_proj"
@@ -846,13 +850,39 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     yield weight_name, weight_tensor
 
         def generator() -> Generator:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
             for weight_name, weight_tensor in self._hf_weight_iter(
                     hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
+                    # weight partitions of different modules occur at
+                    # different dimensions
+                    # TODO: these module names are for Llama only,
+                    # change so that it works with other models as well
+                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                        total_size = weight_tensor.size(-1)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[...,
+                                                         start_index:end_index]
+                    else:
+                        total_size = weight_tensor.size(0)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[ \
+                            start_index:end_index, ...]
                     #  bitsandbytes requires data in GPU
-                    loaded_weight = weight_tensor.cuda().data
+                    if weight_sub_tensor.is_cuda:
+                        loaded_weight = weight_sub_tensor
+                    else:
+                        loaded_weight = weight_sub_tensor.cuda()
+                    # remove the following after the issue is fixed:
+                    # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                    if loaded_weight.is_contiguous() is False:
+                        loaded_weight = loaded_weight.contiguous()
+
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
                             loaded_weight,
@@ -867,6 +897,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         if pre_quant:
             return quantized_checkpoint(), quant_state_dict
+
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequanted Bitsandbytes models are not supported with "
+                "Tensor Parallel. Please try Pipeline Parallel instead.")
+
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,