
fix: kv cache size calculation on TPUs

AlpinDale 7 months ago
parent commit af1286f9fa
1 changed file with 7 additions and 6 deletions

+ 7 - 6
aphrodite/task_handler/tpu_worker.py

@@ -117,14 +117,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         xm.wait_device_ops()
 
         m = xm.get_memory_info(self.device)
-        program_size = 1024 * 1024 * 1024  # 1GB
-        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
-        kv_cache_bytes = int(free_bytes *
-                             self.cache_config.gpu_memory_utilization)
-        kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
+        total_memory_size = m["bytes_limit"]
+        usable_memory_size = int(total_memory_size *
+                                 self.cache_config.gpu_memory_utilization)
+        profiled = m["bytes_used"]  # Weights + intermediate activations.
+        kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_bytes = get_dtype_size(self.cache_dtype)
         block_size = self.cache_config.block_size
         num_tpu_blocks = (kv_cache_bytes //
-                          (kv_cache_dtype_btyes * block_size * num_layers * 2 *
+                          (dtype_bytes * block_size * num_layers * 2 *
                            head_size * num_kv_heads))
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to a multiple of 8.
         return num_tpu_blocks, 0
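
The net effect is a change in what `gpu_memory_utilization` scales. Before, the worker subtracted the profiled usage and a fixed 1 GiB program reservation from total device memory and applied the utilization fraction only to the remainder; after, it applies the fraction to total memory first and then subtracts the profiled usage (weights plus intermediate activations) from that budget. A minimal sketch of the two formulas follows; the memory sizes, model shape, block size, and utilization value below are hypothetical and not taken from the commit:

```python
GiB = 1024 ** 3

# Hypothetical device numbers (assumed, not from the commit).
bytes_limit = 32 * GiB          # m["bytes_limit"]: total TPU memory
bytes_used = 14 * GiB           # m["bytes_used"]: weights + activations
gpu_memory_utilization = 0.90   # cache_config value (assumed)

# Per-block KV cache size, as in the diff:
# dtype_bytes * block_size * num_layers * 2 (K and V) * head_size * num_kv_heads
dtype_bytes = 2                                    # e.g. bfloat16 cache (assumed)
block_size = 16                                    # assumed
num_layers, head_size, num_kv_heads = 32, 128, 8   # assumed model shape
block_bytes = (dtype_bytes * block_size * num_layers * 2 *
               head_size * num_kv_heads)

# Before: reserve a fixed 1 GiB for the compiled program, then apply the
# utilization fraction only to the free memory that remains.
free_bytes = max(bytes_limit - bytes_used - 1 * GiB, 0)
old_blocks = int(free_bytes * gpu_memory_utilization) // block_bytes
old_blocks = (old_blocks // 8) * 8

# After: apply the utilization fraction to total memory first, then
# subtract the profiled usage from that budget.
usable = int(bytes_limit * gpu_memory_utilization)
new_blocks = max(usable - bytes_used, 0) // block_bytes
new_blocks = (new_blocks // 8) * 8

print(f"per-block bytes: {block_bytes}")
print(f"old budget: {old_blocks} blocks, new budget: {new_blocks} blocks")
```

Under these assumed numbers the two budgets differ by a few hundred blocks; the substantive change is that the utilization knob now bounds total memory use directly, rather than scaling only whatever is left after a fixed reservation.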