From d6cf517da5fd4fc155a183d4803a03ceacd3b4cc Mon Sep 17 00:00:00 2001 From: Martin Maarand <maarand@teklia.com> Date: Fri, 29 Oct 2021 15:17:17 +0000 Subject: [PATCH] Raise an exception if multiple transcriptions from the same text_line --- kaldi_data_generator/image_utils.py | 4 +--- kaldi_data_generator/main.py | 26 ++++++++++++++++++++++++++ kaldi_data_generator/utils.py | 26 +++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/kaldi_data_generator/image_utils.py b/kaldi_data_generator/image_utils.py index f7550a8..73d5eef 100644 --- a/kaldi_data_generator/image_utils.py +++ b/kaldi_data_generator/image_utils.py @@ -169,6 +169,4 @@ def resize_transcription_data( orig_polygon, max_width, max_height, scale_x, scale_y_top, scale_y_bottom ) - return TranscriptionData( - trans.element_id, resized_polygon, trans.text, trans.rotation_class - ) + return TranscriptionData.copy_replace_polygon(trans, resized_polygon) diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index 0914fc9..2a61077 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -4,8 +4,10 @@ import argparse import os import random +from collections import Counter from enum import Enum from pathlib import Path +from typing import List import cv2 import numpy as np @@ -188,6 +190,26 @@ class HTRDataGenerator: ) raise e + def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]): + if not lines: + return + + line_elem_counter = Counter([trans.element_id for trans in lines]) + most_common = line_elem_counter.most_common(10) + if most_common[0][-1] > 1: + logger.error("Line elements have multiple transcriptions! Showing top 10:") + logger.error(f"{most_common}") + raise ValueError(f"Multiple transcriptions: {most_common[0]}") + + worker_version_counter = Counter([trans.worker_version_id for trans in lines]) + if len(worker_version_counter) > 1: + logger.warning( + f"There are transcriptions from multiple worker versions on this page: {page_id}:" + ) + logger.warning( + f"Top 10 worker versions: {worker_version_counter.most_common(10)}" + ) + def get_transcriptions(self, page_id: str, accepted_zones): lines = [] try: @@ -222,10 +244,14 @@ class HTRDataGenerator: element_id=res["element"]["id"], polygon=polygon, text=text, + trans_id=res["id"], + worker_version_id=res["worker_version_id"], ) lines.append(trans_data) + self._validate_transcriptions(page_id, lines) + if self.should_rotate: classes_by_elem = self.get_children_classes(page_id) diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py index 671912c..59a4761 100644 --- a/kaldi_data_generator/utils.py +++ b/kaldi_data_generator/utils.py @@ -17,10 +17,20 @@ BoundingBox = NamedTuple( class TranscriptionData: - def __init__(self, element_id, polygon, text, rotation_class=None): + def __init__( + self, + element_id, + polygon, + text, + trans_id, + worker_version_id, + rotation_class=None, + ): self.element_id = element_id self.polygon = np.asarray(polygon).clip(0) self.text = text + self.trans_id = trans_id + self.worker_version_id = worker_version_id self.rotation_class = rotation_class self.rect = BoundingBox._make(cv2.boundingRect(self.polygon)) @@ -37,6 +47,20 @@ class TranscriptionData: def __repr__(self): return str(vars(self)) + @classmethod + def copy_replace_polygon(cls, trans: "TranscriptionData", new_polygon): + """ + Class method to keep the change logic inside the class - less likely to forget to update. + """ + return TranscriptionData( + trans.element_id, + new_polygon, + trans.text, + trans.trans_id, + trans.worker_version_id, + trans.rotation_class, + ) + def write_file(file_name, content): with open(file_name, "w") as f: -- GitLab