
xpu: bump IPEX to 2.3, support GQA (#1042)

AlpinDale 2 months ago
parent
commit
6951928522

+ 9 - 2
Dockerfile.xpu

@@ -1,9 +1,8 @@
-FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04
+FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
     echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
     chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
-    rm /etc/apt/sources.list.d/intel-graphics.list && \
     wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
     echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
     chmod 644 /usr/share/keyrings/intel-graphics.gpg
@@ -11,6 +10,14 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
 RUN apt-get update  -y \
 && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
 
+RUN git clone https://github.com/intel/pti-gpu && \
+    cd pti-gpu/sdk && \
+    mkdir build && \
+    cd build && \
+    cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
+    make -j && \
+    cmake --install . --config Release --prefix "/usr/local"
+
 COPY ./ /workspace/aphrodite-engine
 
 WORKDIR /workspace/aphrodite-engine

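As a rough sanity check (not part of this commit), something like the following could be run inside the built image to confirm that the bumped oneAPI 2024.2 base and the IPEX 2.3 stack pinned in requirements-xpu.txt below import cleanly and see an XPU device; the expected version strings are assumptions taken from that requirements file.

# sanity_check_xpu.py -- hypothetical helper, not shipped with this commit.
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device with torch

print("torch:", torch.__version__)   # expected: 2.3.1+cxx11.abi
print("ipex:", ipex.__version__)     # expected: 2.3.110+xpu
print("xpu available:", torch.xpu.is_available())
if torch.xpu.is_available():
    print("device 0:", torch.xpu.get_device_name(0))
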
+ 37 - 68
aphrodite/_ipex_ops.py

@@ -22,23 +22,29 @@ class ipex_ops:
         x2 = x2.reshape(num, d)
         return x1, x2
 
+    @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)
 
+    @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
+    @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    @staticmethod
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    @staticmethod
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
+
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
 
     def paged_attention_v1(
         out: torch.Tensor,
@@ -128,65 +134,25 @@ class ipex_ops:
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
 
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
-
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
+
+    @staticmethod
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
 
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                            weight: torch.Tensor, epsilon: float) -> None:
@@ -210,11 +176,14 @@ class ipex_ops:
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
 
     def reshape_and_cache(
         key: torch.Tensor,

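The net effect of these changes, besides delegating to the IPEX 2.3 ipex.llm.functional kernels, is a call-convention shift: rms_norm, gelu_fast and gelu_new now return their result instead of writing into a caller-provided buffer. A minimal caller-side sketch (shapes, dtype and epsilon are illustrative assumptions; requires an XPU build):

import torch
from aphrodite._ipex_ops import ipex_ops as ops

x = torch.randn(4, 4096, device="xpu", dtype=torch.float16)  # illustrative shape
w = torch.ones(4096, device="xpu", dtype=torch.float16)

# old convention (pre-commit): the caller allocated the output buffer
#   out = torch.empty_like(x)
#   ops.rms_norm(out, x, w, 1e-6)

# new convention: the op returns the normalized tensor directly
out = ops.rms_norm(x, w, 1e-6)
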
+ 7 - 2
aphrodite/attention/backends/ipex_attn.py

@@ -56,14 +56,19 @@ class IpexAttnBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        from aphrodite._ipex_ops import ipex_ops as ops
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        from aphrodite._ipex_ops import ipex_ops as ops
+
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass

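The new copy_blocks path assumes each per-layer cache tensor stacks the key cache and the value cache along its first dimension, so kv_cache[0] and kv_cache[1] can be handed to ipex_ops separately. A rough illustration of that split (all sizes below are made up for illustration):

import torch

# Hypothetical per-layer KV cache: index 0 holds keys, index 1 holds values.
num_blocks, num_kv_heads, head_size, block_size = 16, 8, 128, 16  # made-up sizes
kv_cache = torch.zeros(2, num_blocks, num_kv_heads, head_size, block_size)

key_cache, value_cache = kv_cache[0], kv_cache[1]  # what copy_blocks now splits out
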
+ 8 - 6
aphrodite/modeling/layers/activation.py

@@ -111,9 +111,7 @@ class NewGELU(CustomOp):
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from aphrodite._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out
+        return ops.gelu_new(x)
 
 
 class FastGELU(CustomOp):
@@ -132,9 +130,7 @@ class FastGELU(CustomOp):
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from aphrodite._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out
+        return ops.gelu_fast(x)
 
 
 class QuickGELU(CustomOp):
@@ -151,6 +147,12 @@ class QuickGELU(CustomOp):
         ops.gelu_quick(out, x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from aphrodite._ipex_ops import ipex_ops as ops
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out
+
 
 class ReLUSquaredActivation(CustomOp):
     """

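For reference (not part of the commit), QuickGELU is commonly defined as x * sigmoid(1.702 * x); the new forward_xpu path is assumed to match that via ipex.llm.functional.gelu_quick. A small reference implementation that could be used to cross-check the kernel:

import torch

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # QuickGELU approximation: x * sigmoid(1.702 * x).
    return x * torch.sigmoid(1.702 * x)

# On an XPU build, the kernel could be checked against the reference:
#   from aphrodite._ipex_ops import ipex_ops as ops
#   x = torch.randn(4, 1024, device="xpu")
#   out = torch.empty_like(x)
#   ops.gelu_quick(out, x)
#   assert torch.allclose(out, quick_gelu_reference(x), atol=1e-3)
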
+ 1 - 4
aphrodite/modeling/layers/layernorm.py

@@ -106,14 +106,11 @@ class RMSNorm(CustomOp):
                 self.variance_epsilon,
             )
             return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
+        return ops.rms_norm(
             x,
             self.weight.data,
             self.variance_epsilon,
         )
-        return out
 
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"

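For context (again, not part of the commit), ops.rms_norm on the XPU path is expected to compute the standard RMSNorm and return a fresh tensor, which is why the out buffer disappears here. A plain-torch reference sketch:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    # RMSNorm: scale x by the reciprocal root-mean-square over the last dim.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(2, 4096)
w = torch.ones(4096)
y = rms_norm_reference(x, w, 1e-6)  # same shape as x, returned rather than written in place
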
+ 7 - 4
requirements-xpu.txt

@@ -3,8 +3,11 @@
 
 setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed.
 
-torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl
-intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl
 
-triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+torch == 2.3.1+cxx11.abi
+intel-extension-for-pytorch == 2.3.110+xpu
+oneccl_bind_pt == 2.3.100+xpu
+
+triton-xpu == 3.0.0b2
+
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
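Since the pins above now resolve through the Intel extra index rather than direct wheel URLs, a quick post-install check (hypothetical helper, not part of this commit; package names taken verbatim from the requirements above) could look like:

from importlib.metadata import PackageNotFoundError, version

for pkg in ("torch", "intel-extension-for-pytorch", "oneccl_bind_pt", "triton-xpu"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED")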