import torch

from aphrodite.multimodal.base import MultiModalInputs, NestedTensors


def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    """Recursively assert that two NestedTensors structures are equal."""
    assert type(expected) == type(actual)
    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
    else:
        for expected_item, actual_item in zip(expected, actual):
            assert_nested_tensors_equal(expected_item, actual_item)


def assert_multimodal_inputs_equal(expected: MultiModalInputs,
                                   actual: MultiModalInputs):
    """Assert that two MultiModalInputs dicts have the same keys and values."""
    assert set(expected.keys()) == set(actual.keys())
    for key in expected:
        assert_nested_tensors_equal(expected[key], actual[key])


def test_multimodal_input_batch_single_tensor():
    # A single item is batched by adding a leading batch dimension.
    t = torch.rand([1, 2])
    result = MultiModalInputs.batch([{"image": t}])
    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})


def test_multimodal_input_batch_multiple_tensors():
    # Tensors of the same shape are stacked into a single batched tensor.
    a = torch.rand([1, 1, 2])
    b = torch.rand([1, 1, 2])
    c = torch.rand([1, 1, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})


def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    # Tensors with mismatched shapes cannot be stacked and stay in a list.
    a = torch.rand([1, 2, 2])
    b = torch.rand([1, 3, 2])
    c = torch.rand([1, 4, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})


def test_multimodal_input_batch_nested_tensors():
    # Single-element lists of same-shape tensors are stacked all the way down.
    a = torch.rand([2, 3])
    b = torch.rand([2, 3])
    c = torch.rand([2, 3])
    result = MultiModalInputs.batch([{"image": [a]}, {"image": [b]},
                                     {"image": [c]}])
    assert_multimodal_inputs_equal(
        result,
        {
            "image": torch.stack(
                [a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)]
            )
        },
    )


def test_multimodal_input_batch_heterogeneous_lists():
    # Lists of different lengths are stacked per item but not across items.
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
    assert_multimodal_inputs_equal(
        result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}
    )


def test_multimodal_input_batch_multiple_batchable_lists():
    # Equal-length lists of same-shape tensors are stacked at every level.
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    d = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
    assert_multimodal_inputs_equal(
        result,
        {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])},
    )


def test_multimodal_input_batch_mixed_stacking_depths():
    # Each item is stacked as deeply as its shapes allow; items whose inner
    # tensors differ in shape remain nested lists.
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 3, 3])
    c = torch.rand([1, 4, 3])

    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
    assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})

    result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
    assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})