
# Training configuration

All hyperparameters are specified and editable in the training scripts; their meaning is documented in comments. This page describes the most useful keys.

## Dataset parameters

| Parameter | Description | Type | Default |
| --------- | ----------- | ---- | ------- |
| `dataset_name` | Name of the dataset. | `str` | |
| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
| `dataset_path` | Path to the dataset. | `str` | |
| `dataset_params.config.dataset_manager` | Dataset manager class. | custom class | `OCRDatasetManager` |
| `dataset_params.config.dataset_class` | Dataset class. | custom class | `OCRDataset` |
| `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and the dataset path as value. | `dict` | |
| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `dataset_params.config.worker_per_gpu` | Number of parallel processes per GPU for data loading. | `int` | `4` |
| `dataset_params.config.height_divisor` | Factor to reduce the height of the feature vector before feeding the decoder. | `int` | `32` |
| `dataset_params.config.width_divisor` | Factor to reduce the width of the feature vector before feeding the decoder. | `int` | `8` |
| `dataset_params.config.padding_value` | Image padding value. | `int` | `0` |
| `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` |
| `dataset_params.config.constraints` | Whether to add end-of-transcription and start-of-transcription tokens in labels. | `list` | `[]` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see dedicated section) |
| `dataset_params.config.augmentation` | Configuration for data augmentation. | `dict` | (see dedicated section) |
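
Put together, a dataset configuration fragment might look like the sketch below. The dataset name, level, variant, and path are hypothetical placeholders, and keys that are not shown (such as `dataset_manager` and `dataset_class`) keep their defaults.

```python
# A minimal sketch of a dataset configuration fragment;
# "my-dataset" and its path are hypothetical placeholders.
dataset_name = "my-dataset"
dataset_level = "page"
dataset_variant = "_sem"
dataset_path = f"datasets/{dataset_name}"

dataset_params = {
    "config": {
        "datasets": {dataset_name: dataset_path},
        "load_in_memory": True,  # load all images in CPU memory
        "worker_per_gpu": 4,     # data-loading processes per GPU
        "height_divisor": 32,    # encoder reduces image height by this factor
        "width_divisor": 8,      # encoder reduces image width by this factor
        "padding_value": 0,      # image padding value
        "padding_token": None,   # transcription padding value
        "preprocessings": [],    # see "Data preprocessing" below
    },
}
```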

## Data preprocessing

Preprocessing is applied before training the network (see `dan/manager/dataset.py`). The following transformations are implemented:

- DPI adjustment

  ```python
  {
      "type": "dpi",
      "source": 300,
      "target": 150,
  }
  ```

- Convert to grayscale

  ```python
  {
      "type": "to_grayscaled"
  }
  ```

- Convert to RGB

  ```python
  {
      "type": "to_RGB"
  }
  ```

- Resize to a fixed height

  ```python
  {
      "type": "fixed_height",
      "fixed_height": 1000,
  }
  ```

- Resize to a maximum size

  ```python
  {
      "type": "resize",
      "keep_ratio": True,
      "max_height": 1000,
      "max_width": None,
  }
  ```

Multiple transformations can be combined. For example, to resize images to a fixed height of 1000 pixels and convert them to RGB, use the following configuration in `dataset_params.config.preprocessings`:

```python
[
    {
        "type": "fixed_height",
        "fixed_height": 1000
    },
    {
        "type": "to_RGB"
    }
]
```
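
For intuition, the sketch below shows how such a list could be applied in order to a PIL image. It illustrates the configuration format only; it is not the actual implementation in `dan/manager/dataset.py`, and the `apply_preprocessings` helper and image path are hypothetical.

```python
from PIL import Image

def apply_preprocessings(image, preprocessings):
    # Hypothetical helper: apply each preprocessing dict in order.
    for pre in preprocessings:
        if pre["type"] == "fixed_height":
            # Scale so the height is fixed and the aspect ratio is kept.
            height = pre["fixed_height"]
            width = round(image.width * height / image.height)
            image = image.resize((width, height))
        elif pre["type"] == "to_RGB":
            image = image.convert("RGB")
    return image

# "page.jpg" is a placeholder path.
image = apply_preprocessings(
    Image.open("page.jpg"),
    [{"type": "fixed_height", "fixed_height": 1000}, {"type": "to_RGB"}],
)
```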

## Data augmentation

Augmentation transformations are applied on-the-fly during training to artificially increase data variability.

The following transformations are implemented in `dan/transforms.py`:

- Color inversion
- DPI adjustment
- Dilation and erosion
- Elastic distortion
- Reducing interline spacing
- Gaussian blur
- Gaussian noise

DAN also takes advantage of transforms from torchvision:

- `ColorJitter`
- `GaussianBlur`
- `RandomCrop`
- `RandomPerspective`

The following configuration is used by default with the `teklia-dan train document` command. Data augmentation is applied with a probability of 0.9; each transformation is then applied independently with a probability of 0.1.

```python
{
    "order": "random",
    "proba": 0.9,
    "augmentations": [
        {
            "type": "dpi",
            "proba": 0.1,
            "min_factor": 0.75,
            "max_factor": 1,
            "preserve_ratio": True,
        },
        {
            "type": "perspective",
            "proba": 0.1,
            "min_factor": 0,
            "max_factor": 0.4,
        },
        {
            "type": "elastic_distortion",
            "proba": 0.1,
            "min_alpha": 0.5,
            "max_alpha": 1,
            "min_sigma": 1,
            "max_sigma": 10,
            "min_kernel_size": 3,
            "max_kernel_size": 9,
        },
        {
            "type": "dilation_erosion",
            "proba": 0.1,
            "min_kernel": 1,
            "max_kernel": 3,
            "iterations": 1,
        },
        {
            "type": "color_jittering",
            "proba": 0.1,
            "factor_hue": 0.2,
            "factor_brightness": 0.4,
            "factor_contrast": 0.4,
            "factor_saturation": 0.4,
        },
        {
            "type": "gaussian_blur",
            "proba": 0.1,
            "min_kernel": 3,
            "max_kernel": 5,
            "min_sigma": 3,
            "max_sigma": 5,
        },
        {
            "type": "gaussian_noise",
            "proba": 0.1,
            "std": 0.5,
        },
        {
            "type": "sharpen",
            "proba": 0.1,
            "min_alpha": 0,
            "max_alpha": 1,
            "min_strength": 0,
            "max_strength": 1,
        },
    ],
}
```
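
To make the probabilities concrete, here is a sketch of how such a configuration can be interpreted during training. This is an assumption about the mechanism, not the code from `dan/transforms.py`, and `apply_transform` is a hypothetical dispatch helper.

```python
import random

def apply_transform(image, aug):
    # Hypothetical dispatch to the transformation named in aug["type"].
    raise NotImplementedError

def augment(image, config):
    # Skip augmentation entirely with probability 1 - config["proba"].
    if random.random() > config["proba"]:
        return image
    augmentations = list(config["augmentations"])
    if config["order"] == "random":
        random.shuffle(augmentations)  # apply in a random order
    for aug in augmentations:
        # Each transformation is applied independently with its own "proba".
        if random.random() <= aug["proba"]:
            image = apply_transform(image, aug)
    return image
```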

## Model parameters

| Name | Description | Type | Default |
| ---- | ----------- | ---- | ------- |
| `model_params.models.encoder` | Encoder class. | custom class | `FCN_encoder` |
| `model_params.models.decoder` | Decoder class. | custom class | `GlobalHTADecoder` |
| `model_params.transfer_learning.encoder` | Model to load for the encoder: `[state_dict_name, checkpoint_path, learnable, strict]`. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` |
| `model_params.transfer_learning.decoder` | Model to load for the decoder: `[state_dict_name, checkpoint_path, learnable, strict]`. | `list` | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
| `model_params.transfered_charset` | Transfer learning of the decision layer based on the charset of the model to transfer. | `bool` | `True` |
| `model_params.additional_tokens` | Number of additional tokens for the decision layer (e.g. the end-of-transcription token); only used with a transferred charset. | `int` | `1` |
| `model_params.input_channels` | Number of channels of the input image. | `int` | `3` |
| `model_params.dropout` | Dropout probability in the encoder. | `float` | `0.5` |
| `model_params.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` |
| `model_params.nb_layers` | Number of layers in the encoder. | `int` | `5` |
| `model_params.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` |
| `model_params.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` |
| `model_params.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` |
| `model_params.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` |
| `model_params.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` |
| `model_params.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` |
| `model_params.dec_pred_dropout` | Dropout rate before the decision layer. | `float` | `0.1` |
| `model_params.dec_att_dropout` | Dropout rate in multi-head attention. | `float` | `0.1` |
| `model_params.dec_dim_feedforward` | Dimension of the feedforward layer in transformer decoder layers. | `int` | `256` |
| `model_params.use_2d_pe` | Whether to use 2D positional embedding. | `bool` | `True` |
| `model_params.use_1d_pe` | Whether to use 1D positional embedding. | `bool` | `True` |
| `model_params.use_lstm` | Whether to use an LSTM layer in the decoder. | `bool` | `False` |
| `model_params.attention_win` | Length of the attention window. | `int` | `100` |
| `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` |
| `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` |
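
Each `transfer_learning` entry is a 4-element list `[state_dict_name, checkpoint_path, learnable, strict]`. As an illustration, the sketch below loads both sub-networks from the default RIMES checkpoint listed above, keeps them trainable, and tolerates a partial match for the decoder weights:

```python
# A sketch of the transfer-learning part of model_params,
# using the default checkpoint path from the table above.
model_params = {
    "transfer_learning": {
        # [state_dict_name, checkpoint_path, learnable, strict]
        "encoder": ["encoder", "pretrained_models/dan_rimes_page.pt", True, True],
        # strict=False tolerates missing or extra decoder weights,
        # e.g. when the charset (and thus the decision layer) changes.
        "decoder": ["decoder", "pretrained_models/dan_rimes_page.pt", True, False],
    },
}
```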

## Training parameters

| Name | Description | Type | Default |
| ---- | ----------- | ---- | ------- |
| `training_params.output_folder` | Directory for checkpoints and results. | `str` | |
| `training_params.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
| `training_params.max_training_time` | Maximum time (in seconds) before stopping training. | `int` | `350000` |
| `training_params.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `"last"` (training). | `str` | `"last"` |
| `training_params.interval_save_weights` | Step to save weights. Set to `None` to keep only the best and last epochs. | `int` | `None` |
| `training_params.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
| `training_params.valid_batch_size` | Mini-batch size for the validation loop. | `int` | `4` |
| `training_params.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
| `training_params.ddp_port` | DDP port. | `int` | `20027` |
| `training_params.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
| `training_params.nb_gpu` | Number of GPUs to train DAN. | `int` | `torch.cuda.device_count()` |
| `training_params.optimizers.all.class` | Optimizer class. | custom class | `Adam` |
| `training_params.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
| `training_params.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training_params.focus_metric` | Metric to focus on to determine the best epoch. | `str` | `cer` |
| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `str` | `"low"` |
| `training_params.set_name_focus_metric` | Dataset to focus on to select the best weights. | `str` | |
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
| `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
| `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
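
The optimizer keys map a parameter-group name to an optimizer class and its constructor arguments. A minimal sketch using the defaults above, assuming the class is passed uninstantiated (as the `custom class` type suggests):

```python
from torch.optim import Adam

training_params = {
    "optimizers": {
        "all": {            # one optimizer for all model parameters
            "class": Adam,  # optimizer class, not an instance
            "args": {"lr": 0.0001, "amsgrad": False},
        },
    },
    "lr_schedulers": None,  # no learning-rate scheduler
}
```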

## MLflow logging

To log your experiment on MLflow, you need to:

- install the extra requirements:

  ```shell
  pip install .[mlflow]
  ```

- update the following arguments:

| Name | Description | Type | Default |
| ---- | ----------- | ---- | ------- |
| `mlflow.run_id` | Name of the current run in MLflow. | `str` | |
| `mlflow.s3_endpoint_url` | URL of the S3 endpoint. | `str` | |
| `mlflow.aws_access_key_id` | Access key ID for the AWS server. | `str` | |
| `mlflow.aws_secret_access_key` | Secret access key for the AWS server. | `str` | |
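
For instance, the MLflow block might look like the sketch below; the run name, endpoint URL, and credentials are placeholders:

```python
# All values below are placeholders.
mlflow = {
    "run_id": "dan-htr-experiment",
    "s3_endpoint_url": "https://s3.example.com",
    "aws_access_key_id": "my-access-key-id",
    "aws_secret_access_key": "my-secret-access-key",
}
```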