/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The PygmalionAI team.
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace aphrodite {

// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};
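
// Note: each Vec<uint16_t, N> packs N half values into an unsigned integer type
// of matching width (2/4/8/16 bytes), so a whole vector moves with a single load
// or store. Each FloatVec specialization names the FP32 type that holds the
// widened form of the corresponding packed-FP16 vector (e.g. uint2 -> Float4_).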

// Utility functions for type conversions.
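// h0_h0 broadcasts one fp16 value into both 16-bit lanes of a 32-bit register;
// the packed f16x2 operations below use it to combine a scalar with a
// two-element half vector without unpacking.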
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
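  // On sm_80 and newer (per the guard above), one cvt.rn.f16x2.f32 converts both
  // floats at once: the first source operand goes to the upper 16 bits and the
  // second to the lower 16 bits, so f.y/f.x are passed in that order to keep x in
  // the low half and y in the high half, matching the per-element path below.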
- asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
- #else
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
- #endif
- #else
- tmp.u16[0] = float_to_half(f.x);
- tmp.u16[1] = float_to_half(f.y);
- #endif
- return tmp.u32;
- }

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}
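
// The mixed overloads above (packed fp16 plus an fp32 vector) widen the fp16
// operand with half2_to_float2 and add in fp32; they are what lets the FloatVec
// accumulator types be updated directly from fp16 inputs.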

// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}
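
// The mul specializations come in three flavors: packed fp16 results (uint*),
// scalar-times-vector forms that broadcast the scalar with h0_h0, and fp32-result
// forms (float2 / Float4_ / Float8_) that widen each product before it is used
// for accumulation.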

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}
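
// As with mul, fma has packed fp16 forms (fma.rn.f16x2 / v_pk_fma_f16), scalar
// broadcasts via h0_h0, and overloads whose accumulator is fp32 (float, float2,
// Float4_, Float8_) so that running sums can be kept in single precision.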

// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}
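
// sum reduces a packed fp16 vector to one float: wider vectors are first folded
// with packed fp16 adds, and only the final two lanes are widened to fp32.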

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}
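
// from_float / to_float give templated kernel code a uniform way to convert
// between the packed fp16 storage types above and their FP32 accumulator
// counterparts (the FloatVec mapping at the top of this file).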

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}
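
// Illustrative sketch (not part of the original header): how the specializations
// above compose into a tiny fp16 dot product. Only overloads defined in this file
// are used; the function name is hypothetical, and real kernels would typically
// go through the generic templated helpers instead.
inline __device__ float example_dot8_fp16(uint4 q, uint4 k) {
  uint4 prod = mul<uint4, uint4, uint4>(q, k);  // lane-wise products, still fp16
  return sum(prod);                             // fold lanes and widen to fp32
}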

} // namespace aphrodite