# @package _global_

# Hydra experiment config: GPT-2 small (HuggingFace implementation) trained on
# OpenWebText, extending the shared OWT base experiment.
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small
  - override /callbacks: [default, norm-monitor, flop-count]

datamodule:
  # Per-device micro-batch size.
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null

callbacks:
  flop_count:
    # Dummy input shape for FLOP counting; sequence length comes from the
    # datamodule so the two stay in sync.
    input_size:
      - ${datamodule.max_length}
    input_dtype:
      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
      _target_: torch.__getattribute__
      _args_:
        - long