Skip to content
Snippets Groups Projects
Verified Commit cc85e2b8 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Specify sets to evaluate

parent ff6cea59
No related branches found
No related tags found
1 merge request!480Evaluate specific sets only
This commit is part of merge request !480. Comments created here will be created in the context of that merge request.
......@@ -8,3 +8,8 @@ logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
# Canonical names of the three dataset splits, shared across the project
# so split identifiers are never hard-coded as bare strings.
TRAIN_NAME = "train"
VAL_NAME = "val"
TEST_NAME = "test"
# All split names, in the conventional train/val/test order.
SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
......@@ -18,12 +18,12 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from dan import TRAIN_NAME
from dan.datasets.download.exceptions import ImageDownloadError
from dan.datasets.download.utils import (
download_image,
get_bbox,
)
from dan.datasets.extract.arkindex import TRAIN_NAME
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
......
......@@ -14,6 +14,7 @@ from uuid import UUID
from tqdm import tqdm
from arkindex_export import Dataset, DatasetElement, Element, open_database
from dan import SPLIT_NAMES, TEST_NAME, TRAIN_NAME, VAL_NAME
from dan.datasets.extract.db import (
get_dataset_elements,
get_elements,
......@@ -32,11 +33,6 @@ from dan.datasets.extract.utils import (
)
from dan.utils import parse_tokens
# Canonical split names (duplicated here; the same constants also exist
# at package level in `dan` — NOTE(review): keep a single definition).
TRAIN_NAME = "train"
VAL_NAME = "val"
TEST_NAME = "test"
# All split names, in the conventional train/val/test order.
SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
logger = logging.getLogger(__name__)
......
......@@ -25,6 +25,7 @@ from nerval.utils import TABLE_HEADER as NERVAL_TABLE_HEADER
from nerval.utils import print_results
from prettytable import MARKDOWN, PrettyTable
from dan import SPLIT_NAMES
from dan.bio import convert
from dan.ocr import wandb
from dan.ocr.manager.metrics import Inference
......@@ -83,6 +84,14 @@ def add_evaluate_parser(subcommands) -> None:
type=Path,
)
# CLI option restricting evaluation to specific dataset splits.
# Bug fix: the help text was copy-pasted from the JSON-output option and
# wrongly read "Where to save evaluation results in JSON format."
parser.add_argument(
    "--sets",
    dest="set_names",
    # One or more split names to evaluate; defaults to every split.
    help="Names of the sets to evaluate. Defaults to all sets (train, val, test).",
    default=SPLIT_NAMES,
    nargs="+",
)
parser.set_defaults(func=run)
......@@ -196,6 +205,7 @@ def eval(
nerval_threshold: float,
output_json: Path | None,
mlflow_logging: bool,
set_names: list[str],
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
......@@ -227,7 +237,7 @@ def eval(
metrics_values, all_inferences = [], {}
for dataset_name in config["dataset"]["datasets"]:
for set_name in ["train", "val", "test"]:
for set_name in set_names:
logger.info(f"Evaluating on set `{set_name}`")
metrics, inferences = model.evaluate(
"{}-{}".format(dataset_name, set_name),
......@@ -272,7 +282,12 @@ def eval(
wandb.log_artifact(artifact, local_path=output_json, name=output_json.name)
def run(config: dict, nerval_threshold: float, output_json: Path | None):
def run(
config: dict,
nerval_threshold: float,
output_json: Path | None,
set_names: list[str] = SPLIT_NAMES,
):
update_config(config)
# Start "Weights & Biases" as soon as possible
......@@ -294,8 +309,8 @@ def run(config: dict, nerval_threshold: float, output_json: Path | None):
):
mp.spawn(
eval,
args=(config, nerval_threshold, output_json, mlflow_logging),
args=(config, nerval_threshold, output_json, mlflow_logging, set_names),
nprocs=config["training"]["device"]["nb_gpu"],
)
else:
eval(0, config, nerval_threshold, output_json, mlflow_logging)
eval(0, config, nerval_threshold, output_json, mlflow_logging, set_names)
......@@ -10,6 +10,7 @@ import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from dan import TRAIN_NAME
from dan.datasets.utils import natural_sort
from dan.utils import check_valid_size, read_image, token_to_ind
......@@ -171,7 +172,7 @@ class OCRDataset(Dataset):
"""
image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
if self.set_name == "train":
if self.set_name == TRAIN_NAME:
image_reduced_shape = [max(1, t) for t in image_reduced_shape]
image_position = [
......
......@@ -10,6 +10,7 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dan import TRAIN_NAME
from dan.ocr.manager.dataset import OCRDataset
from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
......@@ -66,8 +67,8 @@ class OCRDatasetManager:
Load training and validation datasets
"""
self.train_dataset = OCRDataset(
set_name="train",
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
set_name=TRAIN_NAME,
paths_and_sets=self.get_paths_and_sets(self.params[TRAIN_NAME]["datasets"]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
......
......@@ -25,6 +25,7 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose, PILToTensor
from tqdm import tqdm
from dan import TRAIN_NAME
from dan.ocr import wandb
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
......@@ -645,7 +646,7 @@ class GenericTrainingManager:
# init variables
nb_epochs = self.params["training"]["max_nb_epochs"]
metric_names = self.params["training"]["metrics"]["train"]
metric_names = self.params["training"]["metrics"][TRAIN_NAME]
display_values = None
# perform epochs
......@@ -669,7 +670,7 @@ class GenericTrainingManager:
self.latest_epoch
)
# init epoch metrics values
self.metric_manager["train"] = MetricManager(
self.metric_manager[TRAIN_NAME] = MetricManager(
metric_names=metric_names,
dataset_name=self.dataset_name,
tokens=self.tokens,
......@@ -684,7 +685,7 @@ class GenericTrainingManager:
# train on batch data and compute metrics
batch_values = self.train_batch(batch_data, metric_names)
batch_metrics = self.metric_manager["train"].compute_metrics(
batch_metrics = self.metric_manager[TRAIN_NAME].compute_metrics(
batch_values, metric_names
)
batch_metrics["names"] = batch_data["names"]
......@@ -716,14 +717,20 @@ class GenericTrainingManager:
self.dropout_scheduler.update_dropout_rate()
# Add batch metrics values to epoch metrics values
self.metric_manager["train"].update_metrics(batch_metrics)
display_values = self.metric_manager["train"].get_display_values()
self.metric_manager[TRAIN_NAME].update_metrics(batch_metrics)
display_values = self.metric_manager[
TRAIN_NAME
].get_display_values()
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]) * self.nb_workers)
# Log MLflow metrics
logging_metrics(
display_values, "train", num_epoch, mlflow_logging, self.is_master
display_values,
TRAIN_NAME,
num_epoch,
mlflow_logging,
self.is_master,
)
if self.is_master:
......@@ -731,7 +738,7 @@ class GenericTrainingManager:
for key in display_values:
self.writer.add_scalar(
"train/{}_{}".format(
self.params["dataset"]["train"]["name"], key
self.params["dataset"][TRAIN_NAME]["name"], key
),
display_values[key],
num_epoch,
......
......@@ -24,7 +24,7 @@ from arkindex_export import (
WorkerVersion,
database,
)
from dan.datasets.extract.arkindex import TEST_NAME, TRAIN_NAME, VAL_NAME
from dan import TEST_NAME, TRAIN_NAME, VAL_NAME
from tests import FIXTURES
......
......@@ -8,7 +8,7 @@ from operator import itemgetter
import pytest
from arkindex_export import Dataset, DatasetElement, Element
from dan.datasets.extract.arkindex import TRAIN_NAME
from dan import TRAIN_NAME
from dan.datasets.extract.db import (
get_dataset_elements,
get_elements,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment