@@ -339,11 +339,12 @@ class PagedAttentionWithALiBi(PagedAttention):
             # be sliced from a tensor whose length is a multiple of 8.
             padded_len = (prompt_len + 7) // 8 * 8
             bias = torch.empty(
+                1,  # batch_size
                 self.num_heads,
-                padded_len,
+                prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
-            )[:, :prompt_len, :prompt_len].copy_(bias)
+            )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
             input_metadata.attn_bias.append(attn_bias)
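
For context, the comment in the hunk notes that xformers requires a custom attention bias to be sliced from a tensor whose length is a multiple of 8; the change pads only the key dimension and adds an explicit batch dimension in front. Below is a minimal, self-contained sketch of that padding-and-slicing trick, with hypothetical num_heads, prompt_len, and slope values standing in for the class attributes used in the diff:

import torch

# Hypothetical stand-ins for self.num_heads, prompt_len, and
# self.alibi_slopes from the hunk above.
num_heads, prompt_len = 4, 13
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)])

# ALiBi distance bias: bias[i, j] = j - i for query position i, key position j.
positions = torch.arange(prompt_len)
bias = (positions[None, :] - positions[:, None]).float()

# Allocate padded storage so the slice handed to xformers has a row stride
# that is a multiple of 8, then view only the first prompt_len key columns.
padded_len = (prompt_len + 7) // 8 * 8  # 13 -> 16
padded = torch.empty(1, num_heads, prompt_len, padded_len)
view = padded[:, :, :, :prompt_len].copy_(bias)

# Scale each head's bias by its slope; (num_heads, 1, 1) broadcasts
# against the leading batch dimension.
view.mul_(slopes[:, None, None])

assert view.shape == (1, num_heads, prompt_len, prompt_len)
assert view.stride(-2) == padded_len  # row stride keeps the padding

The slice shares storage with the padded tensor, so the result is a [1, num_heads, prompt_len, prompt_len] view whose row stride remains padded_len, which is what satisfies the alignment requirement.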