diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 56ed03871950ed5711f867a052112c2f0b154360..461e1c2a3795fbe56fcfb491f43be16c841d69c2 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -82,6 +82,7 @@ from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerial from arkindex.process.serializers.training import ProcessDatasetSerializer, StartTrainingSerializer from arkindex.process.serializers.worker_runs import WorkerRunEditSerializer, WorkerRunSerializer from arkindex.process.serializers.workers import ( + CorpusWorkerVersionSerializer, DockerWorkerVersionSerializer, RepositorySerializer, WorkerActivitySerializer, @@ -1136,7 +1137,7 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView): Guest access is required on private corpora. No specific rights are required on the workers. """ permission_classes = (IsVerifiedOrReadOnly, ) - serializer_class = WorkerVersionSerializer + serializer_class = CorpusWorkerVersionSerializer # For OpenAPI type discovery queryset = WorkerVersion.objects.none() @@ -1146,20 +1147,20 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView): def get_queryset(self): return ( - self.corpus.worker_versions + self.corpus.worker_version_cache .select_related( - 'worker', - 'revision', + 'worker_version__worker__type', + 'worker_version__worker__repository', + 'worker_version__revision__repo', + 'model_version__model', + 'worker_configuration', ) .order_by( - 'worker__name', - 'revision__hash', + 'worker_version__worker__name', + 'worker_version__revision__hash', ) .prefetch_related( - 'revision__repo', - 'revision__refs', - 'worker__repository', - 'worker__type', + 'worker_version__revision__refs', ) ) diff --git a/arkindex/process/managers.py b/arkindex/process/managers.py index 58f5041d17e0f6857768fe4c3335d241a95d79ae..b3f42aeade3149bc5bb64e2ec928b8fc73230a8e 100644 --- a/arkindex/process/managers.py +++ b/arkindex/process/managers.py @@ -53,6 +53,7 @@ class CorpusWorkerVersionManager(models.Manager): def rebuild(self): """ Rebuild the corpus worker versions cache from all ML results. + Worker version and configuration attributes on the M2M are left blank. """ from arkindex.documents.models import ( Classification, diff --git a/arkindex/process/migrations/0012_corpusworkerversion_model_version_configuration.py b/arkindex/process/migrations/0012_corpusworkerversion_model_version_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..06abf7a6bc94da184d549411b26296b9162004bb --- /dev/null +++ b/arkindex/process/migrations/0012_corpusworkerversion_model_version_configuration.py @@ -0,0 +1,45 @@ +# Generated by Django 4.1.7 on 2023-07-27 16:02 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('training', '0003_index_cleanup'), + ('process', '0011_worker_run_blank_fields'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='corpusworkerversion', + unique_together=set(), + ), + migrations.AddField( + model_name='corpusworkerversion', + name='model_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='corpus_worker_versions', to='training.modelversion'), + ), + migrations.AddField( + model_name='corpusworkerversion', + name='worker_configuration', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='corpus_worker_versions', to='process.workerconfiguration'), + ), + migrations.AddConstraint( + model_name='corpusworkerversion', + constraint=models.UniqueConstraint(condition=models.Q(('model_version__isnull', True), ('worker_configuration__isnull', True)), fields=('corpus', 'worker_version'), name='corpus_workerversion_version_null_configuration_null'), + ), + migrations.AddConstraint( + model_name='corpusworkerversion', + constraint=models.UniqueConstraint(condition=models.Q(('model_version__isnull', False), ('worker_configuration__isnull', True)), fields=('corpus', 'worker_version', 'model_version'), name='corpus_workerversion_version_not_null_configuration_null'), + ), + migrations.AddConstraint( + model_name='corpusworkerversion', + constraint=models.UniqueConstraint(condition=models.Q(('model_version__isnull', True), ('worker_configuration__isnull', False)), fields=('corpus', 'worker_version', 'worker_configuration'), name='corpus_workerversion_version_null_configuration_not_null'), + ), + migrations.AddConstraint( + model_name='corpusworkerversion', + constraint=models.UniqueConstraint(condition=models.Q(('model_version__isnull', False), ('worker_configuration__isnull', False)), fields=('corpus', 'worker_version', 'model_version', 'worker_configuration'), name='corpus_workerversion_version_not_null_configuration_not_null'), + ), + ] diff --git a/arkindex/process/models.py b/arkindex/process/models.py index 7bad318541b4affe4c6a972b36b1c67ed2275fa8..7f9de2924a9fb0c9a146e3388fa2775fc830b25b 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -753,8 +753,17 @@ class Process(IndexableModel): if self.mode == ProcessMode.Workers: CorpusWorkerVersion.objects.bulk_create([ - CorpusWorkerVersion(corpus_id=self.corpus_id, worker_version_id=worker_version_id) - for worker_version_id in self.worker_runs.values_list('version_id', flat=True) + CorpusWorkerVersion( + corpus_id=self.corpus_id, + worker_version_id=worker_version_id, + model_version_id=model_version_id, + worker_configuration_id=conf_id, + ) + for worker_version_id, model_version_id, conf_id in self.worker_runs.values_list( + 'version_id', + 'model_version_id', + 'configuration_id', + ) ], ignore_conflicts=True) worker_version_id = None @@ -1373,10 +1382,44 @@ class CorpusWorkerVersion(models.Model): on_delete=models.CASCADE, related_name='corpus_cache', ) + model_version = models.ForeignKey( + 'training.ModelVersion', + related_name='corpus_worker_versions', + on_delete=models.SET_NULL, + null=True, + blank=True, + ) + worker_configuration = models.ForeignKey( + WorkerConfiguration, + related_name='corpus_worker_versions', + on_delete=models.SET_NULL, + null=True, + blank=True, + ) objects = CorpusWorkerVersionManager() class Meta: - unique_together = ( - ('corpus', 'worker_version') - ) + constraints = [ + # A worker version can be linked to a corpus via multiple model version or configuration + models.UniqueConstraint( + fields=['corpus', 'worker_version'], + condition=Q(model_version__isnull=True, worker_configuration__isnull=True), + name='corpus_workerversion_version_null_configuration_null', + ), + models.UniqueConstraint( + fields=['corpus', 'worker_version', 'model_version'], + condition=Q(model_version__isnull=False, worker_configuration__isnull=True), + name='corpus_workerversion_version_not_null_configuration_null', + ), + models.UniqueConstraint( + fields=['corpus', 'worker_version', 'worker_configuration'], + condition=Q(model_version__isnull=True, worker_configuration__isnull=False), + name='corpus_workerversion_version_null_configuration_not_null', + ), + models.UniqueConstraint( + fields=['corpus', 'worker_version', 'model_version', 'worker_configuration'], + condition=Q(model_version__isnull=False, worker_configuration__isnull=False), + name='corpus_workerversion_version_not_null_configuration_not_null', + ), + ] diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index ca036fa234228abd119383263221d21dc11b7025..bf1267a9cf55a3b19e78738133ee42c061c1beee 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -14,6 +14,7 @@ from rest_framework.exceptions import ValidationError from arkindex.ponos.models import Task from arkindex.ponos.utils import get_process_from_task_auth from arkindex.process.models import ( + CorpusWorkerVersion, GitRef, Process, Repository, @@ -30,6 +31,7 @@ from arkindex.process.models import ( from arkindex.process.serializers.git import GitRefSerializer, RevisionWithRefsSerializer from arkindex.process.utils import hash_object from arkindex.project.serializer_fields import EnumField +from arkindex.training.serializers import ModelVersionLightSerializer from arkindex.users.models import Role @@ -393,6 +395,16 @@ class WorkerConfigurationSerializer(WorkerConfigurationListSerializer): read_only_fields = ('id', 'configuration') +class CorpusWorkerVersionSerializer(serializers.ModelSerializer): + worker_version = WorkerVersionSerializer() + model_version = ModelVersionLightSerializer(read_only=True, allow_null=True) + worker_configuration = WorkerConfigurationSerializer(read_only=True, allow_null=True) + + class Meta: + model = CorpusWorkerVersion + fields = ('id', 'worker_version', 'model_version', 'worker_configuration') + + class WorkerConfigurationExistsErrorSerializer(serializers.Serializer): id = serializers.UUIDField(required=False, help_text="UUID of an existing worker configuration, if the error comes from a duplicate configuration.") configuration = serializers.ListField(child=serializers.CharField(), required=False, help_text="Configuration error message.") diff --git a/arkindex/process/tests/test_transkribus_import.py b/arkindex/process/tests/test_transkribus_import.py index 7c2a6f75d6f06d2ee5e9ac44279d6b2ce7c62c09..282bf27c2ee7c685c5aefbdf2ee7ee875ea62704 100644 --- a/arkindex/process/tests/test_transkribus_import.py +++ b/arkindex/process/tests/test_transkribus_import.py @@ -177,9 +177,9 @@ class TestTranskribusImport(FixtureAPITestCase): process = Process.objects.get(id=data["id"]) self.assertEqual(process.mode, ProcessMode.Transkribus) corpus = process.corpus - with self.assertNumQueries(10): + with self.assertNumQueries(7): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.json()['results']), 1) - self.assertEqual(response.json()['results'][0]['id'], str(self.transkribus_worker_version.id)) + self.assertEqual(response.json()['results'][0]['worker_version']['id'], str(self.transkribus_worker_version.id)) diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index 0465eb2b3128b5198dec81d07f220cb89061d29a..03da27277be7c66a5789d5eb2c989d20481d6afe 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -4,18 +4,21 @@ from django.urls import reverse from rest_framework import status from arkindex.process.models import ( + CorpusWorkerVersion, GitRefType, Process, ProcessMode, Repository, Revision, Worker, + WorkerConfiguration, WorkerType, WorkerVersion, WorkerVersionGPUUsage, WorkerVersionState, ) from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model, ModelVersionState from arkindex.users.models import Right, Role, User @@ -48,6 +51,17 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): repo=cls.repo, ) + cls.model = Model.objects.create( + name='Generic model', + public=False, + ) + cls.model_version = cls.model.versions.create( + state=ModelVersionState.Available, + hash='42', + archive_hash='42', + size=1337, + ) + process = cls.rev.processes.create(mode=ProcessMode.Repository, creator=cls.user) process.start() cls.task = process.tasks.get() @@ -71,7 +85,6 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): with self.assertNumQueries(7): response = self.client.get(reverse('api:workers-list')) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.maxDiff = None self.assertDictEqual(response.json(), { 'count': 5, 'next': None, @@ -1537,9 +1550,10 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): def test_corpus_worker_version_no_login_public(self): self.assertTrue(self.corpus.public) # Use multiple versions to ensure we don't get duplicated queries - self.corpus.worker_versions.set([self.version_1, self.version_2]) + corpus_worker_version_1 = self.corpus.worker_version_cache.create(worker_version=self.version_1) + corpus_worker_version_2 = self.corpus.worker_version_cache.create(worker_version=self.version_2) - with self.assertNumQueries(7): + with self.assertNumQueries(4): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1548,8 +1562,11 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'number': 1, 'previous': None, 'next': None, - 'results': [ - { + 'results': [{ + 'id': str(corpus_worker_version_2.id), + 'worker_configuration': None, + 'model_version': None, + 'worker_version': { 'id': str(self.version_2.id), 'configuration': {'test': 42}, 'revision': { @@ -1574,7 +1591,11 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'dla', } }, - { + }, { + 'id': str(corpus_worker_version_1.id), + 'worker_configuration': None, + 'model_version': None, + 'worker_version': { 'id': str(self.version_1.id), 'configuration': {'test': 42}, 'revision': { @@ -1599,7 +1620,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'reco', } }, - ], + }] }) def test_corpus_worker_version_no_login_private(self): @@ -1618,9 +1639,10 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.user.save() self.client.force_login(self.user) # Use multiple versions to ensure we don't get duplicated queries - self.corpus.worker_versions.set([self.version_1, self.version_2]) + corpus_worker_version_1 = self.corpus.worker_version_cache.create(worker_version=self.version_1) + corpus_worker_version_2 = self.corpus.worker_version_cache.create(worker_version=self.version_2) - with self.assertNumQueries(11): + with self.assertNumQueries(8): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1629,8 +1651,11 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'number': 1, 'previous': None, 'next': None, - 'results': [ - { + 'results': [{ + 'id': str(corpus_worker_version_2.id), + 'worker_configuration': None, + 'model_version': None, + 'worker_version': { 'id': str(self.version_2.id), 'configuration': {'test': 42}, 'revision': { @@ -1655,7 +1680,11 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'dla', } }, - { + }, { + 'id': str(corpus_worker_version_1.id), + 'worker_configuration': None, + 'model_version': None, + 'worker_version': { 'id': str(self.version_1.id), 'configuration': {'test': 42}, 'revision': { @@ -1680,15 +1709,28 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'reco', } }, - ], + }] }) def test_corpus_worker_version_list(self): self.client.force_login(self.user) - # Use multiple versions to ensure we don't get duplicated queries - self.corpus.worker_versions.set([self.version_1, self.version_2]) + # Use multiple versions with a model and worker configuration to ensure we don't get duplicated queries + conf_1 = WorkerConfiguration.objects.create(name="test", configuration={"num": 1337}, worker=self.worker_reco) + corpus_worker_version_1 = CorpusWorkerVersion.objects.create( + corpus=self.corpus, + worker_version=self.version_1, + model_version=self.model_version, + worker_configuration=conf_1, + ) + conf_2 = WorkerConfiguration.objects.create(name="test 2", configuration={"value": "leet"}, worker=self.worker_dla) + corpus_worker_version_2 = CorpusWorkerVersion.objects.create( + corpus=self.corpus, + worker_version=self.version_2, + model_version=self.model_version, + worker_configuration=conf_2, + ) - with self.assertNumQueries(11): + with self.assertNumQueries(8): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1697,8 +1739,9 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'number': 1, 'previous': None, 'next': None, - 'results': [ - { + 'results': [{ + 'id': str(corpus_worker_version_2.id), + 'worker_version': { 'id': str(self.version_2.id), 'configuration': {'test': 42}, 'revision': { @@ -1723,7 +1766,26 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'dla', } }, - { + 'worker_configuration': { + 'id': str(conf_2.id), + 'archived': False, + 'name': 'test 2', + 'configuration': {'value': 'leet'}, + }, + 'model_version': { + 'id': str(self.model_version.id), + 'configuration': {}, + 'model': { + 'id': str(self.model.id), + 'name': 'Generic model' + }, + 'size': 1337, + 'state': ModelVersionState.Available.value, + 'tag': None, + }, + }, { + 'id': str(corpus_worker_version_1.id), + 'worker_version': { 'id': str(self.version_1.id), 'configuration': {'test': 42}, 'revision': { @@ -1748,51 +1810,45 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'slug': 'reco', } }, - ], + 'worker_configuration': { + 'id': str(conf_1.id), + 'archived': False, + 'name': 'test', + 'configuration': {'num': 1337}, + }, + 'model_version': { + 'id': str(self.model_version.id), + 'configuration': {}, + 'model': { + 'id': str(self.model.id), + 'name': 'Generic model' + }, + 'size': 1337, + 'state': ModelVersionState.Available.value, + 'tag': None, + }, + }], }) def test_corpus_worker_version_no_right(self): self.user.rights.all().delete() self.client.force_login(self.user) - # Use multiple versions to ensure we don't get duplicated queries - self.corpus.worker_versions.set([self.version_1, self.version_2]) + corpus_worker_version = self.corpus.worker_version_cache.create(worker_version=self.version_1) - with self.assertNumQueries(11): + with self.assertNumQueries(8): response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { - 'count': 2, + 'count': 1, 'number': 1, 'previous': None, 'next': None, - 'results': [ - { - 'id': str(self.version_2.id), - 'configuration': {'test': 42}, - 'revision': { - 'id': str(self.version_2.revision_id), - 'hash': '1337', - 'author': 'Test user', - 'message': 'My w0rk3r', - 'created': '2020-02-02T01:23:45.678000Z', - 'commit_url': 'http://my_repo.fake/workers/worker/commit/1337', - 'refs': [] - }, - 'gpu_usage': 'disabled', - 'model_usage': False, - 'docker_image': str(self.version_2.docker_image_id), - 'docker_image_iid': None, - 'docker_image_name': f'my_repo.fake/workers/worker/dla:{self.version_2.id}', - 'state': 'available', - 'worker': { - 'id': str(self.worker_dla.id), - 'name': 'Document layout analyser', - 'type': 'dla', - 'slug': 'dla', - } - }, - { + 'results': [{ + 'id': str(corpus_worker_version.id), + 'worker_configuration': None, + 'model_version': None, + 'worker_version': { 'id': str(self.version_1.id), 'configuration': {'test': 42}, 'revision': { @@ -1816,6 +1872,6 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'type': 'recognizer', 'slug': 'reco', } - }, - ], + } + }] }) diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index f025c9a0dd64ad61a3bc1eda9af6ac4f25e8101c..6010c8342c3240fc93e1e58fac10fcc313a0b764 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -374,7 +374,7 @@ class TestModelAPI(FixtureAPITestCase): This also deletes every worker run that used this model version """ self.client.force_login(self.user1) - with self.assertNumQueries(10): + with self.assertNumQueries(11): response = self.client.delete(reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)})) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)