import torch

from aphrodite.multimodal import BatchedTensors


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: BatchedTensors,
                            image_token_id: int) -> torch.Tensor:
    """
    Merge `vision_embeddings` into `inputs_embeds` by overwriting the
    positions in `inputs_embeds` corresponding to placeholder image tokens
    in `input_ids`.

    Note:
        This updates `inputs_embeds` in place.
    """
    mask = (input_ids == image_token_id)
    num_expected_tokens = mask.sum()
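
    # Every image yielded the same number of tokens, so the embeddings
    # arrive pre-stacked as a single (batch, tokens_per_image, ..., embed_dim)
    # tensor that can be flattened and assigned in one shot.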
    if isinstance(vision_embeddings, torch.Tensor):
        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
        total_tokens = batch_size * batch_tokens
        if num_expected_tokens != total_tokens:
            expr = f"{batch_size} x {batch_tokens}"
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
    else:
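        # Token counts differ per image, so the embeddings arrive as a list
        # of per-image tensors that must be concatenated before assignment.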
        size_per_batch = [t.shape[0] for t in vision_embeddings]
        total_tokens = sum(size_per_batch)
        if num_expected_tokens != total_tokens:
            expr = ' + '.join(map(str, size_per_batch))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = torch.cat(vision_embeddings)

    return inputs_embeds
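

# A minimal usage sketch, not part of the original module, exercising both
# input layouts accepted above. The image token id (32000) and all shapes
# here are illustrative assumptions, not values mandated by Aphrodite.
if __name__ == "__main__":
    IMAGE_TOKEN_ID = 32000  # assumed placeholder id, for demonstration only
    embed_dim = 4

    # Case 1: embeddings arrive as one stacked tensor (1 image x 2 tokens),
    # matching the two placeholder positions in `input_ids`.
    input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5])
    inputs_embeds = torch.zeros(4, embed_dim)
    stacked = torch.randn(1, 2, embed_dim)
    merged = merge_vision_embeddings(input_ids, inputs_embeds, stacked,
                                     IMAGE_TOKEN_ID)
    assert torch.equal(merged[1:3], stacked.view(2, embed_dim))

    # Case 2: embeddings arrive as a list of per-image tensors with
    # differing token counts (1 + 2 = 3 placeholders).
    input_ids = torch.tensor(
        [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5])
    inputs_embeds = torch.zeros(5, embed_dim)
    per_image = [torch.randn(1, embed_dim), torch.randn(2, embed_dim)]
    merged = merge_vision_embeddings(input_ids, inputs_embeds, per_image,
                                     IMAGE_TOKEN_ID)
    assert torch.equal(merged[1:4], torch.cat(per_image))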