123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
- * Copyright (c) 2023, The vLLM team.
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #pragma once
- #include "attention_generic.cuh"
- #include <stdint.h>
- namespace aphrodite {
// Custom FP32 aggregates assembled from CUDA's built-in float2.
//
// Float4_ stores four floats as two float2 members (x, y); Float8_ stores
// eight floats as four float2 members (x, y, z, w), in declaration order.
// Both are plain aggregates, so they can be brace-initialized member-wise.
// Their alignment is that of float2 (8 bytes per CUDA's vector_types),
// i.e. Float4_ is NOT 16-byte aligned the way the built-in float4 is.
struct Float4_ {
  float2 x, y;
};

struct Float8_ {
  float2 x, y, z, w;
};
// Vec<float, N>::Type picks the FP32 storage type used for N packed
// elements of Q, K or V: a bare float for N == 1, and the CUDA built-in
// float2 / float4 for N == 2 / N == 4. No other widths are specialized here;
// the primary Vec template is declared outside this file.
template<>
struct Vec<float, 1> {
  typedef float Type;
};

template<>
struct Vec<float, 2> {
  typedef float2 Type;
};

template<>
struct Vec<float, 4> {
  typedef float4 Type;
};
// FloatVec<V>::Type names the FP32 accumulator type that pairs with a Vec
// storage type V. Because these storage types are already FP32, every
// specialization below maps the type onto itself.
template<>
struct FloatVec<float> {
  typedef float Type;
};

template<>
struct FloatVec<float2> {
  typedef float2 Type;
};

template<>
struct FloatVec<float4> {
  typedef float4 Type;
};
// Element-wise addition for FP32 scalars and vectors.
// Every vector overload adds lane by lane through the scalar overload;
// Float4_ is treated as two float2 halves.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  return float2{add(a.x, b.x), add(a.y, b.y)};
}

inline __device__ float4 add(float4 a, float4 b) {
  return float4{add(a.x, b.x), add(a.y, b.y), add(a.z, b.z), add(a.w, b.w)};
}

inline __device__ Float4_ add(Float4_ a, Float4_ b) {
  return Float4_{add(a.x, b.x), add(a.y, b.y)};
}
// Explicit specializations of the generic three-parameter mul template
// (result type first, then the two operand types; presumably declared in
// attention_generic.cuh) for FP32 operands.
//
// Element-wise products of equally shaped operands: float, float2, float4.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  return float2{a.x * b.x, a.y * b.y};
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  return float4{a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w};
}

// Scalar broadcast: the float `a` scales every lane of the vector `b`.
template<>
inline __device__ float2 mul(float a, float2 b) {
  return float2{a * b.x, a * b.y};
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  return float4{a * b.x, a * b.y, a * b.z, a * b.w};
}
// Multiply-add helpers: each returns a * b + c, lane by lane.
//
// The scalar overload is written as `a * b + c`; nvcc contracts this into a
// single fused FFMA when --fmad=true (its default). No fused intrinsic is
// called explicitly, so with --fmad=false the product is rounded separately.
// All vector overloads forward each lane to a narrower overload declared
// earlier in this group. A leading `float a` broadcasts the scalar to every lane.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  return float2{fma(a.x, b.x, c.x), fma(a.y, b.y, c.y)};
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  return float2{fma(a, b.x, c.x), fma(a, b.y, c.y)};
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  return float4{fma(a.x, b.x, c.x),
                fma(a.y, b.y, c.y),
                fma(a.z, b.z, c.z),
                fma(a.w, b.w, c.w)};
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  return float4{fma(a, b.x, c.x),
                fma(a, b.y, c.y),
                fma(a, b.z, c.z),
                fma(a, b.w, c.w)};
}

// Aggregates: scale each float2 member of `b` by `a` and add the matching
// member of `c` (via the float/float2 overload above).
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  return Float4_{fma(a, b.x, c.x), fma(a, b.y, c.y)};
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  return Float8_{fma(a, b.x, c.x),
                 fma(a, b.y, c.y),
                 fma(a, b.z, c.z),
                 fma(a, b.w, c.w)};
}
// Horizontal reductions: specializations of the generic `sum` template that
// collapse an FP32 value into a single float.
// Lanes are accumulated strictly left to right (x, y, z, w; for the
// float2-based aggregates x.x, x.y, y.x, y.y, ...), which is exactly the
// association of one chained `+` expression. This is a sequential sum, not a
// pairwise tree, and the floating-point rounding follows that order.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  float total = v.x;
  total += v.y;
  return total;
}

template<>
inline __device__ float sum(float4 v) {
  float total = v.x;
  total += v.y;
  total += v.z;
  total += v.w;
  return total;
}

template<>
inline __device__ float sum(Float4_ v) {
  float total = v.x.x;
  total += v.x.y;
  total += v.y.x;
  total += v.y.y;
  return total;
}

template<>
inline __device__ float sum(Float8_ v) {
  float total = v.x.x;
  total += v.x.y;
  total += v.y.x;
  total += v.y.y;
  total += v.z.x;
  total += v.z.y;
  total += v.w.x;
  total += v.w.y;
  return total;
}
// Dot products returning a single float.
// For the float2-based aggregates, a running float2 of partial products is
// seeded with the first members' element-wise product. Each later member pair
// is then folded in with fma, innermost call first. The two lanes of that
// running float2 are added at the end.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  const float2 prod = mul<float2, float2, float2>(a, b);
  return prod.x + prod.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  const float2 partial =
      fma(a.y, b.y, mul<float2, float2, float2>(a.x, b.x));
  return partial.x + partial.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  const float2 partial =
      fma(a.w, b.w,
          fma(a.z, b.z,
              fma(a.y, b.y, mul<float2, float2, float2>(a.x, b.x))));
  return partial.x + partial.y;
}
// from_float(dst, src): write an FP32 value into `dst`. The destination
// types here are already FP32, so each overload is a plain copy.
inline __device__ void from_float(float& dst, float src) { dst = src; }

inline __device__ void from_float(float2& dst, float2 src) { dst = src; }

inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
// to_float(u): produce the FP32 form of `u`. Every input type handled here
// (float, float2, float4, Float4_, Float8_) is already FP32, so each
// overload returns its argument unchanged.
inline __device__ float to_float(float u) { return u; }

inline __device__ float2 to_float(float2 u) { return u; }

inline __device__ float4 to_float(float4 u) { return u; }

inline __device__ Float4_ to_float(Float4_ u) { return u; }

inline __device__ Float8_ to_float(Float8_ u) { return u; }
// zero(dst): overwrite a float with 0.0f.
inline __device__ void zero(float& dst) {
  dst = 0.0f;
}
- } // namespace aphrodite
|