# utils.py — multimodal embedding helpers (viewer header/gutter residue removed)
import torch

from aphrodite.multimodal import BatchedTensors
  3. def merge_vision_embeddings(input_ids: torch.Tensor,
  4. inputs_embeds: torch.Tensor,
  5. vision_embeddings: BatchedTensors,
  6. image_token_id: int) -> torch.Tensor:
  7. """
  8. Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
  9. in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
  10. Note:
  11. This updates `inputs_embeds` in place.
  12. """
  13. mask = (input_ids == image_token_id)
  14. num_expected_tokens = mask.sum()
  15. if isinstance(vision_embeddings, torch.Tensor):
  16. batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
  17. total_tokens = batch_size * batch_tokens
  18. if num_expected_tokens != total_tokens:
  19. expr = f"{batch_size} x {batch_tokens}"
  20. raise ValueError(
  21. f"Attempted to assign {expr} = {total_tokens} "
  22. f"image tokens to {num_expected_tokens} placeholders")
  23. inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
  24. else:
  25. size_per_batch = [t.shape[0] for t in vision_embeddings]
  26. total_tokens = sum(size_per_batch)
  27. if num_expected_tokens != total_tokens:
  28. expr = ' + '.join(map(str, size_per_batch))
  29. raise ValueError(
  30. f"Attempted to assign {expr} = {total_tokens} "
  31. f"image tokens to {num_expected_tokens} placeholders")
  32. inputs_embeds[mask] = torch.cat(vision_embeddings)
  33. return inputs_embeds