# layers.py

  1. from dataclasses import dataclass
  2. from typing import Tuple
  3. @dataclass
  4. class AdapterMapping:
  5. # Per every token in input_ids:
  6. index_mapping: Tuple[int, ...]
  7. # Per sampled token:
  8. prompt_mapping: Tuple[int, ...]
  9. def __post_init__(self):
  10. self.index_mapping = tuple(self.index_mapping)
  11. self.prompt_mapping = tuple(self.prompt_mapping)