/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

// NOTE(review): the ROCm half-precision header name is restored from the
// upstream project layout -- confirm it matches the build's HIP install.
#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

// Fixed-width integer types (uint16_t, uint32_t) used as raw FP16 storage.
#include <stdint.h>

namespace aphrodite {

// FP16 vector types for Q, K, V.
//
// FP16 values are carried around as raw bit patterns in unsigned integer
// types: one half in a uint16_t, two packed halves in a uint32_t, four in a
// uint2 and eight in a uint4.
//
// NOTE(review): assumes the primary template declared in
// attention_generic.cuh is Vec<T, VEC_SIZE> -- confirm against that header.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
//
// Each FP16 storage type maps to the FP32 type with the same element count
// (the same pairing used by to_float()/from_float() in this file).
//
// NOTE(review): assumes the primary template declared in
// attention_generic.cuh is FloatVec<T> -- confirm against that header.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
// Duplicates one 16-bit pattern into both halves of a 32-bit word, so a
// scalar half can be fed to the packed (x2) FP16 instructions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t packed;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(packed) : "h"(a));
  return packed;
#else
  const uint32_t lane = static_cast<uint32_t>(a);
  return (lane << 16) | lane;
#endif
}

// Widens a single FP16 bit pattern to an FP32 value.
inline __device__ float half_to_float(uint16_t h) {
  float widened;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(widened) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(widened) : "v"(h));
#endif
  return widened;
}

// Widens two packed FP16 values to a float2; the low 16 bits become .x and
// the high 16 bits become .y.
inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  float2 out;
  out.x = half_to_float(lo);
  out.y = half_to_float(hi);
  return out;
#else
  const uint16_t lo = static_cast<uint16_t>(v & 0xFFFFu);
  const uint16_t hi = static_cast<uint16_t>(v >> 16);
  float2 out;
  out.x = half_to_float(lo);
  out.y = half_to_float(hi);
  return out;
#endif
}

// Narrows an FP32 value to an FP16 bit pattern.
inline __device__ uint16_t float_to_half(float f) {
#ifndef USE_ROCM
  uint16_t bits;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(bits) : "f"(f));
  return bits;
#else
  // The instruction writes a 32-bit register; only its low 16 bits are kept.
  uint32_t reg;
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(reg) : "v"(f));
  return static_cast<uint16_t>(reg);
#endif
}

// Narrows a float2 to two packed FP16 values: .x lands in the low 16 bits,
// .y in the high 16 bits.
inline __device__ uint32_t float2_to_half2(float2 f) {
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Single packed conversion; the first source operand fills the upper half,
  // hence f.y is passed before f.x.
  uint32_t packed;
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(packed) : "f"(f.y), "f"(f.x));
  return packed;
  #else
  // Fallback: convert each component separately, then pack.
  uint16_t lo, hi;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(lo) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(hi) : "f"(f.y));
  return (static_cast<uint32_t>(hi) << 16) | static_cast<uint32_t>(lo);
  #endif
#else
  const uint16_t lo = float_to_half(f.x);
  const uint16_t hi = float_to_half(f.y);
  return (static_cast<uint32_t>(hi) << 16) | static_cast<uint32_t>(lo);
#endif
}

// Vector addition.
// Adds two scalar FP16 values given as raw bit patterns; the result is an
// FP16 bit pattern.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t result;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(result) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(result) : "v"(a), "v"(b));
#endif
  return result;
}

// Adds two pairs of packed FP16 values lane-wise.
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t result;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(result) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(result) : "v"(a), "v"(b));
#endif
  return result;
}

// Lane-wise FP16 addition of four values (two packed pairs).
inline __device__ uint2 add(uint2 a, uint2 b) {
  return make_uint2(add(a.x, b.x), add(a.y, b.y));
}

// Lane-wise FP16 addition of eight values (four packed pairs).
inline __device__ uint4 add(uint4 a, uint4 b) {
  return make_uint4(add(a.x, b.x), add(a.y, b.y), add(a.z, b.z), add(a.w, b.w));
}

// Mixed precision: widens the packed FP16 operand to FP32, then adds it to
// the FP32 operand. The result stays in FP32.
inline __device__ float2 add(uint32_t a, float2 fb) {
  return add(half2_to_float2(a), fb);
}

// Mixed precision, four lanes: applies the (uint32_t, float2) overload to
// each packed pair.
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ out;
  out.x = add(a.x, fb.x);
  out.y = add(a.y, fb.y);
  return out;
}

// Mixed precision, eight lanes: applies the (uint32_t, float2) overload to
// each packed pair.
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ out;
  out.x = add(a.x, fb.x);
  out.y = add(a.y, fb.y);
  out.z = add(a.z, fb.z);
  out.w = add(a.w, fb.w);
  return out;
}

// Vector multiplication.
// Explicit specializations of mul for FP16 operands.
//
// Two families are provided for the same operand types and differ only in
// the accumulator/return type: an FP16 family (uint16_t/uint32_t/uint2/uint4
// results) and an FP32 family (float/float2/Float4_/Float8_ results). Because
// the result type cannot be deduced from the call arguments, every internal
// call names its template arguments explicitly.
//
// NOTE(review): assumes the primary template is
//   template<typename Acc, typename A, typename B> Acc mul(A a, B b);
// as declared in attention_generic.cuh -- confirm against that header.

// FP16 x FP16 -> FP16 (scalar).
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Packed FP16x2 x FP16x2 -> FP16x2 (lane-wise).
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Scalar FP16 x packed FP16x2 -> FP16x2; the scalar is broadcast to both lanes.
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

// Four-lane FP16 product, FP16 result.
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

// Scalar FP16 x four-lane FP16, FP16 result.
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);  // broadcast once, reuse for every packed pair
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

// Eight-lane FP16 product, FP16 result.
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

// Scalar FP16 x eight-lane FP16, FP16 result.
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

// FP16 x FP16 -> FP32: both operands are widened before multiplying.
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

// Packed FP16x2 x FP16x2 -> float2: operands are widened, then multiplied
// with the float2 specialization.
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

// Scalar FP16 x packed FP16x2 -> float2.
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

// Four-lane FP16 product, FP32 result.
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

// Scalar FP16 x four-lane FP16, FP32 result.
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

// Eight-lane FP16 product, FP32 result.
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

// Scalar FP16 x eight-lane FP16, FP32 result.
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
//
// Plain overloads (not templates): the accumulator argument's type selects
// between the FP16 family (integer-typed accumulator) and the FP32 family
// (float-typed accumulator). All compute a * b + c lane-wise.

// Packed FP16x2 fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

// Scalar FP16 broadcast to both lanes, then packed FMA.
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

// Four-lane FP16 FMA.
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

// Scalar FP16 x four-lane FP16 FMA.
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

// Eight-lane FP16 FMA.
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

// Scalar FP16 x eight-lane FP16 FMA.
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// FP16 operands widened to FP32, accumulated into an FP32 scalar.
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

// Packed FP16x2 operands widened to float2, accumulated into a float2.
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

// Scalar FP16 broadcast, FP32 accumulation (two lanes).
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

// Four-lane FP16 operands, FP32 accumulation.
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

// Scalar FP16 x four-lane FP16, FP32 accumulation.
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

// Eight-lane FP16 operands, FP32 accumulation.
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

// Scalar FP16 x eight-lane FP16, FP32 accumulation.
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
//
// Horizontal sum of all FP16 lanes, returned as FP32. For the four- and
// eight-lane forms the packed pairs are first combined with FP16 adds; only
// the final pair is widened to FP32 and summed.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
//
// Each overload narrows an FP32 value/vector into the FP16 storage type with
// the same lane count and writes it to dst.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
//
// Each overload widens an FP16 storage value into the FP32 type with the
// same lane count.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
// Resets a 16-bit value (this file's scalar FP16 storage type) to all-zero
// bits.
inline __device__ void zero(uint16_t& dst) {
  dst = static_cast<uint16_t>(0);
}

}  // namespace aphrodite