@@ -474,9 +474,18 @@ class FlashAttentionMetadataBuilder(
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step scheduling, but they
+                        # are not used anyway, so they can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]
             block_tables = torch.from_numpy(input_block_tables).to(
                 device=device, non_blocking=True)
         else:
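
For context, a minimal standalone sketch of the clamping behavior this hunk introduces. The shapes, batch size, and block tables are illustrative values, and `graph_block_tables` is stubbed as a plain NumPy array rather than real runner state; the device transfer with `non_blocking=True` from the actual code is omitted:

```python
import numpy as np
import torch

# Illustrative stand-in for self.runner.graph_block_tables:
# [max batch size, max context len // block size].
graph_block_tables = np.zeros((8, 4), dtype=np.int32)

batch_size = 2
# The second entry exceeds max_blocks (4), e.g. due to lookahead
# slots allocated by multi-step scheduling.
per_seq_block_tables = [[10, 11], [20, 21, 22, 23, 24, 25]]

input_block_tables = graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(per_seq_block_tables):
    if block_table:
        num_blocks = len(block_table)
        if num_blocks <= max_blocks:
            input_block_tables[i, :num_blocks] = block_table
        else:
            # Blocks beyond the captured graph's capacity are never
            # read, so truncating them is safe.
            input_block_tables[i, :max_blocks] = block_table[:max_blocks]

block_tables = torch.from_numpy(input_block_tables)
print(block_tables)
# tensor([[10, 11,  0,  0],
#         [20, 21, 22, 23]], dtype=torch.int32)
```

Before this change, the oversized second table would have raised a shape-mismatch error when assigned into the fixed-size slice; clamping to `max_blocks` keeps the write within the buffer captured by the CUDA graph.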