layers.py

from dataclasses import dataclass
from typing import Dict, Optional

import torch
from torch import nn


@dataclass
class ControlVectorMapping:
    layer_mapping: Dict[int, torch.Tensor]


class BaseLayerWithControlVector(nn.Module):
    pass


class MLPWithControlVector(BaseLayerWithControlVector):

    def __init__(self, base_layer) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.control_vectors: Dict[int, torch.Tensor] = {}
        self.normalize = False
        self.layer_id: Optional[int] = None
        self.active_vector: Optional[torch.Tensor] = None

    def set_normalization(self, normalize: bool) -> None:
        self.normalize = normalize

    def set_layer_id(self, layer_id: int) -> None:
        self.layer_id = layer_id

    def set_control_vector(self, index: int, cv_vector: torch.Tensor) -> None:
        """Set a control vector at a specific index."""
        self.control_vectors[index] = cv_vector

    def get_control_vector(self, index: int) -> Optional[torch.Tensor]:
        """Get a control vector by index."""
        return self.control_vectors.get(index)

    def reset_control_vector(self, index: int) -> None:
        """Reset a control vector to zeros at a specific index."""
        if index in self.control_vectors:
            self.control_vectors[index] = torch.zeros_like(
                self.control_vectors[index])

    def set_active_tensor(self, index: Optional[int]) -> None:
        """Select which stored control vector the forward pass applies."""
        if index is not None and index in self.control_vectors:
            self.active_vector = self.control_vectors[index]
        else:
            self.active_vector = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional application of control vectors."""
        hidden_states = self.base_layer(hidden_states)
        # Record the pre-steering norm so it can be restored after the
        # control vector is added.
        norm_pre = torch.norm(hidden_states, dim=-1, keepdim=True)
        cv = self.active_vector
        if cv is not None and cv.numel() > 0:
            hidden_states = hidden_states + cv
            if self.normalize:
                # Rescale so each hidden state keeps its original norm.
                hidden_states = hidden_states * norm_pre / torch.norm(
                    hidden_states, dim=-1, keepdim=True)
        return hidden_states
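
A minimal usage sketch, not part of layers.py: it wraps a toy two-layer MLP in MLPWithControlVector, registers and activates a control vector, and checks that norm preservation holds when normalize is enabled. The names base_mlp and hidden_size, and the layer sizes, are hypothetical.

if __name__ == "__main__":
    # Hypothetical toy MLP standing in for a real transformer MLP block.
    hidden_size = 8
    base_mlp = nn.Sequential(
        nn.Linear(hidden_size, 4 * hidden_size),
        nn.GELU(),
        nn.Linear(4 * hidden_size, hidden_size),
    )

    mlp = MLPWithControlVector(base_mlp)
    mlp.set_layer_id(0)
    mlp.set_normalization(True)
    mlp.set_control_vector(0, torch.randn(hidden_size))
    mlp.set_active_tensor(0)

    x = torch.randn(2, 3, hidden_size)  # (batch, seq, hidden)
    out = mlp(x)

    # With normalize=True, each position's hidden state keeps the norm
    # it had before the control vector was added.
    print(out.shape)
    print(torch.allclose(torch.norm(out, dim=-1),
                         torch.norm(base_mlp(x), dim=-1)))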