diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index 4709b5f941bd2c4668cf0dd8d92d075f74e64395..8f3ca71583b59b4c280ae20315677fa22ce42c7f 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -307,6 +307,7 @@ class HTRDataGenerator: trans_data = TranscriptionData( element_id=res["element"]["id"], + element_name=res["element"]["name"], polygon=polygon, text=text, trans_id=res["id"], @@ -380,18 +381,31 @@ class HTRDataGenerator: } def _save_line_image( - self, page_id, i, line_img, manifest_fp=None, trans: TranscriptionData = None + self, page_id, line_img, manifest_fp=None, trans: TranscriptionData = None ): + # Get line id + line_id = trans.element_id + + # Get line number from its name + line_number = trans.element_name.split("_")[-1] + if self.should_rotate: if trans.rotation_class: rotate_angle = ROTATION_CLASSES_TO_ANGLES[trans.rotation_class] line_img = rotate_and_trim(line_img, rotate_angle, WHITE) if self.format == "kraken": - - save_img(f"{self.out_line_dir}/{page_id}_{i}.png", line_img) - manifest_fp.write(f"{page_id}_{i}.png\n") + # Save image using the template {page_id}_{line_number}_{line_id} + # TODO: check if (0>3) is enough (pad line_number to 3 digits) + save_img( + f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.png", + line_img, + ) + manifest_fp.write(f"{page_id}_{line_number:0>3}_{line_id}.png\n") else: - save_img(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img) + save_img( + f"{self.out_line_img_dir}/{page_id}_{line_number:0>3}_{line_id}.jpg", + line_img, + ) def extract_lines(self, page_id: str, image_data: dict): if self.should_filter_by_class or self.should_filter_by_style: @@ -441,7 +455,7 @@ class HTRDataGenerator: # not needed for kaldi manifest_fp = None - for i, trans in enumerate(sorted_lines): + for trans in sorted_lines: extracted_img = extract( img=img, polygon=trans.polygon, @@ -452,16 +466,25 @@ class HTRDataGenerator: grayscale=self.grayscale, ) - self._save_line_image(page_id, i, extracted_img, manifest_fp, trans) + # don't enumerate, read the line number from the elements's name (e.g. line_xx) so that it matches with Arkindex + self._save_line_image(page_id, extracted_img, manifest_fp, trans) if self.format == "kraken": manifest_fp.close() - for i, trans in enumerate(sorted_lines): + for trans in sorted_lines: + line_number = trans.element_name.split("_")[-1] + line_id = trans.element_id if self.format == "kraken": - write_file(f"{self.out_line_dir}/{page_id}_{i}.gt.txt", trans.text) + write_file( + f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.gt.txt", + trans.text, + ) else: - write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", trans.text) + write_file( + f"{self.out_line_text_dir}/{page_id}_{line_number:0>3}_{line_id}.txt", + trans.text, + ) def run_selection(self): selected_elems = [e for e in self.api_client.paginate("ListSelection")] diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py index 281d56ceff6392d93cda308e3de9fd5489e81f13..0aed87a53d6edc45e24cdd9acfb2405f4e185161 100644 --- a/kaldi_data_generator/utils.py +++ b/kaldi_data_generator/utils.py @@ -26,6 +26,7 @@ class TranscriptionData: def __init__( self, element_id, + element_name, polygon, text, trans_id, @@ -33,6 +34,7 @@ class TranscriptionData: rotation_class=None, ): self.element_id = element_id + self.element_name = element_name self.polygon = np.asarray(polygon).clip(0) self.text = text self.trans_id = trans_id @@ -60,6 +62,7 @@ class TranscriptionData: """ return TranscriptionData( trans.element_id, + trans.element_name, new_polygon, trans.text, trans.trans_id,