diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 218117b098da5e62d27d952808678ccaa68a2beb..d9f30cf625131f5e4da3ae436ecdc24df6f01f39 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -92,7 +92,7 @@ class DAN:
         self.decoder = decoder
 
         self.lm_decoder = None
-        if use_language_model:
+        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
                 language_model_path=parameters["lm_decoder"]["language_model_path"],
                 lexicon_path=parameters["lm_decoder"]["lexicon_path"],
@@ -478,6 +478,9 @@ def run(
     )
     batch_size = 1 if use_language_model else batch_size
 
+    # Do not use LM with invalid LM weight
+    use_language_model = dan_model.lm_decoder is not None
+
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):
         process_batch(
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 88a43d9c171ee483e734f01b0d99250f10c5d5ec..9f41f5cfc056a2211c21df4fdd47a2697afe32cd 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -22,8 +22,8 @@ parameters:
       - type: "max_resize"
         max_height: 1500
         max_width: 1500
-  language_model:
-    model: tests/data/prediction/language_model.arpa
-    lexicon: tests/data/prediction/language_lexicon.txt
-    tokens: tests/data/prediction/language_tokens.txt
-    weight: 1.0
+  lm_decoder:
+    language_model_path: tests/data/prediction/language_model.arpa
+    lexicon_path: tests/data/prediction/language_lexicon.txt
+    tokens_path: tests/data/prediction/language_tokens.txt
+    language_model_weight: 1.0
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 30576b1b5136e7459dc5843bbaab172a9173162f..2474473ec683005ab261bd30fd25ab09733ce972 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -628,10 +628,12 @@ def test_run_prediction_batch(
         ),
     ),
 )
+@pytest.mark.parametrize("batch_size", [1, 2])
 def test_run_prediction_language_model(
     image_names,
     language_model_weight,
     expected_predictions,
+    batch_size,
     tmp_path,
 ):
     # Make tmpdir and copy needed images inside
@@ -645,7 +647,7 @@ def test_run_prediction_language_model(
 
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
-    params["parameters"]["language_model"]["weight"] = language_model_weight
+    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
 
     run_prediction(
@@ -669,7 +671,7 @@ def test_run_prediction_language_model(
         max_object_height=None,
         image_extension=".png",
         gpu_device=None,
-        batch_size=1,
+        batch_size=batch_size,
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
         start_token=None,
         use_language_model=True,