/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace aphrodite {

// Define custom FP32 vector data types.
// Four consecutive FP32 values, stored as two float2's.
struct Float4_ {
  float2 x;
  float2 y;
};

// Eight consecutive FP32 values, stored as four float2's.
struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}
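// Usage sketch (illustrative, not part of the original FasterTransformer/vLLM
// code): how the Vec/FloatVec traits and the add/mul overloads above are
// typically composed. The helper name `example_scale` is hypothetical; the
// real consumers are the attention kernels that include this header.
inline __device__ float4 example_scale(float scale, float4 v, float4 bias) {
  // Vec<float, 4>::Type resolves to the builtin float4 vector type.
  using Vec4 = Vec<float, 4>::Type;
  // Element-wise scale * v, then element-wise addition of the bias.
  Vec4 scaled = mul<Vec4, float, Vec4>(scale, v);
  return add(scaled, bias);
}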
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}

// To float (identity for FP32).
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace aphrodite
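// Usage sketch (illustrative, not part of the original file): accumulating a
// dot product over `n` Float8_ chunks with the helpers above, the same
// mul/fma/sum pattern the attention kernels build on. The function name
// `example_qk_dot` and its signature are hypothetical.
namespace aphrodite {

inline __device__ float example_qk_dot(const Float8_* q, const Float8_* k, int n) {
  float acc;
  zero(acc);  // FP32 accumulator, zero-initialized.
  for (int i = 0; i < n; ++i) {
    acc = add(acc, dot(q[i], k[i]));  // 8-wide dot, then scalar accumulate.
  }
  return acc;
}

} // namespace aphrodite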