attention_generic.cuh 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /*
  2. * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  3. * Copyright (c) 2023, The PygmalionAI team.
  4. * Copyright (c) 2023, The vLLM team.
  5. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #pragma once
  20. #include <stdint.h>
  namespace aphrodite {

  // A vector type to store Q, K, V elements.
  //
  // The primary template is intentionally empty. Each supported
  // (element type, VEC_SIZE) combination must be provided as a specialization;
  // none of those specializations live in this header.
  template<typename T, int VEC_SIZE>
  struct Vec {};

  // A vector type to store FP32 accumulators corresponding to type T.
  //
  // Like Vec, the primary template is empty; only specializations defined
  // outside this header are usable.
  template<typename T>
  struct FloatVec {};

  // Template vector operations.
  //
  // `mul` and `sum` are only declared here. Every (type) combination that is
  // actually called must have a matching explicit specialization defined
  // elsewhere; calling an unspecialized combination compiles but fails at
  // (device) link time with an undefined reference.

  // Multiplies `a` by `b`, producing a value of type Acc.
  // NOTE(review): presumably lane-wise for vector types -- confirm against the
  // specializations.
  template<typename Acc, typename A, typename B>
  inline __device__ Acc mul(A a, B b);

  // Collapses `v` into a single float.
  // NOTE(review): presumably the sum of all lanes of `v` -- the exact meaning
  // is defined by the specializations.
  template<typename T>
  inline __device__ float sum(T v);

  // Dot product of two values of the same type T: the product is formed in
  // type T via mul<T, T, T> and then collapsed to a float with sum().
  template<typename T>
  inline __device__ float dot(T a, T b) {
    return sum(mul<T, T, T>(a, b));
  }

  // Dot product whose intermediate product is held in accumulator type A
  // (via mul<A, T, T>) before being collapsed to a float with sum().
  // A does not appear in the parameter list, so it cannot be deduced and must
  // be supplied explicitly at the call site: dot<A>(a, b).
  template<typename A, typename T>
  inline __device__ float dot(T a, T b) {
    return sum(mul<A, T, T>(a, b));
  }

  // Sets every byte of `dst` to zero.
  //
  // The zero value is assembled in a union that overlays T with an array of
  // 32-bit words, so the zeroing is done with whole-word stores.
  //
  // WORDS is computed with ceil-division so the word array always spans every
  // byte of T -- including types smaller than 4 bytes (which would otherwise
  // get a zero-length array and be left uninitialized) and types whose size
  // is not a multiple of 4 (whose trailing bytes would otherwise be left
  // uninitialized). For sizes that are multiples of 4 the result is unchanged.
  //
  // NOTE(review): reading `tmp.raw` after writing `tmp.words` is union type
  // punning, which ISO C++ does not sanction; the pattern is inherited from
  // the FasterTransformer original and relies on compiler support.
  template<typename T>
  inline __device__ void zero(T& dst) {
    constexpr int WORDS =
        static_cast<int>((sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t));
    union {
      T raw;
      uint32_t words[WORDS];
    } tmp;
    #pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
      tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
  }

  }  // namespace aphrodite