# Training configuration

All hyperparameters are specified and editable in the training scripts (their meaning is explained in comments). This page introduces some useful keys and their descriptions.

## Dataset parameters

| Parameter                                | Description                                                                            | Type         | Default                                        |
| ---------------------------------------- | -------------------------------------------------------------------------------------- | ------------ | ---------------------------------------------- |
| `dataset_name`                           | Name of the dataset.                                                                    | `str`        |                                                |
| `dataset_level`                          | Level of the dataset. Should be named after the element type.                           | `str`        |                                                |
| `dataset_variant`                        | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets.  | `str`        |                                                |
| `dataset_path`                           | Path to the dataset.                                                                    | `str`        |                                                |
| `dataset_params.config.dataset_manager`  | Dataset manager class.                                                                  | custom class | `OCRDatasetManager`                            |
| `dataset_params.config.dataset_class`    | Dataset class.                                                                          | custom class | `OCRDataset`                                   |
| `dataset_params.config.datasets`         | Dataset dictionary with the dataset name as key and dataset path as value.              | `dict`       |                                                |
| `dataset_params.config.load_in_memory`   | Load all images in CPU memory.                                                          | `bool`       | `True`                                         |
| `dataset_params.config.worker_per_gpu`   | Number of parallel processes per GPU for data loading.                                  | `int`        | `4`                                            |
| `dataset_params.config.preprocessings`   | List of pre-processing functions to apply to input images.                              | `list`       | (see [dedicated section](#data-preprocessing)) |
| `dataset_params.config.augmentation`     | Configuration for data augmentation.                                                    | `dict`       | (see [dedicated section](#data-augmentation))  |

### Data preprocessing

Preprocessing is applied before training the network (see `dan/manager/dataset.py`). The following transformations are implemented:

* Convert to grayscale

```py
{
    "type": "to_grayscaled"
}
```

* Convert to RGB

```py
{
    "type": "to_RGB"
}
```

* Resize to a fixed height

```py
{
    "type": "fixed_height",
    "fixed_height": 1000,
}
```

* Resize to a maximum size

```py
{
    "type": "resize",
    "keep_ratio": True,
    "max_height": 1000,
    "max_width": None,
}
```

Multiple transformations can be combined. For example, to resize an image to a fixed height of 1000 pixels and convert it to RGB, use the following configuration in `dataset_params.config.preprocessings`:

```py
[
    {
        "type": "fixed_height",
        "fixed_height": 1000
    },
    {
        "type": "to_RGB"
    }
]
```
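To make the semantics concrete, the sketch below applies such a list of preprocessings in order to a PIL image. It is a minimal illustration only, not the actual implementation in `dan/manager/dataset.py`: the `apply_preprocessings` helper and the input file name are hypothetical, and only three of the transformation types are handled.

```py
from PIL import Image

def apply_preprocessings(image, preprocessings):
    # Hypothetical helper: apply each configured transformation in order.
    for preprocessing in preprocessings:
        if preprocessing["type"] == "to_grayscaled":
            image = image.convert("L")
        elif preprocessing["type"] == "to_RGB":
            image = image.convert("RGB")
        elif preprocessing["type"] == "fixed_height":
            # Scale the width so the aspect ratio is preserved.
            height = preprocessing["fixed_height"]
            width = round(image.width * height / image.height)
            image = image.resize((width, height))
    return image

# "page.jpg" is a placeholder input image.
image = apply_preprocessings(
    Image.open("page.jpg"),
    [{"type": "fixed_height", "fixed_height": 1000}, {"type": "to_RGB"}],
)
```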
### Data augmentation

Augmentation transformations are applied on-the-fly during training to artificially increase data variability. The following transformations are implemented in `dan/transforms.py`:

* Color inversion
* Dilation and erosion
* Elastic distortion
* Reducing interline spacing
* Gaussian blur
* Gaussian noise

DAN also takes advantage of [transforms from torchvision](https://pytorch.org/vision/stable/transforms.html):

* ColorJitter
* GaussianBlur
* RandomCrop
* RandomPerspective

The following configuration is used by default when using the `teklia-dan train document` command. Data augmentation is applied with a probability of 0.9, and each transformation has a 0.1 probability to be used.

```py
{
    "order": "random",
    "proba": 0.9,
    "augmentations": [
        {
            "type": "perspective",
            "proba": 0.1,
            "min_factor": 0,
            "max_factor": 0.4,
        },
        {
            "type": "elastic_distortion",
            "proba": 0.1,
            "min_alpha": 0.5,
            "max_alpha": 1,
            "min_sigma": 1,
            "max_sigma": 10,
            "min_kernel_size": 3,
            "max_kernel_size": 9,
        },
        {
            "type": "dilation_erosion",
            "proba": 0.1,
            "min_kernel": 1,
            "max_kernel": 3,
            "iterations": 1,
        },
        {
            "type": "color_jittering",
            "proba": 0.1,
            "factor_hue": 0.2,
            "factor_brightness": 0.4,
            "factor_contrast": 0.4,
            "factor_saturation": 0.4,
        },
        {
            "type": "gaussian_blur",
            "proba": 0.1,
            "min_kernel": 3,
            "max_kernel": 5,
            "min_sigma": 3,
            "max_sigma": 5,
        },
        {
            "type": "gaussian_noise",
            "proba": 0.1,
            "std": 0.5,
        },
        {
            "type": "sharpen",
            "proba": 0.1,
            "min_alpha": 0,
            "max_alpha": 1,
            "min_strength": 0,
            "max_strength": 1,
        },
    ],
}
```
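The sketch below illustrates one plausible reading of this policy: a global gate (`proba`), after which each transformation fires independently with its own `proba`, in shuffled order when `order` is `"random"`. This is an assumption for illustration, not the actual code in `dan/transforms.py`, and `apply_transform` is a hypothetical dispatcher.

```py
import random

def apply_augmentations(image, config):
    # Hedged sketch of the augmentation policy described above.
    if random.random() > config["proba"]:  # global gate, 0.9 by default
        return image
    augmentations = list(config["augmentations"])
    if config["order"] == "random":
        random.shuffle(augmentations)  # shuffle a copy, keep the config intact
    for augmentation in augmentations:
        if random.random() <= augmentation["proba"]:  # 0.1 by default
            # `apply_transform` is a hypothetical dispatcher mapping
            # augmentation["type"] to the matching image transformation.
            image = apply_transform(image, augmentation)
    return image
```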
## Model parameters

| Name                                      | Description                                                                             | Type         | Default                                                            |
| ----------------------------------------- | ---------------------------------------------------------------------------------------- | ------------ | ------------------------------------------------------------------ |
| `model_params.models.encoder`             | Encoder class.                                                                            | custom class | `FCN_encoder`                                                       |
| `model_params.models.decoder`             | Decoder class.                                                                            | custom class | `GlobalHTADecoder`                                                  |
| `model_params.transfer_learning.encoder`  | Model to load for the encoder `[state_dict_name, checkpoint_path, learnable, strict]`.   | `list`       | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]`    |
| `model_params.transfer_learning.decoder`  | Model to load for the decoder `[state_dict_name, checkpoint_path, learnable, strict]`.   | `list`       | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]`   |
| `model_params.transfered_charset`         | Transfer learning of the decision layer based on the charset of the model to transfer.   | `bool`       | `True`                                                              |
| `model_params.additional_tokens`          | For decision layer = `[<eot>, ]`; only used with a transferred charset.                   | `int`        | `1`                                                                 |
| `model_params.input_channels`             | Number of channels of the input image.                                                    | `int`        | `3`                                                                 |
| `model_params.dropout`                    | Dropout probability in the encoder.                                                       | `float`      | `0.5`                                                               |
| `model_params.enc_dim`                    | Dimension of features extracted by the encoder.                                           | `int`        | `256`                                                               |
| `model_params.nb_layers`                  | Number of layers in the encoder.                                                          | `int`        | `5`                                                                 |
| `model_params.h_max`                      | Maximum height for encoder output (for 2D positional embedding).                          | `int`        | `500`                                                               |
| `model_params.w_max`                      | Maximum width for encoder output (for 2D positional embedding).                           | `int`        | `1000`                                                              |
| `model_params.l_max`                      | Maximum predicted sequence length (for 1D positional embedding).                          | `int`        | `15000`                                                             |
| `model_params.dec_num_layers`             | Number of transformer decoder layers.                                                     | `int`        | `8`                                                                 |
| `model_params.dec_num_heads`              | Number of heads in transformer decoder layers.                                            | `int`        | `4`                                                                 |
| `model_params.dec_res_dropout`            | Dropout probability in transformer decoder layers.                                        | `float`      | `0.1`                                                               |
| `model_params.dec_pred_dropout`           | Dropout rate before the decision layer.                                                   | `float`      | `0.1`                                                               |
| `model_params.dec_att_dropout`            | Dropout rate in multi-head attention.                                                     | `float`      | `0.1`                                                               |
| `model_params.dec_dim_feedforward`        | Number of dimensions for the feedforward layer in transformer decoder layers.             | `int`        | `256`                                                               |
| `model_params.use_2d_pe`                  | Whether to use 2D positional embedding.                                                   | `bool`       | `True`                                                              |
| `model_params.use_1d_pe`                  | Whether to use 1D positional embedding.                                                   | `bool`       | `True`                                                              |
| `model_params.use_lstm`                   | Whether to use an LSTM layer in the decoder.                                              | `bool`       | `False`                                                             |
| `model_params.attention_win`              | Length of attention window.                                                               | `int`        | `100`                                                               |
| `model_params.dropout_scheduler.function` | Curriculum dropout scheduler.                                                             | custom class | `exponential_dropout_scheduler`                                     |
| `model_params.dropout_scheduler.T`        | Exponential factor.                                                                       | `float`      | `5e4`                                                               |

## Training parameters

| Name                                                     | Description                                                                   | Type         | Default                                     |
| --------------------------------------------------------- | ------------------------------------------------------------------------------ | ------------ | -------------------------------------------- |
| `training_params.output_folder`                           | Directory for checkpoints and results.                                          | `str`        |                                              |
| `training_params.max_nb_epochs`                           | Maximum number of epochs before stopping training.                              | `int`        | `800`                                        |
| `training_params.max_training_time`                       | Maximum time (in seconds) before stopping training.                             | `int`        | `350000`                                     |
| `training_params.load_epoch`                              | Model to load. Should be either `"best"` (evaluation) or `"last"` (training).   | `str`        | `"last"`                                     |
| `training_params.interval_save_weights`                   | Step to save weights. Set to `None` to keep only best and last epochs.          | `int`        | `None`                                       |
| `training_params.batch_size`                              | Mini-batch size for the training loop.                                          | `int`        | `2`                                          |
| `training_params.valid_batch_size`                        | Mini-batch size for the validation loop.                                        | `int`        | `4`                                          |
| `training_params.use_ddp`                                 | Whether to use DistributedDataParallel.                                         | `bool`       | `False`                                      |
| `training_params.ddp_port`                                | DDP port.                                                                       | `int`        | `20027`                                      |
| `training_params.use_amp`                                 | Whether to enable automatic mixed-precision.                                    | `bool`       | `True`                                       |
| `training_params.nb_gpu`                                  | Number of GPUs to train DAN.                                                    | `int`        | `torch.cuda.device_count()`                  |
| `training_params.optimizers.all.class`                    | Optimizer class.                                                                | custom class | `Adam`                                       |
| `training_params.optimizers.all.args.lr`                  | Learning rate for the optimizer.                                                | `float`      | `0.0001`                                     |
| `training_params.optimizers.all.args.amsgrad`             | Whether to use AMSGrad optimization.                                            | `bool`       | `False`                                      |
| `training_params.lr_schedulers`                           | Learning rate schedulers.                                                       | custom class | `None`                                       |
| `training_params.eval_on_valid`                           | Whether to evaluate and log metrics on the validation set during training.      | `bool`       | `True`                                       |
| `training_params.eval_on_valid_interval`                  | Interval (in epochs) to evaluate during training.                               | `int`        | `5`                                          |
| `training_params.focus_metric`                            | Metric to focus on to determine the best epoch.                                 | `str`        | `cer`                                        |
| `training_params.expected_metric_value`                   | Best value for the focus metric. Should be either `"high"` or `"low"`.          | `str`        | `"low"`                                      |
| `training_params.set_name_focus_metric`                   | Dataset to focus on to select best weights.                                     | `str`        |                                              |
| `training_params.train_metrics`                           | List of metrics to compute during training.                                     | `list`       | `["loss_ce", "cer", "wer", "wer_no_punct"]`  |
| `training_params.eval_metrics`                            | List of metrics to compute during validation.                                   | `list`       | `["cer", "wer", "wer_no_punct"]`             |
| `training_params.force_cpu`                               | Whether to train on CPU (for debugging).                                        | `bool`       | `False`                                      |
| `training_params.max_char_prediction`                     | Maximum number of characters to predict.                                        | `int`        | `1000`                                       |
| `training_params.label_noise_scheduler.min_error_rate`    | Minimum ratio of teacher forcing.                                               | `float`      | `0.2`                                        |
| `training_params.label_noise_scheduler.max_error_rate`    | Maximum ratio of teacher forcing.                                               | `float`      | `0.2`                                        |
| `training_params.label_noise_scheduler.total_num_steps`   | Number of steps before stopping teacher forcing.                                | `float`      | `5e4`                                        |
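The dotted names in the tables above correspond to nested dictionaries in the training script. As a hedged illustration of how they fit together (the nesting is inferred from the parameter names, the values are the defaults listed above, and the output path is a placeholder):

```py
from torch.optim import Adam

# Illustrative sketch only: nesting inferred from the dotted parameter
# names, values taken from the defaults in the tables above.
training_params = {
    "output_folder": "outputs/dan_training",  # hypothetical path
    "max_nb_epochs": 800,
    "max_training_time": 350000,              # in seconds
    "load_epoch": "last",                     # resume from the last epoch
    "batch_size": 2,
    "valid_batch_size": 4,
    "use_ddp": False,
    "optimizers": {
        "all": {
            "class": Adam,
            "args": {"lr": 0.0001, "amsgrad": False},
        },
    },
    "eval_on_valid": True,
    "eval_on_valid_interval": 5,              # evaluate every 5 epochs
    "focus_metric": "cer",                    # select best weights on CER
    "expected_metric_value": "low",           # lower CER is better
    "train_metrics": ["loss_ce", "cer", "wer", "wer_no_punct"],
    "eval_metrics": ["cer", "wer", "wer_no_punct"],
}
```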
## MLFlow logging

To log your experiment on MLFlow, you need to:

- install the extra requirements via

```shell
$ pip install .[mlflow]
```

- update the following arguments:

| Name                           | Description                           | Type  | Default |
| ------------------------------ | ------------------------------------- | ----- | ------- |
| `mlflow.run_id`                | Name of the current run in MLflow.    | `str` |         |
| `mlflow.s3_endpoint_url`       | URL of the S3 endpoint.               | `str` |         |
| `mlflow.aws_access_key_id`     | Access key ID for the AWS server.     | `str` |         |
| `mlflow.aws_secret_access_key` | Secret access key for the AWS server. | `str` |         |
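For instance, the `mlflow` entry in the training configuration might look as follows. All values are placeholders, not working credentials, and `params` stands for the full configuration dictionary in the training script:

```py
# All values below are hypothetical; substitute your own run name,
# S3 endpoint, and AWS credentials.
params["mlflow"] = {
    "run_id": "dan-page-level-experiment",       # name of the run in MLflow
    "s3_endpoint_url": "https://s3.mycompany.example",
    "aws_access_key_id": "MY_ACCESS_KEY_ID",
    "aws_secret_access_key": "MY_SECRET_ACCESS_KEY",
}
```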