diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 3091e8e4929d0e562a6c5a35a680ac2973228d37..347747fc46d1e8ec8fb5ee6315194102c442bce5 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -92,7 +92,7 @@ class DAN:
         self.decoder = decoder
 
         self.lm_decoder = None
-        if use_language_model:
+        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
                 language_model_path=parameters["lm_decoder"]["language_model_path"],
                 lexicon_path=parameters["lm_decoder"]["lexicon_path"],
@@ -479,6 +479,9 @@ def run(
     )
     batch_size = 1 if use_language_model else batch_size
 
+    # Do not use LM with invalid LM weight
+    use_language_model = dan_model.lm_decoder is not None
+
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):
         process_batch(
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 88a43d9c171ee483e734f01b0d99250f10c5d5ec..9f41f5cfc056a2211c21df4fdd47a2697afe32cd 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -22,8 +22,8 @@ parameters:
     - type: "max_resize"
       max_height: 1500
       max_width: 1500
-  language_model:
-    model: tests/data/prediction/language_model.arpa
-    lexicon: tests/data/prediction/language_lexicon.txt
-    tokens: tests/data/prediction/language_tokens.txt
-    weight: 1.0
+  lm_decoder:
+    language_model_path: tests/data/prediction/language_model.arpa
+    lexicon_path: tests/data/prediction/language_lexicon.txt
+    tokens_path: tests/data/prediction/language_tokens.txt
+    language_model_weight: 1.0
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index cc61d2ecf43a271740516f87a0c5bb3e719f74a1..96c2f778ee93264fa67501dd40e38282acf8b77f 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -506,30 +506,22 @@ def test_run_prediction_batch(
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                     "language_model": {
-                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
-                        "confidence": 0.92,
+                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
                     },
                 },
                 {
                     "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                     "language_model": {
-                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
-                        "confidence": 0.88,
+                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
                     },
                 },
                 {
                     "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
-                    "language_model": {
-                        "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
-                        "confidence": 0.86,
-                    },
-                },
-                {
-                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
-                    "language_model": {
-                        "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
-                        "confidence": 0.89,
-                    },
+                    "language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ1431"},
+                },
+                {
+                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
+                    "language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
                 },
             ],
         ),
@@ -545,30 +537,22 @@ def test_run_prediction_batch(
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                     "language_model": {
-                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
-                        "confidence": 0.90,
+                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
                     },
                 },
                 {
                     "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                     "language_model": {
-                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
-                        "confidence": 0.84,
+                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
                     },
                 },
                 {
                     "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
-                    "language_model": {
-                        "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331",
-                        "confidence": 0.83,
-                    },
-                },
-                {
-                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
-                    "language_model": {
-                        "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
-                        "confidence": 0.86,
-                    },
+                    "language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331"},
+                },
+                {
+                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
+                    "language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
                 },
             ],
         ),
@@ -589,10 +573,12 @@ def test_run_prediction_batch(
         ),
     ),
 )
+@pytest.mark.parametrize("batch_size", [1, 2])
 def test_run_prediction_language_model(
     image_names,
     language_model_weight,
     expected_predictions,
+    batch_size,
     tmp_path,
 ):
     # Make tmpdir and copy needed images inside
@@ -606,7 +590,7 @@ def test_run_prediction_language_model(
 
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
-    params["parameters"]["language_model"]["weight"] = language_model_weight
+    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
 
     run_prediction(
@@ -630,7 +614,7 @@ def test_run_prediction_language_model(
         max_object_height=None,
         image_extension=".png",
         gpu_device=None,
-        batch_size=1,
+        batch_size=batch_size,
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
         start_token=None,
         use_language_model=True,
@@ -647,7 +631,3 @@ def test_run_prediction_language_model(
             prediction["language_model"]["text"]
             == expected_prediction["language_model"]["text"]
         )
-        assert np.isclose(
-            prediction["language_model"]["confidence"],
-            expected_prediction["language_model"]["confidence"],
-        )