Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (82)
Showing
with 354 additions and 168 deletions
......@@ -56,7 +56,7 @@ RUN python -m nuitka \
--include-package=ponos \
--include-package=transkribus \
--show-progress \
--lto \
--lto=yes \
--output-dir=/build \
arkindex/manage.py
......
1.1.0-rc3
1.1.1-rc3
......@@ -52,7 +52,7 @@ from arkindex.dataimport.serializers.git import ExternalRepositorySerializer, Re
from arkindex.dataimport.serializers.imports import (
CreateImportTranskribusErrorResponseSerializer,
DataImportFromFilesSerializer,
DataImportLightSerializer,
DataImportListSerializer,
DataImportSerializer,
ElementsWorkflowSerializer,
ImportTranskribusSerializer,
......@@ -76,7 +76,6 @@ from arkindex.project.mixins import (
ConflictAPIException,
CorpusACLMixin,
CustomPaginationViewMixin,
DeprecatedMixin,
ProcessACLMixin,
RepositoryACLMixin,
SelectionMixin,
......@@ -141,7 +140,7 @@ class DataImportsList(ProcessACLMixin, ListAPIView):
List all visible data imports.
"""
permission_classes = (IsVerified, )
serializer_class = DataImportLightSerializer
serializer_class = DataImportListSerializer
def get_queryset(self):
filters = Q()
......@@ -502,7 +501,7 @@ class DataFileList(CorpusACLMixin, ListAPIView):
queryset = DataFile.objects.none()
def get_queryset(self):
return DataFile.objects.filter(corpus=self.get_corpus(self.kwargs['pk']))
return DataFile.objects.filter(corpus=self.get_corpus(self.kwargs['pk']), trashed=False)
@extend_schema(tags=['files'])
......@@ -534,11 +533,10 @@ class DataFileRetrieve(CorpusACLMixin, RetrieveUpdateDestroyAPIView):
def perform_destroy(self, instance):
if not self.has_write_access(instance.corpus):
raise PermissionDenied
try:
instance.s3_delete()
except Exception as e:
logger.warning(f'Could not delete DataFile from S3: {e}')
return super().perform_destroy(instance)
# The file is simply marked as trashed here and will be deleted asynchronously through a cron job
instance.trashed = True
instance.save()
@extend_schema_view(post=extend_schema(operation_id='CreateDataFile', tags=['files']))
......@@ -736,27 +734,6 @@ class RevisionRetrieve(RepositoryACLMixin, RetrieveAPIView):
return Revision.objects.filter(repo__in=self.executable_repositories)
@extend_schema(tags=['repos'])
@extend_schema_view(
get=extend_schema(operation_id='ListWorkersDeprecated'),
post=extend_schema(operation_id='CreateWorkerDeprecated'),
)
class DeprecatedWorkerList(DeprecatedMixin, ListCreateAPIView):
"""
List workers for a given repository UUID
Create a Worker instance on a repository for given slug, name and type, return existing
query if a worker with the same slug already exists on the repository
"""
permission_classes = (IsVerified, )
pagination_class = None
deprecation_message = (
'Listing workers associated to a repository is now deprecated. '
'A user is able to list all the workers for which they have an execution access '
'using `WorkerList` endpoint'
)
@extend_schema(tags=['repos'])
@extend_schema_view(
get=extend_schema(
......@@ -940,14 +917,6 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
'List worker versions used by elements of a given corpus.\n\n'
'No check is performed on workers access level in order to allow any user to see versions.'
),
parameters=[
OpenApiParameter(
'with_element_count',
type=bool,
default=False,
description='Include element counts in the response.',
)
],
)
)
class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
......@@ -964,25 +933,12 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
return get_object_or_404(self.readable_corpora, pk=self.kwargs['pk'])
def get_queryset(self):
corpus = self.get_corpus()
queryset = WorkerVersion.objects \
.filter(elements__corpus_id=corpus.id) \
.prefetch_related(
'revision__repo',
'revision__refs',
'revision__versions',
'worker__repository',
) \
.order_by('-id')
if self.request.query_params.get('with_element_count', '').lower() in ('true', '1'):
queryset = queryset.annotate(element_count=Count('id'))
else:
# The Count() causes Django to add a GROUP BY, and without a count we need a DISTINCT
# because filtering on `elements` causes worker versions to be duplicated.
queryset = queryset.distinct()
return queryset
return self.get_corpus().worker_versions.prefetch_related(
'revision__repo',
'revision__refs',
'revision__versions',
'worker__repository',
).order_by('-id')
@extend_schema(tags=['repos'])
......@@ -1267,6 +1223,8 @@ class ListProcessElements(CustomPaginationViewMixin, CorpusACLMixin, ListAPIView
'image__height',
'polygon',
'image_url',
'rotation_angle',
'mirrored',
)
......
from django.core.management.base import BaseCommand
from arkindex.dataimport.models import CorpusWorkerVersion
class Command(BaseCommand):
    # Management command filling the CorpusWorkerVersion cache from existing ML results.
    help = 'Rebuild the corpus worker versions cache'

    def add_arguments(self, parser):
        parser.add_argument(
            '--drop',
            help='Drop the existing cache before rebuilding.',
            action='store_true',
        )

    def handle(self, *args, drop=False, **options):
        # With --drop, wipe the cache entirely before repopulating it, so stale
        # (corpus, worker_version) pairs do not survive the rebuild.
        if drop:
            CorpusWorkerVersion.objects.all().delete()
            self.stdout.write('Deleted all existing CorpusWorkerVersion.')
        # Delegates to CorpusWorkerVersionManager.rebuild, which scans all ML
        # result models and re-creates the cache entries.
        CorpusWorkerVersion.objects.rebuild()
import logging
from django.db import connections, models
logger = logging.getLogger(__name__)
class ActivityManager(models.Manager):
    """Model management for worker activities"""

    def bulk_insert(self, worker_version_id, process_id, elements_qs, state=None):
        """
        Create initial worker activities from a queryset of elements in an efficient way.

        Due to the possible large amount of elements, we use a bulk insert from the
        elements query (best performances).
        The `ON CONFLICT` clause allows to automatically skip elements that already
        have an activity with this version.

        :param worker_version_id: ID of the worker version the activities are created for.
        :param process_id: ID of the process the activities are attached to.
        :param elements_qs: Queryset of elements to create activities for.
        :param state: Initial WorkerActivityState; defaults to Queued when None.
        """
        # Imported here (not at module level) to avoid a circular import with
        # arkindex.dataimport.models, which itself imports this manager.
        from arkindex.dataimport.models import WorkerActivityState
        if state is None:
            state = WorkerActivityState.Queued
        assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'

        # Reuse the compiled SQL of the elements queryset as a sub-select so the
        # whole insert runs as a single database statement.
        sql, params = elements_qs.values('id').query.sql_with_params()

        # NOTE(review): worker_version_id, state.value and process_id are interpolated
        # directly into the SQL text rather than passed as query parameters. They are
        # internally-generated UUID/enum values here, but parameterizing them would be
        # safer — confirm before reusing this with externally-sourced values.
        with connections['default'].cursor() as cursor:
            cursor.execute(
                f"""
                INSERT INTO dataimport_workeractivity
                    (element_id, worker_version_id, state, process_id, id, created, updated)
                SELECT
                    elt.id,
                    '{worker_version_id}'::uuid,
                    '{state.value}',
                    '{process_id}',
                    uuid_generate_v4(),
                    current_timestamp,
                    current_timestamp
                FROM ({sql}) AS elt
                ON CONFLICT (element_id, worker_version_id) DO NOTHING
                """,
                params
            )
class CorpusWorkerVersionManager(models.Manager):

    def rebuild(self):
        """
        Rebuild the corpus worker versions cache from all ML results.

        Scans every model holding a worker_version foreign key and re-creates
        the distinct (corpus, worker_version) pairs, skipping existing rows.
        """
        # Local import to avoid a circular dependency with the documents app.
        from arkindex.documents.models import Element, Transcription, Entity, TranscriptionEntity, Classification, MetaData

        # Each source yields (corpus_id, worker_version_id) pairs for one ML result model.
        sources = [
            Element.objects.exclude(worker_version_id=None).values_list('corpus_id', 'worker_version_id'),
            Transcription.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
            Entity.objects.exclude(worker_version_id=None).values_list('corpus_id', 'worker_version_id'),
            TranscriptionEntity.objects.exclude(worker_version_id=None).values_list('entity__corpus_id', 'worker_version_id'),
            Classification.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
            MetaData.objects.exclude(worker_version_id=None).values_list('element__corpus_id', 'worker_version_id'),
        ]

        total = len(sources)
        for i, queryset in enumerate(sources, start=1):
            logger.info(f'Rebuilding cache from {queryset.model.__name__} ({i}/{total})')
            # ignore_conflicts skips pairs already present in the cache table.
            rows = [
                self.model(corpus_id=corpus, worker_version_id=version)
                for corpus, version in queryset.distinct()
            ]
            self.bulk_create(rows, ignore_conflicts=True)
# Generated by Django 3.2.3 on 2021-08-31 07:53
import uuid
import django.db.models.deletion
from django.db import migrations, models
def rebuild_reminder(apps, schema_editor):
    """
    Print a reminder to rebuild the cache manually if there is anything in the database.
    """
    Corpus = apps.get_model('documents', 'Corpus')
    if not Corpus.objects.exists():
        # Fresh install: no corpora, nothing to cache, stay quiet.
        return
    print("Please run `arkindex cache_worker_versions` to fill the corpus worker versions cache.")
class Migration(migrations.Migration):
    # Creates the CorpusWorkerVersion cache table and the matching
    # WorkerVersion.corpora many-to-many field going through it.

    dependencies = [
        ('documents', '0042_transcription_entity_confidence'),
        ('dataimport', '0034_worker_run_config'),
    ]

    operations = [
        migrations.CreateModel(
            name='CorpusWorkerVersion',
            fields=[
                ('id', models.UUIDField(
                    default=uuid.uuid4,
                    editable=False,
                    primary_key=True,
                    serialize=False
                )),
                ('corpus', models.ForeignKey(
                    on_delete=django.db.models.deletion.CASCADE,
                    related_name='worker_version_cache',
                    to='documents.corpus',
                )),
                ('worker_version', models.ForeignKey(
                    on_delete=django.db.models.deletion.CASCADE,
                    related_name='corpus_cache',
                    to='dataimport.workerversion',
                )),
            ],
            options={
                # One cache row per (corpus, worker version) pair.
                'unique_together': {('corpus', 'worker_version')},
            },
        ),
        migrations.AddField(
            model_name='workerversion',
            name='corpora',
            field=models.ManyToManyField(
                related_name='worker_versions',
                through='dataimport.CorpusWorkerVersion',
                to='documents.Corpus',
            ),
        ),
        # The cache cannot be rebuilt automatically during the migration:
        # print a reminder to run the management command when data exists.
        migrations.RunPython(
            code=rebuild_reminder,
            reverse_code=migrations.RunPython.noop,
        )
    ]
# Generated by Django 3.2.6 on 2021-09-09 07:21
from django.db import migrations, models
class Migration(migrations.Migration):
    # Adds the DataFile.trashed flag used to mark files for asynchronous
    # deletion (by a cron job) instead of deleting from S3 synchronously.

    dependencies = [
        ('dataimport', '0035_corpus_version_cache'),
    ]

    operations = [
        migrations.AddField(
            model_name='datafile',
            name='trashed',
            field=models.BooleanField(default=False),
        ),
    ]
......@@ -7,12 +7,13 @@ from uuid import UUID
import yaml
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.db import connections, models
from django.db import models
from django.db.models import Q
from django.utils.functional import cached_property
from enumfields import Enum, EnumField
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.managers import ActivityManager, CorpusWorkerVersionManager
from arkindex.dataimport.providers import get_provider, git_providers
from arkindex.dataimport.utils import get_default_farm_id
from arkindex.documents.models import ClassificationState, Element
......@@ -363,6 +364,12 @@ class DataImport(IndexableModel):
from arkindex.project.triggers import initialize_activity
initialize_activity(self)
if self.mode == DataImportMode.Workers:
CorpusWorkerVersion.objects.bulk_create([
CorpusWorkerVersion(corpus_id=self.corpus_id, worker_version_id=worker_version_id)
for worker_version_id in self.worker_runs.values_list('version_id', flat=True)
], ignore_conflicts=True)
def retry(self):
if self.mode == DataImportMode.Repository and self.revision is not None and not self.revision.repo.enabled:
raise ValidationError('Git repository does not have any valid credentials')
......@@ -396,6 +403,7 @@ class DataFile(S3FileMixin, models.Model):
content_type = models.CharField(max_length=120)
corpus = models.ForeignKey('documents.Corpus', on_delete=models.CASCADE, related_name='files')
status = EnumField(S3FileStatus, default=S3FileStatus.Unchecked, max_length=50)
trashed = models.BooleanField(default=False)
s3_bucket = settings.AWS_STAGING_BUCKET
......@@ -541,6 +549,12 @@ class WorkerVersion(models.Model):
# The Docker internal image id (sha256:xxx) that can be shared across multiple images
docker_image_iid = models.CharField(null=True, blank=True, max_length=80)
corpora = models.ManyToManyField(
'documents.Corpus',
through='dataimport.CorpusWorkerVersion',
related_name='worker_versions',
)
class Meta:
unique_together = (('worker', 'revision'),)
constraints = [
......@@ -633,39 +647,6 @@ class WorkerActivityState(Enum):
Error = 'error'
class ActivityManager(models.Manager):
"""Model management for worker activities"""
def bulk_insert(self, worker_version_id, process_id, elements_qs, state=WorkerActivityState.Queued):
"""
Create initial worker activities from a queryset of elements in a efficient way.
Due to the possible large amount of elements, we use a bulk insert from the elements query (best performances).
The `ON CONFLICT` clause allows to automatically skip elements that already have an activity with this version.
"""
assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'
sql, params = elements_qs.values('id').query.sql_with_params()
with connections['default'].cursor() as cursor:
cursor.execute(
f"""
INSERT INTO dataimport_workeractivity
(element_id, worker_version_id, state, process_id, id, created, updated)
SELECT
elt.id,
'{worker_version_id}'::uuid,
'{state.value}',
'{process_id}',
uuid_generate_v4(),
current_timestamp,
current_timestamp
FROM ({sql}) AS elt
ON CONFLICT (element_id, worker_version_id) DO NOTHING
""",
params
)
class WorkerActivity(IndexableModel):
"""
Many-to-many relationship between Element and WorkerVersion
......@@ -702,3 +683,24 @@ class WorkerActivity(IndexableModel):
unique_together = (
('worker_version', 'element'),
)
class CorpusWorkerVersion(models.Model):
    """
    Cache of which worker versions have produced ML results in which corpus.

    Acts as the through model for WorkerVersion.corpora / Corpus.worker_versions.
    """
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    corpus = models.ForeignKey(
        'documents.Corpus',
        on_delete=models.CASCADE,
        related_name='worker_version_cache',
    )
    worker_version = models.ForeignKey(
        WorkerVersion,
        on_delete=models.CASCADE,
        related_name='corpus_cache',
    )

    # Provides the `rebuild` method used by the cache_worker_versions command.
    objects = CorpusWorkerVersionManager()

    class Meta:
        # FIX: the original was missing the trailing comma, making this a plain
        # 2-tuple of field names instead of a tuple of tuples like the other
        # Meta.unique_together declarations in this file. Django normalizes the
        # shorthand, but the explicit form is consistent and unambiguous.
        unique_together = (
            ('corpus', 'worker_version'),
        )
......@@ -115,6 +115,15 @@ class DataImportSerializer(DataImportLightSerializer):
return data
class DataImportListSerializer(DataImportLightSerializer):
    # Extends the light serializer with read-only timestamps for list endpoints.
    created = serializers.DateTimeField(read_only=True)
    # `last_date` is defined on the model/queryset, not visible here; presumably it
    # folds in the latest task update time — TODO confirm against DataImport.
    updated = serializers.DateTimeField(source='last_date', read_only=True)

    class Meta(DataImportLightSerializer.Meta):
        fields = DataImportLightSerializer.Meta.fields + ('created', 'updated')
        read_only_fields = DataImportLightSerializer.Meta.read_only_fields + ('created', 'updated')
class DataImportFromFilesSerializer(serializers.Serializer):
mode = EnumField(DataImportMode, default=DataImportMode.Images)
......
from django.core.management import call_command
from arkindex.dataimport.models import WorkerVersion
from arkindex.project.tests import FixtureTestCase
class TestCacheWorkerVersions(FixtureTestCase):
    # Tests for the `cache_worker_versions` management command.

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        # The fixtures provide exactly two worker versions, ordered here by worker slug.
        cls.dla, cls.recognizer = WorkerVersion.objects.order_by('worker__slug')

    def test_run(self):
        # Pre-seed the cache with one version; rebuilding should add the version
        # actually used by fixture ML results without removing existing entries.
        self.corpus.worker_versions.add(self.dla)
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla])
        call_command('cache_worker_versions')
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla, self.recognizer])

    def test_drop(self):
        # With --drop, pre-existing cache entries are deleted first, so only
        # versions found in actual ML results remain afterwards.
        self.corpus.worker_versions.add(self.dla)
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.dla])
        call_command('cache_worker_versions', drop=True)
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])

    def test_ignore_conflicts(self):
        # Rebuilding when the cache already holds the same pair must neither
        # fail nor create duplicates.
        self.corpus.worker_versions.add(self.recognizer)
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])
        call_command('cache_worker_versions')
        self.assertCountEqual(list(self.corpus.worker_versions.all()), [self.recognizer])
from unittest.mock import patch
from botocore.exceptions import ClientError
from django.test import override_settings
from django.urls import reverse
from rest_framework import status
......@@ -44,6 +43,21 @@ class TestFiles(FixtureAPITestCase):
self.assertEqual(file['name'], self.df.name)
self.assertEqual(file['size'], self.df.size)
def test_file_list_filter_out_trashed_files(self):
    # Files flagged as trashed (awaiting asynchronous deletion by a cron job)
    # must not appear in the corpus file list endpoint.
    DataFile.objects.create(
        name='test2.txt',
        size=42,
        content_type='text/plain',
        corpus=self.corpus,
        trashed=True
    )
    # Two files exist in total: the fixture file plus the trashed one above.
    self.assertEqual(DataFile.objects.count(), 2)
    response = self.client.get(reverse('api:file-list', kwargs={'pk': self.corpus.id}))
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    data = response.json()
    self.assertIn('results', data)
    # Only the non-trashed fixture file should be returned.
    self.assertEqual(len(data['results']), 1)
def test_file_list_requires_login(self):
self.client.logout()
response = self.client.get(reverse('api:file-list', kwargs={'pk': self.corpus.id}))
......@@ -54,30 +68,9 @@ class TestFiles(FixtureAPITestCase):
self.assertTrue(DataFile.objects.exists())
response = self.client.delete(reverse('api:file-retrieve', kwargs={'pk': self.df.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertFalse(DataFile.objects.exists())
self.assertEqual(s3_mock.Object().delete.call_count, 1)
@patch('arkindex.project.aws.s3')
def test_file_delete_ignore_s3_errors(self, s3_mock):
"""
Test the DataFile deletion tries multiple times to delete from S3, but ignores errors if it fails
"""
exceptions = [
ClientError({'Error': {'Code': '500'}}, 'delete_object'),
ClientError({'Error': {'Code': '500'}}, 'delete_object'),
ValueError,
]
def _raise(*args, **kwargs):
raise exceptions.pop(0)
s3_mock.Object().delete.side_effect = _raise
self.assertTrue(DataFile.objects.exists())
response = self.client.delete(reverse('api:file-retrieve', kwargs={'pk': self.df.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertFalse(DataFile.objects.exists())
self.assertEqual(s3_mock.Object().delete.call_count, 3)
self.assertTrue(DataFile.objects.get().trashed)
self.assertEqual(s3_mock.Object().delete.call_count, 0)
def test_file_delete_requires_login(self):
"""
......
......@@ -87,6 +87,9 @@ class TestImports(FixtureAPITestCase):
)
def _serialize_process(self, process):
updated = process.updated
if process.workflow:
updated = max(updated, process.workflow.tasks.order_by('-updated').first().updated)
return {
'name': process.name,
'id': str(process.id),
......@@ -94,7 +97,9 @@ class TestImports(FixtureAPITestCase):
'mode': process.mode.value,
'corpus': process.corpus_id and str(process.corpus.id),
'workflow': process.workflow and f'http://testserver/ponos/v1/workflow/{process.workflow.id}/',
'activity_state': ActivityState.Disabled.value,
'activity_state': process.activity_state.value,
'created': process.created.isoformat().replace('+00:00', 'Z'),
'updated': updated.isoformat().replace('+00:00', 'Z')
}
def build_task(self, workflow_id, run, state, depth=1):
......@@ -157,15 +162,7 @@ class TestImports(FixtureAPITestCase):
data = response.json()
self.assertEqual(len(data['results']), 1)
results = data['results']
self.assertListEqual(results, [{
'name': None,
'id': str(self.user_img_process.id),
'state': State.Unscheduled.value,
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': f'http://testserver/ponos/v1/workflow/{self.user_img_process.workflow.id}/',
'activity_state': ActivityState.Ready.value,
}])
self.assertListEqual(results, [self._serialize_process(self.user_img_process)])
def test_list_exclude_workflow(self):
"""
......@@ -180,15 +177,7 @@ class TestImports(FixtureAPITestCase):
data = response.json()
self.assertEqual(len(data['results']), 1)
results = data['results']
self.assertListEqual(results, [{
'name': None,
'id': str(self.user_img_process.id),
'state': State.Unscheduled.value,
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': None,
'activity_state': ActivityState.Disabled.value,
}])
self.assertListEqual(results, [self._serialize_process(self.user_img_process)])
def test_list_filter_corpus(self):
self.client.force_login(self.superuser)
......@@ -829,7 +818,7 @@ class TestImports(FixtureAPITestCase):
def test_retry_no_workflow(self):
self.client.force_login(self.user)
self.assertIsNone(self.elts_process.workflow)
with self.assertNumQueries(17):
with self.assertNumQueries(18):
response = self.client.post(reverse('api:import-retry', kwargs={'pk': self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.elts_process.refresh_from_db()
......@@ -1077,7 +1066,7 @@ class TestImports(FixtureAPITestCase):
self.assertIsNone(dataimport2.workflow)
self.client.force_login(self.user)
with self.assertNumQueries(21):
with self.assertNumQueries(22):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport2.id)})
)
......@@ -1100,7 +1089,7 @@ class TestImports(FixtureAPITestCase):
self.assertNotEqual(get_default_farm_id(), barley_farm.id)
workers_process = self.corpus.imports.create(creator=self.user, mode=DataImportMode.Workers)
self.client.force_login(self.user)
with self.assertNumQueries(21):
with self.assertNumQueries(22):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(workers_process.id)}),
{'farm': str(barley_farm.id)}
......
from uuid import uuid4
from arkindex.dataimport.models import CorpusWorkerVersion, Repository, RepositoryType, WorkerVersionState
from arkindex.documents.models import Classification, Element, Entity, MetaData, Transcription, TranscriptionEntity
from arkindex.project.tests import FixtureTestCase
from ponos.models import Artifact
class TestManagers(FixtureTestCase):
    # Tests for CorpusWorkerVersionManager.rebuild.

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.repo = Repository.objects.get(type=RepositoryType.Worker)
        cls.revision = cls.repo.revisions.first()
        cls.artifact = Artifact.objects.get()
        # The fixtures have two worker versions, only one of them is used in existing elements
        cls.recognizer = cls.repo.workers.get(slug='reco').versions.get()

    def _make_worker_version(self):
        # Helper creating a fresh available worker version on a brand new worker
        # (random slug), so each ML result below can get a distinct version.
        return self.revision.versions.create(
            worker=self.repo.workers.create(slug=str(uuid4())),
            configuration={},
            state=WorkerVersionState.Available,
            docker_image=Artifact.objects.first(),
        )

    def test_corpus_worker_version_rebuild(self):
        # Assign a different worker version for each ML result to get a lot of versions
        querysets = [
            Element.objects.filter(worker_version_id=None),
            Transcription.objects.filter(worker_version_id=None),
            TranscriptionEntity.objects.filter(worker_version_id=None),
            Entity.objects.filter(worker_version_id=None),
            Classification.objects.filter(worker_version_id=None),
            MetaData.objects.filter(worker_version_id=None),
        ]
        versions = [self.recognizer]
        for queryset in querysets:
            for obj in queryset:
                version = self._make_worker_version()
                versions.append(version)
                obj.worker_version = version
                obj.save()
        # The cache starts empty; rebuilding must register every assigned version.
        self.assertFalse(self.corpus.worker_versions.exists())
        CorpusWorkerVersion.objects.rebuild()
        self.assertCountEqual(self.corpus.worker_versions.all(), versions)
......@@ -82,12 +82,14 @@ class TestProcessElements(FixtureAPITestCase):
cls.page_4 = Element.objects.create(
corpus=cls.private_corpus,
name="Mongolfiere 2",
type=cls.page_type
type=cls.page_type,
rotation_angle=180,
)
cls.page_5 = Element.objects.create(
corpus=cls.private_corpus,
name="Baba au rhum 2",
type=cls.page_type
type=cls.page_type,
mirrored=True,
)
cls.page_1.add_parent(cls.folder_1)
cls.page_2.add_parent(cls.folder_1)
......
......@@ -100,7 +100,7 @@ class TestWorkerActivity(FixtureTestCase):
best_class=agent_class.name
)
dataimport.worker_runs.create(version=self.worker_version, parents=[])
with self.assertNumQueries(20):
with self.assertNumQueries(22):
dataimport.start()
self.assertCountEqual(
......
......@@ -874,7 +874,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
return data
def test_corpus_worker_version_no_login(self):
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
self.corpus.worker_versions.set([self.version_1])
with self.assertNumQueries(8):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
......@@ -893,7 +893,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self.user.verified_email = False
self.user.save()
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
self.corpus.worker_versions.set([self.version_1])
with self.assertNumQueries(12):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
......@@ -910,7 +910,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
def test_corpus_worker_version_list(self):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
self.corpus.worker_versions.set([self.version_1])
with self.assertNumQueries(12):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
......@@ -924,23 +924,3 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self._serialize_worker_version(self.version_1)
]
})
def test_corpus_worker_version_list_with_element_count(self):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(12):
response = self.client.get(
reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}),
{'with_element_count': 'true'}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1, element_count=True)
]
})
......@@ -494,7 +494,7 @@ class TestWorkflows(FixtureAPITestCase):
self.assertIsNone(dataimport_2.workflow)
self.client.force_login(self.user)
with self.assertNumQueries(21):
with self.assertNumQueries(22):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
)
......@@ -511,6 +511,7 @@ class TestWorkflows(FixtureAPITestCase):
'image': 'registry.gitlab.com/arkindex/tasks'
}
})
self.assertFalse(self.corpus.worker_versions.exists())
@patch('arkindex.project.triggers.dataimport_tasks.initialize_activity.delay')
def test_workers_multiple_worker_runs(self, activities_delay_mock):
......@@ -532,9 +533,10 @@ class TestWorkflows(FixtureAPITestCase):
workflow_tmp.start()
self.assertIsNone(dataimport_2.workflow)
self.assertFalse(self.corpus.worker_versions.exists())
self.client.force_login(self.user)
with self.assertNumQueries(28):
with self.assertNumQueries(30):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
)
......@@ -601,6 +603,9 @@ class TestWorkflows(FixtureAPITestCase):
WorkerActivity.objects.filter(worker_version=self.version_2).values_list('element_id', flat=True)
)
# Check that the corpus worker version cache has been updated
self.assertCountEqual(self.corpus.worker_versions.all(), [self.version_1, self.version_2])
def test_create_process_use_cache_option(self):
"""
A process with the `use_cache` parameter creates an initialization task with the --use-cache flag
......@@ -613,7 +618,7 @@ class TestWorkflows(FixtureAPITestCase):
dataimport_2.use_cache = True
dataimport_2.save()
self.client.force_login(self.user)
with self.assertNumQueries(25):
with self.assertNumQueries(27):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(dataimport_2.id)})
)
......