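# simara_formatter.py (3.94 KiB)
# NOTE: "Skip to content" / "Snippets Groups Projects" were GitLab page navigation
# text captured when this file was copied; kept only as a comment so the module parses.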
from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
import os
import numpy as np
from Datasets.dataset_formatters.utils_dataset import natural_sort
from PIL import Image
import xml.etree.ElementTree as ET
import re
from tqdm import tqdm

# Layout string to token: maps each semantic field name to its begin-token character.
# NOTE(review): every token in this file had been reduced to an empty string (the
# non-ASCII characters were lost), which also collapsed SEM_MATCHING_TOKENS into the
# single entry {"": ""}. They are restored here as circled lower-case initials (begin)
# and circled upper-case initials (end), written as \u escapes so they survive
# non-UTF-8 tooling -- confirm they match the characters in Datasets/raw/simara/labels_sem.
SEM_MATCHING_TOKENS_STR = {
    "INTITULE": "\u24d8",  # circled small i
    "DATE": "\u24d3",  # circled small d
    "COTE_SERIE": "\u24e2",  # circled small s
    "ANALYSE_COMPL": "\u24d2",  # circled small c
    "PRECISIONS_SUR_COTE": "\u24df",  # circled small p
    "COTE_ARTICLE": "\u24d0",  # circled small a
}

# Layout begin-token to end-token (one entry per field above; the end-token is the
# circled capital form of the begin-token's letter).
SEM_MATCHING_TOKENS = {
    "\u24d8": "\u24be",  # i -> I
    "\u24d3": "\u24b9",  # d -> D
    "\u24e2": "\u24c8",  # s -> S
    "\u24d2": "\u24b8",  # c -> C
    "\u24df": "\u24c5",  # p -> P
    "\u24d0": "\u24b6",  # a -> A
}

class SimaraDatasetFormatter(OCRDatasetFormatter):
    """Formatter for the SIMARA dataset.

    Raw data is read from ``Datasets/raw/simara``, which holds ``images``, ``labels``
    and ``labels_sem`` sub-folders, each split into one sub-folder per set name.
    """

    def __init__(self, level, set_names=None, dpi=150, sem_token=True):
        """
        :param level: dataset level passed to the parent formatter (e.g. "page").
        :param set_names: names of the splits to format; defaults to
            ["train", "valid", "test"]. A fresh list is built on every call so
            instances never share (and mutate) one default list.
        :param dpi: stored on the instance as ``self.dpi``.
        :param sem_token: if True, the semantic labels (``labels_sem``) are used as
            ground truth and "_sem" is passed to the parent formatter (presumably a
            dataset-name suffix -- see OCRDatasetFormatter).
        """
        if set_names is None:
            set_names = ["train", "valid", "test"]
        super().__init__("simara", level, "_sem" if sem_token else "", set_names)

        self.dpi = dpi
        self.sem_token = sem_token
        self.map_datasets_files.update({
            "simara": {
                # (1,050 for train, 100 for validation and 100 for test)
                "page": {
                    "format_function": self.format_simara_page,
                },
            }
        })
        # Token tables defined at module level.
        self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
        self.matching_tokens = SEM_MATCHING_TOKENS

    def preformat_simara_page(self):
        """
        Collect image paths and transcriptions for every requested set.

        For each name in ``self.set_names``, the file names found in
        ``labels_sem/<set_name>`` drive the collection:
        - the plain label is read from ``labels/<set_name>/<name>``;
        - the semantic label is read from ``labels_sem/<set_name>/<name>``;
        - the image path is ``images/<set_name>/<stem>.jpg``.

        :return: dict mapping set name to a list of samples, each a dict with keys
            "img_path", "label" and "sem_label". The keys "train", "valid" and
            "test" are always present (empty if not requested).
        """
        raw_folder_path = os.path.join("Datasets", "raw", "simara")
        img_folder_path = os.path.join(raw_folder_path, "images")
        labels_folder_path = os.path.join(raw_folder_path, "labels")
        sem_labels_folder_path = os.path.join(raw_folder_path, "labels_sem")

        dataset = {
            "train": list(),
            "valid": list(),
            "test": list()
        }
        for set_name in self.set_names:
            # Sorted so the sample order is reproducible (os.listdir order is arbitrary).
            label_names = sorted(os.listdir(os.path.join(sem_labels_folder_path, set_name)))
            samples = list()
            for label_name in tqdm(label_names, desc='Pre-formatting ' + set_name):
                # Read as UTF-8 explicitly so non-ASCII semantic tokens decode the same
                # on every platform (assumes the raw label files are UTF-8 encoded).
                label_path = os.path.join(labels_folder_path, set_name, label_name)
                with open(label_path, 'r', encoding='utf-8') as f:
                    text = f.read()
                sem_label_path = os.path.join(sem_labels_folder_path, set_name, label_name)
                with open(sem_label_path, 'r', encoding='utf-8') as f:
                    sem_text = f.read()
                img_name = os.path.splitext(label_name)[0] + '.jpg'
                samples.append({
                    "img_path": os.path.join(img_folder_path, set_name, img_name),
                    "label": text,
                    "sem_label": sem_text,
                })
            dataset[set_name] = samples
        return dataset

    def format_simara_page(self):
        """
        Format the SIMARA page dataset.

        For every requested set:
        - each image is written to ``<target_fold_path>/<set_name>`` through
          ``self.load_resize_save``;
        - its ground truth is stored in ``self.gt[set_name]``, keyed by the image
          file name;
        - the characters of that ground truth are added to ``self.charset``.
        The ground truth is the semantic label when ``self.sem_token`` is set,
        otherwise the plain label.
        """
        dataset = self.preformat_simara_page()
        for set_name in self.set_names:
            fold = os.path.join(self.target_fold_path, set_name)
            for sample in tqdm(dataset[set_name], desc='Formatting ' + set_name):
                new_name = os.path.basename(sample['img_path'])
                new_img_path = os.path.join(fold, new_name)
                # NOTE(review): called without size/dpi arguments, so self.dpi is not
                # used here (an earlier version passed 300, self.dpi) -- confirm intent.
                self.load_resize_save(sample["img_path"], new_img_path)
                text = sample["sem_label"] if self.sem_token else sample["label"]
                self.charset = self.charset.union(set(text))
                self.gt[set_name][new_name] = {"text": text}


if __name__ == "__main__":
    # Produce the semantic-token version of the SIMARA page dataset.
    # For the plain-text variant, build the formatter with sem_token=False instead.
    page_formatter = SimaraDatasetFormatter("page", sem_token=True)
    page_formatter.format()