
Training configuration

All hyperparameters are specified and editable in the training script `dan/ocr/document/train.py::get_config` (descriptions are in the comments). This page introduces some useful keys and their descriptions.
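The keys below use a dotted notation for nested entries. Here is a minimal sketch of reading or overriding a dotted key such as `training_params.batch_size`. It assumes the configuration built by `get_config` is a nested Python dictionary, which you should check in the script. The helpers `get_key` and `set_key` are hypothetical and are not part of DAN.

```python
from functools import reduce
from typing import Any


def get_key(config: dict, dotted_key: str) -> Any:
    """Return the value stored under a dotted key, e.g. "training_params.batch_size"."""
    return reduce(lambda node, part: node[part], dotted_key.split("."), config)


def set_key(config: dict, dotted_key: str, value: Any) -> None:
    """Overwrite (or add) the value under a dotted key; parent keys must already exist."""
    *parents, leaf = dotted_key.split(".")
    node = get_key(config, ".".join(parents)) if parents else config
    node[leaf] = value


# Hypothetical excerpt of a configuration dictionary (values are the documented defaults)
config = {
    "training_params": {
        "batch_size": 2,
        "optimizers": {"all": {"args": {"lr": 0.0001, "amsgrad": False}}},
    }
}

set_key(config, "training_params.optimizers.all.args.lr", 0.0003)
print(get_key(config, "training_params.optimizers.all.args.lr"))  # 0.0003
```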

Dataset parameters

| Parameter | Description | Type | Default |
| --- | --- | --- | --- |
| `dataset_name` | Name of the dataset. | `str` | |
| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
| `dataset_path` | Path to the dataset. | `str` | |
| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `dataset_params.config.worker_per_gpu` | Number of parallel processes per GPU for data loading. | `int` | `4` |
| `dataset_params.config.preprocessings` | List of preprocessing functions to apply to input images. | `list` | (see dedicated section) |
| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see dedicated section) |

!!! warning
    The variables `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path` must have values such that the data is located in `{dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}`.
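For example, the data location implied by these four variables can be computed as follows. The function name and the values are made up for illustration and do not refer to a real dataset.

```python
from pathlib import Path


def dataset_folder(dataset_path: str, dataset_name: str, dataset_level: str, dataset_variant: str = "") -> Path:
    """Build {dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}."""
    return Path(dataset_path) / f"{dataset_name}_{dataset_level}{dataset_variant}"


# Hypothetical values: an HTR+NER dataset at page level
print(dataset_folder("data", "my_dataset", "page", "_sem").as_posix())  # data/my_dataset_page_sem
```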

Data preprocessing

Preprocessing is applied before training the network (see dan/manager/dataset.py). The list of accepted transforms is defined in dan/transforms.py:

```python
from enum import Enum


class Preprocessing(Enum):
    # If the image is bigger than the given size, resize it while keeping the original ratio
    MaxResize = "max_resize"
    # Resize the height to a fixed value while keeping the original ratio
    FixedHeightResize = "fixed_height_resize"
    # Resize the width to a fixed value while keeping the original ratio
    FixedWidthResize = "fixed_width_resize"
```

Usage:

  • Resize to a fixed height
```python
[
    {
        "type": Preprocessing.FixedHeightResize,
        "fixed_height": 1500,
    }
]
```
  • Resize to a fixed width
```python
[
    {
        "type": Preprocessing.FixedWidthResize,
        "fixed_width": 1500,
    }
]
```
  • Resize to a maximum size (only if the image is bigger than the given size)
```python
[
    {
        "type": Preprocessing.MaxResize,
        "max_height": 2000,
        "max_width": 2000,
    }
]
```
  • Combine these preprocessing steps
```python
[
    {
        "type": Preprocessing.FixedHeightResize,
        "fixed_height": 2000,
    },
    {
        "type": Preprocessing.FixedWidthResize,
        "fixed_width": 2000,
    }
]
```
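The sketch below shows, in plain arithmetic, the target `(height, width)` that each of the three resize modes produces. It is not DAN's implementation, which lives in `dan/transforms.py`. Two details are assumptions: the rounding, and how `MaxResize` combines its two limits. Here the image is scaled by the more restrictive of the two ratios, so that it fits within both.

```python
def fixed_height_size(height: int, width: int, fixed_height: int) -> tuple[int, int]:
    """Set the height to fixed_height and scale the width by the same factor."""
    ratio = fixed_height / height
    return fixed_height, round(width * ratio)


def fixed_width_size(height: int, width: int, fixed_width: int) -> tuple[int, int]:
    """Set the width to fixed_width and scale the height by the same factor."""
    ratio = fixed_width / width
    return round(height * ratio), fixed_width


def max_resize_size(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]:
    """Shrink only if the image exceeds a limit, keeping the aspect ratio."""
    ratio = min(max_height / height, max_width / width)
    if ratio >= 1:
        return height, width  # already small enough: left untouched
    return round(height * ratio), round(width * ratio)


print(max_resize_size(4000, 1000, 2000, 2000))  # (2000, 500)
print(max_resize_size(1200, 800, 2000, 2000))  # (1200, 800)
print(fixed_height_size(3000, 1000, 1500))  # (1500, 500)
```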

Data augmentation

Augmentation transformations are applied on-the-fly during training to artificially increase data variability.

DAN uses transforms from albumentations. The following configuration is used by default with the `teklia-dan train document` command. Data augmentation is applied with a probability of 0.9. When it is applied, two transformations are randomly selected from the list below.

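```python
from albumentations import (Affine, CoarseDropout, ColorJitter, Downscale, ElasticTransform,
                            GaussianBlur, GaussNoise, Perspective, Sharpen, SomeOf, ToGray)
from cv2 import INTER_NEAREST

# ErosionDilation is a DAN-specific transform (module path assumed; check the DAN sources)
from dan.transforms import ErosionDilation

# With probability 0.9, apply 2 transforms randomly selected from the list
transforms = SomeOf(
    [
        Perspective(scale=(0.05, 0.09), fit_output=True),
        GaussianBlur(sigma_limit=2.5),
        GaussNoise(var_limit=50**2),
        ColorJitter(contrast=0.2, brightness=0.2, saturation=0.2, hue=0.2),
        ElasticTransform(alpha=20.0, sigma=5.0, alpha_affine=1.0, border_mode=0),
        Sharpen(alpha=(0.0, 1.0)),
        ErosionDilation(min_kernel=1, max_kernel=4, iterations=1),
        Affine(shear={"x": (-20, 20), "y": (0, 0)}),
        CoarseDropout(),
        Downscale(scale_min=0.5, scale_max=0.9, interpolation=INTER_NEAREST),
        ToGray(),
    ],
    n=2,
    p=0.9,
)
```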
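As a rough illustration of what `n=2, p=0.9` means, the sketch below mimics the selection step in plain Python. With probability `p` the image is augmented; `n` transforms are then drawn. This is a simplified model that assumes a uniform draw without replacement. The albumentations `SomeOf` class may weight the draw by each transform's own probability, and whether it draws with replacement depends on the library version, so refer to its documentation for the exact rules.

```python
import random

# Names of the default transforms listed above
TRANSFORMS = [
    "Perspective", "GaussianBlur", "GaussNoise", "ColorJitter", "ElasticTransform",
    "Sharpen", "ErosionDilation", "Affine", "CoarseDropout", "Downscale", "ToGray",
]


def pick_transforms(rng: random.Random, n: int = 2, p: float = 0.9) -> list[str]:
    """Return the transforms applied to one image (empty list: image left untouched)."""
    if rng.random() >= p:
        return []
    return rng.sample(TRANSFORMS, n)  # uniform draw without replacement (simplification)


rng = random.Random(0)
draws = [pick_transforms(rng) for _ in range(10_000)]
untouched = sum(1 for d in draws if not d) / len(draws)
print(f"images left untouched: {untouched:.1%}")  # close to 10%
```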

For a detailed description of all augmentation transforms, see the dedicated page.

Model parameters

| Name | Description | Type | Default |
| --- | --- | --- | --- |
| `model_params.models.encoder` | Encoder class. | custom class | `FCN_encoder` |
| `model_params.models.decoder` | Decoder class. | custom class | `GlobalHTADecoder` |
| `model_params.transfer_learning.encoder` | Model to load for the encoder `[state_dict_name, checkpoint_path, learnable, strict]`. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` |
| `model_params.transfer_learning.decoder` | Model to load for the decoder `[state_dict_name, checkpoint_path, learnable, strict]`. | `list` | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
| `model_params.transfered_charset` | Transfer learning of the decision layer based on the charset of the model to transfer. | `bool` | `True` |
| `model_params.additional_tokens` | Number of additional tokens in the decision layer, only used with a transferred charset. | `int` | `1` |
| `model_params.dropout` | Dropout probability in the encoder. | `float` | `0.5` |
| `model_params.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` |
| `model_params.nb_layers` | Number of layers in the encoder. | `int` | `5` |
| `model_params.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` |
| `model_params.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` |
| `model_params.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` |
| `model_params.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` |
| `model_params.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` |
| `model_params.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` |
| `model_params.dec_pred_dropout` | Dropout rate before the decision layer. | `float` | `0.1` |
| `model_params.dec_att_dropout` | Dropout rate in multi-head attention. | `float` | `0.1` |
| `model_params.dec_dim_feedforward` | Number of dimensions of the feedforward layer in transformer decoder layers. | `int` | `256` |
| `model_params.attention_win` | Length of the attention window. | `int` | `100` |
| `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` |
| `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` |
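The curriculum dropout scheduler increases dropout progressively during training, and `T` sets how fast it ramps up. The sketch below assumes an exponential ramp of the form `rate(step) = dropout * (1 - exp(-10 * step / T))`. Whether this matches `exponential_dropout_scheduler` exactly is an assumption; check the DAN sources.

```python
import math


def exponential_dropout(dropout_rate: float, step: int, T: float = 5e4) -> float:
    """Dropout probability after `step` training steps (assumed exponential ramp)."""
    return dropout_rate * (1 - math.exp(-10 * step / T))


# Ramp of the encoder dropout (default 0.5) with the default T
for step in (0, 1_000, 5_000, 20_000, 50_000):
    print(step, round(exponential_dropout(0.5, step), 3))
```

With a larger `T`, the dropout rate takes more steps to approach its target value.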

Training parameters

| Name | Description | Type | Default |
| --- | --- | --- | --- |
| `training_params.output_folder` | Directory for checkpoints and results. | `str` | |
| `training_params.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
| `training_params.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `"last"` (training). | `str` | `"last"` |
| `training_params.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
| `training_params.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
| `training_params.ddp_port` | DDP port. | `int` | `20027` |
| `training_params.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
| `training_params.nb_gpu` | Number of GPUs used to train DAN. | `int` | `torch.cuda.device_count()` |
| `training_params.optimizers.all.class` | Optimizer class. | custom class | `Adam` |
| `training_params.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
| `training_params.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training_params.eval_on_valid_interval` | Interval (in epochs) between evaluations during training. | `int` | `5` |
| `training_params.set_name_focus_metric` | Dataset to focus on when selecting the best weights. | `str` | |
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
| `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
| `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
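The label noise scheduler controls how often the ground-truth tokens fed to the decoder during teacher forcing are corrupted. The usual motivation is to make the decoder robust to its own prediction errors at inference time. The sketch below assumes the noise takes the form of random token substitutions, which DAN's actual noise function may not do. The helper `add_label_noise` is hypothetical. The default `min_error_rate` and `max_error_rate` are both 0.2, so the rate stays at 0.2 whatever the shape of the schedule.

```python
import random


def add_label_noise(tokens: list[int], error_rate: float, vocab_size: int, rng: random.Random) -> list[int]:
    """Replace each token by a random one with probability `error_rate` (hypothetical helper)."""
    return [rng.randrange(vocab_size) if rng.random() < error_rate else token for token in tokens]


rng = random.Random(42)
ground_truth = [5, 12, 7, 3, 9, 1, 0, 4]  # token indices of a target sequence
noisy_input = add_label_noise(ground_truth, error_rate=0.2, vocab_size=20, rng=rng)
print(noisy_input)  # on average, about 20% of positions are resampled
```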

During the validation stage, the batch size is set to 1. This avoids problems with images of very different sizes in the same batch, which would require significant padding and degrade performance.
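A quick calculation shows why mixing image sizes is costly. The sketch below assumes that every image in a batch is padded to the largest height and the largest width in that batch. It reports the fraction of the padded batch that is padding. The image sizes are made up.

```python
def padding_fraction(sizes: list[tuple[int, int]]) -> float:
    """Fraction of pixels that are padding when every (height, width) image
    is padded to the largest height and width found in the batch."""
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    real_pixels = sum(h * w for h, w in sizes)
    padded_pixels = len(sizes) * max_h * max_w
    return 1 - real_pixels / padded_pixels


# A small page and a large page in the same batch (hypothetical sizes)
print(f"{padding_fraction([(500, 400), (2000, 1500)]):.1%}")  # 46.7%
# The same large page alone, as with a batch size of 1
print(f"{padding_fraction([(2000, 1500)]):.1%}")  # 0.0%
```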

MLflow logging

To log your experiment to MLflow, you need to:

  • install the extra requirements via
```shell
$ pip install ".[mlflow]"
```
  • update the following arguments:

| Name | Description | Type | Default |
| --- | --- | --- | --- |
| `mlflow.run_id` | Name of the current run in MLflow. | `str` | |
| `mlflow.s3_endpoint_url` | URL of the S3 endpoint. | `str` | |
| `mlflow.aws_access_key_id` | Access key ID for the AWS server. | `str` | |
| `mlflow.aws_secret_access_key` | Secret access key for the AWS server. | `str` | |
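MLflow and its S3 client read their connection settings from standard environment variables: `MLFLOW_S3_ENDPOINT_URL`, `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. The sketch below shows one way to map the `mlflow` configuration section onto them. Whether DAN does exactly this is an assumption. The helper name and the placeholder values are hypothetical.

```python
import os

# Placeholder values: never commit real credentials
mlflow_config = {
    "run_id": "my-first-run",
    "s3_endpoint_url": "https://s3.example.com",
    "aws_access_key_id": "MY_ACCESS_KEY_ID",
    "aws_secret_access_key": "MY_SECRET_ACCESS_KEY",
}


def export_mlflow_credentials(config: dict) -> None:
    """Expose the S3 settings through the environment variables MLflow and boto3 read."""
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = config["s3_endpoint_url"]
    os.environ["AWS_ACCESS_KEY_ID"] = config["aws_access_key_id"]
    os.environ["AWS_SECRET_ACCESS_KEY"] = config["aws_secret_access_key"]


export_mlflow_credentials(mlflow_config)
```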