Skip to content
Snippets Groups Projects
conftest.py 5.88 KiB
# -*- coding: utf-8 -*-
import os
from pathlib import Path

import pytest
from arkindex_export import open_database
from torch.optim import Adam

from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import Preprocessing

# Absolute path of the test-data directory that sits next to this conftest.
FIXTURES = Path(__file__).resolve().parent.joinpath("data")


@pytest.fixture(autouse=True)
def setup_environment(responses):
    """Make sure the Arkindex API schema URL is configured and reachable.

    When ``ARKINDEX_API_SCHEMA_URL`` is not already present in the
    environment, it is set to the production OpenAPI schema URL. The
    resulting URL is then registered as a passthrough on the ``responses``
    fixture, so requests to it are let through instead of being mocked.
    """
    # setdefault keeps a value already provided by the environment and only
    # writes the production URL when the variable is missing.
    schema_url = os.environ.setdefault(
        "ARKINDEX_API_SCHEMA_URL",
        "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
    )
    responses.add_passthru(schema_url)


@pytest.fixture
def database_path():
    """Location of the SQLite export used as the demo database in tests."""
    return FIXTURES.joinpath("export.sqlite")


@pytest.fixture(autouse=True)
def demo_db(database_path):
    """
    Before every test, hand the database file provided by the
    ``database_path`` fixture to ``arkindex_export.open_database``.
    """
    # NOTE(review): nothing here closes the database afterwards; confirm
    # whether arkindex_export expects an explicit close between tests.
    export_file = database_path
    open_database(export_file)


@pytest.fixture
def training_config():
    """Build a small, CPU-only DAN training configuration for the tests.

    The configuration is assembled from three sections (datasets, model,
    training) and a brand-new dictionary is returned on every call.
    """
    # --- Datasets: a single "training" dataset split into train/val/test ---
    dataset_params = {
        "datasets": {
            # NOTE(review): relative to the current working directory, so this
            # assumes pytest is launched from the repository root.
            "training": "./tests/data/training/training_dataset",
        },
        "train": {
            "name": "training-train",
            "datasets": [
                ("training", "train"),
            ],
        },
        "val": {
            "training-val": [
                ("training", "val"),
            ],
        },
        "test": {
            "training-test": [
                ("training", "test"),
            ],
        },
        "config": {
            "load_in_memory": True,  # Keep every image in CPU memory
            "preprocessings": [
                {
                    "type": Preprocessing.MaxResize,
                    "max_width": 2000,
                    "max_height": 2000,
                },
            ],
            "augmentation": True,
        },
    }

    # --- Model architecture and hyper-parameters ---
    model_params = {
        "models": {
            "encoder": FCN_Encoder,
            "decoder": GlobalHTADecoder,
        },
        "transfer_learning": None,
        # Transfer the decision layer using the charset of the line HTR model
        # (key spelling kept as-is; presumably what the config reader expects)
        "transfered_charset": True,
        "additional_tokens": 1,  # decision layer = [<eot>, ]; only used with a transferred charset
        "input_channels": 3,  # channels of the input image
        "dropout": 0.5,  # encoder dropout rate
        "enc_dim": 256,  # dimension of the extracted features
        "nb_layers": 5,  # encoder layers
        "h_max": 500,  # max encoder output height (2D positional embedding)
        "w_max": 1000,  # max encoder output width (2D positional embedding)
        "l_max": 15000,  # max predicted sequence length (1D positional embedding)
        "dec_num_layers": 8,  # transformer decoder layers
        "dec_num_heads": 4,  # attention heads per transformer decoder layer
        "dec_res_dropout": 0.1,  # dropout inside transformer decoder layers
        "dec_pred_dropout": 0.1,  # dropout right before the decision layer
        "dec_att_dropout": 0.1,  # dropout inside multi-head attention
        "dec_dim_feedforward": 256,  # feed-forward width in transformer decoder layers
        "attention_win": 100,  # attention window length
        # Curriculum dropout
        "dropout_scheduler": {
            "function": exponential_dropout_scheduler,
            "T": 5e4,
        },
    }

    # --- Training loop settings ---
    training_params = {
        "output_folder": "dan_trained_model",  # folder for checkpoints and results
        "max_nb_epochs": 4,  # stop after this many epochs at most
        "max_training_time": 1200,  # stop after this many seconds at most
        "load_epoch": "last",  # ["best", "last"]: "last" resumes training, "best" evaluates
        "interval_save_weights": None,  # None: keep only the best and last weights
        "batch_size": 2,  # training mini-batch size
        "use_ddp": False,  # Use DistributedDataParallel
        "nb_gpu": 0,
        "optimizers": {
            "all": {
                "class": Adam,
                "args": {
                    "lr": 0.0001,
                    "amsgrad": False,
                },
            },
        },
        "lr_schedulers": None,  # Learning rate schedulers
        "eval_on_valid": True,  # Whether to evaluate and log validation metrics during training
        "eval_on_valid_interval": 2,  # Validation interval, in epochs
        "focus_metric": "cer",  # Metric used to pick the best epoch
        "expected_metric_value": "low",  # ["high", "low"]: which direction is better for the focus metric
        "set_name_focus_metric": "training-val",  # Dataset whose focus metric selects the best weights
        "train_metrics": [
            "loss_ce",
            "cer",
            "wer",
            "wer_no_punct",
        ],  # Metrics computed during training
        "eval_metrics": [
            "cer",
            "wer",
            "wer_no_punct",
        ],  # Metrics computed on the validation set during training
        "force_cpu": True,  # True for debugging purposes
        "max_char_prediction": 30,  # maximum number of predicted tokens
        # Min and max error rates are both 0.2, so the label-noise rate stays
        # fixed at 20% for the whole training run
        "label_noise_scheduler": {
            "min_error_rate": 0.2,
            "max_error_rate": 0.2,
            "total_num_steps": 5e4,
        },
    }

    return {
        "dataset_params": dataset_params,
        "model_params": model_params,
        "training_params": training_params,
    }


@pytest.fixture
def prediction_data_path():
    """Directory holding the fixture data used by the prediction tests."""
    return FIXTURES.joinpath("prediction")