  1. # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
  2. from typing import Any, Dict, Optional
  3. import transformers
  4. class UltravoxConfig(transformers.PretrainedConfig):
  5. r"""
  6. This is the configuration class to store the configuration of a
  7. [`UltravoxForConditionalGeneration`]. It is used to instantiate an
  8. Ultravox model according to the specified arguments, defining the model
  9. architecture.
  10. Configuration objects inherit from [`PretrainedConfig`] and can be used to
  11. control the model outputs. Read the documentation from [`PretrainedConfig`]
  12. for more information.
  13. Args:
  14. audio_config (`Union[AutoConfig, dict]`, *optional*):
  15. Custom audio config or dict
  16. text_config (`Union[AutoConfig, dict]`, *optional*):
  17. The config object of the text backbone. Can be any of `LlamaConfig`
  18. or `MistralConfig`.
  19. ignore_index (`int`, *optional*, defaults to -100):
  20. The ignore index for the loss function.
  21. audio_token_index (`int`, *optional*, defaults to 32000):
  22. The audio token index to encode the audio prompt.
  23. stack_factor (`int`, *optional*, defaults to 8):
  24. Audio downsampling factor for the multimodal projector.
  25. norm_init (`float`, *optional*, defaults to 0.4):
  26. The initialization value for the layer normalization.
  27. projector_act (`str`, *optional*, defaults to `"swiglu"`):
  28. The activation function used by the multimodal projector.
  29. text_model_lora_config (`LoraConfigSimplified`, *optional*):
  30. The LoRA configuration for finetuning the text model.
  31. audio_model_lora_config (`LoraConfigSimplified`, *optional*):
  32. The LoRA configuration for finetuning the audio model.
  33. """
  34. model_type = "ultravox"
  35. is_composition = False
  36. def __init__(
  37. self,
  38. audio_config: Optional[Dict[str, Any]] = None,
  39. text_config: Optional[Dict[str, Any]] = None,
  40. audio_model_id: Optional[str] = None,
  41. text_model_id: Optional[str] = None,
  42. ignore_index: int = -100,
  43. audio_token_index: int = 32000,
  44. hidden_size: int = 4096,
  45. stack_factor: int = 8,
  46. norm_init: float = 0.4,
  47. projector_act: str = "swiglu",
  48. text_model_lora_config: Optional[Dict[str, Any]] = None,
  49. audio_model_lora_config: Optional[Dict[str, Any]] = None,
  50. **kwargs,
  51. ):
  52. self.ignore_index = ignore_index
  53. self.audio_model_id = audio_model_id
  54. self.text_model_id = text_model_id
  55. self.audio_token_index = audio_token_index
  56. self.hidden_size = hidden_size
  57. self.stack_factor = stack_factor
  58. self.norm_init = norm_init
  59. self.projector_act = projector_act
  60. if text_model_id is not None:
  61. # Avoid circular import
  62. from aphrodite.transformers_utils.config import get_config
  63. self.text_config = get_config(text_model_id,
  64. trust_remote_code=False)
  65. else:
  66. text_config = text_config or {}
  67. self.text_config = transformers.CONFIG_MAPPING[text_config.get(
  68. "model_type", "llama")](**text_config)
  69. if audio_model_id is not None:
  70. # Avoid circular import
  71. from aphrodite.transformers_utils.config import get_config
  72. self.audio_config = get_config(audio_model_id,
  73. trust_remote_code=False)
  74. else:
  75. audio_config = audio_config or {}
  76. self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
  77. "model_type", "whisper")](**audio_config)
  78. self.text_model_lora_config = text_model_lora_config or {}
  79. self.audio_model_lora_config = audio_model_lora_config or {}
  80. self.vocab_size = self.text_config.vocab_size
  81. self.initializer_range = self.text_config.initializer_range
  82. super().__init__(**kwargs)