From 4fe27870afbe20162036606a98d1ba0b2f00d70c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Thu, 3 Aug 2023 08:31:40 +0200
Subject: [PATCH] Apply f0a6e38c

---
 dan/ocr/document/train_popp.py | 317 +++++++++++++++++++++++++++++++++
 1 file changed, 317 insertions(+)
 create mode 100644 dan/ocr/document/train_popp.py

diff --git a/dan/ocr/document/train_popp.py b/dan/ocr/document/train_popp.py
new file mode 100644
index 00000000..6237748a
--- /dev/null
+++ b/dan/ocr/document/train_popp.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+import json
+import logging
+import random
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from torch.optim import Adam
+
+from dan.decoder import GlobalHTADecoder
+from dan.encoder import FCN_Encoder
+from dan.manager.training import Manager
+from dan.mlflow import MLFLOW_AVAILABLE
+from dan.schedulers import exponential_dropout_scheduler
+from dan.transforms import Preprocessing
+from dan.utils import MLflowNotInstalled
+
+if MLFLOW_AVAILABLE:
+    import mlflow
+
+    from dan.mlflow import make_mlflow_request, start_mlflow_run
+
+
+logger = logging.getLogger(__name__)
+
+
+def train_and_test(rank, params, mlflow_logging=False):
+    """
+    Train a model, then evaluate its best weights on each split of each dataset.
+
+    :param rank: Stored as `ddp_rank` in the training parameters.
+    :param params: Configuration dictionary used to build the `Manager`.
+    :param mlflow_logging: Forwarded to the training and prediction calls.
+    """
+    seed = 0  # Same fixed seed for every random number generator
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    params["training_params"]["ddp_rank"] = rank
+    manager = Manager(params)
+    manager.load_model()
+    if mlflow_logging:
+        logger.info("MLflow logging enabled")
+    manager.train(mlflow_logging=mlflow_logging)
+
+    # Reload the weights of the best epoch before running the evaluation
+    manager.params["training_params"]["load_epoch"] = "best"
+    manager.load_model()
+
+    metrics = ["cer", "wer", "wer_no_punct", "time"]
+    predict_options = {"output": True, "mlflow_logging": mlflow_logging}
+    for dataset in params["dataset_params"]["datasets"]:
+        for split in ["test", "val", "train"]:
+            name, targets = f"{dataset}-{split}", [(dataset, split)]
+            manager.predict(name, targets, metrics, **predict_options)
+
+
+def get_config():
+    """
+    Retrieve the model configuration, returned as a `(params, dataset_name)` tuple
+    """
+    dataset_name = "data/popp"
+    dataset_level = "page"
+    dataset_variant = ""
+    dataset_path = "."
+    params = {
+        # "mlflow": {
+        #     "run_name": "Test log DAN",
+        #     "run_id": None,
+        #     "s3_endpoint_url": "",
+        #     "tracking_uri": "",
+        #     "experiment_id": "0",
+        #     "aws_access_key_id": "",
+        #     "aws_secret_access_key": "",
+        # },
+        "dataset_params": {
+            "datasets": {
+                dataset_name: "{}/{}_{}{}".format(
+                    dataset_path, dataset_name, dataset_level, dataset_variant
+                ),
+            },
+            "train": {
+                "name": "{}-train".format(dataset_name),
+                "datasets": [
+                    (dataset_name, "train"),
+                ],
+            },
+            "val": {
+                "{}-val".format(dataset_name): [
+                    (dataset_name, "val"),
+                ],
+            },
+            "test": {
+                "{}-test".format(dataset_name): [
+                    (dataset_name, "test"),
+                ],
+            },
+            "config": {
+                "load_in_memory": True,  # Load all images in CPU memory
+                "worker_per_gpu": 4,  # Number of parallel processes per GPU for data loading
+                "preprocessings": [
+                    {
+                        "type": Preprocessing.MaxResize,
+                        "max_width": 2000,
+                        "max_height": 2000,
+                    }
+                ],
+                "augmentation": True,
+            },
+        },
+        "model_params": {
+            "models": {
+                "encoder": FCN_Encoder,
+                "decoder": GlobalHTADecoder,
+            },
+            # "transfer_learning": None,
+            "transfer_learning": {
+                # model_name: [state_dict_name, checkpoint_path, learnable, strict]
+                "encoder": [
+                    "encoder",
+                    "pretrained-models/popp_sp.pt",
+                    True,
+                    True,
+                ],
+                "decoder": [
+                    "decoder",
+                    "pretrained-models/popp_sp.pt",
+                    True,
+                    False,
+                ],
+            },
+            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
+            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
+            "input_channels": 3,  # number of channels of input image
+            "dropout": 0.5,  # dropout rate for encoder
+            "enc_dim": 256,  # dimension of extracted features
+            "nb_layers": 5,  # encoder
+            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
+            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
+            "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
+            "dec_num_layers": 8,  # number of transformer decoder layers
+            "dec_num_heads": 4,  # number of heads in transformer decoder layers
+            "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
+            "dec_pred_dropout": 0.1,  # dropout rate before decision layer
+            "dec_att_dropout": 0.1,  # dropout rate in multi head attention
+            "dec_dim_feedforward": 256,  # dimension of the feedforward layer in transformer decoder layers
+            "use_2d_pe": True,  # use 2D positional embedding
+            "use_1d_pe": True,  # use 1D positional embedding
+            "use_lstm": False,
+            "attention_win": 100,  # length of attention window
+            # Curriculum dropout
+            "dropout_scheduler": {
+                "function": exponential_dropout_scheduler,
+                "T": 5e4,
+            },
+        },
+        "training_params": {
+            "output_folder": "outputs/dan_esposalles_record",  # folder name for checkpoints and results; NOTE(review): mentions esposalles in a POPP config, confirm
+            "max_nb_epochs": 2,  # maximum number of epochs before stopping
+            "max_training_time": 3600
+            * 24
+            * 1.9,  # maximum training time before stopping (in seconds), i.e. 1.9 days
+            "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
+            "interval_save_weights": None,  # None: keep best and last only
+            "batch_size": 1,  # mini-batch size for training
+            "valid_batch_size": 1,  # mini-batch size for validation
+            "use_ddp": False,  # Use DistributedDataParallel
+            "ddp_port": "20027",
+            "use_amp": True,  # Enable automatic mixed precision
+            "nb_gpu": torch.cuda.device_count(),
+            "optimizers": {
+                "all": {
+                    "class": Adam,
+                    "args": {
+                        "lr": 0.0001,
+                        "amsgrad": False,
+                    },
+                },
+            },
+            "lr_schedulers": None,  # Learning rate schedulers
+            "eval_on_valid": True,  # Whether to evaluate and log metrics on the validation set during training
+            "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
+            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
+            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
+            "set_name_focus_metric": "{}-val".format(
+                dataset_name
+            ),  # Which dataset to focus on to select best weights
+            "train_metrics": [
+                "loss_ce",
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for training
+            "eval_metrics": [
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for evaluation on validation set during training
+            "force_cpu": True,  # True for debug purposes (start_training then never spawns DDP processes)
+            "max_char_prediction": 10,  # maximum number of predicted tokens
+            # Keep the teacher forcing error rate at 20% during the whole training (min == max)
+            "label_noise_scheduler": {
+                "min_error_rate": 0.2,
+                "max_error_rate": 0.2,
+                "total_num_steps": 5e4,
+            },
+        },
+    }
+
+    return params, dataset_name
+
+
+def serialize_config(config):
+    """
+    Make every field of the configuration JSON-Serializable and remove sensitive information.
+
+    - Classes are transformed using their name attribute
+    - Functions are casted to strings
+
+    :param config: Configuration dictionary, left untouched.
+    :return: A deep copy of the configuration with the fields above converted.
+    """
+    # Work on a copy so that the original config is not altered
+    serialized_config = deepcopy(config)
+
+    # Blank the credentials, only when a `mlflow` section is configured
+    if "mlflow" in serialized_config:
+        for key in (
+            "s3_endpoint_url",
+            "tracking_uri",
+            "aws_access_key_id",
+            "aws_secret_access_key",
+        ):
+            serialized_config["mlflow"][key] = ""
+
+    # Get the name of the classes, for the models and for every optimizer
+    models = serialized_config["model_params"]["models"]
+    for model_name in ("encoder", "decoder"):
+        models[model_name] = models[model_name].__name__
+    for optimizer in serialized_config["training_params"]["optimizers"].values():
+        optimizer["class"] = optimizer["class"].__name__
+
+    # Cast the functions, the augmentation flag and the GPU count to str
+    dataset_config = serialized_config["dataset_params"]["config"]
+    dataset_config["augmentation"] = str(dataset_config["augmentation"])
+    dropout_scheduler = serialized_config["model_params"]["dropout_scheduler"]
+    dropout_scheduler["function"] = str(dropout_scheduler["function"])
+    training_params = serialized_config["training_params"]
+    training_params["nb_gpu"] = str(training_params["nb_gpu"])
+
+    return serialized_config
+
+
+def start_training(config, mlflow_logging: bool) -> None:
+    """
+    Start the training, in this process or in `nb_gpu` spawned processes (DDP).
+
+    :param config: Configuration dictionary, whose `training_params` are read here.
+    :param mlflow_logging: Forwarded to `train_and_test`.
+    """
+    training = config["training_params"]
+    if not training["use_ddp"] or training["force_cpu"]:  # No DDP: stay in this process
+        train_and_test(0, config, mlflow_logging)
+        return
+    mp.spawn(train_and_test, args=(config, mlflow_logging), nprocs=training["nb_gpu"])
+
+
+def run():
+    """
+    Main program, training a new model, using a valid configuration
+
+    When the configuration has a `mlflow` section, the dataset tag, the serialized
+    configuration and the labels are logged to a MLflow run wrapping the training.
+
+    :raises MLflowNotInstalled: If MLflow logging is configured but not available.
+    """
+    config, dataset_name = get_config()
+    if "mlflow" not in config:
+        start_training(config, mlflow_logging=False)
+        return
+
+    if not MLFLOW_AVAILABLE:
+        logger.error(
+            "Cannot log to MLflow. Please install the `mlflow` extra requirements."
+        )
+        raise MLflowNotInstalled()
+
+    datasets = config["dataset_params"]["datasets"]
+    labels_path = Path(datasets[dataset_name]) / "labels.json"
+    with start_mlflow_run(config["mlflow"]) as (mlflow_run, created):
+        state = "Started" if created else "Resumed"
+        logger.info(f"{state} MLflow run with ID ({mlflow_run.info.run_id})")
+        make_mlflow_request(
+            mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
+        )
+        # Read the labels as UTF-8, whatever the encoding of the locale is
+        with open(labels_path, encoding="utf-8") as json_file:
+            labels_artifact = json.load(json_file)
+
+        # Log MLflow artifacts
+        for artifact, filename in [
+            (serialize_config(config), "config.json"),
+            (labels_artifact, "labels.json"),
+        ]:
+            make_mlflow_request(
+                mlflow_method=mlflow.log_dict,
+                dictionary=artifact,
+                artifact_file=filename,
+            )
+        start_training(config, mlflow_logging=True)
-- 
GitLab