diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py
index eadae800de1dbf29901400c36d7ae2e65b630bed..825c96651d2008967e71ea1e4bab3dfd9d151ac2 100644
--- a/arkindex/dataimport/models.py
+++ b/arkindex/dataimport/models.py
@@ -7,7 +7,7 @@ from uuid import UUID
 import yaml
 from django.conf import settings
 from django.contrib.contenttypes.fields import GenericRelation
-from django.db import models
+from django.db import connections, models
 from django.db.models import Q
 from django.utils.functional import cached_property
 from enumfields import Enum, EnumField
@@ -314,6 +314,13 @@ class DataImport(IndexableModel):
     def start(self, chunks=None, thumbnails=False, corpus_id=None):
         self.workflow = self.build_workflow(chunks, thumbnails, corpus_id)
         self.save()
+
+        # Initialize activity on elements processed by this worker version
+        if settings.ARKINDEX_FEATURES['workers_activity'] and self.mode is DataImportMode.Workers:
+            # Note that an async job may be required to initialize activity on a large number of elements
+            for version_id in self.versions.values_list('id', flat=True):
+                WorkerActivity.objects.bulk_insert(worker_version_id=version_id, elements_qs=self.list_elements())
+
         self.workflow.start()
 
     def retry(self):
@@ -564,6 +571,40 @@ class WorkerActivityState(Enum):
     Error = 'error'
 
 
+class ActivityManager(models.Manager):
+    """Model management for worker activities"""
+
+    def bulk_insert(self, worker_version_id, elements_qs, state=WorkerActivityState.Queued):
+        """
+        Create initial worker activities from a queryset of elements in an efficient way.
+        Due to the possibly large number of elements, a single INSERT ... SELECT is built from the elements query.
+        The `ON CONFLICT` clause automatically skips elements that already have an activity with this version.
+        """
+        assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'
+
+        sql, params = elements_qs.values('id').query.sql_with_params()
+
+        with connections['default'].cursor() as cursor:
+            # The version and state are sent as query parameters instead of being formatted into the SQL text;
+            # their placeholders come before those of the elements subquery, hence the parameters order
+            cursor.execute(
+                f"""
+                INSERT INTO dataimport_workeractivity
+                (element_id, worker_version_id, state, id, created, updated)
+                SELECT
+                    elt.id,
+                    %s::uuid,
+                    %s,
+                    uuid_generate_v4(),
+                    current_timestamp,
+                    current_timestamp
+                FROM ({sql}) AS elt
+                ON CONFLICT (element_id, worker_version_id) DO NOTHING
+                """,
+                [str(worker_version_id), state.value, *params]
+            )
+
+
 class WorkerActivity(IndexableModel):
     """
     Many-to-many relationship between Element and WorkerVersion
@@ -586,6 +627,9 @@ class WorkerActivity(IndexableModel):
         default=WorkerActivityState.Queued
     )
 
+    # Specific WorkerActivity manager
+    objects = ActivityManager()
+
     class Meta:
         unique_together = (
             ('worker_version', 'element'),
diff --git a/arkindex/dataimport/tests/test_imports.py b/arkindex/dataimport/tests/test_imports.py
index 686e56ff1ac7b06e30ea327dcc12ad84d4a17618..97e67d8a99177a7cc771ea5774ff9522201cb7dc 100644
--- a/arkindex/dataimport/tests/test_imports.py
+++ b/arkindex/dataimport/tests/test_imports.py
@@ -702,6 +702,7 @@ class TestImports(FixtureAPITestCase):
         self.elts_process.refresh_from_db()
         self.assertEqual(self.elts_process.state, State.Unscheduled)
 
+    @override_settings(ARKINDEX_FEATURES={**settings.ARKINDEX_FEATURES, 'workers_activity': False})
     def test_retry_no_workflow(self):
         self.client.force_login(self.user)
         self.assertIsNone(self.elts_process.workflow)
@@ -944,6 +945,7 @@ class TestImports(FixtureAPITestCase):
             {'__all__': ['Only a DataImport with Workers mode and not already launched can be started later on']}
         )
 
+    @override_settings(ARKINDEX_FEATURES={**settings.ARKINDEX_FEATURES, 'workers_activity': False})
     def test_start_process(self):
         dataimport2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers)
         self.assertIsNone(dataimport2.workflow)
diff --git a/arkindex/dataimport/tests/test_workeractivity.py b/arkindex/dataimport/tests/test_workeractivity.py
index 2517a861d79d2c5262edd2084ae3ce69ae209699..59b7f918b8dddef9f9613e392bf416cc7c47d400 100644
--- a/arkindex/dataimport/tests/test_workeractivity.py
+++ b/arkindex/dataimport/tests/test_workeractivity.py
@@ -1,10 +1,12 @@
 import uuid
 
+from django.conf import settings
+from django.test import override_settings
 from django.urls import reverse
 from rest_framework import status
 
-from arkindex.dataimport.models import WorkerActivityState, WorkerVersion
-from arkindex.documents.models import Element
+from arkindex.dataimport.models import DataImportMode, WorkerActivity, WorkerActivityState, WorkerVersion
+from arkindex.documents.models import Classification, ClassificationState, Element, MLClass
 from arkindex.project.tests import FixtureTestCase
 
 
@@ -18,6 +20,62 @@ class TestWorkerActivity(FixtureTestCase):
         # Create a queued activity for this element
         cls.activity = cls.element.activities.create(worker_version=cls.worker_version, state=WorkerActivityState.Queued)
 
+    def test_bulk_insert_activity_children(self):
+        """
+        Bulk insert worker activities for acts
+        """
+        elements_qs = Element.objects.filter(type__slug='act', type__corpus_id=self.corpus.id)
+        params = {
+            'worker_version_id': self.worker_version.id,
+            'corpus_id': self.corpus.id,
+            'state': WorkerActivityState.Started.value
+        }
+        with self.assertExactQueries('workeractivity_bulk_insert.sql', params=params):
+            WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
+        self.assertEqual(elements_qs.count(), 5)
+        self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 5)
+
+    def test_bulk_insert_activity_existing(self):
+        """
+        Elements in the queryset should be skipped if they already have an activity
+        """
+        elements_qs = Element.objects.filter(type__slug='act', type__corpus_id=self.corpus.id)
+        WorkerActivity.objects.bulk_create([
+            WorkerActivity(element=element, worker_version=self.worker_version, state=WorkerActivityState.Processed.value)
+            for element in elements_qs[:2]
+        ])
+        with self.assertNumQueries(1):
+            WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
+        self.assertEqual(WorkerActivity.objects.filter(element_id__in=elements_qs.values('id')).count(), 5)
+        self.assertEqual(elements_qs.count(), 5)
+        # Only 3 acts have been marked as started for this worker
+        self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 3)
+
+    @override_settings(ARKINDEX_FEATURES={**settings.ARKINDEX_FEATURES, 'workers_activity': True})
+    def test_bulk_insert_children_class_filter(self):
+        """
+        Worker activities creation should work with complex elements selection (e.g. with a class filter)
+        """
+        agent_class = MLClass.objects.create(name='James', corpus=self.corpus)
+        Classification.objects.bulk_create(
+            Classification(ml_class=agent_class, state=ClassificationState.Validated, element=e)
+            for e in self.corpus.elements.filter(type__slug='page')
+        )
+        dataimport = self.corpus.imports.create(
+            creator=self.user,
+            mode=DataImportMode.Workers,
+            corpus=self.corpus,
+            best_class=agent_class.name
+        )
+        dataimport.worker_runs.create(version=self.worker_version, parents=[])
+        with self.assertNumQueries(21):
+            dataimport.start()
+
+        self.assertCountEqual(
+            WorkerActivity.objects.filter(worker_version=self.worker_version).values_list('element_id', flat=True),
+            self.corpus.elements.filter(type__slug='page').values_list('id', flat=True)
+        )
+
     def test_put_activity_requires_internal(self):
         """
         Only internal users (workers) are able to update the state of a worker activity
diff --git a/arkindex/dataimport/tests/test_workflows_api.py b/arkindex/dataimport/tests/test_workflows_api.py
index 632d0a4b6b1c193e122ca90bc8d477d3edd91988..02ea0ac7d3c705dab7ee45ecdccea6d4624ec90e 100644
--- a/arkindex/dataimport/tests/test_workflows_api.py
+++ b/arkindex/dataimport/tests/test_workflows_api.py
@@ -1,9 +1,10 @@
 import yaml
+from django.conf import settings
 from django.test import override_settings
 from rest_framework import status
 from rest_framework.reverse import reverse
 
-from arkindex.dataimport.models import DataImport, DataImportMode, RepositoryType, WorkerVersion
+from arkindex.dataimport.models import DataImport, DataImportMode, RepositoryType, WorkerActivity, WorkerVersion
 from arkindex.dataimport.utils import get_default_farm_id
 from arkindex.documents.models import Corpus, Element
 from arkindex.project.tests import FixtureAPITestCase
@@ -474,15 +475,17 @@ class TestWorkflows(FixtureAPITestCase):
                 f'python3 -m arkindex_tasks.generate_thumbnails /data/initialisation/elements_chunk_{i}.json'
             )
 
+    @override_settings(ARKINDEX_FEATURES={**settings.ARKINDEX_FEATURES, 'workers_activity': False})
     def test_workers_no_worker_runs(self):
         dataimport_2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers)
         self.assertIsNone(dataimport_2.workflow)
 
         self.client.force_login(self.user)
 
-        response = self.client.post(
-            reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
-        )
+        with self.assertNumQueries(21):
+            response = self.client.post(
+                reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
+            )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.json()['id'], str(dataimport_2.id))
         dataimport_2.refresh_from_db()
@@ -498,6 +501,7 @@ class TestWorkflows(FixtureAPITestCase):
             }
         })
 
+    @override_settings(ARKINDEX_FEATURES={**settings.ARKINDEX_FEATURES, 'workers_activity': True})
     def test_workers_multiple_worker_runs(self):
         dataimport_2 = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers)
         run_1 = dataimport_2.worker_runs.create(
@@ -515,7 +519,7 @@ class TestWorkflows(FixtureAPITestCase):
         self.assertIsNone(dataimport_2.workflow)
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(39):
+        with self.assertNumQueries(44):
             response = self.client.post(
                 reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
             )
@@ -549,3 +553,14 @@ class TestWorkflows(FixtureAPITestCase):
                 'env': {'TASK_ELEMENTS': '/data/initialisation/elements.json', 'WORKER_VERSION_ID': str(self.version_2.id)}
             }
         })
+
+        # Check worker activities has been created for concerned elements on both runs
+        elements_ids = self.corpus.elements.values_list('id', flat=True)
+        self.assertCountEqual(
+            elements_ids,
+            WorkerActivity.objects.filter(worker_version=self.version_1).values_list('element_id', flat=True)
+        )
+        self.assertCountEqual(
+            elements_ids,
+            WorkerActivity.objects.filter(worker_version=self.version_2).values_list('element_id', flat=True)
+        )
diff --git a/arkindex/sql_validation/workeractivity_bulk_insert.sql b/arkindex/sql_validation/workeractivity_bulk_insert.sql
new file mode 100644
index 0000000000000000000000000000000000000000..6455e2564231ed0a67ad06a678ea298ff34f9eec
--- /dev/null
+++ b/arkindex/sql_validation/workeractivity_bulk_insert.sql
@@ -0,0 +1,13 @@
+INSERT INTO dataimport_workeractivity (element_id, worker_version_id, state, id, created, updated)
+SELECT elt.id,
+       '{worker_version_id}'::uuid,
+       '{state}',
+       uuid_generate_v4(),
+       current_timestamp,
+       current_timestamp
+FROM
+  (SELECT "documents_element"."id"
+   FROM "documents_element"
+   INNER JOIN "documents_elementtype" ON ("documents_element"."type_id" = "documents_elementtype"."id")
+   WHERE ("documents_elementtype"."corpus_id" = '{corpus_id}'::uuid
+          AND "documents_elementtype"."slug" = 'act')) AS elt ON CONFLICT (element_id, worker_version_id) DO NOTHING