Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
......@@ -135,19 +135,6 @@ def add_predict_parser(subcommands) -> None:
type=int,
default=None,
)
parser.add_argument(
"--threshold-method",
help="Thresholding method.",
choices=["otsu", "simple"],
type=str,
default="otsu",
)
parser.add_argument(
"--threshold-value",
help="Thresholding value.",
type=int,
default=0,
)
parser.add_argument(
"--gpu-device",
help="Use a specific GPU if available.",
......
......@@ -220,8 +220,6 @@ def get_predicted_polygons_with_confidence(
level: Level,
height: int,
width: int,
threshold_method: str = "otsu",
threshold_value: int = 0,
max_object_height: int = 50,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
......@@ -235,8 +233,6 @@ def get_predicted_polygons_with_confidence(
:param level: Level to display (must be in [char, word, line, ner])
:param height: Original image height
:param width: Original image width
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]
:param threshold_value: Thresholding value for the "simple" method.
:param max_object_height: Maximum height of predicted objects.
:param word_separators: List of word separators
:param line_separators: List of line separators
......@@ -256,8 +252,6 @@ def get_predicted_polygons_with_confidence(
max_value,
start_index,
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
max_object_height=max_object_height,
size=(width, height),
)
......@@ -347,35 +341,21 @@ def polygon_to_bbx(polygon: np.ndarray) -> List[Tuple[int, int]]:
return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
def threshold(
mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0
) -> np.ndarray:
def threshold(mask: np.ndarray) -> np.ndarray:
"""
Threshold a grayscale mask.
:param mask: a grayscale image (np.array)
:param threshold_method: method to be used for thresholding. Should be in ["otsu", "simple"].
:param threshold_value: the threshold value used for binarization (used for the "simple" method).
"""
min_kernel = 1
max_kernel = mask.shape[1] // 100
if threshold_method == "simple":
bin_mask = np.array(np.where(mask > threshold_value, 255, 0), dtype=np.uint8)
return np.asarray(bin_mask, dtype=np.uint8)
elif threshold_method == "otsu":
# Blur and apply Otsu thresholding
blur = cv2.GaussianBlur(mask, (15, 15), 0)
_, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Apply dilation
kernel_width = cv2.getStructuringElement(
cv2.MORPH_CROSS, (max_kernel, min_kernel)
)
dilated = cv2.dilate(bin_mask, kernel_width, iterations=3)
return np.asarray(dilated, dtype=np.uint8)
else:
raise NotImplementedError(f"Method {threshold_method} is not implemented.")
# Blur and apply Otsu thresholding
blur = cv2.GaussianBlur(mask, (15, 15), 0)
_, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Apply dilation
kernel_width = cv2.getStructuringElement(cv2.MORPH_CROSS, (max_kernel, min_kernel))
dilated = cv2.dilate(bin_mask, kernel_width, iterations=3)
return np.asarray(dilated, dtype=np.uint8)
def get_polygon(
......@@ -383,8 +363,6 @@ def get_polygon(
max_value: np.float32,
offset: int,
weights: np.ndarray,
threshold_method: str,
threshold_value: int,
size: Tuple[int, int] = None,
max_object_height: int = 50,
) -> Tuple[dict, np.ndarray]:
......@@ -394,19 +372,13 @@ def get_polygon(
:param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
:param offset: Offset value to get the relevant part of text piece
:param size: Target size (width, height) to resize the coverage vector
:param threshold_method: Binarization method to use (should be in ["simple", "otsu"])
:param max_object_height: Maximum height of predicted objects.
:param threshold_value: Threshold value used for the "simple" binarization method
"""
# Compute coverage vector
coverage_vector = compute_coverage(text, max_value, offset, weights, size=size)
# Generate a binary image for the current channel.
bin_mask = threshold(
coverage_vector,
threshold_method=threshold_method,
threshold_value=threshold_value,
)
bin_mask = threshold(coverage_vector)
coord, confidence = (
get_grid_search_contour(coverage_vector, bin_mask, height=max_object_height)
......@@ -475,8 +447,6 @@ def plot_attention(
level: Level,
scale: float,
outname: str,
threshold_method: str = "otsu",
threshold_value: int = 0,
max_object_height: int = 50,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
......@@ -527,8 +497,6 @@ def plot_attention(
max_value,
tot_len,
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
max_object_height=max_object_height,
size=(image.width, image.height),
)
......
......@@ -138,8 +138,6 @@ class DAN:
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
start_token: str = None,
threshold_method: str = "otsu",
threshold_value: int = 0,
max_object_height: int = 50,
use_language_model: bool = False,
) -> dict:
......@@ -151,8 +149,6 @@ class DAN:
:param attentions: Return characters attention weights.
:param attention_level: Level of text pieces (must be in [char, word, line, ner])
:param extract_objects: Whether to extract polygons' coordinates.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param max_object_height: Maximum height of predicted objects.
"""
input_tensor = input_tensor.to(self.device)
......@@ -284,8 +280,6 @@ class DAN:
attention_level,
input_sizes[i][0],
input_sizes[i][1],
threshold_method=threshold_method,
threshold_value=threshold_value,
max_object_height=max_object_height,
word_separators=word_separators,
line_separators=line_separators,
......@@ -309,8 +303,6 @@ def process_batch(
word_separators: List[str],
line_separators: List[str],
predict_objects: bool,
threshold_method: str,
threshold_value: int,
max_object_height: int,
tokens: Dict[str, EntityType],
start_token: str,
......@@ -346,8 +338,6 @@ def process_batch(
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
threshold_method=threshold_method,
threshold_value=threshold_value,
max_object_height=max_object_height,
start_token=start_token,
use_language_model=use_language_model,
......@@ -406,8 +396,6 @@ def process_batch(
line_separators=line_separators,
tokens=tokens,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
max_object_height=max_object_height,
outname=gif_filename,
)
......@@ -434,8 +422,6 @@ def run(
line_separators: List[str],
temperature: float,
predict_objects: bool,
threshold_method: str,
threshold_value: int,
max_object_height: int,
image_extension: str,
gpu_device: int,
......@@ -459,8 +445,6 @@ def run(
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param max_object_height: Maximum height of predicted objects.
:param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
......@@ -498,8 +482,6 @@ def run(
word_separators,
line_separators,
predict_objects,
threshold_method,
threshold_value,
max_object_height,
tokens,
start_token,
......
......@@ -36,6 +36,18 @@ The library already has all the documents needed to run the [training command](.
teklia-dan train --config configs/tests.json
```
The library also includes all the files needed to run the [predict command](../usage/predict/index.md) with a minimal model. From the `tests/data/prediction` directory, run the following command, adding any extra parameters you need:
```shell
teklia-dan predict \
--image-dir images/ \
--image-extension png \
--model popp_line_model.pt \
--parameters parameters.yml \
--charset charset.pkl \
--output /tmp/dan-predict
```
## Documentation
The docstrings follow the [Sphinx](http://www.sphinx-doc.org/) format, and this documentation was generated using [MkDocs](https://mkdocs.org/) and [mkdocstrings](https://mkdocstrings.github.io/).
......
......@@ -24,8 +24,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
| `--max-object-height` | Maximum height for predicted objects. If set, grid search segmentation will be applied and width will be normalized to element width. | `int` | |
| `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` |
| `--line-separators` | List of line separators. | `list` | `["\n"]` |
| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` |
| `--threshold-value ` | Threshold to use for the "simple" thresholding method. | `int` | `0` |
| `--gpu-device` | Use a specific GPU if available. | `int` | |
| `--batch-size` | Size of the batches for prediction. | `int` | `1` |
| `--start-token` | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str` | |
......@@ -130,8 +128,7 @@ teklia-dan predict \
--charset charset.pkl \
--output predict/ \
--attention-map \
--predict-objects \
--threshold-method otsu
--predict-objects
```
It will create the following JSON file named `predict/example.json` and a GIF named `predict/example_line.gif` showing a line-level attention map with the extracted polygons.
......
......@@ -314,8 +314,6 @@ def test_run_prediction(
line_separators=["\n"],
temperature=temperature,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
max_object_height=None,
image_extension=None,
gpu_device=None,
......@@ -512,8 +510,6 @@ def test_run_prediction_batch(
line_separators=["\n"],
temperature=temperature,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
max_object_height=None,
image_extension=".png",
gpu_device=None,
......@@ -664,8 +660,6 @@ def test_run_prediction_language_model(
line_separators=["\n"],
temperature=1.0,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
max_object_height=None,
image_extension=".png",
gpu_device=None,
......