@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
   scalar_t* __restrict__ out,          // [..., d]
   const scalar_t* __restrict__ input,  // [..., 2, d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
     const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
   torch::Tensor& out,    // [..., d]
   torch::Tensor& input)  // [..., 2 * d]
 {
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;
 
   dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
   scalar_t* __restrict__ out,          // [..., d]
   const scalar_t* __restrict__ input,  // [..., d]
   const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = __ldg(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                \
   int d = input.size(-1);                                               \
-  int num_tokens = input.numel() / d;                                   \
+  int64_t num_tokens = input.numel() / d;                               \
   dim3 grid(num_tokens);                                                \
   dim3 block(std::min(d, 1024));                                        \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
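Context for the change (not part of the patch): with a 32-bit `token_idx`, the flattened offset `token_idx * 2 * d + idx` is evaluated entirely in `int` arithmetic and overflows once `num_tokens * 2 * d` exceeds `INT_MAX` (~2.1e9), producing out-of-bounds (illegal) memory accesses for the highest-numbered tokens. Declaring `token_idx` and `idx` as `int64_t` promotes the whole index expression to 64-bit. A minimal host-side sketch of the failure mode, using hypothetical shapes chosen to cross the 2^31 boundary:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shapes: 70,000 tokens, d = 16,384 (so 2 * d = 32,768).
  // num_tokens * 2 * d is ~2.29e9 > INT_MAX, so the last tokens overflow.
  const int token_idx = 70000 - 1;  // index of the last token
  const int d = 16384;

  // All-int arithmetic, as in the old kernel: signed overflow is UB in
  // C++ and wraps to a negative value on typical hardware.
  const int bad_offset = token_idx * 2 * d;

  // Widening one operand first, as in the patched kernel: the multiply
  // is carried out in 64-bit and the offset stays correct.
  const int64_t good_offset = static_cast<int64_t>(token_idx) * 2 * d;

  printf("32-bit offset: %d\n", bad_offset);                 // negative garbage
  printf("64-bit offset: %lld\n", (long long)good_offset);   // 2293727232
  return 0;
}
```

Note that `blockIdx.x` and `threadIdx.x` are `unsigned int`, so assigning them to `int64_t` is a lossless widening. The `dim3 grid(num_tokens)` construction still narrows `num_tokens` back to `unsigned int`, which stays safe as long as the token count is below the 2^31 - 1 grid x-dimension limit.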