# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small
  - override /callbacks: [default, norm-monitor, flop-count]

datamodule:
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null

callbacks:
  flop_count:
    input_size:
      - ${datamodule.max_length}
    input_dtype:
      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
      _target_: torch.__getattribute__
      _args_:
        - long
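
# What the input_dtype block above resolves to: Hydra imports torch, looks up
# torch.__getattribute__, and calls it with "long", so instantiation yields
# torch.__getattribute__("long"), i.e. torch.long. The indirection is needed
# because _target_ must name a callable, and torch.long itself is a dtype.
#
# Hypothetical launch (entry-point name assumed, not part of this config):
#   python run.py experiment=owt/<this-config>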