diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 336da0e422bdef987a889062afd7f124d63047a0..ebd6f2234b02e1bfe52910fdb1f535de53198307 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -477,7 +477,7 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
             """
             Update a dataset.
 
-            Requires a **contributor** access to the dataset's corpus. A dataset's state can only be updated by a Ponos task.
+            Requires a **contributor** access to the dataset's corpus.
             """
         ),
     ),
@@ -486,7 +486,7 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
             """
             Partially update a dataset.
 
-            Requires a **contributor** access to the dataset's corpus. A dataset's state can only be updated by a Ponos task.
+            Requires a **contributor** access to the dataset's corpus.
             """
         )
     ),
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index faf60898dfe6853559cc3375c513d0cca203f6eb..4eab7266f049a4ad0dddd5eebbd39846b483b283 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -419,7 +419,20 @@ class DatasetLightSerializer(serializers.ModelSerializer):
 
 
 class DatasetSerializer(DatasetLightSerializer):
-    state = EnumField(DatasetState, required=False)
+    state = EnumField(
+        DatasetState,
+        required=False,
+        help_text=dedent("""
+            State can only be updated through Ponos task authentication from a process containing this dataset
+            and is limited to these transitions:
+
+            Open → Building → Complete
+                      ↕
+                    Error
+
+            When updated to Complete, the dataset is automatically linked to the task used for authentication.
+        """),
+    )
     corpus = serializers.HiddenField(default=_corpus_from_context)
     creator = serializers.CharField(
         source='creator.display_name',
@@ -427,16 +440,58 @@ class DatasetSerializer(DatasetLightSerializer):
         help_text='Display name of the user who created the dataset.',
     )
 
-    class Meta(DatasetLightSerializer.Meta):
-        fields = ('id', 'name', 'description', 'sets', 'state', 'corpus_id', 'corpus', 'creator', 'task_id', 'created', 'updated')
-        read_only_fields = ('id', 'corpus_id', 'corpus', 'creator', 'task_id', 'created', 'updated')
+    def validate_state(self, state):
+        """
+        Dataset's state update is limited to these transitions:
+        Open → Building → Complete
+                  ↕
+                Error
+        """
+        if not isinstance(self.instance, Dataset) or state == self.instance.state:
+            return state
+
+        transitions = {
+            DatasetState.Open: (DatasetState.Building,),
+            DatasetState.Building: (DatasetState.Complete, DatasetState.Error),
+            DatasetState.Error: (DatasetState.Building,)
+        }
+        if state not in transitions.get(self.instance.state, ()):
+            raise ValidationError(f'Transition from {self.instance.state} to {state} is not allowed.')
+        return state
 
     def validate(self, data):
+        """
+        Restrict state updates on an existing dataset.
+
+        A state change requires Ponos task authentication, from a task whose process
+        includes this dataset. A dataset updated to Complete is linked to that task.
+        """
+        data = super().validate(data)
+        if not isinstance(self.instance, Dataset):
+            return data
+
+        state = data.get('state')
+        if state is None or state == self.instance.state:
+            return data
+
+        # Dataset's state update requires a Ponos task authentication
         request = self.context.get('request')
-        # Only Ponos tasks can update a dataset's state
-        if request and not isinstance(request.auth, Task) and data.get('state'):
-            del data['state']
-        return super().validate(data)
+        if request is None or not isinstance(request.auth, Task):
+            raise ValidationError({
+                'state': ['Ponos task authentication is required to update the state of a Dataset.']
+            })
+        # Dataset's state update is only allowed on tasks of Dataset processes, that have this dataset included
+        if not request.auth.process.datasets.filter(id=self.instance.id).exists():
+            raise ValidationError({'state': ['A task can only update the state of one of the datasets of its process.']})
+        # Link a completed dataset to the current task which generated its artifacts
+        if state == DatasetState.Complete:
+            data['task_id'] = request.auth.id
+
+        return data
+
+    class Meta(DatasetLightSerializer.Meta):
+        fields = ('id', 'name', 'description', 'sets', 'state', 'corpus_id', 'corpus', 'creator', 'task_id', 'created', 'updated')
+        read_only_fields = ('id', 'corpus_id', 'corpus', 'creator', 'task_id', 'created', 'updated')
 
 
 class DatasetElementSerializer(serializers.ModelSerializer):
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 799d2ef956c87a91df17789970a0a25ab8113212..ed7276ed603c2a20a024754c2a5a57e222e5a51d 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -30,9 +30,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
         cls.write_user = User.objects.get(email='user2@user.fr')
         cls.dataset = Dataset.objects.get(name='First Dataset')
         cls.dataset2 = Dataset.objects.get(name='Second Dataset')
+        cls.process.datasets.set((cls.dataset, cls.dataset2))
         cls.private_dataset = Dataset.objects.create(name="Private Dataset", description="Dead Sea Scrolls", corpus=cls.private_corpus, creator=cls.dataset_creator)
-        cls.dataset.task = cls.task
-        cls.dataset.save()
         cls.vol = cls.corpus.elements.get(name='Volume 1')
         cls.page1 = cls.corpus.elements.get(name='Volume 1, page 1r')
         cls.page2 = cls.corpus.elements.get(name='Volume 1, page 1v')
@@ -452,6 +451,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_update(self):
         self.client.force_login(self.user)
+        self.assertIsNone(self.dataset.task_id)
         with self.assertNumQueries(8):
             response = self.client.put(
                 reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
@@ -468,6 +468,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(self.dataset.name, 'Shin Seiki Evangelion')
         self.assertEqual(self.dataset.description, 'Omedeto!')
         self.assertListEqual(self.dataset.sets, ['unit-01', 'unit-00', 'unit-02'])
+        self.assertIsNone(self.dataset.task_id)
 
     def test_update_sets_length(self):
         self.client.force_login(self.user)
@@ -552,9 +553,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
             'sets': ['Set names must be unique.']
         })
 
-    def test_update_not_ponos_state_ignored(self):
+    def test_update_state_requires_ponos_auth(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(8):
+        self.dataset.state = DatasetState.Building
+        self.dataset.save()
+        with self.assertNumQueries(7):
             response = self.client.put(
                 reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
                 data={
@@ -565,12 +568,15 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 },
                 format='json'
             )
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.state, DatasetState.Open)
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {
+            'state': ['Ponos task authentication is required to update the state of a Dataset.']
+        })
 
     def test_update_ponos_task_state_update(self):
-        with self.assertNumQueries(7):
+        self.dataset.state = DatasetState.Building
+        self.dataset.save()
+        with self.assertNumQueries(8):
             response = self.client.put(
                 reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
                 HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
@@ -585,6 +591,72 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.dataset.refresh_from_db()
         self.assertEqual(self.dataset.state, DatasetState.Complete)
+        self.assertEqual(self.dataset.task_id, self.task.id)
+
+    def test_update_ponos_task_state_forbidden(self):
+        """Dataset's state update is limited to specific transitions"""
+        op, build, complete, error = [DatasetState[state] for state in ('Open', 'Building', 'Complete', 'Error')]
+        states = {
+            (op, op): True,
+            (op, build): True,
+            (op, complete): False,
+            (op, error): False,
+            (build, op): False,
+            (build, build): True,
+            (build, complete): True,
+            (build, error): True,
+            (complete, op): False,
+            (complete, build): False,
+            (complete, complete): True,
+            (complete, error): False,
+            (error, op): False,
+            (error, build): True,
+            (error, complete): False,
+            (error, error): True
+        }
+        for ((state_from, state_to), expected) in states.items():
+            with self.subTest(state_from=state_from, state_to=state_to):
+                self.dataset.state = state_from
+                self.dataset.save()
+                response = self.client.put(
+                    reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
+                    HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
+                    data={
+                        'name': self.dataset.name,
+                        'description': self.dataset.description,
+                        'state': state_to.value,
+                    },
+                    format='json',
+                )
+                if expected:
+                    self.assertEqual(response.status_code, status.HTTP_200_OK)
+                    self.dataset.refresh_from_db()
+                    self.assertEqual(self.dataset.state, state_to)
+                else:
+                    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                    self.assertDictEqual(
+                        response.json(),
+                        {'state': [f'Transition from {state_from} to {state_to} is not allowed.']}
+                    )
+
+    def test_update_ponos_task_state_requires_dataset_in_process(self):
+        self.process.process_datasets.all().delete()
+        with self.assertNumQueries(7):
+            response = self.client.put(
+                reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
+                HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
+                data={
+                    'name': 'Shin Seiki Evangelion',
+                    'description': 'Omedeto!',
+                    'sets': ['unit-01', 'unit-00', 'unit-02'],
+                    'state': 'building'
+                },
+                format='json',
+            )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'state': ['A task can only update the state of one of the datasets of its process.']
+        })
 
     def test_update_ponos_task_bad_state(self):
         with self.assertNumQueries(5):
@@ -697,33 +769,50 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {'name': ['This field may not be blank.'], 'description': ['This field may not be blank.']})
 
-    def test_partial_update_not_ponos_state_ignored(self):
+    def test_partial_update_requires_ponos_auth(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(7):
             response = self.client.patch(
                 reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
                 data={
-                    'state': 'complete'
+                    'state': 'building'
                 },
                 format='json'
             )
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.state, DatasetState.Open)
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {
+            'state': ['Ponos task authentication is required to update the state of a Dataset.']
+        })
 
     def test_partial_update_ponos_task_state_update(self):
-        with self.assertNumQueries(7):
+        with self.assertNumQueries(8):
             response = self.client.patch(
                 reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
                 HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
                 data={
-                    'state': 'complete'
+                    'state': 'building'
                 },
                 format='json',
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.state, DatasetState.Complete)
+        self.assertEqual(self.dataset.state, DatasetState.Building)
+
+    def test_partial_update_ponos_task_state_requires_dataset_in_process(self):
+        self.process.process_datasets.all().delete()
+        with self.assertNumQueries(7):
+            response = self.client.patch(
+                reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
+                HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
+                data={
+                    'state': 'building'
+                },
+                format='json',
+            )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'state': ['A task can only update the state of one of the datasets of its process.']
+        })
 
     def test_partial_update_ponos_task_bad_state(self):
         with self.assertNumQueries(5):
@@ -757,6 +846,52 @@ class TestDatasetsAPI(FixtureAPITestCase):
             }
         })
 
+    def test_partial_update_ponos_task_state_forbidden(self):
+        """Dataset's state update is limited to specific transitions"""
+        op, build, complete, error = [DatasetState[state] for state in ('Open', 'Building', 'Complete', 'Error')]
+        states = {
+            (op, op): True,
+            (op, build): True,
+            (op, complete): False,
+            (op, error): False,
+            (build, op): False,
+            (build, build): True,
+            (build, complete): True,
+            (build, error): True,
+            (complete, op): False,
+            (complete, build): False,
+            (complete, complete): True,
+            (complete, error): False,
+            (error, op): False,
+            (error, build): True,
+            (error, complete): False,
+            (error, error): True
+        }
+        for ((state_from, state_to), expected) in states.items():
+            with self.subTest(state_from=state_from, state_to=state_to):
+                self.dataset.state = state_from
+                self.dataset.save()
+                response = self.client.patch(
+                    reverse('api:dataset-update', kwargs={'pk': self.dataset.pk}),
+                    HTTP_AUTHORIZATION=f"Ponos {self.task.token}",
+                    data={
+                        'name': self.dataset.name,
+                        'description': self.dataset.description,
+                        'state': state_to.value,
+                    },
+                    format='json',
+                )
+                if expected:
+                    self.assertEqual(response.status_code, status.HTTP_200_OK)
+                    self.dataset.refresh_from_db()
+                    self.assertEqual(self.dataset.state, state_to)
+                else:
+                    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                    self.assertDictEqual(
+                        response.json(),
+                        {'state': [f'Transition from {state_from} to {state_to} is not allowed.']}
+                    )
+
     # RetrieveDataset
 
     def test_retrieve_requires_login(self):
@@ -809,12 +944,23 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "state": "open",
             "sets": ["training", "test", "validation"],
             "creator": "Test user",
-            "task_id": str(self.task.id),
+            "task_id": None,
             "corpus_id": str(self.corpus.id),
             "created": FAKE_CREATED,
             "updated": FAKE_CREATED
         })
 
+    def test_retrieve_task_id(self):
+        self.client.force_login(self.user)
+        self.dataset.task = self.task
+        self.dataset.save()
+        with self.assertNumQueries(5):
+            response = self.client.get(
+                reverse('api:dataset-update', kwargs={'pk': self.dataset.pk})
+            )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.json()["task_id"], str(self.task.id))
+
     # DestroyDataset
 
     def test_delete_requires_login(self):