model_runner.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633
  1. import dataclasses
  2. import gc
  3. import itertools
  4. import os
  5. import time
  6. import warnings
  7. import weakref
  8. from dataclasses import dataclass
  9. from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
  10. TypeVar, Union)
  11. import numpy as np
  12. import torch
  13. import torch.distributed
  14. import torch.nn as nn
  15. from loguru import logger
  16. from aphrodite.attention import AttentionMetadata, get_attn_backend
  17. from aphrodite.attention.backends.abstract import AttentionState
  18. from aphrodite.attention.backends.utils import CommonAttentionState
  19. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  20. LoRAConfig, ModelConfig, ParallelConfig,
  21. PromptAdapterConfig, SchedulerConfig)
  22. from aphrodite.common.sampling_params import SamplingParams
  23. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  24. SequenceGroupMetadata)
  25. from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
  26. async_tensor_h2d, flatten_2d_lists, is_hip,
  27. is_pin_memory_available)
  28. from aphrodite.distributed import get_pp_group
  29. from aphrodite.distributed.parallel_state import (
  30. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  31. graph_capture)
  32. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  33. from aphrodite.lora.layers import LoRAMapping
  34. from aphrodite.lora.request import LoRARequest
  35. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  36. from aphrodite.modeling import SamplingMetadata, SamplingMetadataCache
  37. from aphrodite.modeling.model_loader import get_model
  38. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  39. from aphrodite.modeling.models.interfaces import (supports_lora,
  40. supports_multimodal)
  41. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  42. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  43. MultiModalInputs, MultiModalRegistry)
  44. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  45. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  46. from aphrodite.prompt_adapter.worker_manager import (
  47. LRUCacheWorkerPromptAdapterManager)
  48. from aphrodite.task_handler.model_runner_base import (
  49. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  50. _add_attn_metadata_broadcastable_dict,
  51. _add_sampling_metadata_broadcastable_dict,
  52. _init_attn_metadata_from_tensor_dict,
  53. _init_sampling_metadata_from_tensor_dict)
  54. if TYPE_CHECKING:
  55. from aphrodite.attention.backends.abstract import AttentionBackend
  56. LORA_WARMUP_RANK = 8
  57. _BATCH_SIZE_ALIGNMENT = 8
  58. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  59. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  60. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  61. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  62. ]
  63. _NUM_WARMUP_ITERS = 2
  64. APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE = int(
  65. os.environ.get("APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE", "0"))
  66. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  67. @dataclass(frozen=True)
  68. class ModelInputForGPU(ModelRunnerInputBase):
  69. """
  70. This base class contains metadata needed for the base model forward pass
  71. but not metadata for possible additional steps, e.g., sampling. Model
  72. runners that run additional steps should subclass this method to add
  73. additional fields.
  74. """
  75. input_tokens: Optional[torch.Tensor] = None
  76. input_positions: Optional[torch.Tensor] = None
  77. seq_lens: Optional[List[int]] = None
  78. query_lens: Optional[List[int]] = None
  79. lora_mapping: Optional["LoRAMapping"] = None
  80. lora_requests: Optional[Set[LoRARequest]] = None
  81. attn_metadata: Optional["AttentionMetadata"] = None
  82. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  83. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  84. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  85. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  86. finished_requests_ids: Optional[List[str]] = None
  87. virtual_engine: int = 0
  88. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  89. tensor_dict = {
  90. "input_tokens": self.input_tokens,
  91. "input_positions": self.input_positions,
  92. "lora_requests": self.lora_requests,
  93. "lora_mapping": self.lora_mapping,
  94. "multi_modal_kwargs": self.multi_modal_kwargs,
  95. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  96. "prompt_adapter_requests": self.prompt_adapter_requests,
  97. "virtual_engine": self.virtual_engine,
  98. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  99. "finished_requests_ids": self.finished_requests_ids,
  100. }
  101. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  102. return tensor_dict
  103. @classmethod
  104. def from_broadcasted_tensor_dict(
  105. cls: Type[TModelInputForGPU],
  106. tensor_dict: Dict[str, Any],
  107. attn_backend: Optional["AttentionBackend"] = None,
  108. ) -> TModelInputForGPU:
  109. if attn_backend is not None:
  110. tensor_dict = _init_attn_metadata_from_tensor_dict(
  111. attn_backend, tensor_dict)
  112. return cls(**tensor_dict)
  113. @dataclass(frozen=True)
  114. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  115. """
  116. Used by the ModelRunner.
  117. """
  118. sampling_metadata: Optional["SamplingMetadata"] = None
  119. # Used for speculative decoding. We do not broadcast it because it is only
  120. # used by the driver worker.
  121. is_prompt: Optional[bool] = None
  122. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  123. tensor_dict = {
  124. "input_tokens": self.input_tokens,
  125. "input_positions": self.input_positions,
  126. "lora_requests": self.lora_requests,
  127. "lora_mapping": self.lora_mapping,
  128. "multi_modal_kwargs": self.multi_modal_kwargs,
  129. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  130. "prompt_adapter_requests": self.prompt_adapter_requests,
  131. "virtual_engine": self.virtual_engine,
  132. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  133. "finished_requests_ids": self.finished_requests_ids,
  134. }
  135. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  136. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  137. self.sampling_metadata)
  138. return tensor_dict
  139. @classmethod
  140. def from_broadcasted_tensor_dict(
  141. cls,
  142. tensor_dict: Dict[str, Any],
  143. attn_backend: Optional["AttentionBackend"] = None,
  144. ) -> "ModelInputForGPUWithSamplingMetadata":
  145. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  146. if attn_backend is not None:
  147. tensor_dict = _init_attn_metadata_from_tensor_dict(
  148. attn_backend, tensor_dict)
  149. return cls(**tensor_dict)
  150. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  151. """Build ModelInputForGPU from SequenceGroupMetadata."""
    # NOTE: ideally, we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily, but kw_only
    # is not supported in python<3.10.
  155. class InterDataForSeqGroup:
  156. """Intermediate data for the current sequence group."""
  157. def simple_reinit(self):
  158. self.input_tokens[0].clear() # type: ignore
  159. self.input_positions[0].clear() # type: ignore
  160. self.seq_lens[0] = 0 # type: ignore
  161. self.orig_seq_lens[0] = 0 # type: ignore
  162. self.query_lens[0] = 0 # type: ignore
  163. self.context_lens[0] = 0 # type: ignore
  164. self.curr_sliding_window_blocks[0] = 0 # type: ignore
  165. self.lora_index_mapping.clear() # type: ignore
  166. self.lora_prompt_mapping.clear() # type: ignore
  167. self.lora_requests.clear() # type: ignore
  168. self.prompt_adapter_index_mapping.clear() # type: ignore
  169. self.prompt_adapter_prompt_mapping.clear() # type: ignore
  170. def __init__(
  171. self,
  172. *,
  173. # From sequence group metadata.
  174. request_id: str,
  175. seq_ids: List[int],
  176. is_prompt: bool,
  177. block_tables: Optional[Dict[int, List[int]]],
  178. computed_block_nums: List[int],
  179. n_seqs: int = 0,
  180. # Input tokens and positions.
  181. input_tokens: Optional[List[List[int]]] = None,
  182. input_positions: Optional[List[List[int]]] = None,
  183. # The sequence length (may be capped to the sliding window).
  184. seq_lens: Optional[List[int]] = None,
  185. # The original sequence length (before applying sliding window).
  186. # This is used to compute slot mapping.
  187. orig_seq_lens: Optional[List[int]] = None,
  188. # The query length.
  189. query_lens: Optional[List[int]] = None,
  190. # The number of tokens that are already computed.
  191. context_lens: Optional[List[int]] = None,
  192. # The current sliding window block.
  193. curr_sliding_window_blocks: Optional[List[int]] = None,
  194. # LoRA inputs.
  195. lora_index_mapping: Optional[List[List[int]]] = None,
  196. lora_prompt_mapping: Optional[List[List[int]]] = None,
  197. lora_requests: Optional[Set[LoRARequest]] = None,
  198. # Prompt adapter inputs.
  199. prompt_adapter_index_mapping: Optional[List[int]] = None,
  200. prompt_adapter_prompt_mapping: Optional[List[int]] = None,
  201. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  202. # Multi-modal inputs.
  203. multi_modal_inputs: Optional[MultiModalInputs] = None,
  204. # Whether the prefix cache is hit (prefill only).
  205. prefix_cache_hit: bool = False,
  206. reinit: bool = False,
  207. reinit_use_defaults: bool = False,
  208. ):
  209. if reinit:
  210. assert len(self.seq_ids) == len(seq_ids) # type: ignore
  211. for i, seq_id in enumerate(seq_ids):
  212. self.seq_ids[i] = seq_id # type: ignore
  213. else:
  214. self.seq_ids = seq_ids
  215. self.request_id = request_id
  216. self.is_prompt = is_prompt
  217. self.block_tables = block_tables
  218. self.computed_block_nums = computed_block_nums
  219. self.n_seqs = n_seqs
  220. if reinit:
  221. if len(self.seq_ids) == 1 and reinit_use_defaults:
  222. self.simple_reinit()
  223. else:
  224. if input_tokens:
  225. self.input_tokens = input_tokens
  226. else:
  227. for seq_id in range(len(self.seq_ids)):
  228. self.input_tokens[seq_id].clear()
  229. if input_positions:
  230. self.input_positions = input_positions
  231. else:
  232. for seq_id in range(len(self.seq_ids)):
  233. self.input_positions[seq_id].clear()
  234. if seq_lens:
  235. self.seq_lens = seq_lens
  236. else:
  237. for seq_id in range(len(self.seq_ids)):
  238. self.seq_lens[seq_id] = 0
  239. if orig_seq_lens:
  240. self.orig_seq_lens = orig_seq_lens
  241. else:
  242. for seq_id in range(len(self.seq_ids)):
  243. self.orig_seq_lens[seq_id] = 0
  244. if query_lens:
  245. self.query_lens = query_lens
  246. else:
  247. for seq_id in range(len(self.seq_ids)):
  248. self.query_lens[seq_id] = 0
  249. if context_lens:
  250. self.context_lens = context_lens
  251. else:
  252. for seq_id in range(len(self.seq_ids)):
  253. self.context_lens[seq_id] = 0
  254. if curr_sliding_window_blocks:
  255. self.curr_sliding_window_blocks = \
  256. curr_sliding_window_blocks
  257. else:
  258. for seq_id in range(len(self.seq_ids)):
  259. self.curr_sliding_window_blocks[seq_id] = 0
  260. if lora_index_mapping:
  261. self.lora_index_mapping = lora_index_mapping
  262. else:
  263. self.lora_index_mapping.clear()
  264. if lora_prompt_mapping:
  265. self.lora_prompt_mapping = lora_prompt_mapping
  266. else:
  267. self.lora_prompt_mapping.clear()
  268. if lora_requests:
  269. self.lora_requests = lora_requests
  270. else:
  271. self.lora_requests.clear()
  272. if prompt_adapter_index_mapping:
  273. self.prompt_adapter_index_mapping = \
  274. prompt_adapter_index_mapping
  275. else:
  276. self.prompt_adapter_index_mapping.clear()
  277. if prompt_adapter_prompt_mapping:
  278. self.prompt_adapter_prompt_mapping = \
  279. prompt_adapter_prompt_mapping
  280. else:
  281. self.prompt_adapter_prompt_mapping.clear()
  282. else:
  283. self.input_tokens = input_tokens or []
  284. self.input_positions = input_positions or []
  285. self.seq_lens = seq_lens or []
  286. self.orig_seq_lens = orig_seq_lens or []
  287. self.query_lens = query_lens or []
  288. self.context_lens = context_lens or []
  289. self.curr_sliding_window_blocks = \
  290. curr_sliding_window_blocks or []
  291. self.lora_index_mapping = lora_index_mapping or []
  292. self.lora_prompt_mapping = lora_prompt_mapping or []
  293. self.lora_requests = lora_requests or set()
  294. self.prompt_adapter_index_mapping = (
  295. prompt_adapter_index_mapping or [])
  296. self.prompt_adapter_prompt_mapping = (
  297. prompt_adapter_prompt_mapping or [])
  298. self.prompt_adapter_request = prompt_adapter_request
  299. self.multi_modal_inputs = multi_modal_inputs
  300. self.prefix_cache_hit = prefix_cache_hit
  301. self.n_seqs = len(self.seq_ids)
  302. if not reinit:
  303. self.__post_init__()
  304. def __post_init__(self):
  305. self.n_seqs = len(self.seq_ids)
  306. self.input_tokens = [[] for _ in range(self.n_seqs)]
  307. self.input_positions = [[] for _ in range(self.n_seqs)]
  308. self.seq_lens = [0] * self.n_seqs
  309. self.orig_seq_lens = [0] * self.n_seqs
  310. self.query_lens = [0] * self.n_seqs
  311. self.context_lens = [0] * self.n_seqs
  312. self.curr_sliding_window_blocks = [0] * self.n_seqs
  313. self.lora_index_mapping = []
  314. self.lora_prompt_mapping = []
  315. def gen_inter_data_builder(self, num_seqs: int):
  316. return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
  317. request_id="",
  318. seq_ids=[0] * num_seqs,
  319. is_prompt=True,
  320. block_tables=None,
  321. computed_block_nums=[])
  322. def init_cached_inter_data(self, *args, **kwargs):
  323. assert len(args) == 0
  324. assert "seq_ids" in kwargs
  325. seq_ids = kwargs["seq_ids"]
  326. num_seqs = len(seq_ids)
  327. # The inter-data cache is per model_runner
  328. inter_data_cache = self.runner.inter_data_cache
  329. if num_seqs not in inter_data_cache:
  330. inter_data_cache[num_seqs] = PyObjectCache(
  331. self.gen_inter_data_builder(num_seqs))
  332. obj = inter_data_cache[num_seqs].get_object()
  333. obj.__init__(*args, **kwargs)
  334. return obj
  335. def reset_cached_inter_data(self):
  336. for cache in self.runner.inter_data_cache.values():
  337. cache.reset()
  338. def __init__(self,
  339. runner: "GPUModelRunnerBase",
  340. finished_requests_ids: Optional[List[str]] = None):
  341. super().__init__()
  342. # Compute functions for each sequence in a sequence group.
  343. # WARNING: The order of the functions matters!
  344. self.per_seq_compute_fns = [
  345. self._compute_lens,
  346. self._compute_for_prefix_cache_hit,
  347. self._compute_for_sliding_window,
  348. self._compute_lora_input,
  349. ]
  350. # Compute functions for each sequence group.
  351. # WARNING: The order of the functions matters!
  352. self.per_seq_group_compute_fns = [
  353. self._compute_prompt_adapter_input,
  354. self._compute_multi_modal_input,
  355. ]
  356. self.runner = runner
  357. self.model_input_cls = self.runner._model_input_cls
  358. self.attn_backend = self.runner.attn_backend
  359. self.scheduler_config = self.runner.scheduler_config
  360. self.sliding_window = self.runner.sliding_window
  361. self.block_size = self.runner.block_size
  362. self.enable_lora = self.runner.lora_config is not None
  363. self.enable_prompt_adapter = (self.runner.prompt_adapter_config
  364. is not None)
  365. self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
  366. self.finished_requests_ids = finished_requests_ids
  367. self.decode_only = True
  368. # Intermediate data (data in CPU before going to GPU) for
  369. # the current sequence group.
  370. self.inter_data_list: List[
  371. ModelInputForGPUBuilder.InterDataForSeqGroup] = []
  372. # Attention metadata inputs.
  373. self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
  374. weakref.proxy(self))
  375. # Engine/Model configurations.
  376. self.chunked_prefill_enabled = (
  377. self.scheduler_config is not None
  378. and self.scheduler_config.chunked_prefill_enabled)
  379. if self.sliding_window is not None:
  380. self.sliding_window_blocks = (
  381. self.sliding_window + self.block_size - 1) // self.block_size
  382. self.block_aligned_sliding_window = \
  383. self.sliding_window_blocks * self.block_size
  384. def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
  385. seq_group_metadata: SequenceGroupMetadata):
  386. """Compute context length, sequence length and tokens
  387. for the given sequence data.
  388. """
  389. seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
  390. token_chunk_size = seq_group_metadata.token_chunk_size
  391. # Compute context length (the number of tokens that are
  392. # already computed) and sequence length (total number of tokens).
  393. seq_len = seq_data.get_len()
  394. if inter_data.is_prompt:
  395. context_len = seq_data.get_num_computed_tokens()
  396. else:
  397. # get_num_computed_tokens is incorrect for spec decoding.
  398. # So, we should have a special logic here.
  399. # TODO: Fix it.
  400. context_len = seq_len - 1
  401. seq_len = min(seq_len, context_len + token_chunk_size)
  402. # Compute tokens.
  403. if inter_data.is_prompt:
  404. tokens = seq_data.get_token_ids()
  405. if context_len != 0 or seq_len < len(tokens):
  406. tokens = tokens[context_len:seq_len]
  407. else:
  408. # Optimization. get_token_ids requires the entire copy of
  409. # tokens.
  410. tokens = seq_data.get_last_token_id()
  411. inter_data.seq_lens[seq_idx] = seq_len
  412. inter_data.orig_seq_lens[seq_idx] = seq_len
  413. inter_data.context_lens[seq_idx] = context_len
  414. if isinstance(tokens, list):
  415. inter_data.input_tokens[seq_idx].extend(tokens)
  416. else:
  417. inter_data.input_tokens[seq_idx].append(tokens)
  418. if (seq_len - context_len) == 1:
  419. inter_data.input_positions[seq_idx].append(seq_len - 1)
  420. else:
  421. inter_data.input_positions[seq_idx].extend(
  422. range(context_len, seq_len))
  423. inter_data.query_lens[
  424. seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
  425. def _compute_for_prefix_cache_hit(
  426. self, inter_data: InterDataForSeqGroup, seq_idx: int,
  427. seq_group_metadata: SequenceGroupMetadata):
  428. """Check if hit prefix cache (i.e., some blocks are already computed).
  429. If hit, update input tokens and positions to only compute the
  430. remaining blocks.
  431. """
  432. computed_block_nums = inter_data.computed_block_nums
  433. # Note that prefix caching does not support sliding window.
  434. prefix_cache_hit = (computed_block_nums is not None
  435. and len(computed_block_nums) > 0
  436. and self.sliding_window is None
  437. and inter_data.is_prompt)
  438. inter_data.prefix_cache_hit = prefix_cache_hit
  439. if not prefix_cache_hit:
  440. return
  441. assert computed_block_nums is not None
  442. # The cache hit prompt tokens in this sequence. Note that
  443. # this may be larger than the sequence length if chunked
  444. # prefill is enabled.
  445. prefix_cache_len = len(computed_block_nums) * self.block_size
  446. # The number of so far computed prompt tokens in this sequence.
  447. context_len = inter_data.context_lens[seq_idx]
  448. # The total number of prompt tokens in this sequence.
  449. # When chunked prefill is enabled, this is the token number of
  450. # computed chunks + current chunk.
  451. seq_len = inter_data.seq_lens[seq_idx]
  452. if prefix_cache_len <= context_len:
  453. # We already passed the cache hit region,
  454. # so do normal computation.
  455. pass
  456. elif context_len < prefix_cache_len < seq_len:
  457. # Partial hit. Compute the missing part.
  458. uncomputed_start = prefix_cache_len - context_len
  459. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  460. seq_idx][uncomputed_start:]
  461. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  462. seq_idx][uncomputed_start:]
  463. context_len = prefix_cache_len
  464. inter_data.context_lens[seq_idx] = context_len
  465. inter_data.query_lens[
  466. seq_idx] = inter_data.seq_lens[seq_idx] - context_len
  467. elif seq_len <= prefix_cache_len:
  468. # Full hit. Only compute the last token to avoid
  469. # erroneous behavior. FIXME: Ideally we should directly
  470. # mark all tokens as computed in the scheduler and do not
  471. # schedule this sequence, so this case should not happen.
  472. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  473. seq_idx][-1:]
  474. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  475. seq_idx][-1:]
  476. inter_data.query_lens[seq_idx] = 1
  477. inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
  478. def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
  479. seq_idx: int,
  480. seq_group_metadata: SequenceGroupMetadata):
  481. """Update seq_len and curr_sliding_window_block for the given
  482. sequence data (only required by decoding) if sliding window is enabled.
  483. """
  484. curr_sliding_window_block = 0
  485. sliding_seq_len = inter_data.seq_lens[seq_idx]
  486. if not inter_data.is_prompt and self.sliding_window is not None:
  487. # TODO: This is a hack to make sliding window work with
  488. # paged attn. We can remove it if we make paged attn kernel
  489. # to properly handle slinding window attn.
  490. curr_sliding_window_block = self.sliding_window_blocks
  491. if self.scheduler_config.use_v2_block_manager:
  492. # number of elements in last block
  493. suff_len = inter_data.seq_lens[seq_idx] % self.block_size
  494. sliding_seq_len = min(
  495. inter_data.seq_lens[seq_idx],
  496. self.block_aligned_sliding_window + suff_len)
  497. if suff_len > 0:
  498. curr_sliding_window_block += 1
  499. else:
  500. sliding_seq_len = min(inter_data.seq_lens[seq_idx],
  501. self.sliding_window)
  502. inter_data.curr_sliding_window_blocks[
  503. seq_idx] = curr_sliding_window_block
  504. inter_data.seq_lens[seq_idx] = sliding_seq_len
  505. def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
  506. seq_idx: int,
  507. seq_group_metadata: SequenceGroupMetadata):
  508. """If LoRA is enabled, compute LoRA index and prompt mapping."""
  509. if not self.enable_lora:
  510. return
  511. lora_id = seq_group_metadata.lora_int_id
  512. if lora_id > 0:
  513. inter_data.lora_requests.add(seq_group_metadata.lora_request)
  514. query_len = inter_data.query_lens[seq_idx]
  515. inter_data.lora_index_mapping.append([lora_id] * query_len)
  516. sampling_params = seq_group_metadata.sampling_params
  517. if sampling_params and sampling_params.prompt_logprobs is not None:
  518. inter_data.lora_prompt_mapping.append([lora_id] * query_len)
  519. elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
  520. inter_data.lora_prompt_mapping.append([lora_id])
  521. else:
  522. inter_data.lora_prompt_mapping.append([])
  523. def _compute_prompt_adapter_input(
  524. self, inter_data: InterDataForSeqGroup,
  525. seq_group_metadata: SequenceGroupMetadata):
  526. """If prompt adapter is enabled, compute index and prompt mapping.
  527. """
  528. # Note that when is_prompt=True, we expect only one sequence
  529. # in the group.
  530. if not self.enable_prompt_adapter:
  531. return
  532. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  533. if prompt_adapter_id <= 0 or not inter_data.is_prompt:
  534. return
  535. # We expect only one sequence in the group when is_prompt=True.
  536. assert inter_data.n_seqs == 1
  537. query_len = inter_data.query_lens[0]
  538. inter_data.prompt_adapter_request = (
  539. seq_group_metadata.prompt_adapter_request)
  540. num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
  541. inter_data.prompt_adapter_index_mapping = [
  542. prompt_adapter_id
  543. ] * num_tokens + [0] * (query_len - num_tokens)
  544. inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
  545. query_len if seq_group_metadata.sampling_params
  546. and seq_group_metadata.sampling_params.prompt_logprobs else 1)
  547. def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
  548. seq_group_metadata: SequenceGroupMetadata):
  549. """If multi-modal data is given, add it to the input."""
  550. mm_data = seq_group_metadata.multi_modal_data
  551. if not mm_data:
  552. return
  553. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  554. inter_data.multi_modal_inputs = mm_kwargs
  555. def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
  556. """Add a sequence group to the builder."""
  557. seq_ids = seq_group_metadata.seq_data.keys()
  558. n_seqs = len(seq_ids)
  559. is_prompt = seq_group_metadata.is_prompt
  560. if is_prompt:
  561. assert n_seqs == 1
  562. self.decode_only = False
  563. inter_data = self.init_cached_inter_data(
  564. request_id=seq_group_metadata.request_id,
  565. seq_ids=seq_ids,
  566. is_prompt=is_prompt,
  567. block_tables=seq_group_metadata.block_tables,
  568. computed_block_nums=seq_group_metadata.computed_block_nums,
  569. reinit=True,
  570. reinit_use_defaults=True)
  571. self.inter_data_list.append(inter_data)
  572. for seq_idx in range(n_seqs):
  573. for per_seq_fn in self.per_seq_compute_fns:
  574. per_seq_fn(inter_data, seq_idx, seq_group_metadata)
  575. for per_seq_group_fn in self.per_seq_group_compute_fns:
  576. per_seq_group_fn(inter_data, seq_group_metadata)
  577. def _use_captured_graph(self, batch_size: int,
  578. max_decode_seq_len: int) -> bool:
  579. return (self.decode_only and not self.runner.model_config.enforce_eager
  580. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  581. and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
  582. def build(self) -> ModelInputForGPU:
  583. """Finalize the builder intermediate data and
  584. create on-device tensors.
  585. """
  586. # Combine and flatten intermediate data.
  587. input_tokens = []
  588. for inter_data in self.inter_data_list:
  589. for cur_input_tokens in inter_data.input_tokens:
  590. input_tokens.extend(cur_input_tokens)
  591. if not input_tokens:
  592. # This may happen when all prefill requests hit
  593. # prefix caching and there is no decode request.
  594. return self.model_input_cls()
  595. input_positions = []
  596. for inter_data in self.inter_data_list:
  597. for cur_input_positions in inter_data.input_positions:
  598. input_positions.extend(cur_input_positions)
  599. seq_lens = []
  600. max_decode_seq_len = 0
  601. for inter_data in self.inter_data_list:
  602. seq_lens.extend(inter_data.seq_lens)
  603. if not inter_data.is_prompt:
  604. max_decode_seq_len = max(max_decode_seq_len,
  605. max(inter_data.seq_lens))
  606. query_lens = []
  607. for inter_data in self.inter_data_list:
  608. query_lens.extend(inter_data.query_lens)
  609. # Mapping from request IDs to sequence IDs. Used for Jamba models
  610. # that manages the cache by itself.
  611. request_ids_to_seq_ids = {
  612. data.request_id: data.seq_ids
  613. for data in self.inter_data_list
  614. }
  615. batch_size = len(input_tokens)
  616. use_captured_graph = self._use_captured_graph(batch_size,
  617. max_decode_seq_len)
  618. # If cuda graph can be used, pad tensors accordingly.
  619. # See `capture_model` API for more details.
  620. # Aphrodite uses cuda graph only for decoding requests.
  621. cuda_graph_pad_size = -1
  622. if use_captured_graph:
  623. graph_batch_size = _get_graph_batch_size(batch_size)
  624. assert graph_batch_size >= batch_size
  625. cuda_graph_pad_size = graph_batch_size - batch_size
  626. batch_size = graph_batch_size
  627. # Tokens and positions.
  628. if cuda_graph_pad_size:
  629. input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
  630. input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
  631. assert self.runner.device is not None
  632. input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
  633. self.runner.device,
  634. self.runner.pin_memory)
  635. input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
  636. self.runner.device,
  637. self.runner.pin_memory)
  638. # Sequence and query lengths.
  639. if cuda_graph_pad_size:
  640. seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
  641. # Attention metadata.
  642. attn_metadata = self.attn_metadata_builder.build(
  643. seq_lens, query_lens, cuda_graph_pad_size, batch_size)
  644. # LoRA data.
  645. lora_requests = set()
  646. lora_mapping = None
  647. if self.enable_lora:
  648. lora_requests = set(r for data in self.inter_data_list
  649. for r in data.lora_requests)
  650. lora_index_mapping = flatten_2d_lists([
  651. flatten_2d_lists(inter_data.lora_index_mapping)
  652. for inter_data in self.inter_data_list
  653. ])
  654. if cuda_graph_pad_size:
  655. lora_index_mapping.extend(
  656. itertools.repeat(0, cuda_graph_pad_size))
  657. lora_prompt_mapping = flatten_2d_lists([
  658. flatten_2d_lists(inter_data.lora_prompt_mapping)
  659. for inter_data in self.inter_data_list
  660. ])
  661. lora_mapping = LoRAMapping(
  662. **dict(index_mapping=lora_index_mapping,
  663. prompt_mapping=lora_prompt_mapping,
  664. is_prefill=not self.decode_only))
  665. # Prompt adapter data.
  666. prompt_adapter_requests: Set[PromptAdapterRequest] = set()
  667. prompt_adapter_mapping = None
  668. if self.enable_prompt_adapter:
  669. prompt_adapter_requests = set(
  670. data.prompt_adapter_request for data in self.inter_data_list
  671. if data.prompt_adapter_request is not None)
  672. prompt_adapter_index_mapping = flatten_2d_lists([
  673. inter_data.prompt_adapter_index_mapping
  674. for inter_data in self.inter_data_list
  675. ])
  676. if cuda_graph_pad_size:
  677. prompt_adapter_index_mapping.extend(
  678. itertools.repeat(0, cuda_graph_pad_size))
  679. prompt_adapter_prompt_mapping = flatten_2d_lists([
  680. inter_data.prompt_adapter_prompt_mapping
  681. for inter_data in self.inter_data_list
  682. ])
  683. prompt_adapter_mapping = PromptAdapterMapping(
  684. prompt_adapter_index_mapping,
  685. prompt_adapter_prompt_mapping,
  686. )
  687. # Multi-modal data.
  688. multi_modal_inputs_list = [
  689. data.multi_modal_inputs for data in self.inter_data_list
  690. if data.multi_modal_inputs is not None
  691. ]
  692. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
  693. return self.model_input_cls(
  694. input_tokens=input_tokens_tensor,
  695. input_positions=input_positions_tensor,
  696. attn_metadata=attn_metadata,
  697. seq_lens=seq_lens,
  698. query_lens=query_lens,
  699. lora_mapping=lora_mapping,
  700. lora_requests=lora_requests,
  701. multi_modal_kwargs=multi_modal_kwargs,
  702. request_ids_to_seq_ids=request_ids_to_seq_ids,
  703. finished_requests_ids=self.finished_requests_ids,
  704. prompt_adapter_mapping=prompt_adapter_mapping,
  705. prompt_adapter_requests=prompt_adapter_requests)
  706. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  707. """
  708. Helper class for shared methods between GPU model runners.
  709. """
  710. _model_input_cls: Type[TModelInputForGPU]
  711. _builder_cls: Type[ModelInputForGPUBuilder]
  712. def __init__(
  713. self,
  714. model_config: ModelConfig,
  715. parallel_config: ParallelConfig,
  716. scheduler_config: SchedulerConfig,
  717. device_config: DeviceConfig,
  718. cache_config: CacheConfig,
  719. load_config: LoadConfig,
  720. lora_config: Optional[LoRAConfig],
  721. kv_cache_dtype: Optional[str] = "auto",
  722. is_driver_worker: bool = False,
  723. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  724. return_hidden_states: bool = False,
  725. input_registry: InputRegistry = INPUT_REGISTRY,
  726. mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
  727. tp_rank: int = 0,
  728. ):
  729. self.model_config = model_config
  730. self.parallel_config = parallel_config
  731. self.scheduler_config = scheduler_config
  732. self.device_config = device_config
  733. self.cache_config = cache_config
  734. self.lora_config = lora_config
  735. self.load_config = load_config
  736. self.is_driver_worker = is_driver_worker
  737. self.prompt_adapter_config = prompt_adapter_config
  738. self.return_hidden_states = return_hidden_states
  739. self.device = self.device_config.device
  740. self.pin_memory = is_pin_memory_available()
  741. self.tp_rank = tp_rank
  742. self.kv_cache_dtype = kv_cache_dtype
  743. self.sliding_window = model_config.get_sliding_window()
  744. self.block_size = cache_config.block_size
  745. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  746. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  747. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  748. ]
  749. self.graph_memory_pool: Optional[Tuple[
  750. int, int]] = None # Set during graph capture.
  751. self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
  752. parallel_config)
  753. # When using CUDA graph, the input block tables must be padded to
  754. # max_seq_len_to_capture. However, creating the block table in
  755. # Python can be expensive. To optimize this, we cache the block table
  756. # in numpy and only copy the actual input content at every iteration.
  757. # The shape of the cached block table will be
  758. # (max batch size to capture, max context len to capture / block size).
  759. self.graph_block_tables = np.zeros(
  760. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  761. dtype=np.int32)
  762. self.attn_backend = get_attn_backend(
  763. self.model_config.get_head_size(),
  764. self.model_config.get_sliding_window(),
  765. self.model_config.dtype,
  766. self.kv_cache_dtype,
  767. self.block_size,
  768. self.model_config.is_attention_free(),
  769. )
  770. if self.attn_backend:
  771. self.attn_state = self.attn_backend.get_state_cls()(
  772. weakref.proxy(self))
  773. else:
  774. self.attn_state = CommonAttentionState(weakref.proxy(self))
  775. # Multi-modal data support
  776. self.input_registry = input_registry
  777. self.mm_registry = mm_registry
  778. self.multi_modal_input_mapper = mm_registry \
  779. .create_input_mapper(model_config)
  780. self.mm_registry.init_mm_limits_per_prompt(self.model_config)
  781. # Lazy initialization
  782. self.model: nn.Module # Set after load_model
  783. # Set after load_model.
  784. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  785. self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
  786. set_cpu_offload_max_bytes(
  787. int(self.cache_config.cpu_offload_gb * 1024**3))
  788. # Used to cache python objects
  789. self.inter_data_cache: Dict[int, PyObjectCache] = {}
  790. self.sampling_metadata_cache: SamplingMetadataCache = \
  791. SamplingMetadataCache()
  792. def load_model(self) -> None:
  793. tp = get_tensor_model_parallel_world_size()
  794. rank = get_tensor_model_parallel_rank()
  795. if rank == 0:
  796. logger.info(f"Loading model {self.model_config.model}...")
  797. with CudaMemoryProfiler() as m:
  798. # measure the time it takes to load the model
  799. start_time = time.time()
  800. self.model = get_model(model_config=self.model_config,
  801. device_config=self.device_config,
  802. load_config=self.load_config,
  803. lora_config=self.lora_config,
  804. parallel_config=self.parallel_config,
  805. scheduler_config=self.scheduler_config,
  806. cache_config=self.cache_config)
  807. end_time = time.time()
  808. self.model_memory_usage = m.consumed_memory
  809. total_time = end_time - start_time
  810. if tp > 1:
  811. if rank == 0:
  812. logger.info(f"Model loaded in {total_time:.2f} seconds.")
  813. logger.info(
  814. "Total model weights memory usage: "
  815. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  816. else:
  817. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  818. logger.info(
  819. "Total model weights memory usage: "
  820. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  821. if self.lora_config:
  822. assert supports_lora(self.model), "Model does not support LoRA"
  823. assert not supports_multimodal(
  824. self.model
  825. ), "To be tested: multimodal language model with LoRA settings."
  826. self.lora_manager = LRUCacheWorkerLoRAManager(
  827. self.scheduler_config.max_num_seqs,
  828. self.scheduler_config.max_num_batched_tokens,
  829. self.vocab_size,
  830. self.lora_config,
  831. self.device,
  832. self.model.embedding_modules,
  833. self.model.embedding_padding_modules,
  834. max_position_embeddings=self.model.config.
  835. max_position_embeddings,
  836. )
  837. self.model = self.lora_manager.create_lora_manager(self.model)
  838. if self.prompt_adapter_config:
  839. self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
  840. self.scheduler_config.max_num_seqs,
  841. self.scheduler_config.max_num_batched_tokens, self.device,
  842. self.prompt_adapter_config)
  843. self.model = (
  844. self.prompt_adapter_manager.create_prompt_adapter_manager(
  845. self.model))
  846. if self.kv_cache_dtype == "fp8" and is_hip():
  847. # Currently only ROCm accepts kv-cache scaling factors
  848. # via quantization_param_path and this will be deprecated
  849. # in the future.
  850. if self.model_config.quantization_param_path is not None:
  851. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  852. warnings.warn(
  853. "Loading kv cache scaling factor from JSON is "
  854. "deprecated and will be removed. Please include "
  855. "kv cache scaling factors in the model checkpoint.",
  856. FutureWarning,
  857. stacklevel=2)
  858. self.model.load_kv_cache_scales(
  859. self.model_config.quantization_param_path)
  860. logger.info(
  861. "Loaded KV cache scaling factors from ",
  862. f"{self.model_config.quantization_param_path}")
  863. else:
  864. raise RuntimeError(
  865. "Using FP8 KV cache and scaling factors provided but "
  866. f"model {self.model.__class__} does not support loading"
  867. " scaling factors.", )
  868. else:
  869. logger.warning(
  870. "Using FP8 KV cache but no scaling factors "
  871. "provided. Defaulting to scaling factors of 1.0. "
  872. "This may lead to less accurate results!")
  873. if APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE:
  874. logger.info("Compiling the model using torch.compile...")
  875. start_time = time.time()
  876. self.model = torch.compile(self.model,
  877. fullgraph=True,
  878. backend="eager")
  879. end_time = time.time()
  880. logger.info(
  881. f"Model compiled in {end_time - start_time:.2f} seconds.")
  882. def get_model_memory_usage(self):
  883. return self.model_memory_usage
  884. def save_sharded_state(
  885. self,
  886. path: str,
  887. pattern: Optional[str] = None,
  888. max_size: Optional[int] = None,
  889. ) -> None:
  890. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  891. ShardedStateLoader.save_model(
  892. self.model,
  893. path,
  894. pattern=pattern,
  895. max_size=max_size,
  896. )
  897. def save_tensorized_model(
  898. self,
  899. tensorizer_config: TensorizerConfig,
  900. ) -> None:
  901. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  902. TensorizerLoader.save_model(
  903. self.model,
  904. tensorizer_config=tensorizer_config,
  905. )
  906. def get_max_block_per_batch(self) -> int:
  907. block_size = self.block_size
  908. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  909. def _prepare_model_input_tensors(
  910. self,
  911. seq_group_metadata_list: List[SequenceGroupMetadata],
  912. finished_requests_ids: Optional[List[str]] = None
  913. ) -> TModelInputForGPU:
  914. """Helper method to prepare the model input based on a given sequence
  915. group. Prepares metadata needed for the base model forward pass but not
  916. metadata for possible additional steps, e.g., sampling.
  917. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  918. The result tensors and data structure also batches input in prefill
  919. -> decode order. For example,
  920. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  921. - input_tokens[num_prefill_tokens:] contains decode tokens.
  922. If cuda graph is required, this API automatically pads inputs.
  923. """
  924. builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
  925. for seq_group_metadata in seq_group_metadata_list:
  926. builder.add_seq_group(seq_group_metadata)
  927. builder.reset_cached_inter_data()
  928. return builder.build() # type: ignore
  929. @torch.inference_mode()
  930. def profile_run(self) -> None:
  931. rank = get_tensor_model_parallel_rank()
  932. if rank == 0:
  933. logger.info("Profiling peak memory usage...")
  934. # Enable top-k sampling to reflect the accurate memory usage.
  935. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  936. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  937. max_num_seqs = self.scheduler_config.max_num_seqs
  938. # This represents the maximum number of different requests
  939. # that will have unique loras, an therefore the max amount of memory
  940. # consumption create dummy lora request copies from the lora request
  941. # passed in, which contains a lora from the lora warmup path.
  942. dummy_lora_requests: List[LoRARequest] = []
  943. dummy_lora_requests_per_seq: List[LoRARequest] = []
  944. if self.lora_config:
  945. assert self.lora_manager is not None
  946. with self.lora_manager.dummy_lora_cache():
  947. for idx in range(self.lora_config.max_loras):
  948. lora_id = idx + 1
  949. dummy_lora_request = LoRARequest(
  950. lora_name=f"warmup_{lora_id}",
  951. lora_int_id=lora_id,
  952. lora_local_path="/not/a/real/path",
  953. )
  954. self.lora_manager.add_dummy_lora(dummy_lora_request,
  955. rank=LORA_WARMUP_RANK)
  956. dummy_lora_requests.append(dummy_lora_request)
  957. dummy_lora_requests_per_seq = [
  958. dummy_lora_requests[idx % len(dummy_lora_requests)]
  959. for idx in range(max_num_seqs)
  960. ]
  961. # Profile memory usage with max_num_sequences sequences and the total
  962. # number of tokens equal to max_num_batched_tokens.
  963. seqs: List[SequenceGroupMetadata] = []
  964. # Additional GPU memory may be needed for multi-modal encoding, which
  965. # needs to be accounted for when calculating the GPU blocks for
  966. # Aphrodite blocker manager.
  967. # To exercise the worst scenario for GPU memory consumption,
  968. # the number of seqs (batch_size) is chosen to maximize the number
  969. # of images processed.
  970. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
  971. self.model_config)
  972. if max_mm_tokens > 0:
  973. max_num_seqs_orig = max_num_seqs
  974. max_num_seqs = min(max_num_seqs,
  975. max_num_batched_tokens // max_mm_tokens)
  976. if max_num_seqs < 1:
  977. expr = (f"min({max_num_seqs_orig}, "
  978. f"{max_num_batched_tokens} // {max_mm_tokens})")
  979. logger.warning(
  980. f"Computed max_num_seqs ({expr}) to be less than 1. "
  981. "Setting it to the minimum value of 1.")
  982. max_num_seqs = 1
  983. batch_size = 0
  984. for group_id in range(max_num_seqs):
  985. seq_len = (max_num_batched_tokens // max_num_seqs +
  986. (group_id < max_num_batched_tokens % max_num_seqs))
  987. batch_size += seq_len
  988. seq_data, dummy_multi_modal_data = self.input_registry \
  989. .dummy_data_for_profiling(self.model_config,
  990. seq_len,
  991. self.mm_registry)
  992. seq = SequenceGroupMetadata(
  993. request_id=str(group_id),
  994. is_prompt=True,
  995. seq_data={group_id: seq_data},
  996. sampling_params=sampling_params,
  997. block_tables=None,
  998. lora_request=dummy_lora_requests_per_seq[group_id]
  999. if dummy_lora_requests_per_seq else None,
  1000. multi_modal_data=dummy_multi_modal_data,
  1001. )
  1002. seqs.append(seq)
  1003. # Run the model with the dummy inputs.
  1004. num_layers = self.model_config.get_num_layers(self.parallel_config)
  1005. kv_caches = [None] * num_layers
  1006. finished_requests_ids = [seq.request_id for seq in seqs]
  1007. model_input = self.prepare_model_input(
  1008. seqs, finished_requests_ids=finished_requests_ids)
  1009. intermediate_tensors = None
  1010. if not get_pp_group().is_first_rank:
  1011. intermediate_tensors = self.model.make_empty_intermediate_tensors(
  1012. batch_size=batch_size,
  1013. dtype=self.model_config.dtype,
  1014. device=self.device)
  1015. self.execute_model(model_input, kv_caches, intermediate_tensors)
  1016. torch.cuda.synchronize()
  1017. return
  1018. def remove_all_loras(self):
  1019. if not self.lora_manager:
  1020. raise RuntimeError("LoRA is not enabled.")
  1021. self.lora_manager.remove_all_adapters()
  1022. def set_active_loras(self, lora_requests: Set[LoRARequest],
  1023. lora_mapping: LoRAMapping) -> None:
  1024. if not self.lora_manager:
  1025. raise RuntimeError("LoRA is not enabled.")
  1026. self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  1027. def add_lora(self, lora_request: LoRARequest) -> bool:
  1028. if not self.lora_manager:
  1029. raise RuntimeError("LoRA is not enabled.")
  1030. return self.lora_manager.add_adapter(lora_request)
  1031. def remove_lora(self, lora_id: int) -> bool:
  1032. if not self.lora_manager:
  1033. raise RuntimeError("LoRA is not enabled.")
  1034. return self.lora_manager.remove_adapter(lora_id)
  1035. def pin_lora(self, lora_id: int) -> bool:
  1036. if not self.lora_manager:
  1037. raise RuntimeError("LoRA is not enabled.")
  1038. return self.lora_manager.pin_adapter(lora_id)
  1039. def list_loras(self) -> Set[int]:
  1040. if not self.lora_manager:
  1041. raise RuntimeError("LoRA is not enabled.")
  1042. return self.lora_manager.list_adapters()
  1043. def remove_all_prompt_adapters(self):
  1044. if not self.prompt_adapter_manager:
  1045. raise RuntimeError("PromptAdapter is not enabled.")
  1046. self.prompt_adapter_manager.remove_all_adapters()
  1047. def set_active_prompt_adapters(
  1048. self, prompt_adapter_requests: Set[PromptAdapterRequest],
  1049. prompt_adapter_mapping: PromptAdapterMapping) -> None:
  1050. if not self.prompt_adapter_manager:
  1051. raise RuntimeError("PromptAdapter is not enabled.")
  1052. self.prompt_adapter_manager.set_active_adapters(
  1053. prompt_adapter_requests, prompt_adapter_mapping)
  1054. def add_prompt_adapter(
  1055. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  1056. if not self.prompt_adapter_manager:
  1057. raise RuntimeError("PromptAdapter is not enabled.")
  1058. return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
  1059. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  1060. if not self.prompt_adapter_manager:
  1061. raise RuntimeError("PromptAdapter is not enabled.")
  1062. return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
  1063. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  1064. if not self.prompt_adapter_manager:
  1065. raise RuntimeError("PromptAdapter is not enabled.")
  1066. return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
  1067. def list_prompt_adapters(self) -> Set[int]:
  1068. if not self.prompt_adapter_manager:
  1069. raise RuntimeError("PromptAdapter is not enabled.")
  1070. return self.prompt_adapter_manager.list_adapters()
  1071. @torch.inference_mode()
  1072. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  1073. """Cuda graph capture a model.
  1074. Note that CUDA graph's performance gain is negligible if number
  1075. of batched tokens are larger than 200. And since CUDA graph
  1076. requires fixed sized tensors, supporting large/variable batch
  1077. size requires high GPU memory overhead. Thus, Aphrodite only captures
  1078. decoding requests. Mixed batch (chunked prefill + decoding) or
  1079. prefill requests are not captured.
  1080. Since it is used for decoding-only, it assumes there's only 1 token
  1081. per sequence in the batch.
  1082. """
  1083. tp_rank = get_tensor_model_parallel_rank()
  1084. assert not self.model_config.enforce_eager
  1085. if tp_rank == 0:
  1086. logger.info(
  1087. "Capturing the model for CUDA graphs. This may lead to "
  1088. "unexpected consequences if the model is not static. To "
  1089. "run the model in eager mode, set 'enforce_eager=True' or "
  1090. "use '--enforce-eager' in the CLI.")
  1091. logger.info(
  1092. "CUDA graphs can take additional 1~3 GiB memory per GPU. "
  1093. "If you are running out of memory, consider decreasing "
  1094. "`gpu_memory_utilization` or enforcing eager mode. "
  1095. "You can also reduce the `max_num_seqs` as needed "
  1096. "to decrease memory usage.")
  1097. start_time = time.perf_counter()
  1098. # Prepare dummy inputs. These will be reused for all batch sizes.
  1099. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  1100. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  1101. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  1102. intermediate_inputs = None
  1103. if not get_pp_group().is_first_rank:
  1104. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  1105. batch_size=max_batch_size,
  1106. dtype=self.model_config.dtype,
  1107. device=self.device)
  1108. # Prepare buffer for outputs. These will be reused for all batch sizes.
  1109. # It will be filled after the first graph capture.
  1110. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  1111. None
  1112. ] * self.parallel_config.pipeline_parallel_size
  1113. graph_batch_size = _get_graph_batch_size(
  1114. self.scheduler_config.max_num_seqs)
  1115. batch_size_capture_list = [
  1116. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  1117. ]
  1118. with self.attn_state.graph_capture(
  1119. max_batch_size), graph_capture() as graph_capture_context:
  1120. # NOTE: Capturing the largest batch size first may help reduce the
  1121. # memory usage of CUDA graph.
  1122. for virtual_engine in range(
  1123. self.parallel_config.pipeline_parallel_size):
  1124. for batch_size in reversed(batch_size_capture_list):
  1125. attn_metadata = (
  1126. self.attn_state.graph_capture_get_metadata_for_batch(
  1127. batch_size))
  1128. if self.lora_config:
  1129. lora_mapping = LoRAMapping(
  1130. **dict(index_mapping=[0] * batch_size,
  1131. prompt_mapping=[0] * batch_size,
  1132. is_prefill=False))
  1133. self.set_active_loras(set(), lora_mapping)
  1134. if self.prompt_adapter_config:
  1135. prompt_adapter_mapping = PromptAdapterMapping(
  1136. [-1] * batch_size,
  1137. [-1] * batch_size,
  1138. )
  1139. self.set_active_prompt_adapters(
  1140. set(), prompt_adapter_mapping)
  1141. graph_runner = CUDAGraphRunner(
  1142. self.model, self.attn_backend.get_name(),
  1143. self.attn_state.graph_clone(batch_size))
  1144. capture_inputs = {
  1145. "input_ids":
  1146. input_tokens[:batch_size],
  1147. "positions":
  1148. input_positions[:batch_size],
  1149. "hidden_or_intermediate_states":
  1150. hidden_or_intermediate_states[
  1151. virtual_engine] # type: ignore
  1152. [:batch_size]
  1153. if hidden_or_intermediate_states[virtual_engine]
  1154. is not None else None,
  1155. "intermediate_inputs":
  1156. intermediate_inputs[:batch_size]
  1157. if intermediate_inputs is not None else None,
  1158. "kv_caches":
  1159. kv_caches[virtual_engine],
  1160. "attn_metadata":
  1161. attn_metadata,
  1162. "memory_pool":
  1163. self.graph_memory_pool,
  1164. "stream":
  1165. graph_capture_context.stream
  1166. }
  1167. if self.has_seqlen_agnostic:
  1168. # Only used by Mamba-based models CUDA graph atm (Jamba)
  1169. capture_inputs.update({
  1170. "seqlen_agnostic_capture_inputs":
  1171. self.model.get_seqlen_agnostic_capture_inputs(
  1172. batch_size)
  1173. })
  1174. graph_runner.capture(**capture_inputs)
  1175. self.graph_memory_pool = graph_runner.graph.pool()
  1176. self.graph_runners[virtual_engine][batch_size] = (
  1177. graph_runner)
  1178. end_time = time.perf_counter()
  1179. elapsed_time = end_time - start_time
  1180. # This usually takes < 10 seconds.
  1181. if tp_rank == 0:
  1182. logger.info(f"Graph capturing finished in {elapsed_time:.2f} secs")
  1183. @property
  1184. def vocab_size(self) -> int:
  1185. return self.model_config.get_vocab_size()
  1186. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  1187. """
  1188. GPU model runner with sampling step.
  1189. """
  1190. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  1191. ModelInputForGPUWithSamplingMetadata)
  1192. _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
  1193. def make_model_input_from_broadcasted_tensor_dict(
  1194. self,
  1195. tensor_dict: Dict[str, Any],
  1196. ) -> ModelInputForGPUWithSamplingMetadata:
  1197. model_input = \
  1198. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  1199. tensor_dict,
  1200. attn_backend=self.attn_backend,
  1201. )
  1202. return model_input
  1203. def prepare_model_input(
  1204. self,
  1205. seq_group_metadata_list: List[SequenceGroupMetadata],
  1206. virtual_engine: int = 0,
  1207. finished_requests_ids: Optional[List[str]] = None
  1208. ) -> ModelInputForGPUWithSamplingMetadata:
  1209. """Prepare the model input based on a given sequence group, including
  1210. metadata for the sampling step.
  1211. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  1212. The result tensors and data structure also batches input in prefill
  1213. -> decode order. For example,
  1214. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  1215. - input_tokens[num_prefill_tokens:] contains decode tokens.
  1216. If cuda graph is required, this API automatically pads inputs.
  1217. """
  1218. model_input = self._prepare_model_input_tensors(
  1219. seq_group_metadata_list, finished_requests_ids)
  1220. if get_pp_group().is_last_rank:
  1221. # Sampling metadata is only required for the final pp group
  1222. generators = self.get_generators(finished_requests_ids)
  1223. sampling_metadata = SamplingMetadata.prepare(
  1224. seq_group_metadata_list, model_input.seq_lens,
  1225. model_input.query_lens, self.device, self.pin_memory,
  1226. generators, self.sampling_metadata_cache)
  1227. else:
  1228. sampling_metadata = None
  1229. is_prompt = (seq_group_metadata_list[0].is_prompt
  1230. if seq_group_metadata_list else None)
  1231. return dataclasses.replace(model_input,
  1232. sampling_metadata=sampling_metadata,
  1233. is_prompt=is_prompt,
  1234. virtual_engine=virtual_engine)
  1235. @torch.inference_mode()
  1236. def execute_model(
  1237. self,
  1238. model_input: ModelInputForGPUWithSamplingMetadata,
  1239. kv_caches: List[torch.Tensor],
  1240. intermediate_tensors: Optional[IntermediateTensors] = None,
  1241. num_steps: int = 1,
  1242. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1243. if num_steps > 1:
  1244. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1245. if self.lora_config:
  1246. assert model_input.lora_requests is not None
  1247. assert model_input.lora_mapping is not None
  1248. self.set_active_loras(model_input.lora_requests,
  1249. model_input.lora_mapping)
  1250. if self.prompt_adapter_config:
  1251. assert model_input.prompt_adapter_requests is not None
  1252. assert model_input.prompt_adapter_mapping is not None
  1253. self.set_active_prompt_adapters(
  1254. model_input.prompt_adapter_requests,
  1255. model_input.prompt_adapter_mapping)
  1256. self.attn_state.begin_forward(model_input)
  1257. # Currently cuda graph is only supported by the decode phase.
  1258. assert model_input.attn_metadata is not None
  1259. prefill_meta = model_input.attn_metadata.prefill_metadata
  1260. decode_meta = model_input.attn_metadata.decode_metadata
  1261. # TODO: We can remove this once all
  1262. # virtual engines share the same kv cache.
  1263. virtual_engine = model_input.virtual_engine
  1264. if prefill_meta is None and decode_meta.use_cuda_graph:
  1265. assert model_input.input_tokens is not None
  1266. graph_batch_size = model_input.input_tokens.shape[0]
  1267. model_executable = self.graph_runners[virtual_engine][
  1268. graph_batch_size]
  1269. else:
  1270. model_executable = self.model
  1271. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1272. seqlen_agnostic_kwargs = {
  1273. "finished_requests_ids": model_input.finished_requests_ids,
  1274. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  1275. } if self.has_seqlen_agnostic else {}
  1276. hidden_or_intermediate_states = model_executable(
  1277. input_ids=model_input.input_tokens,
  1278. positions=model_input.input_positions,
  1279. kv_caches=kv_caches,
  1280. attn_metadata=model_input.attn_metadata,
  1281. intermediate_tensors=intermediate_tensors,
  1282. **MultiModalInputs.as_kwargs(multi_modal_kwargs,
  1283. device=self.device),
  1284. **seqlen_agnostic_kwargs,
  1285. )
  1286. # Compute the logits in the last pipeline stage.
  1287. if not get_pp_group().is_last_rank:
  1288. return hidden_or_intermediate_states
  1289. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1290. model_input.sampling_metadata)
  1291. if not self.is_driver_worker:
  1292. return []
  1293. # Sample the next token.
  1294. output: SamplerOutput = self.model.sample(
  1295. logits=logits,
  1296. sampling_metadata=model_input.sampling_metadata,
  1297. )
  1298. if self.return_hidden_states:
  1299. # we only need to pass hidden states of most recent token
  1300. assert model_input.sampling_metadata is not None
  1301. indices = model_input.sampling_metadata.selected_token_indices
  1302. if model_input.is_prompt:
  1303. hidden_states = hidden_or_intermediate_states.index_select(
  1304. 0, indices)
  1305. elif decode_meta.use_cuda_graph:
  1306. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1307. else:
  1308. hidden_states = hidden_or_intermediate_states
  1309. output.hidden_states = hidden_states
  1310. return [output]
  1311. class CUDAGraphRunner:
  1312. def __init__(self, model: nn.Module, backend_name: str,
  1313. attn_state: AttentionState):
  1314. self.model = model
  1315. self.backend_name = backend_name
  1316. self.attn_state = attn_state
  1317. self.input_buffers: Dict[str, torch.Tensor] = {}
  1318. self.output_buffers: Dict[str, torch.Tensor] = {}
  1319. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1320. @property
  1321. def graph(self):
  1322. assert self._graph is not None
  1323. return self._graph
  1324. def capture(
  1325. self,
  1326. input_ids: torch.Tensor,
  1327. positions: torch.Tensor,
  1328. hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
  1329. torch.Tensor]],
  1330. intermediate_inputs: Optional[IntermediateTensors],
  1331. kv_caches: List[torch.Tensor],
  1332. attn_metadata: AttentionMetadata,
  1333. memory_pool: Optional[Tuple[int, int]],
  1334. stream: torch.cuda.Stream,
  1335. **kwargs,
  1336. ) -> Union[torch.Tensor, IntermediateTensors]:
  1337. assert self._graph is None
  1338. # Run the model a few times without capturing the graph.
  1339. # This is to make sure that the captured graph does not include the
  1340. # kernel launches for initial benchmarking (e.g., Triton autotune).
  1341. # Note one iteration is not enough for torch.jit.script
  1342. for _ in range(_NUM_WARMUP_ITERS):
  1343. self.model(
  1344. input_ids,
  1345. positions,
  1346. kv_caches,
  1347. attn_metadata,
  1348. intermediate_inputs,
  1349. **kwargs,
  1350. )
  1351. torch.cuda.synchronize()
  1352. # Capture the graph.
  1353. self._graph = torch.cuda.CUDAGraph()
  1354. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  1355. output_hidden_or_intermediate_states = self.model(
  1356. input_ids,
  1357. positions,
  1358. kv_caches,
  1359. attn_metadata,
  1360. intermediate_inputs,
  1361. **kwargs,
  1362. )
  1363. if hidden_or_intermediate_states is not None:
  1364. if get_pp_group().is_last_rank:
  1365. hidden_or_intermediate_states.copy_(
  1366. output_hidden_or_intermediate_states)
  1367. else:
  1368. for key in hidden_or_intermediate_states.tensors:
  1369. hidden_or_intermediate_states[key].copy_(
  1370. output_hidden_or_intermediate_states[key])
  1371. else:
  1372. hidden_or_intermediate_states = (
  1373. output_hidden_or_intermediate_states)
  1374. del output_hidden_or_intermediate_states
  1375. # make sure `output_hidden_states` is deleted
  1376. # in the graph's memory pool
  1377. gc.collect()
  1378. torch.cuda.synchronize()
  1379. # Save the input and output buffers.
  1380. self.input_buffers = {
  1381. "input_ids": input_ids,
  1382. "positions": positions,
  1383. "kv_caches": kv_caches,
  1384. **self.attn_state.get_graph_input_buffers(attn_metadata),
  1385. **kwargs,
  1386. }
  1387. if intermediate_inputs is not None:
  1388. self.input_buffers.update(intermediate_inputs.tensors)
  1389. if get_pp_group().is_last_rank:
  1390. self.output_buffers = {
  1391. "hidden_states": hidden_or_intermediate_states
  1392. }
  1393. else:
  1394. self.output_buffers = hidden_or_intermediate_states
  1395. return hidden_or_intermediate_states
  1396. def forward(
  1397. self,
  1398. input_ids: torch.Tensor,
  1399. positions: torch.Tensor,
  1400. kv_caches: List[torch.Tensor],
  1401. attn_metadata: AttentionMetadata,
  1402. intermediate_tensors: Optional[IntermediateTensors],
  1403. **kwargs,
  1404. ) -> torch.Tensor:
  1405. # KV caches are fixed tensors, so we don't need to copy them.
  1406. del kv_caches
  1407. # Copy the input tensors to the input buffers.
  1408. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  1409. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  1410. if self.backend_name != "No attention":
  1411. self.input_buffers["slot_mapping"].copy_(
  1412. attn_metadata.slot_mapping, non_blocking=True)
  1413. self.attn_state.prepare_graph_input_buffers(self.input_buffers,
  1414. attn_metadata)
  1415. if "seqlen_agnostic_capture_inputs" in self.input_buffers:
  1416. self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
  1417. **kwargs)
  1418. if intermediate_tensors is not None:
  1419. for key in intermediate_tensors.tensors:
  1420. self.input_buffers[key].copy_(intermediate_tensors[key],
  1421. non_blocking=True)
  1422. # Run the graph.
  1423. self.graph.replay()
  1424. # Return the output tensor.
  1425. if get_pp_group().is_last_rank:
  1426. return self.output_buffers["hidden_states"]
  1427. return self.output_buffers
  1428. def __call__(self, *args, **kwargs):
  1429. return self.forward(*args, **kwargs)
  1430. def _get_graph_batch_size(batch_size: int) -> int:
  1431. """Returns the padded batch size given actual batch size.
  1432. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1433. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1434. """
  1435. if batch_size <= 2:
  1436. return batch_size
  1437. elif batch_size <= 4:
  1438. return 4
  1439. else:
  1440. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1441. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)