@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
 
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
-            if constexpr (Is_softcap) {
-                Tensor _dtanh = make_tensor_like(scores);
-                flash::calculate_dtanh(scores, _dtanh, params.softcap);
-                return _dtanh;
-            }
-            else {
-                return nullptr;
-            }
-        }());
+        Tensor dtanh = make_tensor_like(scores);
+        if constexpr (Is_softcap) {
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
+        }
 
         // Alibi
         if (Has_alibi) {
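The rewrite drops the immediately-invoked lambda, whose non-softcap branch returned a `nullptr_t`-typed placeholder: `dtanh` is now always allocated with `make_tensor_like` (register-backed storage the compiler can typically elide when unused) and is only filled when `Is_softcap` holds. For intuition, here is a minimal sketch of the dTanh math, assuming `scores` already holds the softcapped values at this point; `calculate_dtanh_sketch` is a hypothetical stand-in, and the real `flash::calculate_dtanh` folds in the library's own scaling conventions:

```cpp
// Hypothetical sketch, not flash-attention's actual implementation.
// Softcapping applies s_capped = softcap * tanh(s / softcap) in the forward
// pass, so d(s_capped)/ds = 1 - tanh^2(s / softcap). Since `scores` holds the
// capped values, tanh(s / softcap) is recovered as scores(i) / softcap.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__
void calculate_dtanh_sketch(cute::Tensor<Engine0, Layout0> const &scores,
                            cute::Tensor<Engine1, Layout1> &dtanh,
                            const float softcap) {
    #pragma unroll
    for (int i = 0; i < cute::size(dtanh); ++i) {
        const float t = scores(i) / softcap;  // = tanh(raw_score / softcap)
        dtanh(i) = 1.f - t * t;               // tanh'(x) = 1 - tanh^2(x)
    }
}
```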
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
-
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
-
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
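The second hunk collapses the softcap correction onto a single line inside the unrolled inner loop. A scalar model of the update, assuming `pointwise_mult(p, dp, d)` computes `p * (dp - d)` (the usual softmax backward; that reading of the helper is an assumption, as is the name below):

```cpp
// Hypothetical scalar model of the loop body above, not library code.
// Softmax backward: dS = P * (dP - dP_sum). When softcapping is enabled, the
// raw score feeds a tanh, so the chain rule contributes one extra factor of
// tanh' (the dtanh value computed earlier) per element.
__forceinline__ __device__
float scaled_ds_model(float p, float dp, float dp_sum, float dtanh, bool is_softcap) {
    float ds = p * (dp - dp_sum);         // assumed meaning of pointwise_mult
    return is_softcap ? ds * dtanh : ds;  // extra tanh' factor when capped
}
```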