Skip to content
Snippets Groups Projects
Commit d6cf517d authored by Martin Maarand
Browse files

Raise an exception if multiple transcriptions from the same text_line

parent 9ab6085f
No related branches found
No related tags found
1 merge request: !20 Raise an exception if multiple transcriptions from the same text_line
Pipeline #74320 passed
......@@ -169,6 +169,4 @@ def resize_transcription_data(
orig_polygon, max_width, max_height, scale_x, scale_y_top, scale_y_bottom
)
return TranscriptionData(
trans.element_id, resized_polygon, trans.text, trans.rotation_class
)
return TranscriptionData.copy_replace_polygon(trans, resized_polygon)
......@@ -4,8 +4,10 @@
import argparse
import os
import random
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import List
import cv2
import numpy as np
......@@ -188,6 +190,26 @@ class HTRDataGenerator:
)
raise e
def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
    """
    Sanity-check the transcriptions collected for a single page.

    An empty ``lines`` is accepted without any check. If any ``element_id``
    occurs on more than one transcription, the ten most frequent ids are
    logged as errors and a ``ValueError`` is raised. If the transcriptions
    carry more than one distinct ``worker_version_id``, a warning naming
    ``page_id`` is logged and the method returns normally.
    """
    if not lines:
        return

    # How many transcriptions reference each line element.
    per_element = Counter(trans.element_id for trans in lines)
    top_elements = per_element.most_common(10)
    # most_common() sorts by descending count, so the first entry holds the maximum.
    _, highest_count = top_elements[0]
    if highest_count > 1:
        logger.error("Line elements have multiple transcriptions! Showing top 10:")
        logger.error(f"{top_elements}")
        raise ValueError(f"Multiple transcriptions: {top_elements[0]}")

    # Mixed worker versions are only reported, never treated as fatal.
    per_worker_version = Counter(trans.worker_version_id for trans in lines)
    if len(per_worker_version) <= 1:
        return
    logger.warning(
        f"There are transcriptions from multiple worker versions on this page: {page_id}:"
    )
    top_versions = per_worker_version.most_common(10)
    logger.warning(f"Top 10 worker versions: {top_versions}")
def get_transcriptions(self, page_id: str, accepted_zones):
lines = []
try:
......@@ -222,10 +244,14 @@ class HTRDataGenerator:
element_id=res["element"]["id"],
polygon=polygon,
text=text,
trans_id=res["id"],
worker_version_id=res["worker_version_id"],
)
lines.append(trans_data)
self._validate_transcriptions(page_id, lines)
if self.should_rotate:
classes_by_elem = self.get_children_classes(page_id)
......
......@@ -17,10 +17,20 @@ BoundingBox = NamedTuple(
class TranscriptionData:
def __init__(self, element_id, polygon, text, rotation_class=None):
def __init__(
self,
element_id,
polygon,
text,
trans_id,
worker_version_id,
rotation_class=None,
):
self.element_id = element_id
self.polygon = np.asarray(polygon).clip(0)
self.text = text
self.trans_id = trans_id
self.worker_version_id = worker_version_id
self.rotation_class = rotation_class
self.rect = BoundingBox._make(cv2.boundingRect(self.polygon))
......@@ -37,6 +47,20 @@ class TranscriptionData:
def __repr__(self):
return str(vars(self))
@classmethod
def copy_replace_polygon(cls, trans: "TranscriptionData", new_polygon):
"""
Class method to keep the change logic inside the class - less likely to forget to update.
"""
return TranscriptionData(
trans.element_id,
new_polygon,
trans.text,
trans.trans_id,
trans.worker_version_id,
trans.rotation_class,
)
def write_file(file_name, content):
with open(file_name, "w") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment