# test_base.py
import torch

from aphrodite.multimodal.base import MultiModalInputs, NestedTensors
  3. def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors):
  4. assert type(expected) == type(actual)
  5. if isinstance(expected, torch.Tensor):
  6. assert torch.equal(expected, actual)
  7. else:
  8. for expected_item, actual_item in zip(expected, actual):
  9. assert_nested_tensors_equal(expected_item, actual_item)
  10. def assert_multimodal_inputs_equal(
  11. expected: MultiModalInputs, actual: MultiModalInputs
  12. ):
  13. assert set(expected.keys()) == set(actual.keys())
  14. for key in expected:
  15. assert_nested_tensors_equal(expected[key], actual[key])
  16. def test_multimodal_input_batch_single_tensor():
  17. t = torch.rand([1, 2])
  18. result = MultiModalInputs.batch([{"image": t}])
  19. assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
  20. def test_multimodal_input_batch_multiple_tensors():
  21. a = torch.rand([1, 1, 2])
  22. b = torch.rand([1, 1, 2])
  23. c = torch.rand([1, 1, 2])
  24. result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
  25. assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
  26. def test_multimodal_input_batch_multiple_heterogeneous_tensors():
  27. a = torch.rand([1, 2, 2])
  28. b = torch.rand([1, 3, 2])
  29. c = torch.rand([1, 4, 2])
  30. result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
  31. assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
  32. def test_multimodal_input_batch_nested_tensors():
  33. a = torch.rand([2, 3])
  34. b = torch.rand([2, 3])
  35. c = torch.rand([2, 3])
  36. result = MultiModalInputs.batch(
  37. [{"image": [a]}, {"image": [b]}, {"image": [c]}]
  38. )
  39. assert_multimodal_inputs_equal(
  40. result,
  41. {
  42. "image": torch.stack(
  43. [a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)]
  44. )
  45. },
  46. )
  47. def test_multimodal_input_batch_heterogeneous_lists():
  48. a = torch.rand([1, 2, 3])
  49. b = torch.rand([1, 2, 3])
  50. c = torch.rand([1, 2, 3])
  51. result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
  52. assert_multimodal_inputs_equal(
  53. result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}
  54. )
  55. def test_multimodal_input_batch_multiple_batchable_lists():
  56. a = torch.rand([1, 2, 3])
  57. b = torch.rand([1, 2, 3])
  58. c = torch.rand([1, 2, 3])
  59. d = torch.rand([1, 2, 3])
  60. result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
  61. assert_multimodal_inputs_equal(
  62. result,
  63. {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])},
  64. )
  65. def test_multimodal_input_batch_mixed_stacking_depths():
  66. a = torch.rand([1, 2, 3])
  67. b = torch.rand([1, 3, 3])
  68. c = torch.rand([1, 4, 3])
  69. result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
  70. assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
  71. result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
  72. assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})