@@ -68,6 +68,33 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      projection_dim=768)
 
 
+class Phi3VImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
+    """
+
+    image_sizes: torch.Tensor
+    """
+    Shape: `(batch_size, 2)`
+    This should be in `(height, width)` format.
+    """
+
+
+class Phi3VImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of the language model backbone.
+    """
+
+
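+# An image input is either raw pixel values, which are encoded by
+# `vision_embed_tokens`, or precomputed image embeddings that bypass it.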
+Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
+
+
 class Phi3ImageEmbeddingBase(nn.Module):
 
     def __init__(self) -> None:
@@ -254,23 +281,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
 
-class Phi3VImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
-    """
-
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@@ -386,7 +396,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                                            input_width=w,
                                                            input_height=h)
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
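+        # A tensor here holds precomputed image embeddings; the image feature
+        # size is its first dimension.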
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -490,25 +500,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
+            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is None and image_embeds is None:
+            return None
+
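+        # If both inputs are present, pixel values take precedence.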
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
+
+            return Phi3VImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes))
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Phi3VImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_input(
+        self,
+        image_input: Phi3VImageInputs,
+    ) -> torch.Tensor:
+
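+        # Precomputed embeddings are returned as-is; pixel values are encoded
+        # by `vision_embed_tokens` below.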
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+        assert self.vision_embed_tokens is not None
+        image_embeds = self.vision_embed_tokens(image_input["data"],
+                                                image_input["image_sizes"])
 
-        return Phi3VImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes))
+        return image_embeds
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -520,8 +560,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings = self.vision_embed_tokens(
-                image_input["data"], image_input["image_sizes"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.model.get_input_embeddings(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,