From 231f955ca18f6d42d87aa0c5d3573ef2aa146664 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 31 Aug 2021 15:03:18 +0000 Subject: [PATCH] Corpus worker versions cache --- arkindex/dataimport/api.py | 33 ++------- .../commands/cache_worker_versions.py | 21 ++++++ arkindex/dataimport/managers.py | 68 ++++++++++++++++++ .../migrations/0035_corpus_version_cache.py | 63 +++++++++++++++++ arkindex/dataimport/models.py | 69 ++++++++++--------- .../dataimport/tests/commands/__init__.py | 0 .../commands/test_cache_worker_versions.py | 30 ++++++++ .../test_fake_worker_version.py | 0 .../tests/{ => commands}/test_import_s3.py | 0 arkindex/dataimport/tests/test_imports.py | 6 +- arkindex/dataimport/tests/test_managers.py | 48 +++++++++++++ .../dataimport/tests/test_workeractivity.py | 2 +- arkindex/dataimport/tests/test_workers.py | 26 +------ .../dataimport/tests/test_workflows_api.py | 11 ++- arkindex/project/settings.py | 4 ++ 15 files changed, 290 insertions(+), 91 deletions(-) create mode 100644 arkindex/dataimport/management/commands/cache_worker_versions.py create mode 100644 arkindex/dataimport/managers.py create mode 100644 arkindex/dataimport/migrations/0035_corpus_version_cache.py create mode 100644 arkindex/dataimport/tests/commands/__init__.py create mode 100644 arkindex/dataimport/tests/commands/test_cache_worker_versions.py rename arkindex/dataimport/tests/{ => commands}/test_fake_worker_version.py (100%) rename arkindex/dataimport/tests/{ => commands}/test_import_s3.py (100%) create mode 100644 arkindex/dataimport/tests/test_managers.py diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index edd37be4bb..10e398edcc 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -940,14 +940,6 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView): 'List worker versions used by elements of a given corpus.\n\n' 'No check is performed on workers access level in order to allow any user to see versions.' ), - parameters=[ - OpenApiParameter( - 'with_element_count', - type=bool, - default=False, - description='Include element counts in the response.', - ) - ], ) ) class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView): @@ -964,25 +956,12 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView): return get_object_or_404(self.readable_corpora, pk=self.kwargs['pk']) def get_queryset(self): - corpus = self.get_corpus() - queryset = WorkerVersion.objects \ - .filter(elements__corpus_id=corpus.id) \ - .prefetch_related( - 'revision__repo', - 'revision__refs', - 'revision__versions', - 'worker__repository', - ) \ - .order_by('-id') - - if self.request.query_params.get('with_element_count', '').lower() in ('true', '1'): - queryset = queryset.annotate(element_count=Count('id')) - else: - # The Count() causes Django to add a GROUP BY, and without a count we need a DISTINCT - # because filtering on `elements` causes worker versions to be duplicated. 
-            queryset = queryset.distinct()
-
-        return queryset
+        return self.get_corpus().worker_versions.prefetch_related(
+            'revision__repo',
+            'revision__refs',
+            'revision__versions',
+            'worker__repository',
+        ).order_by('-id')
 
 
 @extend_schema(tags=['repos'])
diff --git a/arkindex/dataimport/management/commands/cache_worker_versions.py b/arkindex/dataimport/management/commands/cache_worker_versions.py
new file mode 100644
index 0000000000..e46a64cc18
--- /dev/null
+++ b/arkindex/dataimport/management/commands/cache_worker_versions.py
@@ -0,0 +1,21 @@
+from django.core.management.base import BaseCommand
+
+from arkindex.dataimport.models import CorpusWorkerVersion
+
+
+class Command(BaseCommand):
+    help = 'Rebuild the corpus worker versions cache'
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            '--drop',
+            help='Drop the existing cache before rebuilding.',
+            action='store_true',
+        )
+
+    def handle(self, *args, drop=False, **options):
+        if drop:
+            CorpusWorkerVersion.objects.all().delete()
+            self.stdout.write('Deleted all existing CorpusWorkerVersion.')
+
+        CorpusWorkerVersion.objects.rebuild()
diff --git a/arkindex/dataimport/managers.py b/arkindex/dataimport/managers.py
new file mode 100644
index 0000000000..9c8d2c0bac
--- /dev/null
+++ b/arkindex/dataimport/managers.py
@@ -0,0 +1,68 @@
+import logging
+
+from django.db import connections, models
+
+logger = logging.getLogger(__name__)
+
+
+class ActivityManager(models.Manager):
+    """Model management for worker activities"""
+
+    def bulk_insert(self, worker_version_id, process_id, elements_qs, state=None):
+        """
+        Create initial worker activities from a queryset of elements in an efficient way.
+        Due to the possibly large number of elements, we use a bulk insert from the elements query (best performance).
+        The `ON CONFLICT` clause automatically skips elements that already have an activity with this version.
+        """
+        from arkindex.dataimport.models import WorkerActivityState
+
+        if state is None:
+            state = WorkerActivityState.Queued
+
+        assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'
+
+        sql, params = elements_qs.values('id').query.sql_with_params()
+
+        with connections['default'].cursor() as cursor:
+            cursor.execute(
+                f"""
+                INSERT INTO dataimport_workeractivity
+                    (element_id, worker_version_id, state, process_id, id, created, updated)
+                SELECT
+                    elt.id,
+                    '{worker_version_id}'::uuid,
+                    '{state.value}',
+                    '{process_id}',
+                    uuid_generate_v4(),
+                    current_timestamp,
+                    current_timestamp
+                FROM ({sql}) AS elt
+                ON CONFLICT (element_id, worker_version_id) DO NOTHING
+                """,
+                params
+            )
+
+
+class CorpusWorkerVersionManager(models.Manager):
+
+    def rebuild(self):
+        """
+        Rebuild the corpus worker versions cache from all ML results.
+        """
+        from arkindex.documents.models import Classification, Element, Entity, MetaData, Transcription, TranscriptionEntity
+
+        querysets = [
+            Element.objects.exclude(worker_version_id=None).values_list('corpus_id', 'worker_version_id'),
+            Transcription.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
+            Entity.objects.exclude(worker_version_id=None).values_list('corpus_id', 'worker_version_id'),
+            TranscriptionEntity.objects.exclude(worker_version_id=None).values_list('entity__corpus_id', 'worker_version_id'),
+            Classification.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
+            MetaData.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
+        ]
+
+        for i, queryset in enumerate(querysets, start=1):
+            logger.info(f'Rebuilding cache from {queryset.model.__name__} ({i}/{len(querysets)})')
+            self.bulk_create([
+                self.model(corpus_id=corpus_id, worker_version_id=worker_version_id)
+                for corpus_id, worker_version_id in queryset.distinct()
+            ], ignore_conflicts=True)
diff --git a/arkindex/dataimport/migrations/0035_corpus_version_cache.py b/arkindex/dataimport/migrations/0035_corpus_version_cache.py
new file mode 100644
index 0000000000..e4f16bd72b
--- /dev/null
+++ b/arkindex/dataimport/migrations/0035_corpus_version_cache.py
@@ -0,0 +1,63 @@
+# Generated by Django 3.2.3 on 2021-08-31 07:53
+
+import uuid
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+def rebuild_reminder(apps, schema_editor):
+    """
+    Print a reminder to rebuild the cache manually if there is anything in the database.
+    """
+    Corpus = apps.get_model('documents', 'Corpus')
+    if Corpus.objects.exists():
+        print("Please run `arkindex cache_worker_versions` to fill the corpus worker versions cache.")
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('documents', '0042_transcription_entity_confidence'),
+        ('dataimport', '0034_worker_run_config'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='CorpusWorkerVersion',
+            fields=[
+                ('id', models.UUIDField(
+                    default=uuid.uuid4,
+                    editable=False,
+                    primary_key=True,
+                    serialize=False
+                )),
+                ('corpus', models.ForeignKey(
+                    on_delete=django.db.models.deletion.CASCADE,
+                    related_name='worker_version_cache',
+                    to='documents.corpus',
+                )),
+                ('worker_version', models.ForeignKey(
+                    on_delete=django.db.models.deletion.CASCADE,
+                    related_name='corpus_cache',
+                    to='dataimport.workerversion',
+                )),
+            ],
+            options={
+                'unique_together': {('corpus', 'worker_version')},
+            },
+        ),
+        migrations.AddField(
+            model_name='workerversion',
+            name='corpora',
+            field=models.ManyToManyField(
+                related_name='worker_versions',
+                through='dataimport.CorpusWorkerVersion',
+                to='documents.Corpus',
+            ),
+        ),
+        migrations.RunPython(
+            code=rebuild_reminder,
+            reverse_code=migrations.RunPython.noop,
+        )
+    ]
diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py
index 2e4cda6ef3..5f5ff56600 100644
--- a/arkindex/dataimport/models.py
+++ b/arkindex/dataimport/models.py
@@ -7,12 +7,13 @@ from uuid import UUID
 import yaml
 from django.conf import settings
 from django.contrib.contenttypes.fields import GenericRelation
-from django.db import connections, models
+from django.db import models
 from django.db.models import Q
 from django.utils.functional import cached_property
 from enumfields import Enum, EnumField
 from rest_framework.exceptions import ValidationError
 
+from arkindex.dataimport.managers import ActivityManager, CorpusWorkerVersionManager
 from arkindex.dataimport.providers import get_provider, git_providers
 from arkindex.dataimport.utils import get_default_farm_id
 from arkindex.documents.models import ClassificationState, Element
@@ -363,6 +364,12 @@ class DataImport(IndexableModel):
         from arkindex.project.triggers import initialize_activity
         initialize_activity(self)
 
+        if self.mode == DataImportMode.Workers:
+            CorpusWorkerVersion.objects.bulk_create([
+                CorpusWorkerVersion(corpus_id=self.corpus_id, worker_version_id=worker_version_id)
+                for worker_version_id in self.worker_runs.values_list('version_id', flat=True)
+            ], ignore_conflicts=True)
+
     def retry(self):
         if self.mode == DataImportMode.Repository and self.revision is not None and not self.revision.repo.enabled:
             raise ValidationError('Git repository does not have any valid credentials')
@@ -541,6 +548,12 @@ class WorkerVersion(models.Model):
     # The Docker internal image id (sha256:xxx) that can be shared across multiple images
     docker_image_iid = models.CharField(null=True, blank=True, max_length=80)
 
+    corpora = models.ManyToManyField(
+        'documents.Corpus',
+        through='dataimport.CorpusWorkerVersion',
+        related_name='worker_versions',
+    )
+
     class Meta:
         unique_together = (('worker', 'revision'),)
         constraints = [
@@ -633,39 +646,6 @@ class WorkerActivityState(Enum):
     Error = 'error'
 
 
-class ActivityManager(models.Manager):
-    """Model management for worker activities"""
-
-    def bulk_insert(self, worker_version_id, process_id, elements_qs, state=WorkerActivityState.Queued):
-        """
-        Create initial worker activities from a queryset of elements in a efficient way.
-        Due to the possible large amount of elements, we use a bulk insert from the elements query (best performances).
-        The `ON CONFLICT` clause allows to automatically skip elements that already have an activity with this version.
-        """
-        assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'
-
-        sql, params = elements_qs.values('id').query.sql_with_params()
-
-        with connections['default'].cursor() as cursor:
-            cursor.execute(
-                f"""
-                INSERT INTO dataimport_workeractivity
-                    (element_id, worker_version_id, state, process_id, id, created, updated)
-                SELECT
-                    elt.id,
-                    '{worker_version_id}'::uuid,
-                    '{state.value}',
-                    '{process_id}',
-                    uuid_generate_v4(),
-                    current_timestamp,
-                    current_timestamp
-                FROM ({sql}) AS elt
-                ON CONFLICT (element_id, worker_version_id) DO NOTHING
-                """,
-                params
-            )
-
-
 class WorkerActivity(IndexableModel):
     """
     Many-to-many relationship between Element and WorkerVersion
     """
@@ -702,3 +682,24 @@ class WorkerActivity(IndexableModel):
         unique_together = (
             ('worker_version', 'element'),
         )
+
+
+class CorpusWorkerVersion(models.Model):
+    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
+    corpus = models.ForeignKey(
+        'documents.Corpus',
+        on_delete=models.CASCADE,
+        related_name='worker_version_cache',
+    )
+    worker_version = models.ForeignKey(
+        WorkerVersion,
+        on_delete=models.CASCADE,
+        related_name='corpus_cache',
+    )
+
+    objects = CorpusWorkerVersionManager()
+
+    class Meta:
+        unique_together = (
+            ('corpus', 'worker_version'),
+        )
diff --git a/arkindex/dataimport/tests/commands/__init__.py b/arkindex/dataimport/tests/commands/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/arkindex/dataimport/tests/commands/test_cache_worker_versions.py b/arkindex/dataimport/tests/commands/test_cache_worker_versions.py
new file mode 100644
index 0000000000..f90428feac
--- /dev/null
+++ b/arkindex/dataimport/tests/commands/test_cache_worker_versions.py
@@ -0,0 +1,30 @@
+from django.core.management import call_command
+
+from arkindex.dataimport.models import WorkerVersion
+from arkindex.project.tests import FixtureTestCase
+
+
+class TestCacheWorkerVersions(FixtureTestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.dla, cls.recognizer = WorkerVersion.objects.order_by('worker__slug')
+
+    def test_run(self):
+        self.corpus.worker_versions.add(self.dla)
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla])
+        call_command('cache_worker_versions')
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla, self.recognizer])
+
+    def test_drop(self):
+        self.corpus.worker_versions.add(self.dla)
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla])
+        call_command('cache_worker_versions', drop=True)
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])
+
+    def test_ignore_conflicts(self):
+        self.corpus.worker_versions.add(self.recognizer)
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])
+        call_command('cache_worker_versions')
+        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])
diff --git a/arkindex/dataimport/tests/test_fake_worker_version.py b/arkindex/dataimport/tests/commands/test_fake_worker_version.py
similarity index 100%
rename from arkindex/dataimport/tests/test_fake_worker_version.py
rename to arkindex/dataimport/tests/commands/test_fake_worker_version.py
diff --git a/arkindex/dataimport/tests/test_import_s3.py b/arkindex/dataimport/tests/commands/test_import_s3.py
similarity index 100%
rename from arkindex/dataimport/tests/test_import_s3.py
rename to arkindex/dataimport/tests/commands/test_import_s3.py
diff --git a/arkindex/dataimport/tests/test_imports.py b/arkindex/dataimport/tests/test_imports.py
index a3be84c9d2..839e6c60e5 100644
--- a/arkindex/dataimport/tests/test_imports.py
+++ b/arkindex/dataimport/tests/test_imports.py
@@ -818,7 +818,7 @@ class TestImports(FixtureAPITestCase):
     def test_retry_no_workflow(self):
         self.client.force_login(self.user)
         self.assertIsNone(self.elts_process.workflow)
-        with self.assertNumQueries(17):
+        with self.assertNumQueries(18):
             response = self.client.post(reverse('api:import-retry', kwargs={'pk': self.elts_process.id}))
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.elts_process.refresh_from_db()
@@ -1066,7 +1066,7 @@ class TestImports(FixtureAPITestCase):
         self.assertIsNone(dataimport2.workflow)
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(21):
+        with self.assertNumQueries(22):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(dataimport2.id)})
             )
@@ -1089,7 +1089,7 @@ class TestImports(FixtureAPITestCase):
         self.assertNotEqual(get_default_farm_id(), barley_farm.id)
         workers_process = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers)
         self.client.force_login(self.user)
-        with self.assertNumQueries(21):
+        with self.assertNumQueries(22):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(workers_process.id)}),
                 {'farm': str(barley_farm.id)}
diff --git a/arkindex/dataimport/tests/test_managers.py b/arkindex/dataimport/tests/test_managers.py
new file mode 100644
index 0000000000..8f671dd915
--- /dev/null
+++ b/arkindex/dataimport/tests/test_managers.py
@@ -0,0 +1,48 @@
+from uuid import uuid4
+
+from arkindex.dataimport.models import CorpusWorkerVersion, Repository, RepositoryType, WorkerVersionState
+from arkindex.documents.models import Classification, Element, Entity, MetaData, Transcription, TranscriptionEntity
+from arkindex.project.tests import FixtureTestCase
+from ponos.models import Artifact
+
+
+class TestManagers(FixtureTestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.repo = Repository.objects.get(type=RepositoryType.Worker)
+        cls.revision = cls.repo.revisions.first()
+        cls.artifact = Artifact.objects.get()
+        # The fixtures have two worker versions, only one of them is used in existing elements
+        cls.recognizer = cls.repo.workers.get(slug='reco').versions.get()
+
+    def _make_worker_version(self):
+        return self.revision.versions.create(
+            worker=self.repo.workers.create(slug=str(uuid4())),
+            configuration={},
+            state=WorkerVersionState.Available,
+            docker_image=self.artifact,
+        )
+
+    def test_corpus_worker_version_rebuild(self):
+        # Assign a different worker version for each ML result to get a lot of versions
+        querysets = [
+            Element.objects.filter(worker_version_id=None),
+            Transcription.objects.filter(worker_version_id=None),
+            TranscriptionEntity.objects.filter(worker_version_id=None),
+            Entity.objects.filter(worker_version_id=None),
+            Classification.objects.filter(worker_version_id=None),
+            MetaData.objects.filter(worker_version_id=None),
+        ]
+        versions = [self.recognizer]
+        for queryset in querysets:
+            for obj in queryset:
+                version = self._make_worker_version()
+                versions.append(version)
+                obj.worker_version = version
+                obj.save()
+
+        self.assertFalse(self.corpus.worker_versions.exists())
+        CorpusWorkerVersion.objects.rebuild()
+        self.assertCountEqual(self.corpus.worker_versions.all(), versions)
diff --git a/arkindex/dataimport/tests/test_workeractivity.py b/arkindex/dataimport/tests/test_workeractivity.py
index 0a0a919ab7..2f715f5a35 100644
--- a/arkindex/dataimport/tests/test_workeractivity.py
+++ b/arkindex/dataimport/tests/test_workeractivity.py
@@ -100,7 +100,7 @@ class TestWorkerActivity(FixtureTestCase):
             best_class=agent_class.name
         )
         dataimport.worker_runs.create(version=self.worker_version, parents=[])
-        with self.assertNumQueries(20):
+        with self.assertNumQueries(22):
             dataimport.start()
 
         self.assertCountEqual(
diff --git a/arkindex/dataimport/tests/test_workers.py b/arkindex/dataimport/tests/test_workers.py
index 5faba0a8aa..40962cc8f4 100644
--- a/arkindex/dataimport/tests/test_workers.py
+++ b/arkindex/dataimport/tests/test_workers.py
@@ -874,7 +874,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         return data
 
     def test_corpus_worker_version_no_login(self):
-        self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
+        self.corpus.worker_versions.set([self.version_1])
 
         with self.assertNumQueries(8):
             response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
@@ -893,7 +893,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.user.verified_email = False
         self.user.save()
         self.client.force_login(self.user)
-        self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
+        self.corpus.worker_versions.set([self.version_1])
 
         with self.assertNumQueries(12):
             response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
@@ -910,7 +910,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
 
     def test_corpus_worker_version_list(self):
         self.client.force_login(self.user)
-        self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
+        self.corpus.worker_versions.set([self.version_1])
 
         with self.assertNumQueries(12):
             response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
@@ -924,23 +924,3 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
                 self._serialize_worker_version(self.version_1)
             ]
         })
-
-    def test_corpus_worker_version_list_with_element_count(self):
-        self.client.force_login(self.user)
-        self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
-
-        with self.assertNumQueries(12):
-            response = self.client.get(
-                reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}),
-                {'with_element_count': 'true'}
-            )
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-        self.assertDictEqual(response.json(), {
-            'count': None,
-            'previous': None,
-            'next': None,
-            'results': [
-                self._serialize_worker_version(self.version_1, element_count=True)
-            ]
-        })
diff --git a/arkindex/dataimport/tests/test_workflows_api.py b/arkindex/dataimport/tests/test_workflows_api.py
index 738514f85b..ddd39306fc 100644
--- a/arkindex/dataimport/tests/test_workflows_api.py
+++ b/arkindex/dataimport/tests/test_workflows_api.py
@@ -494,7 +494,7 @@ class TestWorkflows(FixtureAPITestCase):
         self.assertIsNone(dataimport_2.workflow)
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(21):
+        with self.assertNumQueries(22):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
             )
@@ -511,6 +511,7 @@ class TestWorkflows(FixtureAPITestCase):
                 'image': 'registry.gitlab.com/arkindex/tasks'
             }
         })
+        self.assertFalse(self.corpus.worker_versions.exists())
 
     @patch('arkindex.project.triggers.dataimport_tasks.initialize_activity.delay')
     def test_workers_multiple_worker_runs(self, activities_delay_mock):
@@ -532,9 +533,10 @@ class TestWorkflows(FixtureAPITestCase):
         workflow_tmp.start()
 
         self.assertIsNone(dataimport_2.workflow)
+        self.assertFalse(self.corpus.worker_versions.exists())
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(28):
+        with self.assertNumQueries(30):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
             )
@@ -601,6 +603,9 @@ class TestWorkflows(FixtureAPITestCase):
             WorkerActivity.objects.filter(worker_version=self.version_2).values_list('element_id', flat=True)
         )
 
+        # Check that the corpus worker version cache has been updated
+        self.assertCountEqual(self.corpus.worker_versions.all(), [self.version_1, self.version_2])
+
     def test_create_process_use_cache_option(self):
         """
         A process with the `use_cache` parameter creates an initialization task with the --use-cache flag
@@ -613,7 +618,7 @@ class TestWorkflows(FixtureAPITestCase):
         dataimport_2.use_cache = True
         dataimport_2.save()
         self.client.force_login(self.user)
-        with self.assertNumQueries(25):
+        with self.assertNumQueries(27):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
             )
diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py
index 6373963ad6..092ad4b11f 100644
--- a/arkindex/project/settings.py
+++ b/arkindex/project/settings.py
@@ -413,6 +413,10 @@ LOGGING = {
             'handlers': ['console'],
             'level': 'INFO',
         },
+        'arkindex.dataimport.managers': {
+            'handlers': ['console'],
+            'level': 'INFO',
+        },
     },
     'formatters': {
         'verbose': {
--
GitLab
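
A minimal usage sketch of the cache this patch introduces, assuming the
migration has been applied and the cache has been filled. It is not part of
the patch itself: the corpus lookup is hypothetical, while the model, manager
and related names all come from the diff above.

    from arkindex.dataimport.models import CorpusWorkerVersion, WorkerVersion
    from arkindex.documents.models import Corpus

    corpus = Corpus.objects.first()  # hypothetical: any existing corpus

    # Reading: CorpusWorkerVersionList now serves a plain join on the through
    # table, instead of a DISTINCT scan over every element of the corpus.
    used_versions = corpus.worker_versions.order_by('-id')

    # Writing: DataImport.start() registers the process's worker versions up
    # front; duplicates are skipped thanks to the unique_together constraint.
    CorpusWorkerVersion.objects.bulk_create(
        [
            CorpusWorkerVersion(corpus_id=corpus.id, worker_version_id=version.id)
            for version in WorkerVersion.objects.all()
        ],
        ignore_conflicts=True,
    )

    # Rebuilding the whole cache from existing ML results, as the migration
    # reminder suggests via `arkindex cache_worker_versions [--drop]`:
    CorpusWorkerVersion.objects.rebuild()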