diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index abe6a03b62fec1913521af177d6bb2e1517349aa..c4b6cd7cf6d9b1884f3ab69eef606d9ef2609844 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -471,7 +471,7 @@ class StartProcess(CorpusACLMixin, APIView): qs = DataImport.objects \ .select_related('corpus') \ .filter(corpus_id__isnull=False) \ - .prefetch_related('versions') + .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version'))) process = get_object_or_404(qs, pk=self.kwargs['pk']) @@ -488,10 +488,18 @@ class StartProcess(CorpusACLMixin, APIView): data = serializer.validated_data errors = defaultdict(list) - # Use process.versions.all() to access the (prefetched) versions to avoid new SQL queries - if len(list(process.versions.all())) > 0: - if data.get('use_gpu') and (not any(item.gpu_usage != WorkerVersionGPUUsage.Disabled for item in process.versions.all())): + # Use process.worker_runs.all() to access the (prefetched) worker_runs to avoid new SQL queries + # The related version have also been prefetched + if len(list(process.worker_runs.all())) > 0: + if data.get('use_gpu') and (not any(item.version.gpu_usage != WorkerVersionGPUUsage.Disabled for item in process.worker_runs.all())): errors['use_gpu'] = 'The process is configured to use GPU, but does not include any workers that support GPU usage.' + # Check if a worker run has no model version but version.model_usage = True + missing_model_versions = [] + for worker_run in process.worker_runs.all(): + if worker_run.version.model_usage and worker_run.model_version_id is None: + missing_model_versions.append(worker_run.version.worker.name) + if len(missing_model_versions) > 0: + errors['model_version'] = f"The following workers require a model version and none was set : {missing_model_versions}" else: if data.get('worker_activity'): errors['worker_activity'] = 'The process must have workers attached to handle their activity.' @@ -634,7 +642,7 @@ class RepositoryList(RepositoryACLMixin, ListAPIView): return self.readable_repositories \ .annotate(authorized_users=Count('memberships')) \ .select_related('credentials') \ - .prefetch_related('corpora', 'workers') \ + .prefetch_related('corpora', 'workers__type') \ .order_by('url') @@ -1218,6 +1226,11 @@ class WorkerRunDetails(CorpusACLMixin, RetrieveUpdateDestroyAPIView): permission_classes = (IsVerified, ) serializer_class = WorkerRunEditSerializer + def get_object(self): + if not hasattr(self, '_worker_run'): + self._worker_run = super().get_object() + return self._worker_run + def get_queryset(self): # Use default DB to avoid a race condition checking process workflow return WorkerRun.objects \ @@ -1225,6 +1238,12 @@ class WorkerRunDetails(CorpusACLMixin, RetrieveUpdateDestroyAPIView): .using('default') \ .select_related('version__worker__type', 'dataimport__workflow', 'dataimport__corpus') + def get_serializer_context(self): + context = super().get_serializer_context() + if 'model_version_id' in self.request.data: + context['model_usage'] = self.get_object().version.model_usage + return context + def check_object_permissions(self, request, worker_run): if not self.has_admin_access(worker_run.dataimport.corpus): raise PermissionDenied(detail='You do not have an admin access to the process project.') diff --git a/arkindex/dataimport/migrations/0048_workerrun_model_version.py b/arkindex/dataimport/migrations/0048_workerrun_model_version.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5e0f7db2eadbc6c6f7f6555f086cf9d2605ea6 --- /dev/null +++ b/arkindex/dataimport/migrations/0048_workerrun_model_version.py @@ -0,0 +1,20 @@ +# Generated by Django 4.0.2 on 2022-05-03 16:14 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('training', '0004_modelversion_archive_hash'), + ('dataimport', '0047_workerversion_model_usage'), + ] + + operations = [ + migrations.AddField( + model_name='workerrun', + name='model_version', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='worker_runs', to='training.modelversion'), + ), + ] diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py index 85aefc4b8f7584c7c787e54bb5de24c892f41a5b..b7e156d82b84891b3b028a439b3d51bf29e3f1f6 100644 --- a/arkindex/dataimport/models.py +++ b/arkindex/dataimport/models.py @@ -652,6 +652,7 @@ class WorkerRun(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) dataimport = models.ForeignKey('dataimport.DataImport', on_delete=models.CASCADE, related_name='worker_runs') version = models.ForeignKey('dataimport.WorkerVersion', on_delete=models.CASCADE, related_name='worker_runs') + model_version = models.ForeignKey('training.ModelVersion', on_delete=models.CASCADE, related_name='worker_runs', null=True) parents = ArrayField(models.UUIDField()) configuration = models.ForeignKey( WorkerConfiguration, diff --git a/arkindex/dataimport/serializers/imports.py b/arkindex/dataimport/serializers/imports.py index 9f887e65c38ec86057c9579a03b52480707295c9..38e7560dde1372504be28138ceba822de6e773b7 100644 --- a/arkindex/dataimport/serializers/imports.py +++ b/arkindex/dataimport/serializers/imports.py @@ -16,6 +16,7 @@ from arkindex.documents.models import Corpus, Element, ElementType from arkindex.documents.serializers.elements import ElementSlimSerializer from arkindex.project.mixins import ProcessACLMixin from arkindex.project.serializer_fields import EnumField, LinearRingField +from arkindex.training.models import Model, ModelVersion from arkindex.users.models import Role from arkindex.users.utils import get_max_level from ponos.models import Farm, State @@ -393,11 +394,12 @@ class WorkerRunSerializer(serializers.ModelSerializer): class Meta: model = WorkerRun - read_only_fields = ('id', 'worker', 'dataimport_id') + read_only_fields = ('id', 'worker', 'dataimport_id', 'model_version_id') fields = ( 'id', 'parents', 'worker_version_id', + 'model_version_id', 'dataimport_id', 'worker', 'configuration_id', @@ -409,6 +411,18 @@ class WorkerRunEditSerializer(WorkerRunSerializer): Serialize a worker run with only parents as editable field """ worker_version_id = serializers.UUIDField(read_only=True, source='version_id') + model_version_id = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.all(), required=False, allow_null=True) + + def validate_model_version_id(self, model_version): + model_usage = self.context.get('model_usage') + if not model_usage: + raise ValidationError("This worker version does not support model usage.") + model = Model.objects.get(id=model_version.model_id) + # Check access rights on model version + access_level = get_max_level(self.context["request"].user, model) + if not access_level or access_level < Role.Contributor.value: + raise ValidationError('You do not have access to this model version.') + return model_version class ImportTranskribusSerializer(serializers.Serializer): diff --git a/arkindex/dataimport/tests/test_imports.py b/arkindex/dataimport/tests/test_imports.py index 9972b5bdc52bd85a2bfac7be0c342a50fa6b6f6f..fb64ea8f608d47fb003ee5dd282dca0f16661ecc 100644 --- a/arkindex/dataimport/tests/test_imports.py +++ b/arkindex/dataimport/tests/test_imports.py @@ -22,6 +22,7 @@ from arkindex.dataimport.models import ( from arkindex.dataimport.utils import get_default_farm_id from arkindex.documents.models import Corpus, ElementType from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model, ModelVersion, ModelVersionState from arkindex.users.models import Role, User from ponos.models import Farm, State, Task, Workflow @@ -66,6 +67,11 @@ class TestImports(FixtureAPITestCase): cls.recognizer = WorkerVersion.objects.get(worker__slug='reco') cls.version_gpu = WorkerVersion.objects.get(worker__slug='worker-gpu') cls.workers_process = cls.corpus.imports.get(mode=DataImportMode.Workers) + cls.version_with_model = WorkerVersion.objects.get(worker__slug='generic') + + # Create a model and a model version + cls.model_1 = Model.objects.create(name='My model') + cls.model_version_1 = ModelVersion.objects.create(model=cls.model_1, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') # Create multiple processes the user can access cls.user2 = User.objects.create_user('user2@test.test', display_name='Process creator') @@ -1155,6 +1161,36 @@ class TestImports(FixtureAPITestCase): {'__all__': ['Only a DataImport with Workers mode and not already launched can be started later on']} ) + def test_start_process_without_required_model(self): + dataimport2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers) + dataimport2.worker_runs.create(version=self.version_with_model, parents=[], configuration=None) + dataimport2.save() + + self.client.force_login(self.user) + response = self.client.post( + reverse('api:process-start', kwargs={'pk': str(dataimport2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual( + response.json(), + {"model_version": f"The following workers require a model version and none was set : {[self.version_with_model.worker.name]}"} + ) + + def test_start_process_with_required_model(self): + dataimport2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers) + run = dataimport2.worker_runs.create(version=self.version_with_model, parents=[], configuration=None) + run.model_version = self.model_version_1 + run.save() + self.assertIsNone(dataimport2.workflow) + + self.client.force_login(self.user) + with self.assertNumQueries(27): + response = self.client.post( + reverse('api:process-start', kwargs={'pk': str(dataimport2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()['id'], str(dataimport2.id)) + def test_start_process_empty(self): dataimport2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers) self.assertIsNone(dataimport2.workflow) @@ -1307,7 +1343,7 @@ class TestImports(FixtureAPITestCase): element_type=self.corpus.types.get(slug='page') ) self.client.force_login(self.user) - with self.assertNumQueries(7): + with self.assertNumQueries(6): response = self.client.post( reverse('api:process-start', kwargs={'pk': str(process.id)}), {'use_cache': 'true', 'worker_activity': 'true', 'use_gpu': 'true'} @@ -1338,8 +1374,12 @@ class TestImports(FixtureAPITestCase): self.assertFalse(process.use_cache) self.assertEqual(process.activity_state, ActivityState.Disabled) - versions_mock.all.return_value = [self.version_gpu] - versions_mock.exists.return_value = True + # Add a worker run to this process + run_mock = MagicMock() + run_mock.version = self.version_gpu + run_mock.build_task_recipe.return_value = {'image': ''} + # Cheat to mock a query on process worker runs + worker_runs_mock.all.return_value = [run_mock] self.client.force_login(self.user) with self.assertNumQueries(20): diff --git a/arkindex/dataimport/tests/test_repos.py b/arkindex/dataimport/tests/test_repos.py index 1e09361614d15597329bccb5f0b00f50c83d7120..cdd08897a67775b3350733e02a1793f48719c1b2 100644 --- a/arkindex/dataimport/tests/test_repos.py +++ b/arkindex/dataimport/tests/test_repos.py @@ -115,7 +115,7 @@ class TestRepositories(FixtureTestCase): self.iiif_repo.memberships.create(user=self.user, level=Role.Admin.value) self.worker_repo.memberships.create(user=self.user, level=Role.Guest.value) self.client.force_login(self.user) - with self.assertNumQueries(10): + with self.assertNumQueries(8): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -128,7 +128,7 @@ class TestRepositories(FixtureTestCase): Multiple repository serialization should not include the git_clone_url field. """ self.client.force_login(self.internal_user) - with self.assertNumQueries(10): + with self.assertNumQueries(8): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -141,7 +141,7 @@ class TestRepositories(FixtureTestCase): """ self.iiif_repo.corpora.create() self.client.force_login(self.internal_user) - with self.assertNumQueries(10): + with self.assertNumQueries(8): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() diff --git a/arkindex/dataimport/tests/test_workerruns.py b/arkindex/dataimport/tests/test_workerruns.py index 2046d3f8fcdd471407a346d9aec499f4bde8808b..d240212608aee59d9fd3afa6fed28c65a07f5fe8 100644 --- a/arkindex/dataimport/tests/test_workerruns.py +++ b/arkindex/dataimport/tests/test_workerruns.py @@ -6,6 +6,7 @@ from rest_framework import status from arkindex.dataimport.models import DataImportMode, WorkerRun, WorkerVersion from arkindex.dataimport.utils import get_default_farm_id from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model, ModelVersion, ModelVersionState from arkindex.users.models import Role from ponos.models import State, Workflow @@ -43,6 +44,11 @@ class TestWorkerRuns(FixtureAPITestCase): # Add an execution access right on the worker cls.worker_1.memberships.create(user=cls.user, level=Role.Contributor.value) + # Model and Model version setup + cls.model_1 = Model.objects.create(name='My model') + cls.model_1.memberships.create(user=cls.user, level=Role.Contributor.value) + cls.model_version_1 = ModelVersion.objects.create(model=cls.model_1, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') + def test_runs_list_requires_login(self): response = self.client.get(reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_1.id)})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -70,6 +76,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker_version_id': str(self.version_1.id), 'dataimport_id': str(self.dataimport_1.id), 'parents': [], + 'model_version_id': None, 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, @@ -183,6 +190,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker_version_id': str(self.version_1.id), 'dataimport_id': str(self.dataimport_2.id), 'parents': [], + 'model_version_id': None, 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, @@ -238,6 +246,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker_version_id': str(self.version_1.id), 'dataimport_id': str(self.dataimport_2.id), 'parents': [], + 'model_version_id': None, 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, @@ -289,6 +298,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'id': str(self.run_1.id), 'worker_version_id': str(self.version_1.id), 'dataimport_id': str(self.dataimport_1.id), + 'model_version_id': None, 'parents': [], 'worker': { 'id': str(self.worker_1.id), @@ -413,6 +423,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'dataimport_id': str(self.dataimport_1.id), 'parents': [], 'worker_version_id': str(self.version_1.id), + 'model_version_id': None, 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, @@ -469,6 +480,151 @@ class TestWorkerRuns(FixtureAPITestCase): '__all__': ['Cannot update a WorkerRun on a DataImport that has already started'] }) + def test_update_run_model_version_not_allowed(self): + """ + The model_version UUID is not allowed when the related version doesn't allow model_usage + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_no_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=False + ) + run_2 = self.dataimport_1.worker_runs.create( + version=version_no_model, + parents=[], + ) + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'model_version_id': ['This worker version does not support model usage.'] + }) + + def test_update_run_unknown_model_version(self): + """ + Cannot use a model version id that doesn't exist + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_no_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run_2 = self.dataimport_1.worker_runs.create( + version=version_no_model, + parents=[], + ) + random_model_version_uuid = str(uuid.uuid4()) + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': random_model_version_uuid, + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'model_version_id': [f'Invalid pk "{random_model_version_uuid}" - object does not exist.'] + }) + + def test_update_run_model_version_no_access(self): + """ + Cannot update a worker run with a model_version UUID, when you don't have access to the model version + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_no_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run_2 = self.dataimport_1.worker_runs.create( + version=version_no_model, + parents=[], + ) + + # Create a model version, the user has no access to + model_no_access = Model.objects.create(name='Secret model') + model_version_no_access = ModelVersion.objects.create(model=model_no_access, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') + + with self.assertNumQueries(10): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': str(model_version_no_access.id), + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'model_version_id': ['You do not have access to this model version.'] + }) + + def test_update_run_model_version(self): + """ + Update the worker run by adding a model_version with a worker version that supports it + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_with_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run_2 = self.dataimport_1.worker_runs.create( + version=version_with_model, + parents=[], + ) + with self.assertNumQueries(12): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'id': str(run_2.id), + 'worker_version_id': str(version_with_model.id), + 'dataimport_id': str(self.dataimport_1.id), + 'model_version_id': str(self.model_version_1.id), + 'parents': [], + 'worker': { + 'id': str(self.worker_1.id), + 'name': self.worker_1.name, + 'type': self.worker_1.type.slug, + 'slug': self.worker_1.slug, + }, + 'configuration_id': None, + }) + def test_update_run(self): rev_2 = self.repo.revisions.create( hash='2', @@ -498,6 +654,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'id': str(self.run_1.id), 'worker_version_id': str(self.version_1.id), 'dataimport_id': str(self.dataimport_1.id), + 'model_version_id': None, 'parents': [str(run_2.id)], 'worker': { 'id': str(self.worker_1.id), diff --git a/arkindex/dataimport/tests/test_workers.py b/arkindex/dataimport/tests/test_workers.py index 2cda8dc6d90effddeeaf23932b0990ef83449238..18e769cb242f3daf191d648f5dc22db17a383d0f 100644 --- a/arkindex/dataimport/tests/test_workers.py +++ b/arkindex/dataimport/tests/test_workers.py @@ -47,6 +47,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): cls.worker_1 = Worker.objects.get(slug='reco') cls.worker_2 = Worker.objects.get(slug='dla') cls.worker_3 = Worker.objects.get(slug='worker-gpu') + cls.worker_4 = Worker.objects.get(slug='generic') cls.rev2 = Revision.objects.create( hash='1234', message='commit message', @@ -76,7 +77,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.get(reverse('api:workers-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { - 'count': 3, + 'count': 4, 'next': None, 'number': 1, 'previous': None, @@ -87,7 +88,15 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'name': 'Document layout analyser', 'slug': 'dla', 'type': 'dla', - }, { + }, + { + 'id': str(self.worker_4.id), + 'repository_id': str(self.repo.id), + 'name': 'Generic worker with a Model', + 'slug': 'generic', + 'type': 'recognizer', + }, + { 'id': str(self.worker_1.id), 'repository_id': str(self.repo.id), 'name': 'Recognizer', diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json index 98e90c6c18bc39382ded8e7c1abd25d407db8c05..1a390c897b8e9ce1cadd9f8f7a3eb899321be9de 100644 --- a/arkindex/documents/fixtures/data.json +++ b/arkindex/documents/fixtures/data.json @@ -71,79 +71,6 @@ "author": "me" } }, -{ - "model": "dataimport.workertype", - "pk": "c852529d-1852-444f-b3c5-f8677bc0069b", - "fields": { - "display_name": "Document layout analyser", - "slug": "dla", - "created": "2022-04-07T01:23:45.678Z", - "updated": "2022-04-07T01:23:45.678Z" - } -}, -{ - "model": "dataimport.workertype", - "pk": "433795b1-c1a7-4fb1-ad88-76c62065510d", - "fields": { - "display_name": "Worker requiring a GPU", - "slug": "worker", - "created": "2022-04-07T01:23:45.678Z", - "updated": "2022-04-07T01:23:45.678Z" - } -}, -{ - "model": "dataimport.workertype", - "pk": "25d7116d-d377-4e13-be6c-b7718f5c73de", - "fields": { - "display_name": "Recognizer", - "slug": "recognizer", - "created": "2022-04-07T01:23:45.678Z", - "updated": "2022-04-07T01:23:45.678Z" - } -}, -{ - "model": "dataimport.workertype", - "pk": "a18e08c0-6b10-4a47-bc54-d08fdc9ed22d", - "fields": { - "display_name": "Classifier", - "slug": "classifier", - "created": "2022-04-07T01:23:45.678Z", - "updated": "2022-04-07T01:23:45.678Z" - } -}, -{ - "model": "dataimport.worker", - "pk": "1fa33954-2e36-4e09-a193-39cdde9efd29", - "fields": { - "name": "Document layout analyser", - "slug": "dla", - "type": "c852529d-1852-444f-b3c5-f8677bc0069b", - "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", - "public": false - } -}, -{ - "model": "dataimport.worker", - "pk": "b12a1147-90f4-490a-8b3c-60b497d10888", - "fields": { - "name": "Worker requiring a GPU", - "slug": "worker-gpu", - "type": "433795b1-c1a7-4fb1-ad88-76c62065510d", - "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", - "public": false - } -}, -{ - "model": "dataimport.worker", - "pk": "d9551cc9-d997-4e1f-9a1c-921a8f8d4e77", - "fields": { - "name": "Recognizer", - "slug": "reco", - "type": "25d7116d-d377-4e13-be6c-b7718f5c73de", - "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", - "public": false - } -}, { "model": "dataimport.workerrun", "pk": "13564c24-e7f1-4452-aa59-a3f6bac9a219", @@ -237,6 +164,17 @@ "public": false } }, +{ + "model": "dataimport.worker", + "pk": "f88dccd1-bc9e-4b71-a628-a03572c0148b", + "fields": { + "name": "Generic worker with a Model", + "slug": "generic", + "type": "25d7116d-d377-4e13-be6c-b7718f5c73de", + "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", + "public": false + } +}, { "model": "dataimport.workerversion", "pk": "45c63f12-63f2-46bf-bd65-960ff6b2896b", @@ -248,6 +186,7 @@ }, "state": "available", "gpu_usage": "disabled", + "model_usage": false, "docker_image": "17ea0ee4-1b5c-4f06-a89f-15ffb01bcffc", "docker_image_iid": null } @@ -263,6 +202,7 @@ }, "state": "available", "gpu_usage": "disabled", + "model_usage": false, "docker_image": "17ea0ee4-1b5c-4f06-a89f-15ffb01bcffc", "docker_image_iid": null } @@ -278,6 +218,23 @@ }, "state": "available", "gpu_usage": "required", + "model_usage": false, + "docker_image": "17ea0ee4-1b5c-4f06-a89f-15ffb01bcffc", + "docker_image_iid": null + } +}, +{ + "model": "dataimport.workerversion", + "pk": "54cb1997-07cc-44a5-bf10-19ade68ad878", + "fields": { + "worker": "f88dccd1-bc9e-4b71-a628-a03572c0148b", + "revision": "840e0937-0496-4eba-8e64-2aa89081ffab", + "configuration": { + "test": 42 + }, + "state": "available", + "gpu_usage": "disabled", + "model_usage": true, "docker_image": "17ea0ee4-1b5c-4f06-a89f-15ffb01bcffc", "docker_image_iid": null } diff --git a/arkindex/documents/management/commands/build_fixtures.py b/arkindex/documents/management/commands/build_fixtures.py index 2d37df3944c7ba0e5d26ea80f05ae68e348150e4..c70fc9826ca60f0084d18db44b13ec5abc885c12 100644 --- a/arkindex/documents/management/commands/build_fixtures.py +++ b/arkindex/documents/management/commands/build_fixtures.py @@ -125,6 +125,7 @@ class Command(BaseCommand): revision=revision, configuration={'test': 42}, state=WorkerVersionState.Available, + model_usage=False, docker_image=docker_image ) dla_worker = WorkerVersion.objects.create( @@ -136,6 +137,7 @@ class Command(BaseCommand): revision=revision, configuration={'test': 42}, state=WorkerVersionState.Available, + model_usage=False, docker_image=docker_image ) @@ -148,10 +150,26 @@ class Command(BaseCommand): revision=revision, configuration={'test': 42}, state=WorkerVersionState.Available, + model_usage=False, docker_image=docker_image, gpu_usage=WorkerVersionGPUUsage.Required ) + # Create a generic worker and its version that uses a ML Model + WorkerVersion.objects.create( + worker=worker_repo.workers.create( + name='Generic worker with a Model', + slug='generic', + type=recognizer_worker_type, + ), + revision=revision, + configuration={'test': 42}, + state=WorkerVersionState.Available, + gpu_usage=False, + model_usage=True, + docker_image=docker_image + ) + # Create a IIIF repository repo = creds.repos.create( url='http://gitlab/repo', diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index fd0c1733bf6f387a3a217a0bebdbbda066aa15ee..2a214da657ced216c1e799123c943276a0f554fa 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -713,9 +713,10 @@ class TestModelAPI(FixtureAPITestCase): def test_destroy_model_versions(self): """To destroy a model version, you need admin rights on the model. + This also deletes every worker run that used this model version """ self.client.force_login(self.user1) - with self.assertNumQueries(9): + with self.assertNumQueries(10): response = self.client.delete(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)})) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -753,5 +754,4 @@ class TestModelAPI(FixtureAPITestCase): with self.assertNumQueries(7): response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.maxDiff = None self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version4))