gpt2s-hf.yaml

# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small
  - override /callbacks: [default, norm-monitor, flop-count]

datamodule:
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null

callbacks:
  flop_count:
    input_size:
      - ${datamodule.max_length}
    input_dtype:
      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
      # (see the sketch below the config)
      _target_: torch.__getattribute__
      _args_:
        - long
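
The input_dtype entry relies on a small Hydra trick: torch.long is a dtype object, not a callable, so it cannot be named directly as a _target_. The config instead targets torch.__getattribute__ and passes "long" as a positional argument. Below is a minimal stand-alone sketch of how that node resolves, assuming the flop-count callback consumes it through the usual hydra.utils.instantiate path; it is an illustration, not part of the config.

import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Reproduce just the input_dtype node from the config above.
node = OmegaConf.create({
    "_target_": "torch.__getattribute__",
    "_args_": ["long"],
})

# instantiate() resolves the target to the torch module's __getattribute__
# method and calls it with "long", yielding the torch.long dtype object.
dtype = instantiate(node)
assert dtype is torch.long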