Add new optimizations from PyTorch 2.0
We should upgrade both the training and the inference code to use torch.compile, the new compilation system introduced in PyTorch 2.0. The PyTorch documentation has some examples. For the most part it should come down to adding
model = torch.compile(model)
in the places where it matters most. The trickier part is DDP support during training.
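
A rough sketch of what this could look like, not a final design. The helper names maybe_compile and wrap_for_ddp_training and the enabled flag are hypothetical and do not exist in our codebase. The hasattr guard assumes we still want to run on PyTorch 1.x, where torch.compile is not available. Wrapping the model in DDP first and compiling the wrapped module is my understanding of what the PyTorch docs recommend, so that should be verified against the version we target.

```python
import torch
import torch.nn as nn


def maybe_compile(model: nn.Module, enabled: bool = True) -> nn.Module:
    """Return torch.compile(model) when possible, else the model unchanged.

    Hypothetical helper: `enabled` would let us switch compilation off
    from a config flag while debugging.
    """
    # torch.compile only exists from PyTorch 2.0 onwards.
    if enabled and hasattr(torch, "compile"):
        return torch.compile(model)
    return model


def wrap_for_ddp_training(
    model: nn.Module, device_id: int, enabled: bool = True
) -> nn.Module:
    """Wrap the model in DDP, then compile the wrapped module.

    Assumptions: the process group is already initialised and the model
    already lives on the GPU `device_id`. The DDP-then-compile order is
    an assumption to check against the PyTorch docs.
    """
    from torch.nn.parallel import DistributedDataParallel as DDP

    ddp_model = DDP(model, device_ids=[device_id])
    return maybe_compile(ddp_model, enabled)
```

For inference, model = maybe_compile(model) after loading the weights would be the whole change. Compilation is lazy, so the first forward pass will be noticeably slower than the following ones.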