From 69f278a9ba5687e97b9b9f817216b9cd8e41b54d Mon Sep 17 00:00:00 2001 From: Tristan Faine <tfaine@teklia.com> Date: Wed, 12 Apr 2023 07:05:03 +0000 Subject: [PATCH] Add predicted objects to predict command --- dan/predict/__init__.py | 18 ++ dan/predict/attention.py | 268 +++++++++++++++++++++++++-- dan/predict/prediction.py | 160 +++++++++++----- docs/assets/example_line_polygon.gif | 3 + docs/usage/predict.md | 59 +++++- 5 files changed, 441 insertions(+), 67 deletions(-) create mode 100644 docs/assets/example_line_polygon.gif diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index 1528d35d..0c17ffd2 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -106,4 +106,22 @@ def add_predict_parser(subcommands) -> None: help="String separators used to split text into lines.", required=False, ) + parser.add_argument( + "--predict-objects", + action="store_true", + help="Whether to output objects when plotting attention maps.", + required=False, + ) + parser.add_argument( + "--threshold-method", + help="Thresholding method.", + choices=["otsu", "simple"], + default="otsu", + ) + parser.add_argument( + "--threshold-value", + help="Thresholding value.", + type=int, + default=0, + ) parser.set_defaults(func=run) diff --git a/dan/predict/attention.py b/dan/predict/attention.py index 3d32b5c8..160dd52b 100644 --- a/dan/predict/attention.py +++ b/dan/predict/attention.py @@ -6,18 +6,43 @@ import numpy as np from PIL import Image from dan import logger +from dan.utils import round_floats -def split_text(text, level, word_separators, line_separators): +def parse_delimiters(delimiters): + return re.compile(r"|".join(delimiters)) + + +def compute_prob_by_separator(characters, probabilities, separator): + """ + Split text and confidences using separators and return a list of average confidence scores. + :param characters: list of characters. + :param probabilities: list of character probabilities. + :param separators: regex for separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences. + Returns a list confidence scores. + """ + # match anything except separators, get start and end index + pattern = re.compile(f"[^{separator.pattern}]+") + matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)] + + # Iterate over text pieces and compute mean confidence + probs = [np.mean(probabilities[start:end]) for (start, end) in matches] + texts = [characters[start:end] for (start, end) in matches] + return texts, probs + + +def split_text(text: str, level: str, word_separators, line_separators): """ Split text into a list of characters, word, or lines. :param text: Text prediction from DAN - :param level: Level to visualize (char, word, line) + :param level: Level to visualize from [char, word, line] + :param word_separators: List of word separators + :param line_separators: List of line separators """ - # split into characters if level == "char": text_split = list(text) offset = 0 + # split into words elif level == "word": text_split = re.split(word_separators, text) @@ -31,13 +56,89 @@ def split_text(text, level, word_separators, line_separators): return text_split, offset -def compute_coverage(text: str, max_value: float, offset: int, attentions): +def split_text_and_confidences( + text, confidences, level, word_separators, line_separators +): + """ + Split text into a list of characters, words or lines with corresponding confidences scores + :param text: Text prediction from DAN + :param confidences: Character confidences + :param level: Level to visualize from [char, word, line] + :param word_separators: List of word separators + :param line_separators: List of line separators + """ + if level == "char": + texts = list(text) + offset = 0 + elif level == "word": + texts, probs = compute_prob_by_separator(text, confidences, word_separators) + offset = 1 + elif level == "line": + texts, probs = compute_prob_by_separator(text, confidences, line_separators) + offset = 1 + else: + logger.error("Level should be either 'char', 'word', or 'line'") + return texts, round_floats(probs), offset + + +def get_predicted_polygons_with_confidence( + text, + weights, + confidences, + level, + height, + width, + threshold_method="otsu", + threshold_value=0, + word_separators=["\n", " "], + line_separators=["\n"], +): + """ + Returns the polygons of each object of the current prediction + :param text: Text predicted by DAN + :param weights: Attention weights of size (n_char, feature_height, feature_width) + :param confidences: Character confidences + :param level: Level to display (must be in [char, word, line]) + :param height: Original image height + :param width: Original image width + :param threshold_method: Thresholding method. Should be in ["otsu", "simple"] + :param threshold_value: Thresholding value for the "simple" method. + :param word_separators: List of word separators + :param line_separators: List of line separators + """ + # Split text into characters, words or lines + text_list, confidence_list, offset = split_text_and_confidences( + text, confidences, level, word_separators, line_separators + ) + + max_value = weights.sum(0).max() + polygons = [] + start_index = 0 + for text_piece, confidence in zip(text_list, confidence_list): + start_index += len(text_piece) + offset + polygon, _ = get_polygon( + text_piece, + max_value, + offset, + weights, + threshold_method=threshold_method, + threshold_value=threshold_value, + size=(width, height), + ) + polygon["text"] = text_piece + polygon["text_confidence"] = confidence + polygons.append(polygon) + return polygons + + +def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple): """ Aggregates attention maps for the current text piece (char, word, line) :param text: Text piece selected with offset after splitting DAN prediction :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization :param offset: Offset value to get the relevant part of text piece :param attentions: Attention weights of size (n_char, feature_height, feature_width) + :param size: Target size (width, height) to resize the coverage vector """ _, height, width = attentions.shape @@ -49,9 +150,130 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions): # Normalize coverage vector coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8) + + # Resize it + if size: + coverage_vector = cv2.resize(coverage_vector, size) + return coverage_vector +def blend_coverage(coverage_vector, image, mask, scale): + """ + Blends current coverage_vector over original image, used to make an attention map. + :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) + :param image: Input image in PIL format + :param mask: Mask of the image (of any color) + :param scale: Scaling factor for the output gif image + """ + height, width = coverage_vector.shape + + # Blend coverage vector with original image + blank_array = np.zeros((height, width)).astype(np.uint8) + coverage_vector = Image.fromarray( + np.stack([coverage_vector, blank_array, blank_array], axis=2), "RGB" + ) + blend = Image.composite(image, coverage_vector, mask) + + # Resize to save time + blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS) + return blend + + +def compute_contour_metrics(coverage_vector, contour): + """ + Compute the contours's area and the mean value inside it. + :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) + :param contour: Contour of the current attention blob + """ + # draw the contour zone + mask = np.zeros(coverage_vector.shape, dtype=np.uint8) + cv2.drawContours(mask, [contour], -1, (255), -1) + + max_value = np.where(mask > 0, coverage_vector, 0).max() / 255 + area = cv2.contourArea(contour) + return max_value, max_value * area + + +def polygon_to_bbx(polygon): + x, y, w, h = cv2.boundingRect(polygon) + return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] + + +def threshold(mask, threshold_method="otsu", threshold_value=0): + """ + Threshold a grayscale mask. + :param mask: a grayscale image (np.array) + :param threshold_method: method to be used for thresholding. Should be in ["otsu", "simple"]. + :param threshold_value: the threshold value used for binarization (used for the "simple" method). + """ + min_kernel = 1 + max_kernel = mask.shape[1] // 100 + + if threshold_method == "simple": + bin_mask = np.array(np.where(mask > threshold_value, 255, 0), dtype=np.uint8) + return np.asarray(bin_mask, dtype=np.uint8) + + elif threshold_method == "otsu": + # Blur and apply Otsu thresholding + blur = cv2.GaussianBlur(mask, (15, 15), 0) + _, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + # Apply dilation + kernel_width = cv2.getStructuringElement( + cv2.MORPH_CROSS, (max_kernel, min_kernel) + ) + dilated = cv2.dilate(bin_mask, kernel_width, iterations=3) + return np.asarray(dilated, dtype=np.uint8) + + else: + raise NotImplementedError(f"Method {threshold_method} is not implemented.") + + +def get_polygon( + text, max_value, offset, weights, threshold_method, threshold_value, size=None +): + """ + Gets polygon associated with element of current text_piece, indexed by offset + :param text: Text piece selected with offset after splitting DAN prediction + :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization + :param offset: Offset value to get the relevant part of text piece + :param size: Target size (width, height) to resize the coverage vector + :param threshold_method: Binarization method to use (should be in ["simple", "otsu"]) + :param threshold_value: Threshold value used for the "simple" binarization method + """ + # Compute coverage vector + coverage_vector = compute_coverage(text, max_value, offset, weights, size=size) + + # Generate a binary image for the current channel. + bin_mask = threshold( + coverage_vector, + threshold_method=threshold_method, + threshold_value=threshold_value, + ) + + # Detect the objects contours + contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if not contours: + return {}, None + + # Select best contour + metrics = [compute_contour_metrics(coverage_vector, cnt) for cnt in contours] + confidences, scores = map(list, zip(*metrics)) + best_contour = contours[np.argmax(scores)] + confidence = round(confidences[np.argmax(scores)] / max_value, 2) + + # Format for JSON + coord = polygon_to_bbx(np.squeeze(best_contour)) + polygon = { + "confidence": confidence, + "polygon": coord, + } + simplified_contour = np.expand_dims(np.array(coord, dtype=np.int32), axis=1) + + return polygon, simplified_contour + + def plot_attention( image, text, @@ -59,8 +281,11 @@ def plot_attention( level, scale, outname, + threshold_method="otsu", + threshold_value=0, word_separators=["\n", " "], line_separators=["\n"], + display_polygons=False, ): """ Create a gif by blending attention maps to the image for each text piece (char, word or line) @@ -70,6 +295,9 @@ def plot_attention( :param level: Level to display (must be in [char, word, line]) :param scale: Scaling factor for the output gif image :param outname: Name of the gif image + :param word_separators: List of word separators + :param line_separators: List of line separators + :param display_polygons: Whether to plot extracted polygons """ height, width, _ = image.shape @@ -84,27 +312,35 @@ def plot_attention( # Iterate on characters, words or lines tot_len = 0 - max_value = weights.sum(0).max() for text_piece in text_list: # Accumulate weights for the current word/line and resize to original image size - coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights) - coverage_vector = cv2.resize(coverage_vector, (width, height)) + coverage_vector = compute_coverage( + text_piece, max_value, tot_len, weights, (width, height) + ) + + # Get polygons if flag is set: + if display_polygons: + # draw the contour + _, contour = get_polygon( + text_piece, + max_value, + tot_len, + weights, + threshold_method=threshold_method, + threshold_value=threshold_value, + size=(width, height), + ) + + if contour is not None: + cv2.drawContours(coverage_vector, [contour], 0, (255), 5) # Keep track of text length tot_len += len(text_piece) + offset # Blend coverage vector with original image - blank_array = np.zeros((height, width)).astype(np.uint8) - coverage_vector = Image.fromarray( - np.stack([coverage_vector, blank_array, blank_array], axis=2), "RGB" - ) - blend = Image.composite(image, coverage_vector, mask) - - # Resize to save time - blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS) - attention_map.append(blend) + attention_map.append(blend_coverage(coverage_vector, image, mask, scale)) attention_map[0].save( outname, diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 55bf817c..9c8de085 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -2,7 +2,7 @@ import os import pickle -import re +from pathlib import Path import cv2 import numpy as np @@ -14,8 +14,13 @@ from dan.datasets.extract.utils import save_json from dan.decoder import GlobalHTADecoder from dan.models import FCN_Encoder from dan.ocr.utils import LM_ind_to_str -from dan.predict.attention import plot_attention -from dan.utils import read_image, round_floats +from dan.predict.attention import ( + get_predicted_polygons_with_confidence, + parse_delimiters, + plot_attention, + split_text_and_confidences, +) +from dan.utils import read_image class DAN: @@ -92,7 +97,13 @@ class DAN: input_sizes, confidences=False, attentions=False, + attention_level=False, + extract_objects=False, + word_separators=["\n", " "], + line_separators=["\n"], start_token=None, + threshold_method="otsu", + threshold_value=0, ): """ Run prediction on an input image. @@ -100,6 +111,10 @@ class DAN: :param input_sizes: The original images sizes. :param confidences: Return the characters probabilities. :param attentions: Return characters attention weights. + :param attention_level: Level of text pieces (must be in [char, word, line]) + :param extract_objects: Whether to extract polygons' coordinates. + :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. + :param threshold_value: Thresholding value to use for the "simple" thresholding method. """ input_tensor = input_tensor.to(self.device) @@ -110,13 +125,20 @@ class DAN: # Run the prediction. with torch.no_grad(): - b = input_tensor.size(0) - reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device) - prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device) + batch_size = input_tensor.size(0) + reached_end = torch.zeros( + (batch_size,), dtype=torch.bool, device=self.device + ) + prediction_len = torch.zeros( + (batch_size,), dtype=torch.int, device=self.device + ) predicted_tokens = ( - torch.ones((b, 1), dtype=torch.long, device=self.device) * start_token + torch.ones((batch_size, 1), dtype=torch.long, device=self.device) + * start_token + ) + predicted_tokens_len = torch.ones( + (batch_size,), dtype=torch.int, device=self.device ) - predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device) whole_output = list() confidence_scores = list() @@ -185,10 +207,11 @@ class DAN: predicted_tokens = predicted_tokens[:, 1:] prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1 predicted_tokens = [ - predicted_tokens[i, : prediction_len[i]] for i in range(b) + predicted_tokens[i, : prediction_len[i]] for i in range(batch_size) ] confidence_scores = [ - confidence_scores[i, : prediction_len[i]].tolist() for i in range(b) + confidence_scores[i, : prediction_len[i]].tolist() + for i in range(batch_size) ] # Transform tokens to characters @@ -198,34 +221,32 @@ class DAN: logger.info("Images processed") - out = {"text": predicted_text} + out = {} + + out["text"] = predicted_text if confidences: out["confidences"] = confidence_scores if attentions: out["attentions"] = attention_maps + if extract_objects: + out["objects"] = [ + get_predicted_polygons_with_confidence( + predicted_text[i], + attention_maps[i], + confidence_scores[i], + attention_level, + input_sizes[i][0], + input_sizes[i][1], + threshold_method=threshold_method, + threshold_value=threshold_value, + word_separators=word_separators, + line_separators=line_separators, + ) + for i in range(batch_size) + ] return out -def parse_delimiters(delimiters): - return re.compile(r"|".join(delimiters)) - - -def compute_prob_by_separator(characters, probabilities, separator): - """ - Split text and confidences using separators and return a list of average confidence scores. - :param characters: list of characters. - :param probabilities: list of probabilities. - :param separators: regex for separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences. - Returns a list confidence scores. - """ - # match anything except separators, get start and end index - pattern = re.compile(f"[^{separator.pattern}]+") - matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)] - - # Iterate over text pieces and compute mean confidence - return [np.mean(probabilities[start:end]) for (start, end) in matches] - - def run( image, model, @@ -240,7 +261,28 @@ def run( attention_map_scale, word_separators, line_separators, + predict_objects, + threshold_method, + threshold_value, ): + """ + Predict a single image save the output + :param image: Path to the image to predict. + :param model: Path to the model to use for prediction. + :param parameters: Path to the YAML parameters file. + :param charset: Path to the charset. + :param output: Path to the output folder where the results will be saved. + :param scale: Scaling factor to resize the image. + :param confidence_score: Whether to compute confidence score. + :param attention_map: Whether to plot the attention map. + :param attention_map_level: Level of objects to extract. + :param attention_map_scale: Scaling factor for the attention map. + :param word_separators: List of word separators. + :param line_separators: List of line separators. + :param predict_objects: Whether to extract objects. + :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. + :param threshold_value: Thresholding value to use for the "simple" thresholding method. + """ # Create output directory if necessary if not os.path.exists(output): os.mkdir(output) @@ -261,41 +303,56 @@ def run( input_tensor = input_tensor.to(device) input_sizes = [im.shape[:2]] + # Parse delimiters to regex + word_separators = parse_delimiters(word_separators) + line_separators = parse_delimiters(line_separators) + # Predict prediction = dan_model.predict( input_tensor, input_sizes, confidences=confidence_score, attentions=attention_map, + attention_level=attention_map_level, + extract_objects=predict_objects, + word_separators=word_separators, + line_separators=line_separators, + threshold_method=threshold_method, + threshold_value=threshold_value, ) - text = prediction["text"][0] - result = {"text": text} - # Parse delimiters to regex - word_separators = parse_delimiters(word_separators) - line_separators = parse_delimiters(line_separators) + result = {} + result["text"] = prediction["text"][0] + + # Return extracted objects (coordinates, text, confidence) + if predict_objects: + result["objects"] = prediction["objects"][0] - # Average character-based confidence scores + # Return mean confidence score if confidence_score: + result["confidences"] = {} + char_confidences = prediction["confidences"][0] - result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)} - if "word" in confidence_score_levels: - word_probs = compute_prob_by_separator( - text, char_confidences, word_separators - ) - result["confidences"].update({"word": round_floats(word_probs)}) - if "line" in confidence_score_levels: - line_probs = compute_prob_by_separator( - text, char_confidences, line_separators + result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) + + for level in confidence_score_levels: + result["confidences"][level] = [] + texts, confidences, _ = split_text_and_confidences( + prediction["text"][0], + char_confidences, + level, + word_separators, + line_separators, ) - result["confidences"].update({"line": round_floats(line_probs)}) - if "char" in confidence_score_levels: - result["confidences"].update({"char": round_floats(char_confidences)}) + + for text, conf in zip(texts, confidences): + result["confidences"][level].append({"text": text, "confidence": conf}) # Save gif with attention map if attention_map: gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif" logger.info(f"Creating attention GIF in {gif_filename}") + # this returns polygons but unused for now. plot_attention( image=im, text=prediction["text"][0], @@ -304,10 +361,13 @@ def run( scale=attention_map_scale, word_separators=word_separators, line_separators=line_separators, + display_polygons=predict_objects, + threshold_method=threshold_method, + threshold_value=threshold_value, outname=gif_filename, ) result["attention_gif"] = gif_filename json_filename = f"{output}/{image.stem}.json" logger.info(f"Saving JSON prediction in {json_filename}") - save_json(json_filename, result) + save_json(Path(json_filename), result) diff --git a/docs/assets/example_line_polygon.gif b/docs/assets/example_line_polygon.gif new file mode 100644 index 00000000..e92c8096 --- /dev/null +++ b/docs/assets/example_line_polygon.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37aad226af685c2d44da6742b817fd80dc3f22f9114acb2b6ac1e66554992c90 +size 27778846 diff --git a/docs/usage/predict.md b/docs/usage/predict.md index bb63b575..ea8ddd93 100644 --- a/docs/usage/predict.md +++ b/docs/usage/predict.md @@ -13,12 +13,15 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image. | `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | | | `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` | | `--confidence-score` | Whether to return confidence scores. | `bool` | `False` | -| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | | +| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | | | `--attention-map` | Whether to plot attention maps. | `bool` | `False` | | `--attention-map-scale` | Image scaling factor before creating the GIF. | `float` | `0.5` | | `--attention-map-level` | Level to plot the attention maps. Should be in `["line", "word", "char"]`. | `str` | `"line"` | +| `--predict-objects` | Whether to return polygons coordinates. | `bool` | `False` | | `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` | | `--line-separators` | List of line separators. | `list` | `["\n"]` | +| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` | +| `--threshold-value ` | Threshold to use for the "simple" thresholding method. | `int` | `0` | ## Examples @@ -100,3 +103,57 @@ It will create the following JSON file named `dan_humu_page/predict/example.json } ``` <img src="../../assets/example_word.gif" > + + +### Predict with line-level attention maps and extract polygons + +To run a prediction, plot line-level attention maps, and extract polygons, run this command: + +```shell +teklia-dan predict \ + --image dan_humu_page/example.jpg \ + --model dan_humu_page/model.pt \ + --parameters dan_humu_page/parameters.yml \ + --charset dan_humu_page/charset.pkl \ + --output dan_humu_page/predict/ \ + --scale 0.5 \ + --attention-map \ + --predict-objects \ + --threshold-method otsu +``` + +It will create the following JSON file named `dan_humu_page/predict/example.json` and a GIF showing a line-level attention map with extracted polygons `dan_humu_page/predict/example_line.gif` + +```json +{ + "text": "Oslo\n39 \nOresden den 24te Rasser!\nH\u00f8jst\u00e6redesherr Hartvig - assert!\nUllereder fra den f\u00f8rste tide da\njeg havder den tilfredsstillelser at vide den ar-\ndistiske ledelser af Kristiania theater i Deres\nhronder, har jeg g\u00e5t hernede med et stille\nh\u00e5b om fra Dem at modtage et forelag, sig -\nsende tils at lade \"K\u00e6rlighedens \u00abKomedie\u00bb\nopf\u00f8re fore det norske purblikum.\nEt s\u00e5dant forslag er imidlertid, imod\nforventning; ikke fremkommet, og jeg n\u00f8des der-\nfor tils self at grivbe initiativet, hvilket hervede\nsker, idet jeg\nbeder\nbet\nragte stigkket some ved denne\nskrivelse officielde indleveret til theatret. No-\nget exemplar af bogen vedlagger jeg ikke da\ndenne (i 2den udgave) med Lethed kan er -\nholdet deroppe.\nDe bet\u00e6nkeligheder, jeg i sin tid n\u00e6-\nrede mod stykkets opf\u00f8relse, er for l\u00e6nge si -\ndem forsvundne. Af mange begn er jeg kom-\nmen til den overbevisning at almenlreden\naru har f\u00e5tt sine \u00f8gne opladte for den sand -\nMed at dette arbejde i sin indersten id\u00e9 hviler\np\u00e5 et ubedinget meralsk grundlag, og brad\nstykkets hele kunstneriske struktuve ang\u00e5r,", + "objects": [ + { + "confidence": 0.68, + "polygon": [ + [ + 264, + 118 + ], + [ + 410, + 118 + ], + [ + 410, + 185 + ], + [ + 264, + 185 + ] + ], + "text": "Oslo", + "text_confidence": 0.8 + }, + ... + "attention_gif": "dan_humu_page/predict/example_line.gif" +} +``` + +<img src="../../assets/example_line_polygon.gif" > -- GitLab