From f203b58794fb37f68b395401f6b75ed15520c40a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 26 Sep 2023 10:32:08 +0200
Subject: [PATCH] Simplify and document data extraction

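Language model resources produced by the extraction command are now
written to a dedicated `language_model/` directory (`corpus.txt`,
`lexicon.txt`, `tokens.txt`) instead of flat `language_*.txt` files at
the root of the output folder.

The prediction parameters for LM decoding are grouped under a single
`language_model` section with shorter keys, e.g. (values taken from the
test fixture updated in this patch):

    language_model:
      model: tests/data/prediction/language_model.arpa
      lexicon: tests/data/prediction/language_lexicon.txt
      tokens: tests/data/prediction/language_tokens.txt
      weight: 1.0

Documentation and tests are updated accordingly.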
---
 dan/datasets/extract/extract.py      |  6 +++---
 dan/ocr/predict/prediction.py        | 10 +++++-----
 docs/get_started/training.md         |  6 +++---
 tests/data/prediction/parameters.yml | 10 +++++-----
 tests/test_extract.py                | 15 +++++++++------
 tests/test_prediction.py             |  2 +-
 6 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 59ec9139..58de1e45 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -393,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_corpus.txt").write_text(
+        (self.output / "language_model" / "corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_tokens.txt").write_text(
+        (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_lexicon.txt").write_text(
+        (self.output / "language_model" / "lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 347747fc..22badd0f 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -92,12 +92,12 @@ class DAN:
         self.decoder = decoder
         self.lm_decoder = None
 
-        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
+        if use_language_model and parameters["language_model"]["weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
-                language_model_path=parameters["lm_decoder"]["language_model_path"],
-                lexicon_path=parameters["lm_decoder"]["lexicon_path"],
-                tokens_path=parameters["lm_decoder"]["tokens_path"],
-                language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+                language_model_path=parameters["language_model"]["model"],
+                lexicon_path=parameters["language_model"]["lexicon"],
+                tokens_path=parameters["language_model"]["tokens"],
+                language_model_weight=parameters["language_model"]["weight"],
             )
 
         self.mean, self.std = (
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 1cb5a0cf..03773c50 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -21,9 +21,9 @@ output/
 │   ├── val
 │   └── test
 ├── language_model
-│   ├── language_corpus.txt
-│   ├── language_lexicon.txt
-│   └── language_tokens.txt
+│   ├── corpus.txt
+│   ├── lexicon.txt
+│   └── tokens.txt
 ```
 
 ## 2. Train
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 9f41f5cf..88a43d9c 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -22,8 +22,8 @@ parameters:
     - type: "max_resize"
       max_height: 1500
       max_width: 1500
-  lm_decoder:
-    language_model_path: tests/data/prediction/language_model.arpa
-    lexicon_path: tests/data/prediction/language_lexicon.txt
-    tokens_path: tests/data/prediction/language_tokens.txt
-    language_model_weight: 1.0
+  language_model:
+    model: tests/data/prediction/language_model.arpa
+    lexicon: tests/data/prediction/language_lexicon.txt
+    tokens: tests/data/prediction/language_tokens.txt
+    weight: 1.0
diff --git a/tests/test_extract.py b/tests/test_extract.py
index fb9fb252..b3186b71 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -395,9 +395,10 @@ def test_extract(
         VAL_DIR / "val-page_1-line_2.jpg",
         VAL_DIR / "val-page_1-line_3.jpg",
         output / "labels.json",
-        output / "language_corpus.txt",
-        output / "language_lexicon.txt",
-        output / "language_tokens.txt",
+        # Language resources
+        output / "language_model" / "corpus.txt",
+        output / "language_model" / "lexicon.txt",
+        output / "language_model" / "tokens.txt",
     ]
     assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
 
@@ -486,20 +487,22 @@ def test_extract(
             "⎵", expected_language_corpus
         )
 
-    assert (output / "language_corpus.txt").read_text() == expected_language_corpus
+    assert (
+        output / "language_model" / "corpus.txt"
+    ).read_text() == expected_language_corpus
 
     # Check "language_tokens.txt"
     expected_language_tokens = [
         t if t != " " else "⎵" for t in sorted(list(expected_charset))
     ]
     expected_language_tokens.append("◌")
-    assert (output / "language_tokens.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
         expected_language_tokens
     )
 
     # Check "language_lexicon.txt"
     expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
-    assert (output / "language_lexicon.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
         expected_language_lexicon
     )
 
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b65c7494..994e27b1 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -700,7 +700,7 @@ def test_run_prediction_language_model(
 
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
-    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
+    params["parameters"]["language_model"]["weight"] = language_model_weight
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
 
     run_prediction(
-- 
GitLab