From 8489982e72833cedfb2621bc866ad99a29c86f08 Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 13:37:17 +0000
Subject: [PATCH] Update eval data for tests

---
 configs/eval.json                         |  8 +--
 tests/data/evaluate/checkpoints/best_0.pt |  4 +-
 tests/data/evaluate/metrics_table.md      | 10 ++--
 tests/data/prediction/labels.json         | 12 ++++
 tests/test_evaluate.py                    | 72 ++++++++++++-----------
 5 files changed, 62 insertions(+), 44 deletions(-)
 create mode 100644 tests/data/prediction/labels.json

diff --git a/configs/eval.json b/configs/eval.json
index 01072cef..d93920aa 100644
--- a/configs/eval.json
+++ b/configs/eval.json
@@ -1,7 +1,7 @@
 {
     "dataset": {
         "datasets": {
-            "training": "tests/data/training/training_dataset"
+            "training": "tests/data/prediction"
         },
         "train": {
             "name": "training-train",
@@ -19,8 +19,8 @@
                 ["training", "test"]
             ]
         },
-        "max_char_prediction": 30,
-        "tokens": null
+        "max_char_prediction": 200,
+        "tokens": "tests/data/prediction/tokens.yml"
     },
     "model": {
         "transfered_charset": true,
@@ -45,7 +45,7 @@
     },
     "training": {
         "data": {
-            "batch_size": 2,
+            "batch_size": 1,
             "load_in_memory": true,
             "worker_per_gpu": 4,
             "preprocessings": [
diff --git a/tests/data/evaluate/checkpoints/best_0.pt b/tests/data/evaluate/checkpoints/best_0.pt
index 79bcb28a..67b9565a 100644
--- a/tests/data/evaluate/checkpoints/best_0.pt
+++ b/tests/data/evaluate/checkpoints/best_0.pt
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:428ceb4d08363c05b6e60e87e5e1ae65560d345756926c23f13e6d191dc33d69
-size 84773087
+oid sha256:072302b3c54aa6e9a3afb06cf45c2a8e97e20d300854bfacd585cba61282e252
+size 84723223
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index 0ff41bfb..d67456d8 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -1,5 +1,5 @@
-| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
-|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|
-| train |     130.23    |   130.23  |     100.0     |   100.0   |       100.0        |
-|  val  |     126.83    |   126.83  |     100.0     |   100.0   |       100.0        |
-|  test |     112.24    |   112.24  |     100.0     |   100.0   |       100.0        |
+| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER  |
+|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|:----:|
+| train |     18.89     |   21.05   |     26.67     |   26.67   |       26.67        | 7.14 |
+|  val  |      8.82     |   11.54   |      50.0     |    50.0   |        50.0        | 0.0  |
+|  test |      2.78     |    3.33   |     14.29     |   14.29   |       14.29        | 0.0  |
diff --git a/tests/data/prediction/labels.json b/tests/data/prediction/labels.json
new file mode 100644
index 00000000..6efebc38
--- /dev/null
+++ b/tests/data/prediction/labels.json
@@ -0,0 +1,12 @@
+{
+    "train": {
+        "tests/data/prediction/images/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",
+        "tests/data/prediction/images/0dfe8bcd-ed0b-453e-bf19-cc697012296e.png": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle"
+    },
+    "val": {
+        "tests/data/prediction/images/2c242f5c-e979-43c4-b6f2-a6d4815b651d.png": "ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331"
+    },
+    "test": {
+        "tests/data/prediction/images/ffdec445-7f14-4f5f-be44-68d0844d0df1.png": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère"
+    }
+}
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 06958d64..98fbd938 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -52,43 +52,49 @@ def test_add_metrics_table_row():
     (
         (
             {
-                "nb_chars": 43,
-                "cer": 1.3023,
-                "nb_chars_no_token": 43,
-                "cer_no_token": 1.3023,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_words_no_token": 9,
-                "wer_no_token": 1.0,
+                "nb_chars": 90,
+                "cer": 0.1889,
+                "nb_chars_no_token": 76,
+                "cer_no_token": 0.2105,
+                "nb_words": 15,
+                "wer": 0.2667,
+                "nb_words_no_punct": 15,
+                "wer_no_punct": 0.2667,
+                "nb_words_no_token": 15,
+                "wer_no_token": 0.2667,
+                "nb_tokens": 14,
+                "ner": 0.0714,
                 "nb_samples": 2,
             },
             {
-                "nb_chars": 41,
-                "cer": 1.2683,
-                "nb_chars_no_token": 41,
-                "cer_no_token": 1.2683,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_words_no_token": 9,
-                "wer_no_token": 1.0,
-                "nb_samples": 2,
+                "nb_chars": 34,
+                "cer": 0.0882,
+                "nb_chars_no_token": 26,
+                "cer_no_token": 0.1154,
+                "nb_words": 8,
+                "wer": 0.5,
+                "nb_words_no_punct": 8,
+                "wer_no_punct": 0.5,
+                "nb_words_no_token": 8,
+                "wer_no_token": 0.5,
+                "nb_tokens": 8,
+                "ner": 0.0,
+                "nb_samples": 1,
             },
             {
-                "nb_chars": 49,
-                "cer": 1.1224,
-                "nb_chars_no_token": 49,
-                "cer_no_token": 1.1224,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_words_no_token": 9,
-                "wer_no_token": 1.0,
-                "nb_samples": 2,
+                "nb_chars": 36,
+                "cer": 0.0278,
+                "nb_chars_no_token": 30,
+                "cer_no_token": 0.0333,
+                "nb_words": 7,
+                "wer": 0.1429,
+                "nb_words_no_punct": 7,
+                "wer_no_punct": 0.1429,
+                "nb_words_no_token": 7,
+                "wer_no_token": 0.1429,
+                "nb_tokens": 6,
+                "ner": 0.0,
+                "nb_samples": 1,
             },
         ),
     ),
@@ -106,7 +112,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
         filename = (
             evaluate_config["training"]["output_folder"]
             / "results"
-            / f"predict_training-{split_name}_0.yaml"
+            / f"predict_training-{split_name}_1685.yaml"
         )
 
         with filename.open() as f:
-- 
GitLab