diff --git a/arkindex/dataimport/admin.py b/arkindex/dataimport/admin.py index 877cdae7f01f6c2279389ca210bb87cfb4f43271..d2cc69cee0886bc1010803cfac29438bcbf63afb 100644 --- a/arkindex/dataimport/admin.py +++ b/arkindex/dataimport/admin.py @@ -9,6 +9,7 @@ from arkindex.dataimport.models import ( Revision, Worker, WorkerConfiguration, + WorkerType, WorkerVersion, ) from arkindex.users.admin import GroupMembershipInline, UserMembershipInline @@ -100,15 +101,21 @@ class WorkerConfigurationInline(admin.StackedInline): class WorkerAdmin(admin.ModelAdmin): list_display = ('id', 'name', 'slug', 'type', 'repository') - field = ('id', 'name', 'slug', 'type', 'repository') + fields = ('id', 'name', 'slug', 'type', 'repository', 'public') readonly_fields = ('id', ) inlines = [WorkerVersionInline, UserMembershipInline, GroupMembershipInline, WorkerConfigurationInline] +class WorkerTypeAdmin(admin.ModelAdmin): + list_display = ('id', 'slug', 'display_name', 'created') + fields = ('id', 'slug', 'display_name') + readonly_fields = ('id', ) + + class WorkerVersionAdmin(admin.ModelAdmin): list_display = ('id', 'worker', 'revision') list_filter = ('worker', ) - field = ('id', 'worker', 'revision', 'configuration') + fields = ('id', 'worker', 'revision', 'configuration') readonly_fields = ('id', ) raw_id_fields = ('docker_image', 'revision') @@ -124,5 +131,6 @@ admin.site.register(DataFile, DataFileAdmin) admin.site.register(Revision, RevisionAdmin) admin.site.register(Repository, RepositoryAdmin) admin.site.register(Worker, WorkerAdmin) +admin.site.register(WorkerType, WorkerTypeAdmin) admin.site.register(WorkerVersion, WorkerVersionAdmin) admin.site.register(WorkerConfiguration, WorkerConfigurationAdmin) diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 3b3a6b179edab400612d0d5ea7fd5f06874ae4d3..c3d82a3acd6086cc10a1d90b069018fd7a2b5874 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -48,6 +48,7 @@ from arkindex.dataimport.models import ( WorkerActivityState, WorkerConfiguration, WorkerRun, + WorkerType, WorkerVersion, WorkerVersionGPUUsage, ) @@ -815,6 +816,7 @@ class WorkerList(WorkerACLMixin, ListCreateAPIView): """ return self.executable_workers \ .order_by('name', 'id') \ + .select_related('type') \ .distinct() def create(self, request, *args, **kwargs): @@ -830,11 +832,17 @@ class WorkerList(WorkerACLMixin, ListCreateAPIView): elif access_right < Role.Admin.value: raise PermissionDenied(detail='You do not have admin access to this repository.') + # Get or create a WorkerType with given slug + worker_type, _ = WorkerType.objects.get_or_create( + slug=serializer.validated_data['type'], + defaults={'display_name': serializer.validated_data['type'].slug.capitalize()} + ) + worker, created = repo.workers.get_or_create( slug=serializer.validated_data['slug'], defaults={ 'name': serializer.validated_data['name'], - 'type': serializer.validated_data['type'], + 'type': worker_type, } ) @@ -1117,7 +1125,7 @@ class WorkerRunList(WorkerACLMixin, ListCreateAPIView): raise PermissionDenied(detail='You do not have an admin access to the corpus of this process.') return process.worker_runs \ - .select_related('version__worker') \ + .select_related('version__worker__type') \ .order_by('id') def perform_create(self, serializer): @@ -1186,7 +1194,7 @@ class WorkerRunDetails(CorpusACLMixin, RetrieveUpdateDestroyAPIView): return WorkerRun.objects \ .filter(dataimport__corpus_id__isnull=False) \ .using('default') \ - .select_related('version__worker', 'dataimport__workflow', 'dataimport__corpus') + .select_related('version__worker__type', 'dataimport__workflow', 'dataimport__corpus') def check_object_permissions(self, request, worker_run): if not self.has_admin_access(worker_run.dataimport.corpus): @@ -1615,7 +1623,7 @@ class CreateProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView): def get_queryset(self): return DataImport.objects \ - .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker'))) + .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker__type'))) def check_object_permissions(self, request, template): access_level = self.process_access_level(template) @@ -1663,7 +1671,7 @@ class ApplyProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView): def get_queryset(self): return DataImport.objects \ .filter(mode=DataImportMode.Template) \ - .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker'))) + .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker__type'))) def check_object_permissions(self, request, template): access_level = self.process_access_level(template) diff --git a/arkindex/dataimport/management/commands/fake_worker_version.py b/arkindex/dataimport/management/commands/fake_worker_version.py index c6f7dd5a0c147318b9df4ff954baeb53437a20c4..3295efbd6584f17e88f1a8c095a443f718d44e22 100644 --- a/arkindex/dataimport/management/commands/fake_worker_version.py +++ b/arkindex/dataimport/management/commands/fake_worker_version.py @@ -3,7 +3,7 @@ import uuid from django.core.management.base import BaseCommand, CommandError -from arkindex.dataimport.models import Repository, RepositoryType, Revision, Worker, WorkerVersion +from arkindex.dataimport.models import Repository, RepositoryType, Revision, Worker, WorkerType, WorkerVersion class Command(BaseCommand): @@ -41,10 +41,15 @@ class Command(BaseCommand): } ) + worker_type, _ = WorkerType.objects.get_or_create( + slug="classifier", + defaults={'display_name': "classifier"} + ) + worker, _ = Worker.objects.get_or_create( name=name, slug=slug, - type="classifier", + type=worker_type, repository=repo ) diff --git a/arkindex/dataimport/migrations/0046_workertype_alter_worker_type.py b/arkindex/dataimport/migrations/0046_workertype_alter_worker_type.py new file mode 100644 index 0000000000000000000000000000000000000000..25365bba557d23f8e00d4480b23933708ec32569 --- /dev/null +++ b/arkindex/dataimport/migrations/0046_workertype_alter_worker_type.py @@ -0,0 +1,79 @@ +# Generated by Django 4.0.2 on 2022-04-07 11:43 + +import uuid + +import django.db.models.deletion +from django.db import migrations, models + + +def update_worker_types(apps, schema_editor): + Worker = apps.get_model('dataimport', 'Worker') + WorkerType = apps.get_model('dataimport', 'WorkerType') + + # Get list of current worker types + current_types = Worker.objects.values('type').distinct() + created_types = WorkerType.objects.bulk_create( + [WorkerType(slug=type_slug['type'], display_name=type_slug['type'].capitalize()) for type_slug in current_types] + ) + for worker_type in created_types: + Worker.objects.filter(type=worker_type.slug).update(type_fk=worker_type.id) + + +def retrieve_worker_type_slugs(apps, schema_editor): + Worker = apps.get_model('dataimport', 'Worker') + for worker in Worker.objects.all(): + worker.type = worker.type_fk.slug + worker.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataimport', '0045_remove_dataimport_best_class'), + ] + + operations = [ + migrations.CreateModel( + name='WorkerType', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('slug', models.CharField(max_length=100, unique=True)), + ('display_name', models.CharField(max_length=100)), + ('created', models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now)), + ('updated', models.DateTimeField(auto_now=True)) + ], + ), + migrations.AddField( + model_name='worker', + name='type_fk', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='type', to='dataimport.workertype'), + ), + migrations.AlterField( + model_name="worker", + name="type", + field=models.CharField(max_length=50, null=True), + ), + migrations.RunPython( + update_worker_types, + reverse_code=retrieve_worker_type_slugs + ), + migrations.RemoveField( + model_name='worker', + name='type', + ), + migrations.RenameField( + model_name='worker', + old_name='type_fk', + new_name='type', + ), + migrations.AlterField( + model_name="worker", + name="type", + field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='type', to='dataimport.workertype'), + ), + migrations.AlterField( + model_name="workertype", + name="created", + field=models.DateTimeField(auto_now_add=True) + ) + ] diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py index 457bc9aa2159e6b06afb61b28cedf4e6ea6a8acd..176a71215728701c57ecfb0773b57e62e9411a9c 100644 --- a/arkindex/dataimport/models.py +++ b/arkindex/dataimport/models.py @@ -530,7 +530,7 @@ class Worker(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) name = models.CharField(max_length=100) slug = models.CharField(max_length=100) - type = models.CharField(max_length=50) + type = models.ForeignKey('dataimport.WorkerType', on_delete=models.PROTECT, related_name='type') repository = models.ForeignKey('dataimport.Repository', on_delete=models.CASCADE, related_name='workers') memberships = GenericRelation('users.Right', 'content_id') @@ -544,6 +544,18 @@ class Worker(models.Model): return f'{self.name}' +class WorkerType(IndexableModel): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + slug = models.CharField(max_length=100, unique=True) + display_name = models.CharField(max_length=100) + + def __str__(self): + return self.slug + + def __repr__(self): + return f'<WorkerType {self.slug}>' + + class WorkerVersionState(Enum): Created = 'created' Processing = 'processing' diff --git a/arkindex/dataimport/serializers/workers.py b/arkindex/dataimport/serializers/workers.py index 68a0c010ca26bb64addfb35608bd7f5658c18878..c1583d2b6ca18b3585a28c0fe2bc7c14245a628d 100644 --- a/arkindex/dataimport/serializers/workers.py +++ b/arkindex/dataimport/serializers/workers.py @@ -15,6 +15,7 @@ from arkindex.dataimport.models import ( WorkerActivity, WorkerActivityState, WorkerConfiguration, + WorkerType, WorkerVersion, WorkerVersionGPUUsage, WorkerVersionState, @@ -27,6 +28,8 @@ class WorkerLightSerializer(serializers.ModelSerializer): """ Serialize a simple repository worker """ + # We want to only have to specify the type as a slug (char) in stead of the actual workerType id + type = serializers.SlugRelatedField(queryset=WorkerType.objects.all(), slug_field='slug') class Meta: model = Worker diff --git a/arkindex/dataimport/tests/commands/test_fake_worker_version.py b/arkindex/dataimport/tests/commands/test_fake_worker_version.py index ef715ea106d41917cb7c6692dbb2a2cdcf77a44f..5eb1af226ce89a5d3bc6b8edf48042d83827d098 100644 --- a/arkindex/dataimport/tests/commands/test_fake_worker_version.py +++ b/arkindex/dataimport/tests/commands/test_fake_worker_version.py @@ -1,7 +1,15 @@ from django.core.management import call_command from django.core.management.base import CommandError -from arkindex.dataimport.models import Repository, RepositoryType, Revision, Worker, WorkerVersion, WorkerVersionState +from arkindex.dataimport.models import ( + Repository, + RepositoryType, + Revision, + Worker, + WorkerType, + WorkerVersion, + WorkerVersionState, +) from arkindex.project.tests import FixtureTestCase @@ -31,7 +39,7 @@ class TestFakeWorker(FixtureTestCase): self.assertEqual(worker.name, name) self.assertEqual(worker.slug, slug) - self.assertEqual(worker.type, "classifier") + self.assertEqual(worker.type.slug, "classifier") self.assertIsNotNone(revision.hash) self.assertEqual(revision.message, "Fake revision") @@ -98,10 +106,14 @@ class TestFakeWorker(FixtureTestCase): type=RepositoryType.Worker, ) + worker_type = WorkerType.objects.get( + slug="classifier" + ) + worker = Worker.objects.create( name=name, slug=slug, - type="classifier", + type=worker_type, repository=repo ) diff --git a/arkindex/dataimport/tests/test_managers.py b/arkindex/dataimport/tests/test_managers.py index 8f671dd915f87cc443f3937a7e92b14c1b1a85f9..35676c0a92a8050537b39501bbab20ca50591428 100644 --- a/arkindex/dataimport/tests/test_managers.py +++ b/arkindex/dataimport/tests/test_managers.py @@ -1,6 +1,6 @@ from uuid import uuid4 -from arkindex.dataimport.models import CorpusWorkerVersion, Repository, RepositoryType, WorkerVersionState +from arkindex.dataimport.models import CorpusWorkerVersion, Repository, RepositoryType, WorkerType, WorkerVersionState from arkindex.documents.models import Classification, Element, Entity, MetaData, Transcription, TranscriptionEntity from arkindex.project.tests import FixtureTestCase from ponos.models import Artifact @@ -17,9 +17,12 @@ class TestManagers(FixtureTestCase): # The fixtures have two worker versions, only one of them is used in existing elements cls.recognizer = cls.repo.workers.get(slug='reco').versions.get() + # Retrieve one workerType + cls.worker_type = WorkerType.objects.get(slug='recognizer') + def _make_worker_version(self): return self.revision.versions.create( - worker=self.repo.workers.create(slug=str(uuid4())), + worker=self.repo.workers.create(slug=str(uuid4()), type=self.worker_type), configuration={}, state=WorkerVersionState.Available, docker_image=Artifact.objects.first(), diff --git a/arkindex/dataimport/tests/test_repos.py b/arkindex/dataimport/tests/test_repos.py index 4b4cbd56ce04993b8b217c91c25ebb438ec04d8e..1e09361614d15597329bccb5f0b00f50c83d7120 100644 --- a/arkindex/dataimport/tests/test_repos.py +++ b/arkindex/dataimport/tests/test_repos.py @@ -57,7 +57,7 @@ class TestRepositories(FixtureTestCase): { 'id': str(w.id), 'name': w.name, - 'type': w.type, + 'type': w.type.slug, 'slug': w.slug, 'repository_id': str(self.worker_repo.id), } for w in self.worker_repo.workers.all() @@ -115,7 +115,7 @@ class TestRepositories(FixtureTestCase): self.iiif_repo.memberships.create(user=self.user, level=Role.Admin.value) self.worker_repo.memberships.create(user=self.user, level=Role.Guest.value) self.client.force_login(self.user) - with self.assertNumQueries(7): + with self.assertNumQueries(10): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -128,7 +128,7 @@ class TestRepositories(FixtureTestCase): Multiple repository serialization should not include the git_clone_url field. """ self.client.force_login(self.internal_user) - with self.assertNumQueries(7): + with self.assertNumQueries(10): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -141,7 +141,7 @@ class TestRepositories(FixtureTestCase): """ self.iiif_repo.corpora.create() self.client.force_login(self.internal_user) - with self.assertNumQueries(7): + with self.assertNumQueries(10): response = self.client.get(reverse('api:repository-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() diff --git a/arkindex/dataimport/tests/test_signals.py b/arkindex/dataimport/tests/test_signals.py index 2a62986f68828be327b47aeb1ebd5c255997339b..2f047f6964fb81e3e7ea7267e0011e3f0b059558 100644 --- a/arkindex/dataimport/tests/test_signals.py +++ b/arkindex/dataimport/tests/test_signals.py @@ -1,6 +1,6 @@ from rest_framework.exceptions import ValidationError -from arkindex.dataimport.models import DataImportMode, RepositoryType, Worker, WorkerRun, WorkerVersion +from arkindex.dataimport.models import DataImportMode, RepositoryType, Worker, WorkerRun, WorkerType, WorkerVersion from arkindex.dataimport.signals import _list_ancestors from arkindex.project.tests import FixtureAPITestCase @@ -22,11 +22,13 @@ class TestSignals(FixtureAPITestCase): author='bob', ) + cls.worker_type_1 = WorkerType.objects.get(slug="recognizer") + cls.worker_1 = Worker.objects.create( repository=cls.repo, name='Worker 1', slug='worker_1', - type='classifier' + type=cls.worker_type_1 ) cls.version_1 = WorkerVersion.objects.create( worker=cls.worker_1, diff --git a/arkindex/dataimport/tests/test_workeractivity.py b/arkindex/dataimport/tests/test_workeractivity.py index 655a17bedf8b4b088c8b6f17362e3290941db6d3..aed2479545b1b5911c97a0e2daf83f49dc70410b 100644 --- a/arkindex/dataimport/tests/test_workeractivity.py +++ b/arkindex/dataimport/tests/test_workeractivity.py @@ -13,6 +13,7 @@ from arkindex.dataimport.models import ( WorkerActivity, WorkerActivityState, WorkerConfiguration, + WorkerType, WorkerVersion, WorkerVersionState, ) @@ -35,6 +36,7 @@ class TestWorkerActivity(FixtureTestCase): ) cls.configuration = WorkerConfiguration.objects.create(worker=cls.worker_version.worker, name='A config', configuration={'a': 'b'}) cls.process.worker_runs.create(version=cls.worker_version, parents=[], configuration=cls.configuration) + cls.worker_type = WorkerType.objects.get(slug='recognizer') def setUp(self): super().setUp() @@ -370,7 +372,7 @@ class TestWorkerActivity(FixtureTestCase): worker=Repository.objects.first().workers.create( name='New version', slug='new', - type='new', + type=self.worker_type, ), revision=self.worker_version.revision, configuration={}, diff --git a/arkindex/dataimport/tests/test_workerruns.py b/arkindex/dataimport/tests/test_workerruns.py index 860c549564374b2169505b6be7eb383a7f09f626..9fa88d7bc2b7808da51cc92e761d5f861f3db53b 100644 --- a/arkindex/dataimport/tests/test_workerruns.py +++ b/arkindex/dataimport/tests/test_workerruns.py @@ -73,7 +73,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': None, @@ -170,7 +170,7 @@ class TestWorkerRuns(FixtureAPITestCase): def test_runs_post_create_worker_run(self): self.client.force_login(self.user) - with self.assertNumQueries(15): + with self.assertNumQueries(16): response = self.client.post( reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)}), data={'worker_version_id': str(self.version_1.id), 'parents': []}, format='json' @@ -186,7 +186,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': None, @@ -196,7 +196,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.worker_1.memberships.filter(user=self.user).update(level=Role.Guest.value) self.worker_1.repository.memberships.create(user=self.user, level=Role.Contributor.value) self.client.force_login(self.user) - with self.assertNumQueries(16): + with self.assertNumQueries(17): response = self.client.post( reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)}), data={'worker_version_id': str(self.version_1.id), 'parents': []}, format='json' @@ -219,7 +219,7 @@ class TestWorkerRuns(FixtureAPITestCase): def test_create_run_configuration(self): self.client.force_login(self.user) - with self.assertNumQueries(16): + with self.assertNumQueries(17): response = self.client.post( reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)}), data={ @@ -241,7 +241,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': str(self.configuration_1.id) @@ -293,7 +293,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': None, @@ -416,7 +416,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': None, @@ -502,7 +502,7 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, }, 'configuration_id': None, diff --git a/arkindex/dataimport/tests/test_workers.py b/arkindex/dataimport/tests/test_workers.py index 29898206046c45e62cf36d47fabede499295b2d4..c5e866dbf1a9ccfe8140c8b60dddab06707fa45c 100644 --- a/arkindex/dataimport/tests/test_workers.py +++ b/arkindex/dataimport/tests/test_workers.py @@ -9,6 +9,7 @@ from arkindex.dataimport.models import ( RepositoryType, Revision, Worker, + WorkerType, WorkerVersion, WorkerVersionGPUUsage, WorkerVersionState, @@ -53,6 +54,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): repo=cls.repo, ) + cls.worker_type_classifier = WorkerType.objects.get(slug="classifier") cls.worker_config = cls.worker_1.configurations.create(name='config time', configuration={'key': 'value'}) def setUp(self): @@ -179,7 +181,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): repository=repo2, name='Worker 2', slug='worker_2', - type='classifier' + type=self.worker_type_classifier ) worker_2.memberships.create(user=self.user, level=Role.Contributor.value) @@ -197,7 +199,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': str(repo2.id), 'name': 'Worker 2', 'slug': 'worker_2', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ] }) @@ -217,11 +219,11 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) repo2.memberships.create(user=self.user, level=Role.Guest.value) self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.get(reverse('api:worker-retrieve', kwargs={'pk': str(worker_2.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -229,7 +231,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': str(repo2.id), 'name': worker_2.name, 'slug': worker_2.slug, - 'type': worker_2.type, + 'type': worker_2.type.slug, }) def test_worker_create_requires_login(self): @@ -253,7 +255,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': 'repo', 'name': 'Worker post', 'slug': 'worker_post', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -267,7 +269,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': uuid.uuid4(), 'name': 'Worker post', 'slug': 'worker_post', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -295,7 +297,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': str(self.repo.id), 'name': 'Worker post', 'slug': 'worker_post', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -303,7 +305,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertNotEqual(data['id'], str(self.worker_1.id)) self.assertEqual(data['name'], 'Worker post') self.assertEqual(data['slug'], 'worker_post') - self.assertEqual(data['type'], 'classifier') + self.assertEqual(data['type'], self.worker_type_classifier.slug) def test_workers_create_internal_user(self): """ @@ -323,7 +325,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': str(self.repo.id), 'name': 'Worker post', 'slug': 'worker_post', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -337,7 +339,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'repository_id': str(self.repo.id), 'name': 'Worker Test', 'slug': 'reco', - 'type': 'classifier' + 'type': self.worker_type_classifier.slug } ) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -362,7 +364,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): state=WorkerVersionState.Error, docker_image=self.version_1.docker_image ) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.get(reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -383,7 +385,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, } }) @@ -415,7 +417,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.worker_1.public = True self.worker_1.save() self.client.force_login(user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.get(reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -428,7 +430,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): repository=self.repo, name='Worker 2', slug='worker_2', - type='classifier' + type=self.worker_type_classifier ) version_2 = WorkerVersion.objects.create( worker=worker_2, @@ -437,7 +439,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): gpu_usage=WorkerVersionGPUUsage.Disabled ) - with self.assertNumQueries(13): + with self.assertNumQueries(14): response = self.client.get(reverse('api:worker-versions', kwargs={'pk': str(worker_2.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -476,7 +478,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.internal_user) # Complete mode - with self.assertNumQueries(13): + with self.assertNumQueries(14): response = self.client.get( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), {'mode': 'complete'} @@ -485,7 +487,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(response.json()['count'], 6) # Simple mode - with self.assertNumQueries(10): + with self.assertNumQueries(11): response = self.client.get( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), {'mode': 'simple'} @@ -888,7 +890,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): Worker versions may be retrieved with no authentication in order to see the version on public resources """ - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.get(reverse('api:version-retrieve', kwargs={'pk': str(self.version_1.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -905,7 +907,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, } }) @@ -928,7 +930,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'worker': { 'id': str(self.worker_1.id), 'name': self.worker_1.name, - 'type': self.worker_1.type, + 'type': self.worker_1.type.slug, 'slug': self.worker_1.slug, } }) @@ -1168,7 +1170,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'worker': { 'id': str(version.worker.id), 'name': version.worker.name, - 'type': version.worker.type, + 'type': version.worker.type.slug, 'slug': version.worker.slug, } } @@ -1179,7 +1181,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): def test_corpus_worker_version_no_login(self): self.corpus.worker_versions.set([self.version_1]) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1198,7 +1200,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.user) self.corpus.worker_versions.set([self.version_1]) - with self.assertNumQueries(12): + with self.assertNumQueries(13): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1215,7 +1217,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.user) self.corpus.worker_versions.set([self.version_1]) - with self.assertNumQueries(12): + with self.assertNumQueries(13): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1266,7 +1268,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) config_1 = worker_2.configurations.create(name='config_1', configuration={'key': 'value'}) config_2 = worker_2.configurations.create(name='config_2', configuration={'dulce': 'et decorum est'}) repo2.memberships.create(user=self.user, level=Role.Guest.value) @@ -1335,7 +1337,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) worker_2.configurations.create(name=name, configuration={'key': 'value'}) repo2.memberships.create(user=self.user, level=Role.Admin.value) @@ -1362,7 +1364,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) worker_2.configurations.create(name='config-name', configuration=config) repo2.memberships.create(user=self.user, level=Role.Admin.value) @@ -1393,7 +1395,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) worker_2.configurations.create(name='config_1', configuration={'key': 'value'}) repo2.memberships.create(user=self.user, level=Role.Admin.value) @@ -1422,7 +1424,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): credentials=self.creds, provider_name='GitLabProvider' ) - worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type='classifier') + worker_2 = repo2.workers.create(name='Worker 2', slug='worker_2', type=self.worker_type_classifier) worker_2.configurations.create(name='config_1', configuration={'key': 'value'}) repo2.memberships.create(user=self.user, level=Role.Admin.value) diff --git a/arkindex/documents/export/worker_version.sql b/arkindex/documents/export/worker_version.sql index 509b9417c91c4dd49d916efa2afb344d07f95f6f..ef48cebd2d02bdecfe38b79df48810d9e095666f 100644 --- a/arkindex/documents/export/worker_version.sql +++ b/arkindex/documents/export/worker_version.sql @@ -3,9 +3,10 @@ -- fills up the RAM. Adding DISTINCT to all the SELECT queries of the UNION -- slows this query down by ~20%. Using multiple INs instead of a UNION makes -- this query twice as slow. -SELECT version.id, worker.name, worker.slug, worker.type, revision.hash, repository.url +SELECT version.id, worker.name, worker.slug, workertype.slug, revision.hash, repository.url FROM dataimport_workerversion version INNER JOIN dataimport_worker worker ON (version.worker_id = worker.id) +INNER JOIN dataimport_workertype workertype ON (worker.type_id = workertype.id) INNER JOIN dataimport_repository repository ON (worker.repository_id = repository.id) INNER JOIN dataimport_revision revision ON (version.revision_id = revision.id) WHERE version.id IN ( diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json index 61ba25ced1cd909433a5bd0a6a470b1185069922..74d49569a55f8fd11bec20b2e0af4dcd43dc8729 100644 --- a/arkindex/documents/fixtures/data.json +++ b/arkindex/documents/fixtures/data.json @@ -45,13 +45,53 @@ "author": "me" } }, +{ + "model": "dataimport.workertype", + "pk": "c852529d-1852-444f-b3c5-f8677bc0069b", + "fields": { + "display_name": "Document layout analyser", + "slug": "dla", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "433795b1-c1a7-4fb1-ad88-76c62065510d", + "fields": { + "display_name": "Worker requiring a GPU", + "slug": "worker", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "25d7116d-d377-4e13-be6c-b7718f5c73de", + "fields": { + "display_name": "Recognizer", + "slug": "recognizer", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "a18e08c0-6b10-4a47-bc54-d08fdc9ed22d", + "fields": { + "display_name": "Classifier", + "slug": "classifier", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, { "model": "dataimport.worker", "pk": "1fa33954-2e36-4e09-a193-39cdde9efd29", "fields": { "name": "Document layout analyser", "slug": "dla", - "type": "dla", + "type": "c852529d-1852-444f-b3c5-f8677bc0069b", "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", "public": false } @@ -62,7 +102,80 @@ "fields": { "name": "Worker requiring a GPU", "slug": "worker-gpu", - "type": "worker", + "type": "433795b1-c1a7-4fb1-ad88-76c62065510d", + "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", + "public": false + } +}, +{ + "model": "dataimport.worker", + "pk": "d9551cc9-d997-4e1f-9a1c-921a8f8d4e77", + "fields": { + "name": "Recognizer", + "slug": "reco", + "type": "25d7116d-d377-4e13-be6c-b7718f5c73de", + "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", + "public": false + } +}, +{ + "model": "dataimport.workertype", + "pk": "c852529d-1852-444f-b3c5-f8677bc0069b", + "fields": { + "display_name": "Document layout analyser", + "slug": "dla", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "433795b1-c1a7-4fb1-ad88-76c62065510d", + "fields": { + "display_name": "Worker requiring a GPU", + "slug": "worker", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "25d7116d-d377-4e13-be6c-b7718f5c73de", + "fields": { + "display_name": "Recognizer", + "slug": "recognizer", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.workertype", + "pk": "a18e08c0-6b10-4a47-bc54-d08fdc9ed22d", + "fields": { + "display_name": "Classifier", + "slug": "classifier", + "created": "2022-04-07T01:23:45.678Z", + "updated": "2022-04-07T01:23:45.678Z" + } +}, +{ + "model": "dataimport.worker", + "pk": "1fa33954-2e36-4e09-a193-39cdde9efd29", + "fields": { + "name": "Document layout analyser", + "slug": "dla", + "type": "c852529d-1852-444f-b3c5-f8677bc0069b", + "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", + "public": false + } +}, +{ + "model": "dataimport.worker", + "pk": "b12a1147-90f4-490a-8b3c-60b497d10888", + "fields": { + "name": "Worker requiring a GPU", + "slug": "worker-gpu", + "type": "433795b1-c1a7-4fb1-ad88-76c62065510d", "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", "public": false } @@ -73,7 +186,7 @@ "fields": { "name": "Recognizer", "slug": "reco", - "type": "recognizer", + "type": "25d7116d-d377-4e13-be6c-b7718f5c73de", "repository": "453d0ed8-81c0-4c39-81ea-1083315e92e2", "public": false } diff --git a/arkindex/documents/management/commands/build_fixtures.py b/arkindex/documents/management/commands/build_fixtures.py index 02240cea731288d9197ffda9b944cbb0b0ba863d..60b560a3bab1317c942f0934494b2bf6874677a6 100644 --- a/arkindex/documents/management/commands/build_fixtures.py +++ b/arkindex/documents/management/commands/build_fixtures.py @@ -8,6 +8,7 @@ from django.utils import timezone as DjangoTimeZone from arkindex.dataimport.models import ( RepositoryType, + WorkerType, WorkerVersion, WorkerVersionGPUUsage, WorkerVersionState, @@ -101,6 +102,11 @@ class Command(BaseCommand): author="Test user" ) + # Create worker types + dla_worker_type = WorkerType.objects.create(slug="dla") + recognizer_worker_type = WorkerType.objects.create(slug="recognizer") + gpu_worker_type = WorkerType.objects.create(slug="worker") + # Create a fake docker build with a docker image task farm = Farm.objects.create(name="Wheat farm") workflow = Workflow.objects.create(farm=farm, recipe='tasks:\n docker_build:\n image: reco') @@ -112,7 +118,7 @@ class Command(BaseCommand): worker=worker_repo.workers.create( name='Recognizer', slug='reco', - type='recognizer', + type=recognizer_worker_type, ), revision=revision, configuration={'test': 42}, @@ -123,7 +129,7 @@ class Command(BaseCommand): worker=worker_repo.workers.create( name='Document layout analyser', slug='dla', - type='dla', + type=dla_worker_type, ), revision=revision, configuration={'test': 42}, @@ -135,7 +141,7 @@ class Command(BaseCommand): worker=worker_repo.workers.create( name='Worker requiring a GPU', slug='worker-gpu', - type='worker', + type=gpu_worker_type, ), revision=revision, configuration={'test': 42}, diff --git a/arkindex/documents/tests/tasks/test_export.py b/arkindex/documents/tests/tasks/test_export.py index b7004ff70b31935d3970fc909ae76eae1e5660ac..bcf292571f929c4af3d18af570e964528b39815c 100644 --- a/arkindex/documents/tests/tasks/test_export.py +++ b/arkindex/documents/tests/tasks/test_export.py @@ -9,7 +9,7 @@ from uuid import uuid4 from django.core import mail from django.test import override_settings -from arkindex.dataimport.models import Repository, RepositoryType, WorkerVersion, WorkerVersionState +from arkindex.dataimport.models import Repository, RepositoryType, WorkerType, WorkerVersion, WorkerVersionState from arkindex.documents.export import export_corpus from arkindex.documents.models import ( Classification, @@ -53,9 +53,12 @@ class TestExport(FixtureTestCase): # The fixtures have two worker versions, only one of them is used in existing elements cls.recognizer = cls.repo.workers.get(slug='reco').versions.get() + # Retrieve a workerType for created workers + cls.worker_type = WorkerType.objects.get(slug='dla') + def _make_worker_version(self): return self.revision.versions.create( - worker=self.repo.workers.create(slug=str(uuid4())), + worker=self.repo.workers.create(slug=str(uuid4()), type=self.worker_type), configuration={}, state=WorkerVersionState.Available, docker_image=Artifact.objects.first(), @@ -190,7 +193,7 @@ class TestExport(FixtureTestCase): str(version.id), version.worker.name, version.worker.slug, - version.worker.type, + version.worker.type.slug, version.revision.hash, version.worker.repository.url ), @@ -198,7 +201,7 @@ class TestExport(FixtureTestCase): str(metadata_version.id), metadata_version.worker.name, metadata_version.worker.slug, - metadata_version.worker.type, + metadata_version.worker.type.slug, metadata_version.revision.hash, metadata_version.worker.repository.url ) diff --git a/arkindex/project/tests/test_acl_mixin.py b/arkindex/project/tests/test_acl_mixin.py index 9aed1a000638da88567f657927a7fd80ff47608c..5e41b9c45476c8f609f1c1d740b46497525f6e35 100644 --- a/arkindex/project/tests/test_acl_mixin.py +++ b/arkindex/project/tests/test_acl_mixin.py @@ -3,7 +3,7 @@ import uuid from django.contrib.auth.models import AnonymousUser from django.contrib.contenttypes.models import ContentType -from arkindex.dataimport.models import DataImport, DataImportMode, Repository, RepositoryType, Revision +from arkindex.dataimport.models import DataImport, DataImportMode, Repository, RepositoryType, Revision, WorkerType from arkindex.documents.models import Corpus from arkindex.project.mixins import ( ACLMixin, @@ -50,7 +50,8 @@ class TestACLMixin(FixtureTestCase): cls.group3 = Group.objects.create(name='Group3') cls.repo1 = Repository.objects.create(type=RepositoryType.Worker, url='http://repo1') - cls.worker = cls.repo1.workers.create(name='repo1 worker', slug='worker') + cls.worker_type = WorkerType.objects.create(display_name='ner', slug='ner') + cls.worker = cls.repo1.workers.create(name='repo1 worker', slug='worker', type=cls.worker_type) cls.corpus1 = Corpus.objects.create(name="Corpus1", id=uuid.UUID('bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb')) cls.corpus2 = Corpus.objects.create(name="Corpus2", id=uuid.UUID('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')) diff --git a/arkindex/sql_validation/indexer_prefetch.sql b/arkindex/sql_validation/indexer_prefetch.sql index 9aba9063ff974d2de0883fad4b91b5504e137686..a3615d5d6b1a89eee429bb7507386236c942e76a 100644 --- a/arkindex/sql_validation/indexer_prefetch.sql +++ b/arkindex/sql_validation/indexer_prefetch.sql @@ -60,7 +60,7 @@ WHERE "dataimport_workerversion"."id" IN ('{worker_version_id}'::uuid); SELECT "dataimport_worker"."id", "dataimport_worker"."name", "dataimport_worker"."slug", - "dataimport_worker"."type", + "dataimport_worker"."type_id", "dataimport_worker"."repository_id", "dataimport_worker"."public" FROM "dataimport_worker" @@ -126,7 +126,7 @@ WHERE "dataimport_workerversion"."id" IN ('{worker_version_id}'::uuid); SELECT "dataimport_worker"."id", "dataimport_worker"."name", "dataimport_worker"."slug", - "dataimport_worker"."type", + "dataimport_worker"."type_id", "dataimport_worker"."repository_id", "dataimport_worker"."public" FROM "dataimport_worker"