Skip to content
Snippets Groups Projects
Commit 1cbc0e3d authored by Manon Blanco, committed by Yoann Schneider
Browse files

Support new dataset API

parent f0d051f4
Branches
Tags
1 merge request: !498 "Support new dataset API"
Pipeline #160731 passed
......@@ -49,7 +49,10 @@ class DatasetMixin:
"ListProcessDatasets", id=self.process_information["id"]
)
return map(Dataset, list(results))
return map(
lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
list(results),
)
def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
"""
......@@ -65,9 +68,11 @@ class DatasetMixin:
results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
def format_result(result):
if result["set"] not in dataset.selected_sets:
return
return (result["set"], Element(**result["element"]))
return map(format_result, list(results))
return filter(None, map(format_result, list(results)))
@unsupported_cache
def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
......
......@@ -601,14 +601,15 @@ def default_dataset():
"id": "dataset_id",
"name": "My dataset",
"description": "A super dataset built by me",
"sets": ["set_1", "set_2", "set_3"],
"sets": ["set_1", "set_2", "set_3", "set_4"],
"state": DatasetState.Open.value,
"corpus_id": "corpus_id",
"creator": "creator@teklia.com",
"task_id": "11111111-1111-1111-1111-111111111111",
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
},
selected_sets=["set_1", "set_2", "set_3"],
)
......
......@@ -296,6 +296,22 @@ def test_list_dataset_elements_per_split(
"worker_run_id": None,
},
},
# `set_4` is not in `default_dataset.selected_sets`
{
"set": "set_4",
"element": {
"id": "4444",
"type": "page",
"name": "Test 5",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
]
responses.add(
responses.GET,
......@@ -362,6 +378,8 @@ def test_list_datasets_api_error(responses, mock_dataset_worker):
def test_list_datasets(responses, mock_dataset_worker):
expected_results = [
{
"id": "process_dataset_1",
"dataset": {
"id": "dataset_1",
"name": "Dataset 1",
"description": "My first great dataset",
......@@ -371,7 +389,11 @@ def test_list_datasets(responses, mock_dataset_worker):
"creator": "test@teklia.com",
"task_id": "task_id_1",
},
"sets": ["test"],
},
{
"id": "process_dataset_2",
"dataset": {
"id": "dataset_2",
"name": "Dataset 2",
"description": "My second great dataset",
......@@ -381,7 +403,11 @@ def test_list_datasets(responses, mock_dataset_worker):
"creator": "test@teklia.com",
"task_id": "task_id_2",
},
"sets": ["train", "val"],
},
{
"id": "process_dataset_3",
"dataset": {
"id": "dataset_3",
"name": "Dataset 3 (TRASHME)",
"description": "My third dataset, in error",
......@@ -391,6 +417,8 @@ def test_list_datasets(responses, mock_dataset_worker):
"creator": "test@teklia.com",
"task_id": "task_id_3",
},
"sets": ["random set"],
},
]
responses.add(
responses.GET,
......@@ -403,7 +431,11 @@ def test_list_datasets(responses, mock_dataset_worker):
},
)
assert list(mock_dataset_worker.list_datasets()) == expected_results
for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
assert dataset == {
**expected_results[idx]["dataset"],
"selected_sets": expected_results[idx]["sets"],
}
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
......
......@@ -52,6 +52,8 @@ def test_list_process_datasets(
):
expected_results = [
{
"id": "process_dataset_1",
"dataset": {
"id": "dataset_1",
"name": "Dataset 1",
"description": "My first great dataset",
......@@ -61,7 +63,11 @@ def test_list_process_datasets(
"creator": "test@teklia.com",
"task_id": "task_id_1",
},
"sets": ["test"],
},
{
"id": "process_dataset_2",
"dataset": {
"id": "dataset_2",
"name": "Dataset 2",
"description": "My second great dataset",
......@@ -71,7 +77,11 @@ def test_list_process_datasets(
"creator": "test@teklia.com",
"task_id": "task_id_2",
},
"sets": ["train", "val"],
},
{
"id": "process_dataset_3",
"dataset": {
"id": "dataset_3",
"name": "Dataset 3 (TRASHME)",
"description": "My third dataset, in error",
......@@ -81,6 +91,8 @@ def test_list_process_datasets(
"creator": "test@teklia.com",
"task_id": "task_id_3",
},
"sets": ["random set"],
},
]
responses.add(
responses.GET,
......@@ -95,7 +107,10 @@ def test_list_process_datasets(
for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
assert isinstance(dataset, Dataset)
assert dataset == expected_results[idx]
assert dataset == {
**expected_results[idx]["dataset"],
"selected_sets": expected_results[idx]["sets"],
}
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
......@@ -240,7 +255,25 @@ def test_list_dataset_elements(
json={
"count": 4,
"next": None,
"results": expected_results,
"results": expected_results
# `set_4` is not in `default_dataset.selected_sets`
+ [
{
"set": "set_4",
"element": {
"id": "4444",
"type": "page",
"name": "Test 5",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
}
],
},
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment