causal_conv1d.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Plain-data parameter bundle for a causal-conv1d call: problem sizes, an
  // activation flag, per-tensor strides and type-erased data pointers.
  // Member order and member types are intentionally left untouched, since
  // other translation units may depend on this exact layout.
  struct ConvParamsBase {
      // Unsigned 32-bit type shared by every stride member below.
      // NOTE(review): if consumers compute offsets as `index * stride` in this
      // type, offsets reaching 2^32 would wrap -- confirm against the kernels
      // before using very large tensors.
      typedef uint32_t index_t;

      // Problem sizes (width is presumably the convolution filter width --
      // TODO confirm against the kernels).
      int batch;
      int dim;
      int seqlen;
      int width;

      // Presumably toggles a SiLU activation on the output -- confirm in kernels.
      bool silu_activation;

      // Strides. The _batch / _c / _l suffixes presumably denote the batch,
      // channel and sequence-length axes; units (elements vs bytes) are
      // whatever the caller stores -- TODO confirm.
      index_t x_batch_stride;
      index_t x_c_stride;
      index_t x_l_stride;

      index_t weight_c_stride;
      index_t weight_width_stride;

      index_t out_batch_stride;
      index_t out_c_stride;
      index_t out_l_stride;

      index_t conv_state_batch_stride;
      index_t conv_state_c_stride;
      index_t conv_state_l_stride;

      // Common data pointers. Element types are erased (void *); the
      // __restrict__ qualifier is a no-aliasing hint to the compiler.
      void* __restrict__ x_ptr;
      void* __restrict__ weight_ptr;
      void* __restrict__ bias_ptr;
      void* __restrict__ out_ptr;
      void* __restrict__ conv_state_ptr;
      void* __restrict__ seq_idx_ptr;

      // Deliberately NOT __restrict__: initial_states may point at the same
      // buffer as final_states, so no no-aliasing promise can be made.
      void* initial_states_ptr;
      index_t initial_states_batch_stride;
      index_t initial_states_l_stride;
      index_t initial_states_c_stride;

      void* final_states_ptr;
      index_t final_states_batch_stride;
      index_t final_states_l_stride;
      index_t final_states_c_stride;
  };