diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 01251d788aae229bd08e155a240d1380cfe7bab6..a946ca9677eea48b9b760b96d640ede7e39071c5 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -1232,7 +1232,7 @@ class WorkerRunDetails(CorpusACLMixin, RetrieveUpdateDestroyAPIView): return WorkerRun.objects \ .filter(dataimport__corpus_id__isnull=False) \ .using('default') \ - .select_related('version__worker__type', 'dataimport__workflow', 'dataimport__corpus') + .select_related('version__worker__type', 'dataimport__workflow', 'dataimport__corpus', 'version__revision__repo') def get_serializer_context(self): context = super().get_serializer_context() diff --git a/arkindex/dataimport/serializers/imports.py b/arkindex/dataimport/serializers/imports.py index 38e7560dde1372504be28138ceba822de6e773b7..1a1aa13f20b827817db1b53023a6c459da61660d 100644 --- a/arkindex/dataimport/serializers/imports.py +++ b/arkindex/dataimport/serializers/imports.py @@ -11,7 +11,11 @@ from arkindex.dataimport.models import ( WorkerRun, ) from arkindex.dataimport.serializers.git import RevisionSerializer -from arkindex.dataimport.serializers.workers import WorkerLightSerializer +from arkindex.dataimport.serializers.workers import ( + WorkerConfigurationSerializer, + WorkerLightSerializer, + WorkerVersionSerializer, +) from arkindex.documents.models import Corpus, Element, ElementType from arkindex.documents.serializers.elements import ElementSlimSerializer from arkindex.project.mixins import ProcessACLMixin @@ -413,6 +417,17 @@ class WorkerRunEditSerializer(WorkerRunSerializer): worker_version_id = serializers.UUIDField(read_only=True, source='version_id') model_version_id = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.all(), required=False, allow_null=True) + worker_version = WorkerVersionSerializer(read_only=True, source='version') + configuration = WorkerConfigurationSerializer(read_only=True) + process = DataImportLightSerializer(read_only=True, source='dataimport') + + class Meta(WorkerRunSerializer.Meta): + fields = WorkerRunSerializer.Meta.fields + ( + 'worker_version', + 'configuration', + 'process', + ) + def validate_model_version_id(self, model_version): model_usage = self.context.get('model_usage') if not model_usage: @@ -424,6 +439,11 @@ class WorkerRunEditSerializer(WorkerRunSerializer): raise ValidationError('You do not have access to this model version.') return model_version + def validate(self, data): + # Store configuration if provided + data['configuration'] = data.get('configuration_id') + return data + class ImportTranskribusSerializer(serializers.Serializer): """ diff --git a/arkindex/dataimport/tests/test_workerruns.py b/arkindex/dataimport/tests/test_workerruns.py index d240212608aee59d9fd3afa6fed28c65a07f5fe8..670cabe5044866d4bea94e02741e116c0b9e292d 100644 --- a/arkindex/dataimport/tests/test_workerruns.py +++ b/arkindex/dataimport/tests/test_workerruns.py @@ -2,6 +2,7 @@ import uuid from django.urls import reverse from rest_framework import status +from rest_framework.serializers import DateTimeField from arkindex.dataimport.models import DataImportMode, WorkerRun, WorkerVersion from arkindex.dataimport.utils import get_default_farm_id @@ -24,6 +25,77 @@ tasks: ARTIFACT_ID = uuid.uuid4() +def _deserialize_worker_configuration(worker_config): + return { + 'id': str(worker_config.id), + 'name': worker_config.name, + 'configuration': worker_config.configuration + } + + +def _deserialize_worker(worker): + return { + 'id': str(worker.id), + 'name': worker.name, + 'type': str(worker.type), + 'slug': worker.slug + } + + +def _deserialize_revision(revision): + return { + 'id': str(revision.id), + 'commit_url': revision.commit_url, + 'author': revision.author, + 'created': DateTimeField().to_representation(value=revision.created), + 'hash': revision.hash, + 'message': revision.message, + 'refs': [], + } + + +def _deserialize_worker_version(worker_version): + return { + 'id': str(worker_version.id), + 'revision': _deserialize_revision(worker_version.revision), + 'docker_image': str(worker_version.docker_image.id) if worker_version.docker_image else None, + 'docker_image_iid': str(worker_version.docker_image_iid) if worker_version.docker_image_iid else None, + 'docker_image_name': worker_version.docker_image_name, + 'configuration': worker_version.configuration, + 'state': worker_version.state.value, + 'gpu_usage': worker_version.gpu_usage.value, + 'model_usage': worker_version.model_usage, + 'worker': _deserialize_worker(worker_version.worker) + } + + +def _deserialize_worker_process(process): + return { + 'id': str(process.id), + 'name': process.name, + 'state': process.state.value, + 'mode': process.mode.value, + 'corpus': str(process.corpus.id), + 'workflow': (str(process.workflow) if process.workflow else None), + 'activity_state': process.activity_state.value, + } + + +def _deserialize_worker_run(worker_run): + return { + 'id': str(worker_run.id), + 'worker_version_id': str(worker_run.version.id), + 'dataimport_id': str(worker_run.dataimport.id), + 'parents': [str(parent_id) for parent_id in worker_run.parents], + 'model_version_id': str(worker_run.model_version_id) if worker_run.model_version_id else None, + 'worker': _deserialize_worker(worker_run.version.worker), + 'configuration_id': str(worker_run.configuration.id) if worker_run.configuration else None, + 'configuration': (_deserialize_worker_configuration(worker_run.configuration) if worker_run.configuration else None), + 'worker_version': _deserialize_worker_version(worker_run.version), + 'process': _deserialize_worker_process(worker_run.dataimport), + } + + class TestWorkerRuns(FixtureAPITestCase): """ Test worker runs endpoints and methods @@ -275,11 +347,12 @@ class TestWorkerRuns(FixtureAPITestCase): """ self.worker_1.memberships.update(level=Role.Guest.value) self.client.force_login(self.user) - with self.assertNumQueries(6): + with self.assertNumQueries(8): response = self.client.get( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) ) self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), _deserialize_worker_run(self.run_1)) def test_retrieve_run_invalid_id(self): self.client.force_login(self.user) @@ -290,24 +363,10 @@ class TestWorkerRuns(FixtureAPITestCase): def test_retrieve_run(self): self.client.force_login(self.user) - with self.assertNumQueries(6): + with self.assertNumQueries(8): response = self.client.get(reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) - - self.assertDictEqual(response.json(), { - 'id': str(self.run_1.id), - 'worker_version_id': str(self.version_1.id), - 'dataimport_id': str(self.dataimport_1.id), - 'model_version_id': None, - 'parents': [], - 'worker': { - 'id': str(self.worker_1.id), - 'name': self.worker_1.name, - 'type': self.worker_1.type.slug, - 'slug': self.worker_1.slug, - }, - 'configuration_id': None, - }) + self.assertDictEqual(response.json(), _deserialize_worker_run(self.run_1)) def test_update_run_requires_login(self): rev_2 = self.repo.revisions.create( @@ -325,12 +384,13 @@ class TestWorkerRuns(FixtureAPITestCase): parents=[], ) - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), - data={ - 'parents': [str(run_2.id)], - }, format='json' - ) + with self.assertNumQueries(0): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), + data={ + 'parents': [str(run_2.id)], + }, format='json' + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_update_run_no_project_admin_right(self): @@ -339,9 +399,10 @@ class TestWorkerRuns(FixtureAPITestCase): """ self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value) self.client.force_login(self.user) - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) - ) + with self.assertNumQueries(5): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {'detail': 'You do not have an admin access to the process project.'}) @@ -362,22 +423,24 @@ class TestWorkerRuns(FixtureAPITestCase): ) self.client.force_login(self.user) - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': '12341234-1234-1234-1234-123412341234'}), - data={ - 'parents': [str(run_2.id)], - }, format='json' - ) + with self.assertNumQueries(3): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': '12341234-1234-1234-1234-123412341234'}), + data={ + 'parents': [str(run_2.id)], + }, format='json' + ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_update_run_inexistant_parent(self): self.client.force_login(self.user) - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), - data={ - 'parents': ['12341234-1234-1234-1234-123412341234'], - }, format='json' - ) + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), + data={ + 'parents': ['12341234-1234-1234-1234-123412341234'], + }, format='json' + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), [ f"Can't add or update WorkerRun {self.run_1.id} because parents field isn't properly defined. It can be either because" @@ -390,7 +453,7 @@ class TestWorkerRuns(FixtureAPITestCase): Dataimport field cannot be updated """ self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(10): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), data={ @@ -399,7 +462,8 @@ class TestWorkerRuns(FixtureAPITestCase): }, format='json' ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.json()['dataimport_id'], str(self.dataimport_1.id)) + + self.assertEqual(response.json(), _deserialize_worker_run(self.run_1)) self.run_1.refresh_from_db() self.assertEqual(self.run_1.dataimport.id, self.dataimport_1.id) @@ -409,7 +473,7 @@ class TestWorkerRuns(FixtureAPITestCase): """ self.client.force_login(self.user) dla_version = WorkerVersion.objects.get(worker__slug='dla') - with self.assertNumQueries(8): + with self.assertNumQueries(10): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), data={ @@ -418,27 +482,14 @@ class TestWorkerRuns(FixtureAPITestCase): }, format='json' ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - 'id': str(self.run_1.id), - 'dataimport_id': str(self.dataimport_1.id), - 'parents': [], - 'worker_version_id': str(self.version_1.id), - 'model_version_id': None, - 'worker': { - 'id': str(self.worker_1.id), - 'name': self.worker_1.name, - 'type': self.worker_1.type.slug, - 'slug': self.worker_1.slug, - }, - 'configuration_id': None, - }) + self.assertEqual(response.json(), _deserialize_worker_run(self.run_1)) self.run_1.refresh_from_db() self.assertNotEqual(self.run_1.version_id, dla_version.id) def test_update_run_configuration(self): self.client.force_login(self.user) self.assertEqual(self.run_1.configuration, None) - with self.assertNumQueries(9): + with self.assertNumQueries(11): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': self.run_1.id}), data={ @@ -448,16 +499,18 @@ class TestWorkerRuns(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.run_1.refresh_from_db() + self.assertEqual(response.json(), _deserialize_worker_run(self.run_1)) self.assertEqual(self.run_1.configuration.id, self.configuration_1.id) def test_update_run_invalid_configuration(self): self.client.force_login(self.user) self.assertEqual(self.run_1.configuration, None) - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': self.run_1.id}), - data={'configuration_id': str(self.configuration_2.id)}, - format='json' - ) + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': self.run_1.id}), + data={'configuration_id': str(self.configuration_2.id)}, + format='json' + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {'configuration_id': ['The configuration must be part of the same worker.']}) @@ -602,7 +655,8 @@ class TestWorkerRuns(FixtureAPITestCase): version=version_with_model, parents=[], ) - with self.assertNumQueries(12): + self.assertEqual(run_2.model_version_id, None) + with self.assertNumQueries(14): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -610,20 +664,9 @@ class TestWorkerRuns(FixtureAPITestCase): }, format='json' ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - 'id': str(run_2.id), - 'worker_version_id': str(version_with_model.id), - 'dataimport_id': str(self.dataimport_1.id), - 'model_version_id': str(self.model_version_1.id), - 'parents': [], - 'worker': { - 'id': str(self.worker_1.id), - 'name': self.worker_1.name, - 'type': self.worker_1.type.slug, - 'slug': self.worker_1.slug, - }, - 'configuration_id': None, - }) + run_2.refresh_from_db() + self.assertEqual(response.json(), _deserialize_worker_run(run_2)) + self.assertEqual(run_2.model_version_id, self.model_version_1.id) def test_update_run(self): rev_2 = self.repo.revisions.create( @@ -640,9 +683,8 @@ class TestWorkerRuns(FixtureAPITestCase): version=version_2, parents=[], ) - self.client.force_login(self.user) - with self.assertNumQueries(9): + with self.assertNumQueries(11): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}), data={ @@ -650,20 +692,8 @@ class TestWorkerRuns(FixtureAPITestCase): }, format='json' ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - 'id': str(self.run_1.id), - 'worker_version_id': str(self.version_1.id), - 'dataimport_id': str(self.dataimport_1.id), - 'model_version_id': None, - 'parents': [str(run_2.id)], - 'worker': { - 'id': str(self.worker_1.id), - 'name': self.worker_1.name, - 'type': self.worker_1.type.slug, - 'slug': self.worker_1.slug, - }, - 'configuration_id': None, - }) + self.run_1.refresh_from_db() + self.assertDictEqual(response.json(), _deserialize_worker_run(self.run_1)) def test_delete_run_requires_login(self): response = self.client.delete(