diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 66e30710a84b347c8051e239e50983b7eff2945f..07091be5301ae8a9f8189ef701e369afdc76b068 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -141,4 +141,16 @@ def add_predict_parser(subcommands) -> None:
         action="store_true",
         required=False,
     )
+    parser.add_argument(
+        "--compile-model",
+        help="Whether to compile the model. Recommended to speed up inference.",
+        action="store_true",
+        required=False,
+    )
+    parser.add_argument(
+        "--dynamic-mode",
+        help="Whether to use the dynamic mode during model compilation. Recommended for prediction on images of variable size.",
+        action="store_true",
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index adf9efafe39464ee86dc80242f0a7874091091b2..7b6ab65ec0c8a4058e0754ad909fdd67a2ecd73c 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -52,12 +52,16 @@ class DAN:
         path: Path,
         mode: str = "eval",
         use_language_model: bool = False,
+        compile_model: bool = False,
+        dynamic_mode: bool = False,
     ) -> None:
         """
         Load a trained model.
         :param path: Path to the directory containing the model, the YAML parameters file and the charset file.
         :param mode: The mode to load the model (train or eval).
         :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
+        :param compile_model: Whether to compile the model.
+        :param dynamic_mode: Whether to use the dynamic mode during model compilation.
         """
         model_path = path / "model.pt"
         assert model_path.is_file(), f"File {model_path} not found"
@@ -84,6 +88,15 @@ class DAN:
 
         logger.debug(f"Loaded model {model_path}")
 
+        if compile_model:
+            torch.compiler.cudagraph_mark_step_begin()
+            encoder = torch.compile(encoder, dynamic=True if dynamic_mode else None)
+
+            torch.compiler.cudagraph_mark_step_begin()
+            decoder = torch.compile(decoder, dynamic=True if dynamic_mode else None)
+
+            logger.info("Encoder and decoder have been compiled")
+
         if mode == "train":
             encoder.train()
             decoder.train()
@@ -445,6 +458,8 @@ def run(
     tokens: Dict[str, EntityType],
     start_token: str,
     use_language_model: bool,
+    compile_model: bool,
+    dynamic_mode: bool,
 ) -> None:
     """
     Predict a single image save the output
@@ -464,6 +479,8 @@
     :param tokens: NER tokens used.
     :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
     :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
+    :param compile_model: Whether to compile the model.
+    :param dynamic_mode: Whether to use the dynamic mode during model compilation.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -473,7 +490,13 @@
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
-    dan_model.load(model, mode="eval", use_language_model=use_language_model)
+    dan_model.load(
+        model,
+        mode="eval",
+        use_language_model=use_language_model,
+        compile_model=compile_model,
+        dynamic_mode=dynamic_mode,
+    )
 
     images = image_dir.rglob(f"*{image_extension}")
     for image_batch in list_to_batches(images, n=batch_size):
diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md
index 656011b1966cefc03de25ebbe99448ded57feeab..dfb8870c75fab461d96bc343b48105e91d56aba4 100644
--- a/docs/usage/predict/index.md
+++ b/docs/usage/predict/index.md
@@ -25,6 +25,8 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 | `--batch-size` | Size of the batches for prediction. | `int` | `1` |
 | `--start-token` | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str` | |
 | `--use-language-model` | Whether to use an explicit language model to rescore text hypotheses. | `bool` | `False` |
+| `--compile-model` | Whether to compile the model. Recommended to speed up inference. | `bool` | `False` |
+| `--dynamic-mode` | Whether to use the dynamic mode during model compilation. Recommended for prediction on images of variable size. | `bool` | `False` |
 
 The `--model` argument expects a directory with the following files:
 
@@ -277,3 +279,28 @@ It will create the following JSON file named after the image in the `predict_wor
   }
 }
 ```
+
+### Speed up prediction with model compilation
+
+To speed up prediction, it is recommended to [compile models using `torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
+
+Run this command to use this option:
+
+```shell
+teklia-dan predict \
+    --image-dir images/ \
+    --model models \
+    --output predict/ \
+    --compile-model
+```
+
+When predicting on images of variable size, it is recommended to enable the `dynamic` mode:
+
+```shell
+teklia-dan predict \
+    --image-dir images/ \
+    --model models \
+    --output predict/ \
+    --compile-model \
+    --dynamic-mode
+```
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 1fc5423e2bbaf53c231dcb2f1ef57cb2439a5151..7634fa0a8aba7fd5fdd9e1c7b939fce3d17d4e4f 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -341,6 +341,8 @@ def test_run_prediction(
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
         start_token=None,
         use_language_model=False,
+        compile_model=False,
+        dynamic_mode=False,
     )
 
     prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
@@ -534,6 +536,8 @@ def test_run_prediction_batch(
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
         start_token=None,
         use_language_model=False,
+        compile_model=False,
+        dynamic_mode=False,
     )
 
     for image_name, expected_prediction in zip(image_names, expected_predictions):
@@ -693,6 +697,8 @@ def test_run_prediction_language_model(
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
         start_token=None,
        use_language_model=True,
+        compile_model=False,
+        dynamic_mode=False,
     )
 
     for image_name, expected_prediction in zip(image_names, expected_predictions):