|
@@ -111,7 +111,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|
|
raise NotImplementedError("TPU version must be 4 or higher.")
|
|
|
|
|
|
self.megacore_mode = None
|
|
|
- tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower()
|
|
|
+ tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
|
|
|
if not tpu_type.endswith("lite"):
|
|
|
if self.num_kv_heads % 2 == 0:
|
|
|
self.megacore_mode = "kv_head"
|