From de0eb743e2a25d4760837025cd06f88c4ed544be Mon Sep 17 00:00:00 2001 From: manonBlanco <blanco@teklia.com> Date: Mon, 30 Oct 2023 12:46:22 +0100 Subject: [PATCH] Only support thresholding method `otsu` --- dan/ocr/predict/__init__.py | 13 ---------- dan/ocr/predict/attention.py | 50 +++++++----------------------------- dan/ocr/predict/inference.py | 18 ------------- docs/usage/predict/index.md | 5 +--- tests/test_prediction.py | 6 ----- 5 files changed, 10 insertions(+), 82 deletions(-) diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py index e5f7b2bc..fd535905 100644 --- a/dan/ocr/predict/__init__.py +++ b/dan/ocr/predict/__init__.py @@ -135,19 +135,6 @@ def add_predict_parser(subcommands) -> None: type=int, default=None, ) - parser.add_argument( - "--threshold-method", - help="Thresholding method.", - choices=["otsu", "simple"], - type=str, - default="otsu", - ) - parser.add_argument( - "--threshold-value", - help="Thresholding value.", - type=int, - default=0, - ) parser.add_argument( "--gpu-device", help="Use a specific GPU if available.", diff --git a/dan/ocr/predict/attention.py b/dan/ocr/predict/attention.py index 3ddbae9e..8e1b07e7 100644 --- a/dan/ocr/predict/attention.py +++ b/dan/ocr/predict/attention.py @@ -220,8 +220,6 @@ def get_predicted_polygons_with_confidence( level: Level, height: int, width: int, - threshold_method: str = "otsu", - threshold_value: int = 0, max_object_height: int = 50, word_separators: re.Pattern = parse_delimiters(["\n", " "]), line_separators: re.Pattern = parse_delimiters(["\n"]), @@ -235,8 +233,6 @@ def get_predicted_polygons_with_confidence( :param level: Level to display (must be in [char, word, line, ner]) :param height: Original image height :param width: Original image width - :param threshold_method: Thresholding method. Should be in ["otsu", "simple"] - :param threshold_value: Thresholding value for the "simple" method. :param max_object_height: Maximum height of predicted objects. :param word_separators: List of word separators :param line_separators: List of line separators @@ -256,8 +252,6 @@ def get_predicted_polygons_with_confidence( max_value, start_index, weights, - threshold_method=threshold_method, - threshold_value=threshold_value, max_object_height=max_object_height, size=(width, height), ) @@ -347,35 +341,21 @@ def polygon_to_bbx(polygon: np.ndarray) -> List[Tuple[int, int]]: return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] -def threshold( - mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0 -) -> np.ndarray: +def threshold(mask: np.ndarray) -> np.ndarray: """ Threshold a grayscale mask. :param mask: a grayscale image (np.array) - :param threshold_method: method to be used for thresholding. Should be in ["otsu", "simple"]. - :param threshold_value: the threshold value used for binarization (used for the "simple" method). """ min_kernel = 1 max_kernel = mask.shape[1] // 100 - if threshold_method == "simple": - bin_mask = np.array(np.where(mask > threshold_value, 255, 0), dtype=np.uint8) - return np.asarray(bin_mask, dtype=np.uint8) - - elif threshold_method == "otsu": - # Blur and apply Otsu thresholding - blur = cv2.GaussianBlur(mask, (15, 15), 0) - _, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - # Apply dilation - kernel_width = cv2.getStructuringElement( - cv2.MORPH_CROSS, (max_kernel, min_kernel) - ) - dilated = cv2.dilate(bin_mask, kernel_width, iterations=3) - return np.asarray(dilated, dtype=np.uint8) - - else: - raise NotImplementedError(f"Method {threshold_method} is not implemented.") + # Blur and apply Otsu thresholding + blur = cv2.GaussianBlur(mask, (15, 15), 0) + _, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + # Apply dilation + kernel_width = cv2.getStructuringElement(cv2.MORPH_CROSS, (max_kernel, min_kernel)) + dilated = cv2.dilate(bin_mask, kernel_width, iterations=3) + return np.asarray(dilated, dtype=np.uint8) def get_polygon( @@ -383,8 +363,6 @@ def get_polygon( max_value: np.float32, offset: int, weights: np.ndarray, - threshold_method: str, - threshold_value: int, size: Tuple[int, int] = None, max_object_height: int = 50, ) -> Tuple[dict, np.ndarray]: @@ -394,19 +372,13 @@ def get_polygon( :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization :param offset: Offset value to get the relevant part of text piece :param size: Target size (width, height) to resize the coverage vector - :param threshold_method: Binarization method to use (should be in ["simple", "otsu"]) :param max_object_height: Maximum height of predicted objects. - :param threshold_value: Threshold value used for the "simple" binarization method """ # Compute coverage vector coverage_vector = compute_coverage(text, max_value, offset, weights, size=size) # Generate a binary image for the current channel. - bin_mask = threshold( - coverage_vector, - threshold_method=threshold_method, - threshold_value=threshold_value, - ) + bin_mask = threshold(coverage_vector) coord, confidence = ( get_grid_search_contour(coverage_vector, bin_mask, height=max_object_height) @@ -475,8 +447,6 @@ def plot_attention( level: Level, scale: float, outname: str, - threshold_method: str = "otsu", - threshold_value: int = 0, max_object_height: int = 50, word_separators: re.Pattern = parse_delimiters(["\n", " "]), line_separators: re.Pattern = parse_delimiters(["\n"]), @@ -527,8 +497,6 @@ def plot_attention( max_value, tot_len, weights, - threshold_method=threshold_method, - threshold_value=threshold_value, max_object_height=max_object_height, size=(image.width, image.height), ) diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index 0d67983a..e012a5bf 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -138,8 +138,6 @@ class DAN: line_separators: re.Pattern = parse_delimiters(["\n"]), tokens: Dict[str, EntityType] = {}, start_token: str = None, - threshold_method: str = "otsu", - threshold_value: int = 0, max_object_height: int = 50, use_language_model: bool = False, ) -> dict: @@ -151,8 +149,6 @@ class DAN: :param attentions: Return characters attention weights. :param attention_level: Level of text pieces (must be in [char, word, line, ner]) :param extract_objects: Whether to extract polygons' coordinates. - :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. - :param threshold_value: Thresholding value to use for the "simple" thresholding method. :param max_object_height: Maximum height of predicted objects. """ input_tensor = input_tensor.to(self.device) @@ -284,8 +280,6 @@ class DAN: attention_level, input_sizes[i][0], input_sizes[i][1], - threshold_method=threshold_method, - threshold_value=threshold_value, max_object_height=max_object_height, word_separators=word_separators, line_separators=line_separators, @@ -309,8 +303,6 @@ def process_batch( word_separators: List[str], line_separators: List[str], predict_objects: bool, - threshold_method: str, - threshold_value: int, max_object_height: int, tokens: Dict[str, EntityType], start_token: str, @@ -346,8 +338,6 @@ def process_batch( word_separators=word_separators, line_separators=line_separators, tokens=tokens, - threshold_method=threshold_method, - threshold_value=threshold_value, max_object_height=max_object_height, start_token=start_token, use_language_model=use_language_model, @@ -406,8 +396,6 @@ def process_batch( line_separators=line_separators, tokens=tokens, display_polygons=predict_objects, - threshold_method=threshold_method, - threshold_value=threshold_value, max_object_height=max_object_height, outname=gif_filename, ) @@ -434,8 +422,6 @@ def run( line_separators: List[str], temperature: float, predict_objects: bool, - threshold_method: str, - threshold_value: int, max_object_height: int, image_extension: str, gpu_device: int, @@ -459,8 +445,6 @@ def run( :param word_separators: List of word separators. :param line_separators: List of line separators. :param predict_objects: Whether to extract objects. - :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. - :param threshold_value: Thresholding value to use for the "simple" thresholding method. :param max_object_height: Maximum height of predicted objects. :param gpu_device: Use a specific GPU if available. :param batch_size: Size of the batches for prediction. @@ -498,8 +482,6 @@ def run( word_separators, line_separators, predict_objects, - threshold_method, - threshold_value, max_object_height, tokens, start_token, diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md index bb313b2b..51f99d4a 100644 --- a/docs/usage/predict/index.md +++ b/docs/usage/predict/index.md @@ -24,8 +24,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image. | `--max-object-height` | Maximum height for predicted objects. If set, grid search segmentation will be applied and width will be normalized to element width. | `int` | | | `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` | | `--line-separators` | List of line separators. | `list` | `["\n"]` | -| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` | -| `--threshold-value ` | Threshold to use for the "simple" thresholding method. | `int` | `0` | | `--gpu-device` | Use a specific GPU if available. | `int` | | | `--batch-size` | Size of the batches for prediction. | `int` | `1` | | `--start-token` | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str` | | @@ -130,8 +128,7 @@ teklia-dan predict \ --charset charset.pkl \ --output predict/ \ --attention-map \ - --predict-objects \ - --threshold-method otsu + --predict-objects ``` It will create the following JSON file named `predict/example.json` and a GIF showing a line-level attention map with extracted polygons `predict/example_line.gif` diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 30576b1b..6affeeb9 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -314,8 +314,6 @@ def test_run_prediction( line_separators=["\n"], temperature=temperature, predict_objects=False, - threshold_method="otsu", - threshold_value=0, max_object_height=None, image_extension=None, gpu_device=None, @@ -512,8 +510,6 @@ def test_run_prediction_batch( line_separators=["\n"], temperature=temperature, predict_objects=False, - threshold_method="otsu", - threshold_value=0, max_object_height=None, image_extension=".png", gpu_device=None, @@ -664,8 +660,6 @@ def test_run_prediction_language_model( line_separators=["\n"], temperature=1.0, predict_objects=False, - threshold_method="otsu", - threshold_value=0, max_object_height=None, image_extension=".png", gpu_device=None, -- GitLab