* Add training code from `lm1b_nnx`
* Add support for distributed training via sharding
* Add support for multiple samples per sequence (sequence packing), as in `lm1b`, by modifying the attention layer
* Replace `tfds` with `grain`
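Supporting multiple samples per sequence usually comes down to the attention mask. A token must only attend to earlier tokens that belong to the same packed sample. The following is a minimal sketch that assumes each packed batch carries an integer segment-id array, with `0` marking padding. That convention is an assumption and is not taken from this repository. The sketch uses NumPy, and `jax.numpy` provides the same calls. The function name `packed_causal_mask` is hypothetical.

```python
import numpy as np


def packed_causal_mask(segment_ids: np.ndarray) -> np.ndarray:
    """Boolean attention mask for packed sequences (hypothetical helper).

    segment_ids: int array of shape [batch, seq_len]. Tokens that share an id
    belong to the same sample. Id 0 is ASSUMED to mark padding.
    Returns a bool array of shape [batch, 1, seq_len, seq_len] where
    mask[b, 0, q, k] is True if query position q may attend to key position k.
    """
    seq_len = segment_ids.shape[-1]
    # Same-sample constraint: query and key must carry the same segment id.
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
    # Padding tokens (id 0) neither attend nor are attended to.
    not_pad = (segment_ids[:, :, None] != 0) & (segment_ids[:, None, :] != 0)
    # Causal constraint: a query only sees keys at or before its own position.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    mask = same_segment & not_pad & causal[None, :, :]
    # Insert a head axis so the mask broadcasts over attention heads.
    return mask[:, None, :, :]
```

Inside the attention layer, the mask would typically be applied to the logits before the softmax, for example `jnp.where(mask, logits, large_negative_value)`. Flax's `linen` API also ships `make_attention_mask`, `make_causal_mask`, and `combine_masks` helpers, which can build an equivalent mask.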