From d6cf517da5fd4fc155a183d4803a03ceacd3b4cc Mon Sep 17 00:00:00 2001
From: Martin Maarand <maarand@teklia.com>
Date: Fri, 29 Oct 2021 15:17:17 +0000
Subject: [PATCH] Raise an exception if there are multiple transcriptions from
 the same text_line

---
 kaldi_data_generator/image_utils.py |  4 +---
 kaldi_data_generator/main.py        | 26 ++++++++++++++++++++++++++
 kaldi_data_generator/utils.py       | 26 +++++++++++++++++++++++++-
 3 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/kaldi_data_generator/image_utils.py b/kaldi_data_generator/image_utils.py
index f7550a8..73d5eef 100644
--- a/kaldi_data_generator/image_utils.py
+++ b/kaldi_data_generator/image_utils.py
@@ -169,6 +169,4 @@ def resize_transcription_data(
         orig_polygon, max_width, max_height, scale_x, scale_y_top, scale_y_bottom
     )
 
-    return TranscriptionData(
-        trans.element_id, resized_polygon, trans.text, trans.rotation_class
-    )
+    return TranscriptionData.copy_replace_polygon(trans, resized_polygon)
diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 0914fc9..2a61077 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -4,8 +4,10 @@
 import argparse
 import os
 import random
+from collections import Counter
 from enum import Enum
 from pathlib import Path
+from typing import List
 
 import cv2
 import numpy as np
@@ -188,6 +190,26 @@ class HTRDataGenerator:
             )
             raise e
 
+    def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
+        """Raise ValueError on duplicate line transcriptions; warn on mixed worker versions."""
+        if not lines:
+            return
+
+        top_elements = Counter(trans.element_id for trans in lines).most_common(10)
+        _, highest_count = top_elements[0]
+        if highest_count > 1:
+            logger.error("Line elements have multiple transcriptions! Showing top 10:")
+            logger.error(f"{top_elements}")
+            raise ValueError(f"Multiple transcriptions: {top_elements[0]}")
+
+        # Mixed worker versions are only reported, they do not abort the extraction.
+        versions = Counter(trans.worker_version_id for trans in lines)
+        if len(versions) > 1:
+            logger.warning(
+                f"There are transcriptions from multiple worker versions on this page: {page_id}:"
+            )
+            logger.warning(f"Top 10 worker versions: {versions.most_common(10)}")
+
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
@@ -222,10 +244,14 @@ class HTRDataGenerator:
                     element_id=res["element"]["id"],
                     polygon=polygon,
                     text=text,
+                    trans_id=res["id"],
+                    worker_version_id=res["worker_version_id"],
                 )
 
                 lines.append(trans_data)
 
+            self._validate_transcriptions(page_id, lines)
+
             if self.should_rotate:
                 classes_by_elem = self.get_children_classes(page_id)
 
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index 671912c..59a4761 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -17,10 +17,20 @@ BoundingBox = NamedTuple(
 
 
 class TranscriptionData:
-    def __init__(self, element_id, polygon, text, rotation_class=None):
+    def __init__(
+        self,
+        element_id,
+        polygon,
+        text,
+        trans_id,
+        worker_version_id,
+        rotation_class=None,
+    ):
         self.element_id = element_id
         self.polygon = np.asarray(polygon).clip(0)
         self.text = text
+        self.trans_id = trans_id
+        self.worker_version_id = worker_version_id
         self.rotation_class = rotation_class
 
         self.rect = BoundingBox._make(cv2.boundingRect(self.polygon))
@@ -37,6 +47,20 @@ class TranscriptionData:
     def __repr__(self):
         return str(vars(self))
 
+    @classmethod
+    def copy_replace_polygon(cls, trans: "TranscriptionData", new_polygon):
+        """
+        Return a copy of `trans` with `new_polygon`; kept in the class so new fields are not missed.
+        """
+        return cls(
+            element_id=trans.element_id,
+            polygon=new_polygon,
+            text=trans.text,
+            trans_id=trans.trans_id,
+            worker_version_id=trans.worker_version_id,
+            rotation_class=trans.rotation_class,
+        )
+
 
 def write_file(file_name, content):
     with open(file_name, "w") as f:
-- 
GitLab