From f203b58794fb37f68b395401f6b75ed15520c40a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 26 Sep 2023 10:32:08 +0200
Subject: [PATCH] Simplify and document data extraction

---
 dan/datasets/extract/extract.py      |  6 +++---
 dan/ocr/predict/prediction.py        | 10 +++++-----
 docs/get_started/training.md         |  6 +++---
 tests/data/prediction/parameters.yml | 10 +++++-----
 tests/test_extract.py                | 15 +++++++++------
 tests/test_prediction.py             |  2 +-
 6 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 59ec9139..58de1e45 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -393,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_corpus.txt").write_text(
+        (self.output / "language_model" / "corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_tokens.txt").write_text(
+        (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_lexicon.txt").write_text(
+        (self.output / "language_model" / "lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 347747fc..22badd0f 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -92,12 +92,12 @@ class DAN:
         self.decoder = decoder
         self.lm_decoder = None

-        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
+        if use_language_model and parameters["language_model"]["weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
-                language_model_path=parameters["lm_decoder"]["language_model_path"],
-                lexicon_path=parameters["lm_decoder"]["lexicon_path"],
-                tokens_path=parameters["lm_decoder"]["tokens_path"],
-                language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+                language_model_path=parameters["language_model"]["model"],
+                lexicon_path=parameters["language_model"]["lexicon"],
+                tokens_path=parameters["language_model"]["tokens"],
+                language_model_weight=parameters["language_model"]["weight"],
             )

         self.mean, self.std = (
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 1cb5a0cf..03773c50 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -21,9 +21,9 @@ output/
 │   ├── val
 │   └── test
 ├── language_model
-│   ├── language_corpus.txt
-│   ├── language_lexicon.txt
-│   └── language_tokens.txt
+│   ├── corpus.txt
+│   ├── lexicon.txt
+│   └── tokens.txt
 ```

 ## 2. Train
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 9f41f5cf..88a43d9c 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -22,8 +22,8 @@ parameters:
     - type: "max_resize"
       max_height: 1500
       max_width: 1500
-  lm_decoder:
-    language_model_path: tests/data/prediction/language_model.arpa
-    lexicon_path: tests/data/prediction/language_lexicon.txt
-    tokens_path: tests/data/prediction/language_tokens.txt
-    language_model_weight: 1.0
+  language_model:
+    model: tests/data/prediction/language_model.arpa
+    lexicon: tests/data/prediction/language_lexicon.txt
+    tokens: tests/data/prediction/language_tokens.txt
+    weight: 1.0
diff --git a/tests/test_extract.py b/tests/test_extract.py
index fb9fb252..b3186b71 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -395,9 +395,10 @@ def test_extract(
         VAL_DIR / "val-page_1-line_2.jpg",
         VAL_DIR / "val-page_1-line_3.jpg",
         output / "labels.json",
-        output / "language_corpus.txt",
-        output / "language_lexicon.txt",
-        output / "language_tokens.txt",
+        # Language resources
+        output / "language_model" / "corpus.txt",
+        output / "language_model" / "lexicon.txt",
+        output / "language_model" / "tokens.txt",
     ]

     assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
@@ -486,20 +487,22 @@ def test_extract(
         "⎵", expected_language_corpus
     )

-    assert (output / "language_corpus.txt").read_text() == expected_language_corpus
+    assert (
+        output / "language_model" / "corpus.txt"
+    ).read_text() == expected_language_corpus

     # Check "language_tokens.txt"
     expected_language_tokens = [
         t if t != " " else "⎵" for t in sorted(list(expected_charset))
     ]
     expected_language_tokens.append("◌")
-    assert (output / "language_tokens.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
         expected_language_tokens
     )

     # Check "language_lexicon.txt"
     expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
-    assert (output / "language_lexicon.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
         expected_language_lexicon
     )

diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b65c7494..994e27b1 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -700,7 +700,7 @@ def test_run_prediction_language_model(
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
-    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
+    params["parameters"]["language_model"]["weight"] = language_model_weight

     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))

     run_prediction(
-- 
GitLab
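
Reviewer note (illustration only, not part of the patch): the renamed YAML block in parameters.yml and the decoder construction in dan/ocr/predict/prediction.py now line up key for key. A minimal sketch of that mapping, assuming PyYAML is installed and using a plain dict to stand in for the real CTCLanguageDecoder arguments:

```python
import yaml

# Load the test fixture shipped with this patch.
with open("tests/data/prediction/parameters.yml") as f:
    params = yaml.safe_load(f)["parameters"]

lm = params["language_model"]

# Mirror the post-patch logic: only build the LM decoder when weight > 0.
decoder_kwargs = None
if lm["weight"] > 0:
    decoder_kwargs = {
        "language_model_path": lm["model"],     # was lm_decoder.language_model_path
        "lexicon_path": lm["lexicon"],          # was lm_decoder.lexicon_path
        "tokens_path": lm["tokens"],            # was lm_decoder.tokens_path
        "language_model_weight": lm["weight"],  # was lm_decoder.language_model_weight
    }

print(decoder_kwargs)
```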