#pragma once

#include "cuda_bf16_wrapper.h"

#include <cuda_fp16.h>
#include <cuda_runtime_api.h>

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
// Exit with a diagnostic (file, line, error string) if a CUDA runtime call fails.
#define CHECK_CUDA(call)                                                                                               \
    do {                                                                                                               \
        cudaError_t status_ = call;                                                                                    \
        if (status_ != cudaSuccess) {                                                                                  \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_));              \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)
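
// Typical use of CHECK_CUDA (a minimal sketch; the allocation below is
// illustrative, not part of this header):
//
//   float* buf = nullptr;
//   CHECK_CUDA(cudaMalloc(&buf, 1024 * sizeof(float)));
//   CHECK_CUDA(cudaMemset(buf, 0, 1024 * sizeof(float)));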
// The parameters shared by the masked (self) attention and cross attention kernels.
//
// Dimension terminology used in the comments below:
//   B:  batch size (number of sequences)
//   L:  sequence length
//   D:  hidden dimension
//   H:  number of heads
//   Dh: hidden dimension per head, Dh = D / H
template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;

    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;

    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;

    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;

    // The indirections to use when beam search writes into the K/V caches.
    const int* cache_indir = nullptr;

    // Strides to handle the case when Q, K and V live in a single fused buffer.
    int stride_q = 0;
    int stride_k = 0;
    int stride_v = 0;

    // The batch size (B).
    int batch_size = 0;

    // The beam width.
    int beam_width = 0;

    // The maximum length of the K/V memory (cache capacity, L).
    int memory_max_len = 0;

    // The number of query heads (H), the number of K/V heads, and their ratio.
    // num_heads_kv < num_heads enables grouped-query attention, where
    // num_heads_q_kv_ratio = num_heads / num_heads_kv query heads share one K/V head.
    int num_heads = 0;
    int num_heads_kv = 0;
    int num_heads_q_kv_ratio = 0;

    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;

    // Rotary position embeddings: the per-head dimension they rotate, whether to
    // use the GPT-NeoX layout, and the frequency base (theta).
    int rotary_embedding_dim = 0;
    bool neox_rotary_style = false;
    float rotary_base = 0.0f;

    // The maximum length of the input sequences.
    int max_input_length = 0;

    // The current timestep of the decoding.
    int timestep = 0;

    // 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Optional inputs for padded contexts, prefix prompts and relative attention bias.
    const int* total_padding_tokens = nullptr;
    const bool* masked_tokens = nullptr;
    const int* prefix_prompt_lengths = nullptr;
    int max_prefix_prompt_length = 0;
    const T* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;

    // The slope per head of the linear position bias (ALiBi). Size H.
    const T* linear_bias_slopes = nullptr;
    // IA3 adapter weights and the per-sequence task ids that select them.
    const T* ia3_key_weights = nullptr;
    const T* ia3_value_weights = nullptr;
    const int* ia3_tasks = nullptr;
    // Quantization scales, used when int8_mode != 0.
    const float* qkv_scale_out = nullptr;
    const float* attention_out_scale = nullptr;
    int int8_mode = 0;
    // Optional precomputed rotary cos/sin tables.
    const T* rotary_cos = nullptr;
    const T* rotary_sin = nullptr;
    // Optional subset of heads to compute: indices of the non-zero heads and their count.
    const int* nnz_head_idx = nullptr;
    int nnz_heads = 0;
};
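
// Example head configuration (illustrative values, not defaults): a grouped-query
// model with 32 query heads sharing 8 K/V heads of size 128 would set
//
//   params.num_heads            = 32;
//   params.num_heads_kv         = 8;
//   params.num_heads_q_kv_ratio = 4;    // 32 / 8
//   params.hidden_size_per_head = 128;
//   params.inv_sqrt_dh          = 1.f / sqrtf(128.f);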
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {

    // Output buffer for the cross attentions, if they are returned.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention kernel to exit early for finished sequences.
    bool* finished = nullptr;

    // Required in case of cross attention; kept in the base-case template as well
    // until `if constexpr` (C++17) can prune it.
    int* memory_length_per_sample = nullptr;

    // Required in case of masked attention with different sequence lengths.
    const int* length_per_sample = nullptr;
};
template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {

    // Output buffer for the cross attentions, if they are returned.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention kernel to exit early for finished sequences.
    bool* finished = nullptr;

    // The length of the memory (encoder output) for each sample.
    int* memory_length_per_sample = nullptr;

    // The current decoding length of each sample.
    const int* length_per_sample = nullptr;
};
template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // Cross attentions to return when is_return_cross_attentions is set.
    int max_decoder_seq_len = 0;
    T* cross_attention_out = nullptr;
    bool is_return_cross_attentions = false;
};
// Launch the masked (self) attention kernel for one decoding step. The uint16_t
// overload is used for fp16 data passed as raw 16-bit words.
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
#endif

// Launch the cross attention kernel for one decoding step.
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif
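
// ---------------------------------------------------------------------------
// Usage sketch for one fp32 decoding step. Everything below is illustrative:
// the model dimensions, the bare cudaMalloc setup, and run_one_step() are
// assumptions for the example, not part of this header.
//
//   #include <cmath>
//
//   void run_one_step(cudaStream_t stream) {
//       const int B = 1, H = 8, Dh = 64, L_max = 256;
//       const int D = H * Dh;
//
//       float *q, *k, *v, *out, *k_cache, *v_cache;
//       CHECK_CUDA(cudaMalloc(&q, B * D * sizeof(float)));
//       CHECK_CUDA(cudaMalloc(&k, B * D * sizeof(float)));
//       CHECK_CUDA(cudaMalloc(&v, B * D * sizeof(float)));
//       CHECK_CUDA(cudaMalloc(&out, B * D * sizeof(float)));
//       CHECK_CUDA(cudaMalloc(&k_cache, B * L_max * D * sizeof(float)));
//       CHECK_CUDA(cudaMalloc(&v_cache, B * L_max * D * sizeof(float)));
//
//       Masked_multihead_attention_params<float> params;
//       params.q = q; params.k = k; params.v = v;
//       params.out = out;
//       params.k_cache = k_cache;
//       params.v_cache = v_cache;
//       params.batch_size = B;
//       params.beam_width = 1;
//       params.memory_max_len = L_max;
//       params.num_heads = H;
//       params.num_heads_kv = H;            // no grouped-query sharing
//       params.num_heads_q_kv_ratio = 1;
//       params.hidden_size_per_head = Dh;
//       params.timestep = 0;                // first decoding step
//       params.inv_sqrt_dh = 1.f / sqrtf((float)Dh);
//
//       masked_multihead_attention(params, stream);
//       CHECK_CUDA(cudaStreamSynchronize(stream));
//   }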