```python
from dataclasses import dataclass
from typing import Dict, Optional

import torch
from torch import nn


@dataclass
class ControlVectorMapping:
    layer_mapping: Dict[int, torch.Tensor]


class BaseLayerWithControlVector(nn.Module):
    pass


class MLPWithControlVector(BaseLayerWithControlVector):

    def __init__(self, base_layer) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.control_vectors: Dict[int, torch.Tensor] = {}
        self.normalize = False
        self.layer_id: Optional[int] = None
        self.active_vector: Optional[torch.Tensor] = None

    def set_normalization(self, normalize: bool) -> None:
        self.normalize = normalize

    def set_layer_id(self, layer_id: int) -> None:
        self.layer_id = layer_id

    def set_control_vector(self, index: int, cv_vector: torch.Tensor) -> None:
        """Set a control vector at a specific index."""
        self.control_vectors[index] = cv_vector

    def get_control_vector(self, index: int) -> Optional[torch.Tensor]:
        """Get a control vector by index."""
        return self.control_vectors.get(index)

    def reset_control_vector(self, index: int) -> None:
        """Reset a control vector to zeros at a specific index."""
        if index in self.control_vectors:
            self.control_vectors[index] = torch.zeros_like(
                self.control_vectors[index])

    def set_active_tensor(self, index: Optional[int]) -> None:
        """Select which stored control vector the forward pass applies."""
        if index is not None and index in self.control_vectors:
            self.active_vector = self.control_vectors[index]
        else:
            self.active_vector = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional application of control vectors."""
        hidden_states = self.base_layer(hidden_states)
        # Record the pre-steering norm so it can be restored after the
        # control vector shifts the hidden states.
        norm_pre = torch.norm(hidden_states, dim=-1, keepdim=True)
        cv = self.active_vector
        if cv is not None and cv.numel() > 0:
            hidden_states = hidden_states + cv
            if self.normalize:
                # Rescale so each position keeps its original norm.
                hidden_states = hidden_states * norm_pre / torch.norm(
                    hidden_states, dim=-1, keepdim=True)
        return hidden_states
```
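To see the wrapper in action, here is a minimal usage sketch. The toy two-layer MLP, the hidden size, and the random steering direction are illustrative assumptions rather than part of the listing above; in a real model, `base_layer` would be an MLP block of the transformer being steered.

```python
import torch
from torch import nn

# Illustrative stand-in for a transformer MLP block (assumed, not from
# the original listing).
hidden_size = 16
base_mlp = nn.Sequential(
    nn.Linear(hidden_size, 4 * hidden_size),
    nn.GELU(),
    nn.Linear(4 * hidden_size, hidden_size),
)

mlp = MLPWithControlVector(base_mlp)
mlp.set_layer_id(0)
mlp.set_normalization(True)

# Register a steering direction for this layer and make it active.
steering_direction = torch.randn(hidden_size)
mlp.set_control_vector(index=0, cv_vector=steering_direction)
mlp.set_active_tensor(0)

x = torch.randn(2, 8, hidden_size)  # (batch, seq_len, hidden)
out = mlp(x)

# With normalize=True, each position's norm matches the unsteered output,
# so the control vector changes direction but not magnitude.
unsteered = base_mlp(x)
print(torch.allclose(out.norm(dim=-1), unsteered.norm(dim=-1), atol=1e-5))
```

The norm restoration is the design choice worth noting: adding a fixed vector would otherwise inflate activation magnitudes at every position, and renormalizing keeps the steered hidden states on the same scale the downstream layers were trained to expect.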