# -*- coding: utf-8 -*-

import json
import logging
import pickle
import random
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from uuid import UUID

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

from arkindex_export import open_database
from dan.datasets.extract.db import (
    Element,
    get_elements,
    get_transcription_entities,
    get_transcriptions,
)
from dan.datasets.extract.exceptions import (
    ImageDownloadError,
    NoEndTokenError,
    NoTranscriptionError,
    ProcessingError,
    UnknownTokenInText,
)
from dan.datasets.extract.utils import (
    Tokenizer,
    download_image,
    get_bbox,
    insert_token,
    normalize_linebreaks,
    normalize_spaces,
)
from dan.utils import EntityType, LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
    BoundingBox,
    Extraction,
    polygon_to_bbox,
)

IMAGES_DIR = "images"  # Subpath to the images directory.
LANGUAGE_DIR = "language_model"  # Subpath to the language model directory.

TRAIN_NAME = "train"
SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
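# Illustrative formatted URL (hypothetical host, assuming `get_bbox` returns "x,y,w,h"):
#   https://iiif.example.com/image/abc/100,50,800,600/full/0/default.jpg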

logger = logging.getLogger(__name__)


class ArkindexExtractor:
    """
    Extract data from Arkindex
    """

    def __init__(
        self,
        folders: List[str] = [],
        element_type: List[str] = [],
        parent_element_type: Optional[str] = None,
        output: Optional[Path] = None,
        entity_separators: List[str] = ["\n", " "],
        unknown_token: str = "⁇",
        tokens: Optional[Path] = None,
        transcription_worker_version: Optional[Union[str, bool]] = None,
        entity_worker_version: Optional[Union[str, bool]] = None,
        max_width: Optional[int] = None,
        max_height: Optional[int] = None,
        keep_spaces: bool = False,
        image_extension: str = "",
        allow_empty: bool = False,
        subword_vocab_size: int = 1000,
    ) -> None:
        self.folders = folders
        self.element_type = element_type
        self.parent_element_type = parent_element_type
        self.output = output
        self.entity_separators = entity_separators
        self.unknown_token = unknown_token
        self.tokens = parse_tokens(tokens) if tokens else None
        self.transcription_worker_version = transcription_worker_version
        self.entity_worker_version = entity_worker_version
        self.max_width = max_width
        self.max_height = max_height
        self.image_extension = image_extension
        self.allow_empty = allow_empty
        self.mapping = LMTokenMapping()
        self.keep_spaces = keep_spaces
        self.subword_vocab_size = subword_vocab_size

        self.data: Dict = defaultdict(dict)
        self.charset = set()
        self.language_corpus = defaultdict(list)
        self.language_tokens = []
        self.language_lexicon = defaultdict(list)

        # Image download tasks to process
        self.tasks: List[Dict[str, str]] = []

    def get_iiif_size_arg(self, width: int, height: int) -> str:
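        """
        Compute the IIIF size parameter for an image of the given dimensions.

        Returns "full" when the image fits within the configured maximum width
        and height; otherwise a width- or height-constrained size string such
        as "2000," or ",2000" (illustrative values assuming a 2000 px limit).
        """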
        if (self.max_width is None or width <= self.max_width) and (
            self.max_height is None or height <= self.max_height
        ):
            return IIIF_FULL_SIZE

        bigger_width = self.max_width and width >= self.max_width
        bigger_height = self.max_height and height >= self.max_height

        if bigger_width and bigger_height:
            # Both dimensions exceed the limits: resize along the dimension
            # that needs the most shrinking, to preserve the aspect ratio.
            # This ratio tells which dimension needs the most shrinking.
            ratio = width * self.max_height / (height * self.max_width)
            return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
        # Only the width exceeds its maximum
        elif bigger_width:
            return f"{self.max_width},"
        # Only the height exceeds its maximum
        elif bigger_height:
            return f",{self.max_height}"

    def build_iiif_url(self, polygon, image_url) -> Tuple[BoundingBox, str]:
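        """
        Build the IIIF crop URL for the polygon's bounding box and return the
        bounding box together with the URL.

        The size parameter is computed by `get_iiif_size_arg`; the rotation
        parameter is left at 0, as rotations are handled by the extraction library.
        """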
        bbox = polygon_to_bbox(json.loads(str(polygon)))
        size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
        # Rotations are done using the lib
        return bbox, IIIF_URL.format(
            image_url=image_url, bbox=get_bbox(polygon), size=size
        )

    def _keep_char(self, char: str) -> bool:
        # Between entities, keep only separator characters; keep everything
        # when no separator was given
        return not self.entity_separators or char in self.entity_separators

    def reconstruct_text(self, full_text: str, entities) -> str:
        """
        Insert tokens delimiting the start and end of each entity in the transcription.
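
        Illustrative example (hypothetical token configuration): if the entity type
        "surname" maps to start token "Ⓢ" and end token "Ⓔ", the entity "Smith" in
        "John Smith" is rendered as "ⓈSmithⒺ" in the output.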
        """
        text, text_offset = "", 0
        for entity in entities:
            # Text before entity
            text += "".join(
                filter(self._keep_char, full_text[text_offset : entity.offset])
            )

            entity_type: EntityType = self.tokens.get(entity.type)
            if not entity_type:
                logger.warning(
                    f"Label `{entity.type}` is missing in the NER configuration."
                )
            # We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends
            elif not entity_type.end and not self.entity_separators:
                raise NoEndTokenError(entity.type)

            # Entity text:
            # - with tokens if there is an entity_type
            # - without tokens if there is no entity_type but we want to keep the whole text
            if entity_type or not self.entity_separators:
                text += insert_token(
                    full_text,
                    entity_type,
                    offset=entity.offset,
                    length=entity.length,
                )
            text_offset = entity.offset + entity.length

        # Remaining text after the last entity
        text += "".join(filter(self._keep_char, full_text[text_offset:]))

        if not self.entity_separators or self.keep_spaces:
            return text

        # Clean up to avoid several consecutive separators between entities
        text, full_text = "", text
        for char in full_text:
            last_char = text[-1] if len(text) else ""

            # Keep the current character unless it creates two consecutive separators
            if (
                char not in self.entity_separators
                or last_char not in self.entity_separators
            ):
                text += char
            # If several separators follow each other, keep only one according to the given order
            elif self.entity_separators.index(char) < self.entity_separators.index(
                last_char
            ):
                text = text[:-1] + char

        # Remove separators at the beginning and end of text
        return text.strip("".join(self.entity_separators))

    def extract_transcription(self, element: Element):
        """
        Extract the element's transcription.
        If the entities are needed, they are added to the transcription using tokens.
        """
        transcriptions = get_transcriptions(
            element.id, self.transcription_worker_version
        )
        if len(transcriptions) == 0:
            if self.allow_empty:
                return ""
            raise NoTranscriptionError(element.id)

        transcription = random.choice(transcriptions)

        if not self.tokens:
            return transcription.text.strip()

        entities = get_transcription_entities(
            transcription.id, self.entity_worker_version
        )
        return self.reconstruct_text(transcription.text, entities)

    def get_image(
        self,
        split: str,
        polygon: List[List[int]],
        image_url: str,
        destination: Path,
    ) -> None:
        """Save the element's image to the given path and applies any image operations needed.

        :param split: Dataset split this image belongs to.
        :param polygon: Polygon of the processed element.
        :param image_url: Base IIIF URL of the image.
        :param destination: Where the image should be saved.
        """
        bbox, download_url = self.build_iiif_url(polygon=polygon, image_url=image_url)
        try:
            img: Image.Image = download_image(download_url)

            # The polygon's coordinates are relative to the full image,
            # so we remove the offset of the bounding rectangle
            polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]

            # Normalize bbox
            bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)

            image = extract(
                img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
                polygon=np.asarray(polygon).clip(0),
                bbox=bbox,
                extraction_mode=Extraction.boundingRect,
                max_deskew_angle=45,
            )

            destination.parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(destination), image)

        except Exception as e:
            raise ImageDownloadError(
                split=split, path=str(destination), url=download_url, exc=e
            )

    def format_text(self, text: str, charset: Optional[set] = None):
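        """
        Normalize spaces and linebreaks (unless `keep_spaces` is set) and, when a
        charset is given, replace out-of-charset characters with the unknown token.
        """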
        if not self.keep_spaces:
            text = normalize_spaces(text)
            text = normalize_linebreaks(text)

        # Replace unknown characters with the unknown token
        if charset is not None:
            unknown_charset = set(text) - charset
            text = text.translate(
                {
                    ord(unknown_char): self.unknown_token
                    for unknown_char in unknown_charset
                }
            )
        return text.strip()

    def process_element(
        self,
        element: Element,
        split: str,
    ):
        """
        Extract an element's data and save it to disk.
        The output path is directly related to the split of the element.
        """
        text = self.extract_transcription(element)

        if self.unknown_token in text:
            raise UnknownTokenInText(element_id=element.id)

        image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
            self.image_extension
        )

        # Create task for multithreading pool if image does not exist yet
        if not image_path.exists():
            self.tasks.append(
                {
                    "split": split,
                    "polygon": json.loads(str(element.polygon)),
                    "image_url": element.image.url,
                    "destination": image_path,
                }
            )

        text = self.format_text(
            text,
            # Do not replace unknown characters in train split
            charset=self.charset if split != TRAIN_NAME else None,
        )

        self.data[split][str(image_path)] = text
        self.charset = self.charset.union(set(text))

    def process_parent(
        self,
        pbar,
        parent: Element,
        split: str,
    ):
        """
        Extract data from a parent element.
        """
        base_description = (
            f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
        )
        pbar.set_description(desc=base_description)
        if self.element_type == [parent.type]:
            try:
                self.process_element(parent, split)
            except ProcessingError as e:
                logger.warning(f"Skipping {parent.id}: {str(e)}")
        # Extract children elements
        else:
            children = get_elements(
                parent.id,
                self.element_type,
            )

            nb_children = children.count()
            for idx, element in enumerate(children, start=1):
                # Update the description to reflect the children processing progress
                pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
                try:
                    self.process_element(element, split)
                except ProcessingError as e:
                    logger.warning(f"Skipping {element.id}: {str(e)}")

    def format_lm_files(self) -> None:
        """
        Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
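
        Also build the LM corpus (character, word and subword tokenizations of the
        train split) and the matching lexicons. Each lexicon line maps a token to
        its character sequence, e.g. a word entry like "paris p a r i s"
        (illustrative value, assuming `char_tokenize` space-separates characters).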
        """
        logger.info("Preparing language resources")

        # Build LM tokens
        for token in sorted(list(self.charset)):
            assert (
                token not in self.mapping.encode.values()
            ), f"Special token {token} is reserved for language modeling."
            self.language_tokens.append(
                self.mapping.encode[token] if token in self.mapping.encode else token
            )

        self.language_tokens.append(self.mapping.ctc.encoded)
        assert all(
            len(token) == 1 for token in self.language_tokens
        ), "Tokens should be single characters."

        # Build LM corpus
        train_corpus = [text.replace("\n", " ") for text in self.data["train"].values()]
        tokenizer = Tokenizer(
            train_corpus,
            outdir=self.output / "language_model",
            mapping=self.mapping,
            tokens=self.tokens,
            subword_vocab_size=self.subword_vocab_size,
        )
        self.language_corpus["characters"] = [
            tokenizer.char_tokenize(doc) for doc in train_corpus
        ]
        self.language_corpus["words"] = [
            tokenizer.word_tokenize(doc) for doc in train_corpus
        ]
        self.language_corpus["subwords"] = [
            tokenizer.subword_tokenize(doc) for doc in train_corpus
        ]

        # Build vocabulary
        word_vocabulary = {
            word
            for doc in self.language_corpus["words"]
            for word in doc.split()
            if word != ""
        }
        subword_vocabulary = {
            subword
            for doc in self.language_corpus["subwords"]
            for subword in doc.split()
            if subword != ""
        }

        # Build LM lexicon
        self.language_lexicon["characters"] = [
            f"{token} {token}" for token in self.language_tokens
        ]
        self.language_lexicon["words"] = [
            f"{word} {tokenizer.char_tokenize(word)}"
            for word in sorted(word_vocabulary)
            if word != ""
        ]
        self.language_lexicon["subwords"] = [
            f"{subword} {tokenizer.char_tokenize(subword)}"
            for subword in sorted(subword_vocabulary)
        ]

    def export(self):
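        """
        Write the extracted data to disk: `labels.json`, the language-model corpus,
        lexicon and token files, and the pickled charset.
        """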
        (self.output / "labels.json").write_text(
            json.dumps(
                self.data,
                sort_keys=True,
                indent=4,
            )
        )
        for level in ["characters", "words", "subwords"]:
            (self.output / "language_model" / f"corpus_{level}.txt").write_text(
                "\n".join(self.language_corpus[level])
            )
            (self.output / "language_model" / f"lexicon_{level}.txt").write_text(
                "\n".join(self.language_lexicon[level])
            )
        (self.output / "language_model" / "tokens.txt").write_text(
            "\n".join(self.language_tokens)
        )
        (self.output / "charset.pkl").write_bytes(
            pickle.dumps(sorted(list(self.charset)))
        )

    def download_images(self):
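        """
        Download all queued images using a thread pool. Samples whose download
        failed are removed from the labels and the failures are reported at the end.
        """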
        failed_downloads = []
        with tqdm(
            desc="Downloading images", total=len(self.tasks)
        ) as pbar, ThreadPoolExecutor() as executor:

            def process_future(future: Future):
                """
                Callback function called at the end of the thread
                """
                # Update the progress bar count
                pbar.update(1)

                exc = future.exception()
                if exc is None:
                    # No error
                    return
                # The download failed: drop the corresponding sample
                assert isinstance(exc, ImageDownloadError)
                # Remove the transcription from the labels dict
                del self.data[exc.split][exc.path]
                # Keep track of the failed URL
                failed_downloads.append((exc.url, exc.message))

            # Submit all tasks
            for task in self.tasks:
                executor.submit(self.get_image, **task).add_done_callback(
                    process_future
                )

        if failed_downloads:
            logger.error(f"Failed to download {len(failed_downloads)} image(s).")
            print(*list(map(": ".join, failed_downloads)), sep="\n")

    def run(self):
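        """
        Extract the data from every folder, download the images, build the
        language-model resources and export everything to the output directory.
        """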
        # Iterate over the subsets to find the page images and labels.
        for folder_id, split in zip(self.folders, SPLIT_NAMES):
            with tqdm(
                get_elements(
                    folder_id,
                    [self.parent_element_type],
                ),
                desc=f"Extracting data from ({folder_id}) for split ({split})",
            ) as pbar:
                # Iterate over the pages to create splits at page level.
                for parent in pbar:
                    self.process_parent(
                        pbar=pbar,
                        parent=parent,
                        split=split,
                    )
                    # Progress bar updates
                    pbar.update()
                    pbar.refresh()

        self.download_images()
        self.format_lm_files()
        self.export()


def run(
    database: Path,
    element_type: List[str],
    parent_element_type: str,
    output: Path,
    entity_separators: List[str],
    unknown_token: str,
    tokens: Path,
    train_folder: UUID,
    val_folder: UUID,
    test_folder: UUID,
    transcription_worker_version: Optional[Union[str, bool]],
    entity_worker_version: Optional[Union[str, bool]],
    max_width: Optional[int],
    max_height: Optional[int],
    image_format: str,
    keep_spaces: bool,
    allow_empty: bool,
    subword_vocab_size: int,
):
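    """
    Open the Arkindex export database, create the output directory layout and
    run the extraction over the train, validation and test folders.
    """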
    assert database.exists(), f"No file found @ {database}"
    open_database(path=database)

    folders = [str(train_folder), str(val_folder), str(test_folder)]

    # Create directories
    for split in SPLIT_NAMES:
        Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
    Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)

    ArkindexExtractor(
        folders=folders,
        element_type=element_type,
        parent_element_type=parent_element_type,
        output=output,
        entity_separators=entity_separators,
        unknown_token=unknown_token,
        tokens=tokens,
        transcription_worker_version=transcription_worker_version,
        entity_worker_version=entity_worker_version,
        max_width=max_width,
        max_height=max_height,
        keep_spaces=keep_spaces,
        image_extension=image_format,
        allow_empty=allow_empty,
        subword_vocab_size=subword_vocab_size,
    ).run()
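

# Illustrative direct call of `run` (all values below are hypothetical
# placeholders, including the file paths and folder UUIDs):
#
#     run(
#         database=Path("arkindex_export.sqlite"),
#         element_type=["text_line"],
#         parent_element_type="page",
#         output=Path("extracted_data"),
#         entity_separators=["\n", " "],
#         unknown_token="⁇",
#         tokens=None,
#         train_folder=UUID("11111111-1111-1111-1111-111111111111"),
#         val_folder=UUID("22222222-2222-2222-2222-222222222222"),
#         test_folder=UUID("33333333-3333-3333-3333-333333333333"),
#         transcription_worker_version=None,
#         entity_worker_version=None,
#         max_width=None,
#         max_height=None,
#         image_format=".jpg",
#         keep_spaces=False,
#         allow_empty=False,
#         subword_vocab_size=1000,
#     )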