@@ -25,6 +25,8 @@ from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
 from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -661,6 +663,8 @@ class ShardedStateLoader(BaseModelLoader):
 class BitsAndBytesModelLoader(BaseModelLoader):
     """Model loader to load model weights with BitAndBytes quantization."""
 
+    # TODO: these module names are for Llama only,
+    # change so that it works with other models as well
     default_target_modules = [
         "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
         "o_proj"
@@ -846,13 +850,39 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 yield weight_name, weight_tensor
 
         def generator() -> Generator:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
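+            # Each tensor-parallel rank slices the target weights below and
+            # quantizes only its own shard.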
             for weight_name, weight_tensor in self._hf_weight_iter(
                     hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
+                    # weight partitions of different modules occur at
+                    # different dimensions
+                    # TODO: these module names are for Llama only,
+                    # change so that it works with other models as well
+                    if 'down_proj' in weight_name or 'o_proj' in weight_name:
+                        total_size = weight_tensor.size(-1)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[...,
+                                                          start_index:end_index]
+                    else:
+                        total_size = weight_tensor.size(0)
+                        start_index = total_size // tp_size * tp_rank
+                        end_index = total_size // tp_size * (tp_rank + 1)
+                        weight_sub_tensor = weight_tensor[
+                            start_index:end_index, ...]
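+                    # Illustrative shapes (assuming a Llama-7B checkpoint):
+                    # with tp_size=2, a [4096, 11008] down_proj weight becomes
+                    # a [4096, 5504] shard per rank (last-dim split), while a
+                    # [4096, 4096] q_proj weight becomes [2048, 4096].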
                     # bitsandbytes requires data in GPU
-                    loaded_weight = weight_tensor.cuda().data
+                    if weight_sub_tensor.is_cuda:
+                        loaded_weight = weight_sub_tensor
+                    else:
+                        loaded_weight = weight_sub_tensor.cuda()
+                    # remove the following after the issue is fixed:
+                    # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                    if not loaded_weight.is_contiguous():
+                        loaded_weight = loaded_weight.contiguous()
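+                    # (Last-dim slices above are non-contiguous views, hence
+                    # the copy to a contiguous layout before quantization.)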
+
                     with set_default_torch_dtype(torch.float32):
                         processed_weight, quant_state = quantize_4bit(
                             loaded_weight,
@@ -867,6 +897,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Prequantized BitsAndBytes models are not supported with "
+                "tensor parallelism. Please try pipeline parallelism instead.")
+
         if pre_quant:
             return quantized_checkpoint(), quant_state_dict
+
         return generator(), quant_state_dict
 
     def _load_weights(self, model_config: ModelConfig,