# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from array import array
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union, cast)

import librosa
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY
from aphrodite.inputs.data import LLMInputs
from aphrodite.inputs.registry import InputContext
from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.loader import DefaultModelLoader
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.utils import (flatten_bn,
                                             group_weights_with_prefix,
                                             init_aphrodite_registered_model,
                                             merge_multimodal_embeddings)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs, NestedTensors
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.ultravox import UltravoxConfig

_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
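# NOTE: the 6.25 tokens/second figure follows from the default Whisper front
# end (16 kHz audio, hop_length=160 -> 100 mel frames/second), the encoder's
# stride-2 convolution halving that to 50 frames/second, and the projector's
# stack_factor of 8: 50 / 8 = 6.25. Illustrative derivation; the exact rate
# depends on the configured feature extractor and stack_factor.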


class UltravoxAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, 80, M)`"""


class UltravoxAudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""


UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
    return WhisperFeatureExtractor.from_pretrained(model_id)


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
    return cached_feature_extractor(
        ctx.get_hf_config(UltravoxConfig).audio_model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
    feature_extractor = whisper_feature_extractor(ctx)
    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
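# With the default Whisper chunk_length of 30 seconds, the cap above works out
# to ceil(30 * 6.25) = 188 placeholder tokens per audio (illustrative figure).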


def dummy_seq_data_for_ultravox(
    ctx: InputContext,
    seq_len: int,
    audio_count: int,
):
    audio_placeholder = array(
        APHRODITE_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)

    # Add a separator between each chunk.
    audio_token_ids = (audio_placeholder +
                       array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
    other_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                            [0]) * (seq_len - len(audio_token_ids))

    return SequenceData(audio_token_ids + other_token_ids)


def dummy_audio_for_ultravox(
    ctx: InputContext,
    audio_count: int,
):
    feature_extractor = whisper_feature_extractor(ctx)
    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
    return {"audio": [audio_and_sr] * audio_count}


def dummy_data_for_ultravox(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
):
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)

    return (seq_data, mm_dict)


def input_mapper_for_ultravox(ctx: InputContext, data: object):
    if not isinstance(data, list):
        data = [data]

    audio_features = []
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")

        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
        feature_extractor = whisper_feature_extractor(ctx)

        if sr != feature_extractor.sampling_rate:
            audio = librosa.resample(audio,
                                     orig_sr=sr,
                                     target_sr=feature_extractor.sampling_rate)
            sr = feature_extractor.sampling_rate
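
        # With the default Whisper n_fft of 400, the STFT needs at least
        # n_fft // 2 + 1 = 201 samples (~12.5 ms at 16 kHz); shorter clips are
        # zero-padded below. (Values assume the default feature extractor.)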
        minimum_audio_length = feature_extractor.n_fft // 2 + 1
        if len(audio) < minimum_audio_length:
            # Not enough audio; pad it.
            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))

        single_audio_features = feature_extractor(
            audio, sampling_rate=sr, padding="longest",
            return_tensors="pt")["input_features"]

        # Remove the batch dimension because we're wrapping it in a list.
        audio_features.append(single_audio_features.squeeze(0))

    return MultiModalInputs({"audio_features": audio_features})


def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return llm_inputs

    feature_extractor = whisper_feature_extractor(ctx)
    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]

    audio_token_counts = []
    for audio_data, sample_rate in audios:
        audio_length = audio_data.shape[0]
        if sample_rate != feature_extractor.sampling_rate:
            # Account for resampling.
            adjustment = feature_extractor.sampling_rate / sample_rate
            audio_length = math.ceil(adjustment * audio_length)
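
        # Worked example (Whisper defaults, illustrative): 10 s of 16 kHz
        # audio is 160000 samples -> ceil((160000 - 159) / 160) = 1000 mel
        # frames. The encoder halves the frame count and the projector stacks
        # by stack_factor, hence the division by stack_factor * 2 below:
        # ceil(1000 / 16) = 63 tokens, i.e. roughly 6.25 tokens/second.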
        feature_extractor_output_length = math.ceil(
            (audio_length - (feature_extractor.hop_length - 1)) /
            feature_extractor.hop_length)

        uv_config = ctx.get_hf_config(UltravoxConfig)
        audio_num_tokens = min(
            max(
                1,
                math.ceil(feature_extractor_output_length /
                          (uv_config.stack_factor * 2))),
            get_ultravox_max_audio_tokens(ctx))
        audio_token_counts.append(audio_num_tokens)

    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
        repeat_count=audio_token_counts,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor
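
    # Shape sketch (illustrative): with stack_factor=8 and an encoder output
    # of (B, 1500, C), T is padded up to 1504 and the result is
    # (B, 188, 8 * C); trailing frames are zero-padded, never dropped.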
    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor -
                 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        B, T, C = audio_embeds.shape
        audio_embeds = audio_embeds.view(B, T // self.stack_factor,
                                         C * self.stack_factor)
        return audio_embeds


class FlippedSiluAndMul(SiluAndMul):
    """Ultravox is trained with SwiGLU with flipped halves."""
    def forward(self, x: torch.Tensor):
        a, b = x.chunk(2, dim=-1)
        flipped = torch.cat((b, a), dim=-1)
        return super().forward(flipped)


class UltravoxProjector(nn.Module):

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim)
        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
        dim = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = FlippedSiluAndMul()
            dim = dim // 2
        else:
            self.act = get_act_fn(config.projector_act)

        self.linear_2 = nn.Linear(dim,
                                  config.text_config.hidden_size,
                                  bias=False)
        self.ln_post = RMSNorm(config.text_config.hidden_size)
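
    # Data flow (for the swiglu projector, illustrative): the stacked
    # (B, T / stack_factor, stack_factor * audio_hidden) features are
    # RMS-normalized, projected to hidden_size, gated down to hidden_size // 2
    # by FlippedSiluAndMul, projected to the text model's hidden size, and
    # RMS-normalized again.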
    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
        directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
            than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        expected_seq_length = (self.config.max_source_positions *
                               self.conv1.stride[0] * self.conv2.stride[0])
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
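

# The registrations below wire Ultravox into Aphrodite's multimodal pipeline:
# the input mapper turns raw (audio, sample_rate) tuples into Whisper mel
# features, the max-tokens hook reserves placeholder budget per audio, and the
# dummy-data / input-processor hooks handle memory profiling and prompt
# expansion respectively.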
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: UltravoxConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional["QuantizationConfig"] = None):
        super().__init__()
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)
        if config.text_model_id is not None:
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        audio_input = input_features.to(self.audio_tower.dtype)
        audio_features = self.audio_tower(audio_input)
        audio_features = audio_features.to(self.audio_tower.dtype)
        audio_embeddings = self.multi_modal_projector(audio_features)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        if isinstance(audio_features, torch.Tensor):
            # Combine the B and N dimensions for the encoder/projector
            flattened = flatten_bn(audio_features)
            flattened_embeddings = self._audio_features_to_embeddings(
                flattened)

            # Restore the original dimensions
            embeddings = flattened_embeddings.unflatten(
                0, audio_features.shape[:2])
            return embeddings

        result = []
        # TODO: Batch heterogeneous tensors through the encoder/projector
        for audio_features_item in audio_features:
            if isinstance(audio_features_item, torch.Tensor):
                result.append(
                    self._audio_features_to_embeddings(audio_features_item))
            else:
                embeddings = [
                    # Add a batch dimension to embed it, then remove it.
                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
                                                       ).squeeze(0)
                    for tensor in audio_features_item
                ]
                result.append(embeddings)

        return result
    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[torch.Tensor],
                **kwargs) -> SamplerOutput:
        """Run forward pass for Ultravox.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted audio embeddings. The
        to-be-inserted audio has a size that is essentially 6.25 tokens per
        second of audio. This way, the `positions` and `attn_metadata` are
        consistent with the `input_ids`.

        Args:
            audio_features: A batch of audio inputs [B, N, 80, M].
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is not None:
            audio_embeddings = self._process_audio_input(audio_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
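
            # Overwrite the placeholder positions in `inputs_embeds` with the
            # projected audio embeddings; text token embeddings are untouched.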
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, audio_embeddings,
                _AUDIO_PLACEHOLDER_TOKEN)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load audio tower weights
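        # Parameters are registered under base_model_prefix ("model.encoder.")
        # so they line up with Whisper-style checkpoint names; weights without
        # a matching parameter are skipped rather than raising.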
        audio_tower_weights = weights_group["audio_tower"]
        audio_tower_params_dict = dict(
            self.audio_tower.named_parameters(
                prefix=self.audio_tower.base_model_prefix))
        for name, loaded_weight in audio_tower_weights:
            if name in audio_tower_params_dict:
                param = audio_tower_params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        # load projector weights
        projector_weights = weights_group["multi_modal_projector"]
        projector_params_dict = dict(
            self.multi_modal_projector.named_parameters())
        for name, loaded_weight in projector_weights:
            param = projector_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])