Compare revisions

Changes are shown as if the source revision was being merged into the target revision (target project: atr/dan).

Commits on Source (2)

Showing changes with 251 additions and 32 deletions.
@@ -44,6 +44,7 @@ repos:
     rev: 0.7.16
     hooks:
       - id: mdformat
+        exclude: tests/data/analyze
        # Optionally add plugins
        additional_dependencies:
          - mdformat-mkdocs[recommended]
@@ -76,7 +76,7 @@ To apply DAN to an image, one needs to first add a few imports and to load an image

```python
import cv2

-from dan.predict import DAN
+from dan.ocr.predict import DAN

image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
```
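For context, the documentation this hunk edits goes on to load and run the model. A minimal sketch of that continuation with the new import path; the `load`/`predict` signatures and the path placeholders are assumptions for illustration, not confirmed by this diff:

```python
import cv2

from dan.ocr.predict import DAN

image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)

# Assumed API, for illustration only: instantiate on a device, load the
# trained weights and charset, then run inference on the prepared image.
model = DAN("cpu")
model.load(MODEL_PATH, PARAMS_PATH, CHARSET_PATH, mode="eval")  # hypothetical arguments
text, confidence_scores = model.predict(image, confidences=True)  # hypothetical signature
```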
@@ -3,8 +3,7 @@ import argparse
import errno

from dan.datasets import add_dataset_parser
-from dan.ocr import add_train_parser
-from dan.predict import add_predict_parser
+from dan.ocr import add_predict_parser, add_train_parser


def get_parser():
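The consolidated import works because both parsers now live under dan.ocr. A sketch of how get_parser() plausibly wires the subcommands; the subparser setup and the prog name are assumptions based on these imports, not code shown in this hunk:

```python
import argparse

from dan.datasets import add_dataset_parser
from dan.ocr import add_predict_parser, add_train_parser


def get_parser() -> argparse.ArgumentParser:
    # Root parser; each add_*_parser registers one subcommand on it.
    parser = argparse.ArgumentParser(prog="dan")  # prog name is a placeholder
    subcommands = parser.add_subparsers(metavar="subcommand")

    add_dataset_parser(subcommands)
    add_train_parser(subcommands)
    add_predict_parser(subcommands)

    return parser
```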
@@ -3,6 +3,7 @@
Preprocess datasets for training.
"""

+from dan.datasets.analyze import add_analyze_parser
from dan.datasets.extract import add_extract_parser
from dan.datasets.format import add_format_parser

@@ -17,3 +18,4 @@ def add_dataset_parser(subcommands) -> None:
    add_extract_parser(subcommands)
    add_format_parser(subcommands)
+    add_analyze_parser(subcommands)
dan/datasets/analyze/__init__.py

# -*- coding: utf-8 -*-
"""
Analyze dataset and display statistics in markdown format.
"""
import json
from pathlib import Path
from typing import Dict

import yaml

from dan.datasets.analyze.statistics import run


def read_yaml(yaml_path: str) -> Dict:
    """
    Read YAML tokens file
    """
    filename = Path(yaml_path)
    assert filename.exists()
    return yaml.safe_load(filename.read_text())


def read_json(json_path: str) -> Dict:
    """
    Read labels JSON file
    """
    filename = Path(json_path)
    assert filename.exists()
    return json.loads(filename.read_text())


def add_analyze_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "analyze",
        description=__doc__,
        help=__doc__,
    )

    parser.add_argument(
        "--labels",
        type=read_json,
        help="Path to the formatted labels in JSON format.",
        required=True,
    )

    parser.add_argument(
        "--tokens",
        type=read_yaml,
        help="Path to the tokens YAML file.",
        required=False,
    )

    parser.add_argument(
        "--output-file",
        dest="output",
        type=Path,
        help="The statistics will be saved to this file in Markdown format.",
        required=True,
    )

    parser.set_defaults(func=run)
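Since --labels and --tokens are parsed straight into dicts and handed to run, their expected shapes can be read off statistics.py below. A hedged sketch of a direct call; the split names, paths and token characters are invented placeholders, and the image paths must exist on disk because sizes are read with imagesize:

```python
from pathlib import Path

from dan.datasets.analyze.statistics import run

# --labels: one dict per split, mapping each image path to its transcription.
labels = {
    "train": {
        "images/train/page_0.jpg": {"text": "ⓝJohn Doe\nⓓ1867"},
    },
}

# --tokens: one entry per NER entity; only the "start" character is counted
# by the statistics below ("end" is included for shape only, as an assumption).
tokens = {
    "name": {"start": "ⓝ", "end": ""},
    "date": {"start": "ⓓ", "end": ""},
}

run(labels=labels, tokens=tokens, output=Path("statistics.md"))
```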
dan/datasets/analyze/statistics.py

# -*- coding: utf-8 -*-
from collections import Counter, defaultdict
from operator import itemgetter
from pathlib import Path
from typing import Dict, List, Optional

import imagesize
import numpy as np
from mdutils.mdutils import MdUtils
from prettytable import MARKDOWN, PrettyTable

from dan import logger

METRIC_COLUMN = "Metric"
def create_table(
    data: Dict,
    count: bool = False,
    total: bool = True,
):
    """
    Each key will be made into a column.
    We compute min, max, mean, median and total by default; the total row
    can be disabled and a count (length) row can be enabled.
    """
    statistics = PrettyTable(field_names=[METRIC_COLUMN, *data.keys()])
    statistics.align.update({METRIC_COLUMN: "l"})
    statistics.set_style(MARKDOWN)

    operations = []

    if count:
        operations.append(("Count", len))

    operations.extend(
        [
            ("Min", np.min),
            ("Max", np.max),
            ("Mean", np.mean),
            ("Median", np.median),
        ]
    )

    if total:
        operations.append(("Total", np.sum))

    statistics.add_rows(
        [
            [col_name, *list(map(operator, data.values()))]
            for col_name, operator in operations
        ]
    )

    return statistics
class Statistics:
    HEADERS = {
        "Images": "Images statistics",
        "Labels": "Labels statistics",
        "Chars": "Characters statistics",
        "Tokens": "NER tokens statistics",
    }

    def __init__(self, filename: str) -> None:
        self.document = MdUtils(file_name=filename, title="Statistics")

    def _write_section(self, table: PrettyTable, title: str, level: int = 2):
        """
        Write a new section to the file:

        <title with appropriate level>

        <table>
        """
        self.document.new_header(level=level, title=title, add_table_of_contents="n")
        self.document.write("\n")

        logger.info(f"{title}\n\n{table}\n")
        self.document.write(table.get_string())
        self.document.write("\n")

    def create_image_statistics(self, images: List[str]):
        """
        Compute statistics on image sizes and write them to file.
        """
        shapes = list(map(imagesize.get, images))
        widths, heights = zip(*shapes)

        self._write_section(
            table=create_table(
                data={"Width": widths, "Height": heights}, count=True, total=False
            ),
            title=Statistics.HEADERS["Images"],
        )

    def create_label_statistics(self, labels: List[str]):
        """
        Compute statistics on text labels and write them to file.
        """
        char_counter = Counter()
        data = defaultdict(list)

        for text in labels:
            char_counter.update(text)
            data["Chars"].append(len(text))
            data["Words"].append(len(text.split()))
            data["Lines"].append(len(text.split("\n")))

        self._write_section(
            table=create_table(data=data),
            title=Statistics.HEADERS["Labels"],
        )

        self.create_character_occurrences_statistics(char_counter)

    def create_character_occurrences_statistics(self, char_counter: Counter):
        """
        Compute statistics on the character distribution and write them to file.
        """
        char_occurrences = PrettyTable(
            field_names=["Character", "Occurrence"],
        )
        char_occurrences.align.update({"Character": "l", "Occurrence": "r"})
        char_occurrences.set_style(MARKDOWN)
        char_occurrences.add_rows(list(char_counter.most_common()))

        self._write_section(
            table=char_occurrences, title=Statistics.HEADERS["Chars"], level=3
        )

    def create_ner_statistics(self, labels: List[str], ner_tokens: Dict) -> None:
        """
        Compute statistics on NER token presence.
        """
        entity_counter = defaultdict(list)
        for text in labels:
            for ner_label, token in ner_tokens.items():
                entity_counter[ner_label].append(text.count(token["start"]))

        self._write_section(
            table=create_table(data=entity_counter),
            title=Statistics.HEADERS["Tokens"],
            level=3,
        )

    def run(self, labels: Dict, tokens: Optional[Dict]):
        # Iterate over each split
        for split_name, split_data in labels.items():
            self.document.new_header(level=1, title=split_name.capitalize())

            # Image statistics
            # Paths to the images are the keys of the dict
            self.create_image_statistics(images=split_data.keys())

            # The text is under the "text" key of the values
            labels = list(map(itemgetter("text"), split_data.values()))

            # Text statistics
            self.create_label_statistics(labels=labels)

            if tokens is not None:
                self.create_ner_statistics(labels=labels, ner_tokens=tokens)

        self.document.create_md_file()


def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
    """
    Compute and save dataset statistics.
    """
    Statistics(filename=str(output)).run(labels=labels, tokens=tokens)
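To make the helper concrete, a small example of create_table on toy image sizes; the numbers are invented and the printed table is approximate, following PrettyTable's MARKDOWN style:

```python
from dan.datasets.analyze.statistics import create_table

table = create_table(
    data={"Width": [800, 1200, 1000], "Height": [600, 900, 750]},
    count=True,
    total=False,
)
print(table)
# Prints roughly:
# | Metric | Width  | Height |
# |:-------|:------:|:------:|
# | Count  |   3    |   3    |
# | Min    |  800   |  600   |
# | Max    |  1200  |  900   |
# | Mean   | 1000.0 | 750.0  |
# | Median | 1000.0 | 750.0  |
```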
@@ -3,7 +3,8 @@
Train a new DAN model.
"""

-from dan.ocr.document import add_document_parser
+from dan.ocr.predict import add_predict_parser  # noqa
+from dan.ocr.train import run


def add_train_parser(subcommands) -> None:
@@ -12,6 +13,5 @@ def add_train_parser(subcommands) -> None:
        description=__doc__,
        help=__doc__,
    )
-    subcommands = parser.add_subparsers(metavar="subcommand")
-    add_document_parser(subcommands)
+    parser.set_defaults(func=run)
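The net effect is that train binds its runner directly instead of going through a document subcommand. How the CLI consumes func is not shown here; the dispatch below is the standard argparse pattern such code relies on, sketched with a stub runner:

```python
import argparse


def run(**kwargs):
    # Stub standing in for dan.ocr.train.run.
    print("training with", kwargs)


parser = argparse.ArgumentParser()
subcommands = parser.add_subparsers(metavar="subcommand")
train = subcommands.add_parser("train", help="Train a new DAN model.")
train.set_defaults(func=run)

# "train" now resolves straight to the runner; there is no "document" level.
args = vars(parser.parse_args(["train"]))
args.pop("func")(**args)
```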
File moved
dan/ocr/document/__init__.py

# -*- coding: utf-8 -*-
"""
Train a DAN model at document level.
"""
from dan.ocr.document.train import run


def add_document_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "document",
        description=__doc__,
        help=__doc__,
    )

    parser.set_defaults(func=run)
File moved
File moved
File moved
File moved
@@ -8,8 +8,8 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

-from dan.manager.dataset import OCRDataset
-from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
+from dan.ocr.manager.dataset import OCRDataset
+from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
@@ -17,10 +17,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

-from dan.manager.metrics import MetricManager
-from dan.manager.ocr import OCRDatasetManager
-from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
-from dan.schedulers import DropoutScheduler
+from dan.ocr.manager.metrics import MetricManager
+from dan.ocr.manager.ocr import OCRDatasetManager
+from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
+from dan.ocr.schedulers import DropoutScheduler
from dan.utils import fix_ddp_layers_names, ind_to_token

if MLFLOW_AVAILABLE:
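These hunks are pure import moves. Any code outside the repository that used the old paths needs the same renames; the mapping, read directly from this diff:

```python
# New locations after this refactor (previous paths in comments):
from dan.ocr.manager.dataset import OCRDataset               # was dan.manager.dataset
from dan.ocr.manager.metrics import MetricManager            # was dan.manager.metrics
from dan.ocr.manager.ocr import OCRDatasetManager            # was dan.manager.ocr
from dan.ocr.mlflow import MLFLOW_AVAILABLE                  # was dan.mlflow
from dan.ocr.schedulers import DropoutScheduler              # was dan.schedulers
from dan.ocr.transforms import get_preprocessing_transforms  # was dan.transforms
```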
File moved
@@ -5,7 +5,7 @@ Predict on an image using a trained DAN model.

import pathlib

-from dan.predict.prediction import run
+from dan.ocr.predict.prediction import run


def add_predict_parser(subcommands) -> None:
File moved
@@ -12,15 +12,15 @@ import yaml

from dan import logger
from dan.datasets.extract.utils import parse_tokens
-from dan.decoder import GlobalHTADecoder
-from dan.encoder import FCN_Encoder
-from dan.predict.attention import (
+from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.encoder import FCN_Encoder
+from dan.ocr.predict.attention import (
    get_predicted_polygons_with_confidence,
    parse_delimiters,
    plot_attention,
    split_text_and_confidences,
)
-from dan.transforms import get_preprocessing_transforms
+from dan.ocr.transforms import get_preprocessing_transforms
from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
File moved