diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 59ec9139627cc8cf20c01602e25bba900918158f..58de1e456feb2651967f634ad8b96627df0cd1ee 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -393,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_corpus.txt").write_text(
+        (self.output / "language_model" / "corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_tokens.txt").write_text(
+        (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_lexicon.txt").write_text(
+        (self.output / "language_model" / "lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 347747fc46d1e8ec8fb5ee6315194102c442bce5..22badd0f1dee2aefc71df772e824350bf05d89fa 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -92,12 +92,12 @@ class DAN:
         self.decoder = decoder
         self.lm_decoder = None

-        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
+        if use_language_model and parameters["language_model"]["weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
-                language_model_path=parameters["lm_decoder"]["language_model_path"],
-                lexicon_path=parameters["lm_decoder"]["lexicon_path"],
-                tokens_path=parameters["lm_decoder"]["tokens_path"],
-                language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+                language_model_path=parameters["language_model"]["model"],
+                lexicon_path=parameters["language_model"]["lexicon"],
+                tokens_path=parameters["language_model"]["tokens"],
+                language_model_weight=parameters["language_model"]["weight"],
             )

         self.mean, self.std = (
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 1cb5a0cffb55d6aaf9c9e2fee9fc65e433add1af..03773c50db6d9ccff6a0abdfa3167d47f201f373 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -21,9 +21,9 @@ output/
 │   ├── val
 │   └── test
 ├── language_model
-│   ├── language_corpus.txt
-│   ├── language_lexicon.txt
-│   └── language_tokens.txt
+│   ├── corpus.txt
+│   ├── lexicon.txt
+│   └── tokens.txt
 ```

 ## 2. Train
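Aside, not part of the patch: a minimal sketch of how the renamed `language_model` parameter block is consumed, mirroring the `prediction.py` hunk above. The `dan.ocr.decoder` import path and the hard-coded YAML path are assumptions for illustration only.

```python
import yaml

# Assumed import path for the decoder class used in prediction.py above.
from dan.ocr.decoder import CTCLanguageDecoder

# Parse the prediction parameters (path taken from the test fixtures below).
with open("tests/data/prediction/parameters.yml") as f:
    parameters = yaml.safe_load(f)["parameters"]

lm_params = parameters["language_model"]
lm_decoder = None
# Only build the decoder when the language model carries a positive weight,
# matching the guard in DAN.__init__ above.
if lm_params["weight"] > 0:
    lm_decoder = CTCLanguageDecoder(
        language_model_path=lm_params["model"],
        lexicon_path=lm_params["lexicon"],
        tokens_path=lm_params["tokens"],
        language_model_weight=lm_params["weight"],
    )
```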
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 9f41f5cfc056a2211c21df4fdd47a2697afe32cd..88a43d9c171ee483e734f01b0d99250f10c5d5ec 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -22,8 +22,8 @@ parameters:
     - type: "max_resize"
       max_height: 1500
       max_width: 1500
-  lm_decoder:
-    language_model_path: tests/data/prediction/language_model.arpa
-    lexicon_path: tests/data/prediction/language_lexicon.txt
-    tokens_path: tests/data/prediction/language_tokens.txt
-    language_model_weight: 1.0
+  language_model:
+    model: tests/data/prediction/language_model.arpa
+    lexicon: tests/data/prediction/language_lexicon.txt
+    tokens: tests/data/prediction/language_tokens.txt
+    weight: 1.0
diff --git a/tests/test_extract.py b/tests/test_extract.py
index fb9fb2521172da3f2a6732be60cae2e023b12a8d..b3186b71b9a1adfa169f1d84d83ffbeaba26a093 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -395,9 +395,10 @@ def test_extract(
         VAL_DIR / "val-page_1-line_2.jpg",
         VAL_DIR / "val-page_1-line_3.jpg",
         output / "labels.json",
-        output / "language_corpus.txt",
-        output / "language_lexicon.txt",
-        output / "language_tokens.txt",
+        # Language resources
+        output / "language_model" / "corpus.txt",
+        output / "language_model" / "lexicon.txt",
+        output / "language_model" / "tokens.txt",
     ]

     assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
@@ -486,20 +487,22 @@ def test_extract(
         "⎵", expected_language_corpus
     )

-    assert (output / "language_corpus.txt").read_text() == expected_language_corpus
+    assert (
+        output / "language_model" / "corpus.txt"
+    ).read_text() == expected_language_corpus

     # Check "language_tokens.txt"
     expected_language_tokens = [
         t if t != " " else "⎵" for t in sorted(list(expected_charset))
     ]
     expected_language_tokens.append("◌")
-    assert (output / "language_tokens.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
         expected_language_tokens
     )

     # Check "language_lexicon.txt"
     expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
-    assert (output / "language_lexicon.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
         expected_language_lexicon
     )

diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b65c749416d1676b9fd5cd20a7cbcaf872e376bc..994e27b1e808092f205031032fea01e86510702e 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -700,6 +700,6 @@ def test_run_prediction_language_model(
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
-    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
+    params["parameters"]["language_model"]["weight"] = language_model_weight
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))

     run_prediction(
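Aside: the extractor now nests the three language-model files under `language_model/`. Below is a minimal `pathlib` sketch of that layout with stand-in data; the explicit `mkdir` is an assumption, since the diff does not show where the subdirectory is created.

```python
from pathlib import Path

# Illustrative output directory; the real one comes from the extractor config.
output = Path("output")

# Assumption: the language_model/ subdirectory must exist before the
# write_text() calls in extract.py can succeed.
(output / "language_model").mkdir(parents=True, exist_ok=True)

language_corpus = ["first line", "second line"]  # stand-in data
(output / "language_model" / "corpus.txt").write_text("\n".join(language_corpus))

# The nested path is what test_extract.py now lists in expected_paths.
assert (output / "language_model" / "corpus.txt").read_text() == "first line\nsecond line"
```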