ultravox.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
  2. """PyTorch Ultravox model."""
  3. import itertools
  4. import math
  5. from array import array
  6. from functools import lru_cache
  7. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  8. TypedDict, Union, cast)
  9. import librosa
  10. import numpy as np
  11. import torch
  12. import torch.utils.checkpoint
  13. from torch import nn
  14. from torch.nn import functional as F
  15. from transformers.models.whisper import WhisperFeatureExtractor
  16. from transformers.models.whisper.modeling_whisper import WhisperEncoder
  17. from aphrodite.attention import AttentionMetadata
  18. from aphrodite.common.config import CacheConfig, MultiModalConfig
  19. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  20. SequenceData)
  21. from aphrodite.inputs import INPUT_REGISTRY
  22. from aphrodite.inputs.data import LLMInputs
  23. from aphrodite.inputs.registry import InputContext
  24. from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
  25. from aphrodite.modeling.layers.layernorm import RMSNorm
  26. from aphrodite.modeling.layers.sampler import SamplerOutput
  27. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  28. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  29. from aphrodite.modeling.models.utils import (filter_weights, flatten_bn,
  30. init_aphrodite_registered_model,
  31. merge_multimodal_embeddings)
  32. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  33. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  34. from aphrodite.multimodal.base import MultiModalInputs, NestedTensors
  35. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  36. repeat_and_pad_placeholder_tokens)
  37. from aphrodite.quantization.base_config import QuantizationConfig
  38. from aphrodite.transformers_utils.configs.ultravox import UltravoxConfig
  39. _AUDIO_PLACEHOLDER_TOKEN = 128002
  40. _AUDIO_TOKENS_PER_SECOND = 6.25
  41. class UltravoxAudioFeatureInputs(TypedDict):
  42. type: Literal["audio_features"]
  43. data: NestedTensors
  44. """Shape: `(batch_size, num_audios, 80, M)"""
  45. class UltravoxAudioEmbeddingInputs(TypedDict):
  46. type: Literal["audio_embeds"]
  47. data: NestedTensors
  48. """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
  49. UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
  50. UltravoxAudioEmbeddingInputs]
  51. @lru_cache
  52. def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
  53. return WhisperFeatureExtractor.from_pretrained(model_id)
  54. def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
  55. return cached_feature_extractor(
  56. ctx.get_hf_config(UltravoxConfig).audio_model_id)
  57. def get_ultravox_max_audio_tokens(ctx: InputContext):
  58. feature_extractor = whisper_feature_extractor(ctx)
  59. return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
  60. def dummy_data_for_ultravox(
  61. ctx: InputContext,
  62. seq_len: int,
  63. mm_counts: Mapping[str, int],
  64. ):
  65. feature_extractor = whisper_feature_extractor(ctx)
  66. audio_count = mm_counts["audio"]
  67. audio_placeholder = array(
  68. APHRODITE_TOKEN_ID_ARRAY_TYPE,
  69. [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
  70. # Add a separator between each chunk.
  71. audio_token_ids = (audio_placeholder +
  72. array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
  73. other_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
  74. [0]) * (seq_len - len(audio_token_ids))
  75. audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
  76. mm_dict = {"audio": [audio_and_sr] * audio_count}
  77. return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
  78. def input_mapper_for_ultravox(ctx: InputContext, data: object):
  79. if not isinstance(data, list):
  80. data = [data]
  81. audio_features = []
  82. for audio_input in data:
  83. if not isinstance(audio_input, tuple):
  84. raise NotImplementedError(
  85. f"Unsupported data type: {type(audio_input)}")
  86. (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
  87. feature_extractor = whisper_feature_extractor(ctx)
  88. if sr != feature_extractor.sampling_rate:
  89. audio = librosa.resample(audio,
  90. orig_sr=sr,
  91. target_sr=feature_extractor.sampling_rate)
  92. sr = feature_extractor.sampling_rate
  93. minimum_audio_length = feature_extractor.n_fft // 2 + 1
  94. if len(audio) < minimum_audio_length:
  95. # Not enough audio; pad it.
  96. audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
  97. single_audio_features = feature_extractor(
  98. audio, sampling_rate=sr, padding="longest",
  99. return_tensors="pt")["input_features"]
  100. # Remove the batch dimension because we're wrapping it in a list.
  101. audio_features.append(single_audio_features.squeeze(0))
  102. return MultiModalInputs({"audio_features": audio_features})
  103. def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
  104. multi_modal_data = llm_inputs.get("multi_modal_data")
  105. if multi_modal_data is None or "audio" not in multi_modal_data:
  106. return llm_inputs
  107. feature_extractor = whisper_feature_extractor(ctx)
  108. audios = multi_modal_data["audio"]
  109. if not isinstance(audios, list):
  110. audios = [audios]
  111. audio_token_counts = []
  112. for audio_data, sample_rate in audios:
  113. audio_length = audio_data.shape[0]
  114. if sample_rate != feature_extractor.sampling_rate:
  115. # Account for resampling.
  116. adjustment = feature_extractor.sampling_rate / sample_rate
  117. audio_length = math.ceil(adjustment * audio_length)
  118. feature_extractor_output_length = math.ceil(
  119. (audio_length - (feature_extractor.hop_length - 1)) /
  120. feature_extractor.hop_length)
  121. uv_config = ctx.get_hf_config(UltravoxConfig)
  122. audio_num_tokens = min(
  123. max(
  124. 1,
  125. math.ceil(feature_extractor_output_length /
  126. (uv_config.stack_factor * 2))),
  127. get_ultravox_max_audio_tokens(ctx))
  128. audio_token_counts.append(audio_num_tokens)
  129. tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
  130. new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
  131. tokenizer,
  132. llm_inputs.get("prompt"),
  133. llm_inputs["prompt_token_ids"],
  134. placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
  135. repeat_count=audio_token_counts,
  136. )
  137. # NOTE: Create a defensive copy of the original inputs
  138. return LLMInputs(prompt_token_ids=new_token_ids,
  139. prompt=new_prompt,
  140. multi_modal_data=multi_modal_data)
  141. class StackAudioFrames(nn.Module):
  142. """
  143. Stack the audio embedding frames to reduce the sequence length by a factor
  144. of `stack_factor`.
  145. """
  146. def __init__(self, stack_factor: int = 8):
  147. super().__init__()
  148. self.stack_factor = stack_factor
  149. def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
  150. B, T, C = audio_embeds.shape
  151. T_pad = (T + self.stack_factor -
  152. 1) // self.stack_factor * self.stack_factor
  153. audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
  154. B, T, C = audio_embeds.shape
  155. audio_embeds = audio_embeds.view(B, T // self.stack_factor,
  156. C * self.stack_factor)
  157. return audio_embeds
  158. class FlippedSiluAndMul(SiluAndMul):
  159. """Ultravox is trained with SwiGLU with flipped halves."""
  160. def forward(self, x: torch.Tensor):
  161. a, b = x.chunk(2, dim=-1)
  162. flipped = torch.cat((b, a), dim=-1)
  163. return super().forward(flipped)
  164. class UltravoxProjector(nn.Module):
  165. def __init__(self, config: UltravoxConfig):
  166. super().__init__()
  167. self.hidden_dim = config.hidden_size
  168. self._pad_and_stack = StackAudioFrames(config.stack_factor)
  169. dim = config.audio_config.hidden_size * config.stack_factor
  170. self.ln_pre = RMSNorm(dim)
  171. self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
  172. dim = self.hidden_dim
  173. if config.projector_act == "swiglu":
  174. self.act = FlippedSiluAndMul()
  175. dim = dim // 2
  176. else:
  177. self.act = get_act_fn(config.projector_act)
  178. self.linear_2 = nn.Linear(dim,
  179. config.text_config.hidden_size,
  180. bias=False)
  181. self.ln_post = RMSNorm(config.text_config.hidden_size)
  182. def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
  183. audio_features = self._pad_and_stack(audio_features)
  184. audio_features = self.ln_pre(audio_features)
  185. hidden_states = self.linear_1(audio_features)
  186. hidden_states = self.act(hidden_states)
  187. hidden_states = self.linear_2(hidden_states)
  188. hidden_states = self.ln_post(hidden_states)
  189. return hidden_states
  190. class ModifiedWhisperEncoder(WhisperEncoder):
  191. """
  192. Encoder portion of OpenAI's Whisper model.
  193. This implementation is a slightly modified version of HF Transformers'
  194. Whisper Encoder, with only a few fixes:
  195. 1. base_model_prefix updated to allow for doing `.from_pretrained`
  196. directly on the encoder
  197. 2. allow less than 30 second of audio padding to be passed in:
  198. - relaxed ValueError check for `input_features` length to be less
  199. than or equal to `expected_seq_length` instead of strictly equal
  200. - embed_pos is now sliced to match the length of `inputs_embeds`
  201. Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
  202. See commentary: https://github.com/huggingface/transformers/issues/25744
  203. """
  204. base_model_prefix = "model.encoder"
  205. def forward(
  206. self,
  207. input_features,
  208. ):
  209. expected_seq_length = (self.config.max_source_positions *
  210. self.conv1.stride[0] * self.conv2.stride[0])
  211. if input_features.shape[-1] > expected_seq_length:
  212. raise ValueError(
  213. f"Whisper expects the mel input features to be of length "
  214. f"{expected_seq_length} or less, but found "
  215. f"{input_features.shape[-1]}. Make sure to pad the input mel "
  216. f"features to {expected_seq_length}.")
  217. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  218. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  219. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  220. embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]
  221. hidden_states = inputs_embeds + embed_pos
  222. hidden_states = nn.functional.dropout(hidden_states,
  223. p=self.dropout,
  224. training=self.training)
  225. for encoder_layer in self.layers:
  226. layer_outputs = encoder_layer(
  227. hidden_states,
  228. None,
  229. layer_head_mask=None,
  230. )
  231. hidden_states = layer_outputs[0]
  232. hidden_states = self.layer_norm(hidden_states)
  233. return hidden_states
  234. @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
  235. @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
  236. "audio", get_ultravox_max_audio_tokens)
  237. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
  238. @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
  239. class UltravoxModel(nn.Module, SupportsMultiModal):
  240. def __init__(self,
  241. config: UltravoxConfig,
  242. multimodal_config: MultiModalConfig,
  243. cache_config: Optional[CacheConfig] = None,
  244. quant_config: Optional["QuantizationConfig"] = None):
  245. super().__init__()
  246. self.config = config
  247. self.multi_modal_config = multimodal_config
  248. assert self.multi_modal_config
  249. if config.audio_model_id is not None:
  250. self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
  251. config.audio_model_id)
  252. else:
  253. self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
  254. self.multi_modal_projector = UltravoxProjector(config)
  255. self.language_model = init_aphrodite_registered_model(
  256. config.text_config, cache_config, quant_config)
  257. def _audio_features_to_embeddings(
  258. self, input_features: torch.Tensor) -> torch.Tensor:
  259. audio_input = input_features.to(self.audio_tower.dtype)
  260. audio_features = self.audio_tower(audio_input)
  261. audio_features = audio_features.to(self.audio_tower.dtype)
  262. audio_embeddings = self.multi_modal_projector(audio_features)
  263. return audio_embeddings
  264. def _parse_and_validate_audio_input(
  265. self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
  266. audio_features = kwargs.pop("audio_features", None)
  267. audio_embeds = kwargs.pop("audio_embeds", None)
  268. if audio_features is None and audio_embeds is None:
  269. return None
  270. if audio_features is not None:
  271. if not isinstance(audio_features, (torch.Tensor, list)):
  272. raise ValueError("Incorrect type of audio features. "
  273. f"Got type: {type(audio_features)}")
  274. return UltravoxAudioFeatureInputs(type="audio_features",
  275. data=audio_features)
  276. if audio_embeds is not None:
  277. if not isinstance(audio_embeds, (torch.Tensor, list)):
  278. raise ValueError("Incorrect type of audio embeds. "
  279. f"Got type: {type(audio_embeds)}")
  280. return UltravoxAudioEmbeddingInputs(type="audio_embeds",
  281. data=audio_embeds)
  282. raise AssertionError("This line should be unreachable.")
  283. def _process_audio_input(
  284. self, audio_input: UltravoxAudioInputs) -> NestedTensors:
  285. if audio_input["type"] == "audio_embeds":
  286. return audio_input["data"]
  287. audio_features = audio_input["data"]
  288. if isinstance(audio_features, torch.Tensor):
  289. # Combine the B and N dimensions for the encoder/projector
  290. flattened = flatten_bn(audio_features)
  291. flattened_embeddings = self._audio_features_to_embeddings(
  292. flattened)
  293. # Restore the original dimensions
  294. embeddings = flattened_embeddings.unflatten(
  295. 0, audio_features.shape[:2])
  296. return embeddings
  297. result = []
  298. # TODO: Batch heterogeneous tensors through the encoder/projector
  299. for audio_features_item in audio_features:
  300. if isinstance(audio_features_item, torch.Tensor):
  301. result.append(
  302. self._audio_features_to_embeddings(audio_features_item))
  303. else:
  304. embeddings = [
  305. # Add a batch dimension to embed it, then remove it.
  306. self._audio_features_to_embeddings(tensor.unsqueeze(0)
  307. ).squeeze(0)
  308. for tensor in audio_features_item
  309. ]
  310. result.append(embeddings)
  311. return result
  312. def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
  313. kv_caches: List[torch.Tensor],
  314. attn_metadata: AttentionMetadata,
  315. intermediate_tensors: Optional[torch.Tensor],
  316. **kwargs) -> SamplerOutput:
  317. """Run forward pass for Ultravox
  318. One key thing to understand is the `input_ids` already accounts for the
  319. positions of the to-be-inserted audio embeddings. The to-be-inserted
  320. audio has a size that is essentially 6.25 tokens per second of audio.
  321. This way, the `positions` and `attn_metadata` are consistent
  322. with the `input_ids`.
  323. Args:
  324. audio_features: A batch of audio inputs [B, N, 80, M].
  325. """
  326. audio_input = self._parse_and_validate_audio_input(**kwargs)
  327. if audio_input is not None:
  328. audio_embeddings = self._process_audio_input(audio_input)
  329. inputs_embeds = self.language_model.model.get_input_embeddings(
  330. input_ids)
  331. inputs_embeds = merge_multimodal_embeddings(
  332. input_ids, inputs_embeds, audio_embeddings,
  333. _AUDIO_PLACEHOLDER_TOKEN)
  334. input_ids = None
  335. else:
  336. inputs_embeds = None
  337. hidden_states = self.language_model.model(
  338. input_ids=input_ids,
  339. positions=positions,
  340. kv_caches=kv_caches,
  341. attn_metadata=attn_metadata,
  342. intermediate_tensors=intermediate_tensors,
  343. inputs_embeds=inputs_embeds)
  344. return hidden_states
  345. def compute_logits(self, hidden_states: torch.Tensor,
  346. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  347. return self.language_model.compute_logits(hidden_states,
  348. sampling_metadata)
  349. def sample(
  350. self,
  351. logits: torch.Tensor,
  352. sampling_metadata: SamplingMetadata,
  353. ) -> Optional[SamplerOutput]:
  354. return self.language_model.sample(logits, sampling_metadata)
  355. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  356. # prepare weight iterators for components
  357. projector_weights, llm_weights = itertools.tee(weights, 2)
  358. # load projector weights
  359. projector_weights = filter_weights(projector_weights,
  360. "multi_modal_projector")
  361. projector_params_dict = dict(
  362. self.multi_modal_projector.named_parameters())
  363. for name, loaded_weight in projector_weights:
  364. param = projector_params_dict[name]
  365. weight_loader = getattr(param, "weight_loader",
  366. default_weight_loader)
  367. weight_loader(param, loaded_weight)
  368. # load llm backbone
  369. llm_weights = filter_weights(llm_weights, "language_model")
  370. self.language_model.load_weights(llm_weights)