diff --git a/bio_parser/parse/document.py b/bio_parser/parse/document.py
index 6be99e9a6fcbccd914034cc4ac5b85a2604b2b8e..f06e24c6da0aea2c7cbe8dcea19cfa7c0af0cc91 100644
--- a/bio_parser/parse/document.py
+++ b/bio_parser/parse/document.py
@@ -247,6 +247,8 @@ class Span:
 class Document:
     """Representation of a BIO document."""
 
+    filename: str
+    """Document filename"""
     bio_repr: str
     """Full BIO representation of the Document"""
     tokens: list[Token] = field(default_factory=list)
@@ -396,4 +398,4 @@ class Document:
         Returns:
             Document: Parsed document
         """
-        return Document(filepath.read_text())
+        return Document(filepath.stem, filepath.read_text())
diff --git a/bio_parser/utils.py b/bio_parser/utils.py
index a5f485cb51097d800e7555cb8e2e9670b413f256..ba12461dbf513ee18f5b76d22f66a7b06013b6ac 100644
--- a/bio_parser/utils.py
+++ b/bio_parser/utils.py
@@ -84,7 +84,7 @@ def load_dataset(
         prediction_dir (Path): Path to prediction directory.
 
     Returns:
-        list[tuple[Document, Document]]: A list of tuple containing the label and corresponding prediction Documents.
+        A list of tuple containing the label and corresponding prediction Documents.
     """
     sorted_labels = sorted(label_dir.glob("*.bio"), key=attrgetter("name"))
     sorted_predictions = sorted(prediction_dir.glob("*.bio"), key=attrgetter("name"))
diff --git a/tests/parse/test_document.py b/tests/parse/test_document.py
index 1329516ff2525cd1068db2eb12ae27dd8d7c91da..565f2a1d5dd5489ce78a66f1bf62057e90cf35de 100644
--- a/tests/parse/test_document.py
+++ b/tests/parse/test_document.py
@@ -193,7 +193,7 @@ def test_consecutive_entities():
     # BIO FILE
     # dog B-Animal
     # cat B-Animal
-    document = Document("dog B-Animal\ncat B-Animal")
+    document = Document("test", "dog B-Animal\ncat B-Animal")
 
     assert document.chars == ["d", "o", "g", " ", "c", "a", "t"]
 
diff --git a/tests/parse/test_validate.py b/tests/parse/test_validate.py
index 160ea7b253e4f4987d1e2f9f725ab0a4425916db..b4e40765158dc18452d2d44f8ac21f0a4fc73421 100644
--- a/tests/parse/test_validate.py
+++ b/tests/parse/test_validate.py
@@ -16,6 +16,7 @@ def test_valid():
     # Check content of JSON
     assert json.loads(output.read_text()) == {
         "bio_repr": "San B-GPE\nFrancisco I-GPE\nconsiders O\nbanning B-VERB\nsidewalk O\ndelivery O\nrobots O",
+        "filename": "valid",
         "tokens": [
             {"idx": 0, "text": "San B-GPE"},
             {"idx": 1, "text": "Francisco I-GPE"},
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1269b25d0f9fe1cb919b31bd9f4b453679e58e02..c56d97220c90b5d1bc0f3b69b0ebedeabdf638d6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import pytest
 from bio_parser.parse.document import Document
 from bio_parser.utils import check_complete, check_valid_bio, load_dataset
@@ -129,16 +131,25 @@ def test_check_complete_raise(labels, predictions, message):
         check_complete(labels, predictions)
 
 
-def test_load_dataset():
-    label_dir = DATA / "labels"
-    prediction_dir = DATA / "predictions"
+@pytest.mark.parametrize(
+    ("label_dir", "prediction_dir", "expected_names"),
+    [
+        (
+            DATA / "labels",
+            DATA / "predictions",
+            [Path("example_0.bio"), Path("example_1.bio"), Path("example_2.bio")],
+        ),
+    ],
+)
+def test_load_dataset(label_dir, prediction_dir, expected_names):
     documents = load_dataset(label_dir, prediction_dir)
-    for i in range(3):
-        filename = f"example_{i}.bio"
+    for i, expected_name in enumerate(expected_names):
         assert documents[i] == (
-            Document.from_file(label_dir / filename),
-            Document.from_file(prediction_dir / filename),
+            Document.from_file(label_dir / expected_name),
+            Document.from_file(prediction_dir / expected_name),
         )
+        assert documents[i][0].filename == expected_name.stem
+        assert documents[i][1].filename == expected_name.stem
 
 
 @pytest.mark.parametrize(