Commit c832f4de authored by Eva Bardou, committed by Yoann Schneider

Iterate over the selected sets to list the elements

parent 2f0cf175
1 merge request: !505 Iterate over the selected sets to list the elements
Pipeline #162799 passed
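
In short: `DatasetMixin.list_dataset_elements` no longer fetches every element of the dataset and discards those outside `dataset.selected_sets` on the client side. When only some sets are selected, it now paginates the `ListDatasetElements` endpoint once per selected set; when all sets are selected, a single unfiltered call is kept. A minimal standalone sketch of that logic (the real code is a method on `DatasetMixin`; the free function and the import path below are illustrative assumptions, not part of this commit):

# Simplified sketch of the new behaviour, not the verbatim worker method.
# `Element`, `api_client.paginate` and the "ListDatasetElements" endpoint name
# are taken from the diff below; the import path is assumed.
from arkindex_worker.models import Element


def iter_dataset_elements(api_client, dataset):
    if dataset.sets == dataset.selected_sets:
        # Every set is selected: one unfiltered, paginated call is enough.
        results = api_client.paginate("ListDatasetElements", id=dataset.id)
    else:
        # Only some sets are selected: one filtered call per set, chained lazily.
        results = (
            element
            for selected_set in dataset.selected_sets
            for element in api_client.paginate(
                "ListDatasetElements", id=dataset.id, set=selected_set
            )
        )
    # Yield (set name, Element) pairs; no client-side set filtering is needed.
    return ((result["set"], Element(**result["element"])) for result in results)
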
@@ -51,7 +51,7 @@ class DatasetMixin:
         return map(
             lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
-            list(results),
+            results,
         )

     def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
@@ -65,14 +65,20 @@ class DatasetMixin:
             dataset, Dataset
         ), "dataset shouldn't be null and should be a Dataset"

-        results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
+        if dataset.sets == dataset.selected_sets:
+            results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
+        else:
+            results = iter(
+                element
+                for selected_set in dataset.selected_sets
+                for element in self.api_client.paginate(
+                    "ListDatasetElements", id=dataset.id, set=selected_set
+                )
+            )

-        def format_result(result):
-            if result["set"] not in dataset.selected_sets:
-                return
-            return (result["set"], Element(**result["element"]))
-
-        return filter(None, map(format_result, list(results)))
+        return map(
+            lambda result: (result["set"], Element(**result["element"])), results
+        )

     @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
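The test updates below follow the same pattern: instead of a single mocked call to the unfiltered elements endpoint, one response is registered per selected set. A condensed sketch of that registration loop, with a placeholder dataset id (the real tests use the `default_dataset` fixture and real element payloads):

# Sketch of the per-set mocking pattern used in the updated tests; the dataset id
# and the empty result pages are placeholders, not values from this diff.
import responses

DATASET_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"  # placeholder

for selected_set in ["set_1", "set_2", "set_3"]:
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/datasets/{DATASET_ID}/elements/"
        f"?set={selected_set}&with_count=true",
        status=200,
        json={"count": 0, "next": None, "results": []},
    )
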
@@ -195,7 +195,7 @@ def test_list_dataset_elements_per_split_api_error(
):
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
status=500,
)
@@ -211,23 +211,23 @@ def test_list_dataset_elements_per_split_api_error(
# The API call is retried 5 times
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
]
@@ -235,110 +235,60 @@ def test_list_dataset_elements_per_split_api_error(
def test_list_dataset_elements_per_split(
responses, mock_dataset_worker, default_dataset
):
expected_results = [
{
"set": "set_1",
"element": {
"id": "0000",
"type": "page",
"name": "Test",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_1",
"element": {
"id": "1111",
"type": "page",
"name": "Test 2",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_2",
"element": {
"id": "2222",
"type": "page",
"name": "Test 3",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_3",
"element": {
"id": "3333",
"type": "page",
"name": "Test 4",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
# `set_4` is not in `default_dataset.selected_sets`
{
"set": "set_4",
"element": {
"id": "4444",
"type": "page",
"name": "Test 5",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
expected_results = []
for selected_set in default_dataset.selected_sets:
index = selected_set[-1]
expected_results.append(
{
"set": selected_set,
"element": {
"id": str(index) * 4,
"type": "page",
"name": f"Test {index}",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
}
)
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
status=200,
json={
"count": 1,
"next": None,
"results": [expected_results[-1]],
},
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=200,
json={
"count": 4,
"next": None,
"results": expected_results,
},
)
)
assert list(
mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
) == [
("set_1", [expected_results[0]["element"], expected_results[1]["element"]]),
("set_2", [expected_results[2]["element"]]),
("set_3", [expected_results[3]["element"]]),
("set_1", [expected_results[0]["element"]]),
("set_2", [expected_results[1]["element"]]),
("set_3", [expected_results[2]["element"]]),
]
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_2&with_count=true",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_3&with_count=true",
),
]
@@ -360,7 +310,7 @@ def test_list_datasets_api_error(responses, mock_dataset_worker):
with pytest.raises(
Exception, match="Stopping pagination as data will be incomplete"
):
mock_dataset_worker.list_datasets()
next(mock_dataset_worker.list_datasets())
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
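The `next(...)` change in `test_list_datasets_api_error` above is a consequence of the first hunk: `list_datasets` now maps over the paginator directly instead of materialising it with `list(...)`, so the failing API call only happens once the returned iterator is consumed. A standalone illustration of that laziness (not worker code):

def failing_paginator():
    # Stand-in for a paginated API call that errors out when iterated.
    raise Exception("Stopping pagination as data will be incomplete")
    yield  # the bare yield makes this a generator function

# Building the map() does not iterate the generator, so nothing fails yet...
results = map(lambda page: page, failing_paginator())

# ...the exception only surfaces once the iterator is consumed.
try:
    next(results)
except Exception as exc:
    print(exc)  # Stopping pagination as data will be incomplete
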
@@ -139,12 +139,26 @@ def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload,
mock_dataset_worker.list_dataset_elements(**payload)
@pytest.mark.parametrize(
"sets",
[
["set_1"],
["set_1", "set_2", "set_3"],
["set_1", "set_2", "set_3", "set_4"],
],
)
def test_list_dataset_elements_api_error(
responses, mock_dataset_worker, default_dataset
responses, mock_dataset_worker, sets, default_dataset
):
default_dataset.selected_sets = sets
query_params = (
"?with_count=true"
if sets == default_dataset.sets
else "?set=set_1&with_count=true"
)
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
status=500,
)
@@ -160,122 +174,107 @@ def test_list_dataset_elements_api_error(
# The API call is retried 5 times
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
),
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
),
]
@pytest.mark.parametrize(
"sets",
[
["set_1"],
["set_1", "set_2", "set_3"],
["set_1", "set_2", "set_3", "set_4"],
],
)
def test_list_dataset_elements(
responses,
mock_dataset_worker,
sets,
default_dataset,
):
expected_results = [
{
"set": "set_1",
"element": {
"id": "0000",
"type": "page",
"name": "Test",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_1",
"element": {
"id": "1111",
"type": "page",
"name": "Test 2",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_2",
"element": {
"id": "2222",
"type": "page",
"name": "Test 3",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_3",
"element": {
"id": "3333",
"type": "page",
"name": "Test 4",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
default_dataset.selected_sets = sets
dataset_elements = []
for split in default_dataset.sets:
index = split[-1]
dataset_elements.append(
{
"set": split,
"element": {
"id": str(index) * 4,
"type": "page",
"name": f"Test {index}",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
}
)
if split == "set_1":
dataset_elements.append({**dataset_elements[-1]})
dataset_elements[-1]["element"]["name"] = f"Test {index} (bis)"
# All sets are selected, we call the unfiltered endpoint once
if default_dataset.sets == default_dataset.selected_sets:
expected_results = dataset_elements
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
status=200,
json={
"count": len(expected_results),
"next": None,
"results": expected_results,
},
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=200,
json={
"count": 4,
"next": None,
"results": expected_results
# `set_4` is not in `default_dataset.selected_sets`
+ [
{
"set": "set_4",
"element": {
"id": "4444",
"type": "page",
"name": "Test 5",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
}
],
},
)
)
expected_calls = [
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true"
]
# Not all sets are selected, we call the filtered endpoint multiple times, once per set
else:
expected_results, expected_calls = [], []
for selected_set in default_dataset.selected_sets:
partial_results = [
element
for element in dataset_elements
if element["set"] == selected_set
]
expected_results += partial_results
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
status=200,
json={
"count": len(partial_results),
"next": None,
"results": partial_results,
},
)
expected_calls += [
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true"
]
for idx, element in enumerate(
mock_dataset_worker.list_dataset_elements(dataset=default_dataset)
@@ -285,15 +284,10 @@ def test_list_dataset_elements(
expected_results[idx]["element"],
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert len(responses.calls) == len(BASE_API_CALLS) + len(expected_calls)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
),
]
] == BASE_API_CALLS + [("GET", expected_call) for expected_call in expected_calls]
@pytest.mark.parametrize(