utils.py 3.93 KiB
"""Utils functions."""
import logging
from operator import attrgetter
from pathlib import Path
from bio_parser.parse.document import Document
from bio_parser.parse.exceptions import ForbiddenEntityName, InvalidFile
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def check_complete(labels: list[Path], predictions: list[Path]):
    """Ensure label and prediction BIO files match one-to-one by filename.

    Files are compared using their base name only (``Path.name``), so the
    parent directories of the two lists may differ.

    Args:
        labels: List of label BIO files.
        predictions: List of prediction BIO files.

    Raises:
        FileNotFoundError: If a label file has no prediction file with the
            same name, or a prediction file has no label file with the same
            name. The message lists the missing filenames for each side.
    """
    expected_names = {path.name for path in labels}
    found_names = {path.name for path in predictions}

    # Both sides expose exactly the same filenames: nothing to report.
    if expected_names == found_names:
        return

    # Labels without a prediction, and predictions without a label.
    without_prediction = expected_names - found_names
    without_label = found_names - expected_names

    problems = []
    if without_prediction:
        problems.append(f"Missing prediction files: {without_prediction}.")
    if without_label:
        problems.append(f"Missing label files: {without_label}.")
    raise FileNotFoundError("\n".join(problems))
def check_valid_bio(
    bio_files: list[Path], global_stat_name: str | None = None
) -> list[Document]:
    """Parse every BIO file, stopping at the first problem encountered.

    Args:
        bio_files (list[Path]): BIO files to parse, processed in order.
        global_stat_name (str | None, optional): Entity name that must not
            appear in any document. The check is skipped when this is
            ``None`` or empty. Defaults to None.

    Raises:
        FileNotFoundError: A file does not exist.
        InvalidFile: ``Document.from_file`` raised while parsing a file.
        ForbiddenEntityName: A document contains an entity named
            ``global_stat_name``.

    Returns:
        list[Document]: The parsed documents, in the same order as
            ``bio_files``.
    """
    documents = []
    for bio_path in bio_files:
        # A missing file is reported as such rather than as an invalid one.
        if not bio_path.exists():
            raise FileNotFoundError(
                f"BIO file {bio_path} does not exist.",
            )

        # Any parsing failure is surfaced as InvalidFile, keeping the cause.
        try:
            parsed_document = Document.from_file(bio_path)
        except Exception as error:
            raise InvalidFile(bio_path) from error

        # The reserved name must not be used by any entity of the document.
        if global_stat_name:
            entity_names = {entity[0] for entity in parsed_document.entities}
            if global_stat_name in entity_names:
                raise ForbiddenEntityName(
                    filename=bio_path, entity_name=global_stat_name
                )

        documents.append(parsed_document)
    return documents
def load_dataset(
    label_dir: Path,
    prediction_dir: Path,
) -> list[tuple[Document, Document]]:
    """Load and pair the label and prediction BIO files of a dataset.

    Only files matching ``*.bio`` directly inside each directory are
    considered. Both listings are sorted by filename, so that labels and
    predictions are paired positionally once the dataset has been checked
    for completeness.

    Args:
        label_dir (Path): Path to the label directory.
        prediction_dir (Path): Path to the prediction directory.

    Raises:
        FileNotFoundError: If either directory contains no ``*.bio`` file.
            Errors raised by ``check_complete`` and ``check_valid_bio`` are
            propagated unchanged.

    Returns:
        A list of tuples, each holding a label Document and the
        corresponding prediction Document.
    """
    by_name = attrgetter("name")
    label_files = sorted(label_dir.glob("*.bio"), key=by_name)
    prediction_files = sorted(prediction_dir.glob("*.bio"), key=by_name)

    # Report every empty directory at once instead of only the first one.
    problems = []
    if not label_files:
        problems.append(f"Empty label directory: {label_dir}.")
    if not prediction_files:
        problems.append(f"Empty prediction directory: {prediction_dir}.")
    if problems:
        raise FileNotFoundError("\n".join(problems))

    # Filenames must match on both sides before parsing anything.
    check_complete(label_files, prediction_files)

    logger.info("Loading labels...")
    label_documents = check_valid_bio(label_files)
    logger.info("Loading prediction...")
    prediction_documents = check_valid_bio(prediction_files)
    logger.info("The dataset is complete and valid.")

    # Both lists follow the same filename order, so zipping pairs them up.
    document_pairs = list(zip(label_documents, prediction_documents))
    return document_pairs