Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: arkindex/backend
Commits on Source (18)
Showing 768 additions and 449 deletions
1.3.6-beta1 → 1.3.6
......@@ -1696,6 +1696,9 @@ class ElementMetadata(ListCreateAPIView):
class ElementMetadataBulk(CreateAPIView):
"""
Create multiple metadata on an existing element.
Exactly one of the `worker_version` or `worker_run_id` fields must be set.
If `worker_run_id` is set, the worker version will be deduced from it.
"""
permission_classes = (IsVerified, )
serializer_class = MetaDataBulkSerializer
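For illustration, a request body satisfying this rule might look like the following sketch (the UUID is hypothetical; sending both worker fields, or neither, is rejected with a 400 error):
# Exactly one of the two worker fields is provided; the worker version is
# deduced server-side from the worker run.
payload = {
    'worker_run_id': '11111111-1111-1111-1111-111111111111',  # hypothetical ID
    'metadata_list': [
        {'type': 'text', 'name': 'language', 'value': 'Chinese'},
    ],
}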
......@@ -1904,7 +1907,10 @@ class ElementTypeUpdate(RetrieveUpdateDestroyAPIView):
)
class ElementBulkCreate(CreateAPIView):
"""
Create multiple child elements at once on a single parent
Create multiple child elements at once on a single parent.
Exactly one of the `worker_version` or `worker_run_id` fields must be set.
If `worker_run_id` is set, the worker version will be deduced from it.
"""
serializer_class = ElementBulkSerializer
permission_classes = (IsVerified, )
......
......@@ -6,7 +6,7 @@ from uuid import UUID
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.management.base import BaseCommand
from django.db.models import Exists, Max, OuterRef, Q, Value
from django.db.models import Exists, F, Max, OuterRef, Q, Value
from django.utils import timezone
from arkindex.documents.models import CorpusExport, CorpusExportState, Element
......@@ -89,10 +89,27 @@ class Command(BaseCommand):
def cleanup_expired_workflows(self):
# Keep workflows that built artifacts for WorkerVersions on Git tags or main branches
worker_version_docker_image_workflows = GitRef.objects.filter(
Q(type=GitRefType.Tag)
| Q(type=GitRefType.Branch, name__in=('master', 'main'))
).values('revision__versions__docker_image__task__workflow_id')
worker_version_docker_image_workflows = (
GitRef
.objects
.filter(
Q(type=GitRefType.Tag)
| Q(type=GitRefType.Branch, name__in=('master', 'main'))
)
# There might be a revision with no WorkerVersions at all, or a revision whose
# WorkerVersions have no docker_image, either of which could cause the selected
# workflow ID to be NULL. This query is used in a NOT IN clause, which returns
# FALSE when a workflow is in this subquery, but NULL instead of TRUE when it
# isn't, due to SQL's three-valued logic around NULL. The parent query would
# then evaluate a WHERE NULL, which is treated as FALSE, so every workflow
# would be excluded and none would ever be considered expired.
#
# Excluding NULLs directly with a .exclude(revision__...__workflow_id=None)
# causes the JOINs to be duplicated, so we use an annotation to make sure the
# ORM understands we are filtering on the column that we are selecting.
.annotate(workflow_id=F('revision__versions__docker_image__task__workflow_id'))
.exclude(workflow_id=None)
.values('workflow_id')
)
expired_workflows = Workflow \
.objects \
......
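As a minimal sketch of the pattern above (the parent query is truncated here, so how the subquery is consumed is an assumption based on the comment):
# Annotate-then-exclude keeps NULLs out of a subquery destined for a NOT IN
# clause: in SQL's three-valued logic, `2 NOT IN (1, NULL)` evaluates to NULL
# rather than TRUE, so a single NULL in the subquery silences every match.
from django.db.models import F

kept_workflow_ids = (
    GitRef.objects
    .annotate(workflow_id=F('revision__versions__docker_image__task__workflow_id'))
    .exclude(workflow_id=None)  # drop the NULLs before the NOT IN comparison
    .values('workflow_id')
)
expired_workflows = Workflow.objects.exclude(id__in=kept_workflow_ids)  # assumed usage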
# Generated by Django 4.1.3 on 2022-12-05 16:16
from django.db import migrations, models
from arkindex.documents.models import MetaType
class Migration(migrations.Migration):
dependencies = [
('documents', '0059_remove_corpus_thumbnail'),
]
operations = [
migrations.RemoveConstraint(
model_name='metadata',
name='metadata_numeric_values',
),
migrations.AddConstraint(
model_name='metadata',
constraint=models.CheckConstraint(
check=~models.Q(type=MetaType.Numeric) | models.Q(value__iregex=r'^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:E[+-]?\d+)?$'),
name='metadata_numeric_values'
),
),
]
......@@ -11,7 +11,6 @@ from django.core.exceptions import ValidationError
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import connections, models, transaction
from django.db.models import Deferrable, Q
from django.db.models.expressions import RawSQL
from django.db.models.functions import Cast, Least
from django.utils.functional import cached_property
from enumfields import Enum, EnumField
......@@ -923,14 +922,15 @@ class MetaData(InterpretedDateMixin, models.Model):
class Meta:
ordering = ('element', 'name', 'id')
constraints = [
# Either the metadata is not numeric, or it has a value that casts to a double properly.
# Casting to a double can cause an exception to occur if the value is invalid, so instead we check that
# the value matches the PostgreSQL syntax for numeric constants using a regex.
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-CONSTANTS-NUMERIC
# 1, +1, -1, .1, 1., 1.0, 1E4, 1E+4, 1e-4, -4E2, -4.E5, .5e2, +123.456E+9 are all acceptable values.
# This regex looks for an optional leading sign, then splits into two cases to handle both `xx(.(yy))`
# and `.yy`, then accepts an optional exponent with an optional leading sign.
models.CheckConstraint(
# Either the metadata is not numeric, or it has a value that casts to a double properly.
# If the cast fails, it will always cause an exception, but SQL still wants a boolean expression,
# so we just tell it to check that it isn't null.
# Django does not let us write a proper boolean expression in here since the left-hand side needs to
# use a Cast() function, but the maintainers chose to not implement any support for this case, so we
# have to use RawSQL. https://code.djangoproject.com/ticket/31646
check=~Q(type=MetaType.Numeric) | RawSQL('value::double precision IS NOT NULL', params=(), output_field=models.BooleanField()),
check=~Q(type=MetaType.Numeric) | Q(value__iregex=r'^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:E[+-]?\d+)?$'),
name='metadata_numeric_values'
),
# There can be a worker run ID only if there is a worker version ID,
......
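As a quick sanity check of the regex above, a sketch using Python's re module (which agrees with PostgreSQL's case-insensitive POSIX matching for this particular pattern):
import re

NUMERIC = re.compile(r'^[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:E[+-]?\d+)?$', re.IGNORECASE)

# All the documented examples are accepted...
assert all(NUMERIC.match(v) for v in ('1', '+1', '-1', '.1', '1.', '1.0', '1E4', '1e-4', '-4.E5', '+123.456E+9'))
# ...while empty strings, lone dots and malformed exponents are rejected.
assert not any(NUMERIC.match(v) for v in ('', '.', '1.2.3', 'abc', '1e'))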
......@@ -20,11 +20,10 @@ from arkindex.documents.serializers.light import (
ElementTypeLightSerializer,
MetaDataLightSerializer,
)
from arkindex.documents.serializers.ml import ClassificationSerializer
from arkindex.documents.serializers.ml import ClassificationSerializer, WorkerRunSummarySerializer
from arkindex.images.models import Image
from arkindex.images.serializers import ZoneSerializer
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.process.serializers.worker_runs import WorkerRunSummarySerializer
from arkindex.project.fields import Array
from arkindex.project.mixins import SelectionMixin
from arkindex.project.serializer_fields import LinearRingField
......@@ -189,7 +188,7 @@ class MetaDataBulkSerializer(serializers.Serializer):
required=False,
allow_null=True,
style={'base_template': 'input.html'},
source='worker_run'
source='worker_run',
)
class Meta:
......@@ -213,8 +212,8 @@ class MetaDataBulkSerializer(serializers.Serializer):
def validate(self, data):
data = super().validate(data)
worker_run = data.get('worker_run')
if worker_run is not None:
data['worker_version'] = WorkerVersion(id=worker_run.version_id)
if not worker_run and not data.get('worker_version'):
raise ValidationError('Exactly one of `worker_version` or `worker_run` must be set.')
request_metadata = self.make_metadata_tuples(data['metadata_list'])
unique_metadata = set(request_metadata)
if len(unique_metadata) != len(request_metadata):
......@@ -229,13 +228,24 @@ class MetaDataBulkSerializer(serializers.Serializer):
return data
def create(self, validated_data):
base_attrs = {}
worker_run = validated_data.get('worker_run', None)
if worker_run is not None:
# Retrieve the version from the specified worker run
# No need for an extra query here
base_attrs.update({
"worker_run": worker_run,
"worker_version_id": worker_run.version_id,
})
else:
base_attrs.update({"worker_version": validated_data['worker_version']})
validated_data['metadata_list'] = MetaData.objects.bulk_create([
MetaData(
**m,
**base_attrs,
id=uuid.uuid4(),
element=self.context['element'],
worker_version=validated_data.get('worker_version', None),
worker_run=validated_data.get('worker_run', None),
**m
)
for m in validated_data['metadata_list']
])
......@@ -934,7 +944,11 @@ class ElementBulkSerializer(serializers.Serializer):
element_errors[i] = {'polygon': ["An element's polygon must not exceed its image's bounds."]}
worker_run = data.get('worker_run', None)
if worker_run:
if not worker_run and not data.get('worker_version'):
errors['non_field_errors'] = [
'Exactly one of `worker_version` or `worker_run_id` must be set.'
]
elif worker_run:
data['worker_version'] = WorkerVersion(id=worker_run.version_id)
if element_errors:
......
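The `WorkerVersion(id=worker_run.version_id)` expression used by both serializers relies on `version_id` being a local column of the worker run, already loaded with its row; a brief sketch of why no extra query is needed:
# Building an unsaved WorkerVersion shell only copies a UUID that is already in
# memory; no SELECT is issued, yet the instance carries a valid primary key and
# can stand in wherever a worker version reference is expected.
version = WorkerVersion(id=worker_run.version_id)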
......@@ -2,8 +2,8 @@ from rest_framework import serializers
from arkindex.documents.models import Corpus, Entity, EntityLink, EntityRole, EntityType, TranscriptionEntity
from arkindex.documents.serializers.light import CorpusLightSerializer, InterpretedDateSerializer
from arkindex.documents.serializers.ml import WorkerRunSummarySerializer
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.process.serializers.worker_runs import WorkerRunSummarySerializer
from arkindex.project.serializer_fields import EnumField
from arkindex.project.tools import WorkerRunOrVersionValidator
......
......@@ -18,11 +18,16 @@ from arkindex.documents.models import (
)
from arkindex.documents.serializers.light import ElementZoneSerializer
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.process.serializers.worker_runs import WorkerRunSummarySerializer
from arkindex.project.serializer_fields import EnumField, LinearRingField
from arkindex.project.tools import ConditionalUniqueValidator, WorkerRunOrVersionValidator, polygon_outside_image
# Defined here to avoid circular imports, as this serializer is used by the documents serializers
class WorkerRunSummarySerializer(serializers.Serializer):
id = serializers.UUIDField()
summary = serializers.CharField(help_text="Human-readable summary of a WorkerRun's information")
class ClassificationMode(Enum):
"""
Classification mode
......@@ -591,7 +596,7 @@ class ClassificationBulkSerializer(serializers.Serializer):
Cannot use ModelSerializer as they become read-only when nested
"""
id = serializers.UUIDField(read_only=True)
class_name = serializers.CharField(source='ml_class')
ml_class = serializers.UUIDField()
confidence = serializers.FloatField(min_value=0, max_value=1)
high_confidence = serializers.BooleanField(default=False)
state = EnumField(ClassificationState, read_only=True)
......@@ -639,43 +644,33 @@ class ClassificationsSerializer(serializers.Serializer):
def validate(self, data):
data = super().validate(data)
errors = defaultdict(list)
ml_class_names = [
ml_class_ids = [
classification['ml_class']
for classification in data['classifications']
]
if len(ml_class_names) != len(set(ml_class_names)):
raise ValidationError({
'classifications': ['Duplicated ML classes are not allowed from the same worker version or worker run.']
})
return data
if len(ml_class_ids) != len(set(ml_class_ids)):
errors['classifications'].append('Duplicated ML classes are not allowed from the same worker version or worker run.')
def create(self, validated_data):
parent = validated_data['parent']
ml_class_names = set(
classification['ml_class']
for classification in validated_data['classifications']
)
# Fetch existing ML classes
ml_classes = dict(
# Check that all ML classes exist in one query
ml_class_count = (
MLClass
.objects
.using('default')
.filter(corpus_id=parent.corpus_id, name__in=ml_class_names)
.values_list('name', 'id')
.filter(corpus_id=data['parent'].corpus_id, id__in=ml_class_ids)
.count()
)
if ml_class_count < len(set(ml_class_ids)):
errors['classifications'].append('Some ML classes do not exist or are not linked to this corpus.')
# Create missing classes
new_classes = [
MLClass(id=uuid.uuid4(), corpus_id=parent.corpus_id, name=name)
for name in ml_class_names - set(ml_classes.keys())
]
MLClass.objects.bulk_create(new_classes)
if errors:
raise ValidationError(errors)
ml_classes.update({ml_class.name: ml_class.id for ml_class in new_classes})
return data
def create(self, validated_data):
parent = validated_data['parent']
worker_version = validated_data.get('worker_version')
worker_run = validated_data.get('worker_run')
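The existence check above replaces the previous fetch-and-create logic with a single aggregate query; a sketch of the pattern (variable names are illustrative):
# If the count of matching rows is lower than the number of distinct requested
# IDs, at least one ID is unknown or belongs to another corpus; since `id` is
# unique, the count can never exceed the number of distinct IDs.
requested_ids = set(ml_class_ids)
matching = MLClass.objects.filter(corpus_id=corpus_id, id__in=requested_ids).count()
if matching < len(requested_ids):
    raise ValidationError('Some ML classes do not exist or are not linked to this corpus.')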
......@@ -687,6 +682,8 @@ class ClassificationsSerializer(serializers.Serializer):
parent.classifications.filter(worker_version=worker_version, worker_run_id__isnull=True).delete()
worker_version_id = worker_version.id
# These attributes will both be used by the bulk_create below
# and sent back in the API response
for cl in validated_data['classifications']:
cl['id'] = uuid.uuid4()
cl['state'] = ClassificationState.Pending
......@@ -695,12 +692,12 @@ class ClassificationsSerializer(serializers.Serializer):
Classification(
id=cl['id'],
element=parent,
ml_class_id=ml_classes[cl['ml_class']],
ml_class_id=cl['ml_class'],
confidence=cl['confidence'],
high_confidence=cl['high_confidence'],
worker_version_id=worker_version_id,
worker_run=worker_run,
state=cl['state']
state=cl['state'],
)
for cl in validated_data['classifications']
])
......
......@@ -577,6 +577,98 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(ponos_s3_mock.Object().delete.call_count, 4)
@patch('ponos.models.s3')
def test_cleanup_expired_workflows_null(self, ponos_s3_mock, s3_mock):
repo = Repository.objects.get(url='http://my_repo.fake/workers/worker')
# This revision on the `main` branch does not have any WorkerVersions.
# Improper handling of NULL values in the queries that look for expired workflows,
# and that exclude revisions which should not be deleted based on their GitRefs,
# could cause this revision to prevent any expired workflow from ever being found.
empty_revision = repo.revisions.create(
hash=str(uuid.uuid4()),
message='A revision with no worker versions',
author='Someone',
)
empty_revision.refs.create(repository=repo, type=GitRefType.Branch, name='main')
unavailable_worker_revision = repo.revisions.create(
hash=str(uuid.uuid4()),
message='A revision with a worker version with no artifact',
author='Someone',
)
unavailable_worker_revision.refs.create(repository=repo, type=GitRefType.Tag, name='1.2.3-rc4')
# Same as above: this WorkerVersion with no docker_image could break the query.
repo.workers.get(slug='dla').versions.create(
revision=unavailable_worker_revision,
configuration={},
)
# These artifacts should be cleaned up
lonely_revision, lonely_artifact = self._make_revision_artifact()
branch_revision, branch_artifact = self._make_revision_artifact()
branch_revision.refs.create(repository=branch_revision.repo, type=GitRefType.Branch, name='my-awesome-branch')
ponos_s3_mock.Object().key = 's3_key'
self.assertEqual(
self.cleanup(),
dedent(
"""
Removing orphaned Ponos artifacts…
Successfully cleaned up orphaned Ponos artifacts.
Removing 2 artifacts of expired workflows from S3…
Removing artifact s3_key
Removing artifact s3_key
Removing logs for 2 tasks of expired workflows from S3…
Removing task log s3_key
Removing task log s3_key
Updating 2 available worker versions to the Error state…
Removing 2 artifacts of expired workflows…
Removing 2 tasks of expired workflows…
Removing 2 expired workflows…
Successfully cleaned up expired workflows.
Removing 0 old corpus exports from S3…
Removing 0 old corpus exports…
Successfully cleaned up old corpus exports.
Removing orphaned corpus exports…
Successfully cleaned up orphaned corpus exports.
Deleting 0 DataFiles marked as trashed from S3 and the database…
Successfully cleaned up DataFiles marked as trashed.
Removing orphan images…
Successfully cleaned up orphan images.
Removing orphaned local images…
Successfully cleaned up orphaned local images.
Removing orphaned Ponos logs…
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
"""
).strip()
)
with self.assertRaises(Artifact.DoesNotExist):
lonely_artifact.refresh_from_db()
with self.assertRaises(Task.DoesNotExist):
lonely_artifact.task.refresh_from_db()
with self.assertRaises(Workflow.DoesNotExist):
lonely_artifact.task.workflow.refresh_from_db()
with self.assertRaises(Artifact.DoesNotExist):
branch_artifact.refresh_from_db()
with self.assertRaises(Task.DoesNotExist):
branch_artifact.task.refresh_from_db()
with self.assertRaises(Workflow.DoesNotExist):
branch_artifact.task.workflow.refresh_from_db()
# Those still exist, refreshing works
lonely_revision.refresh_from_db()
branch_revision.refresh_from_db()
empty_revision.refresh_from_db()
unavailable_worker_revision.refresh_from_db()
self.assertEqual(ponos_s3_mock.Object().delete.call_count, 4)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_local_images(self, cleanup_s3_mock, s3_mock):
ImageServer.objects.local.images.create(path='path%2Fto%2Fimage.jpg')
......
from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus, MLClass
from arkindex.documents.models import Corpus
from arkindex.process.models import WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
......@@ -15,13 +15,8 @@ class TestBulkClassification(FixtureAPITestCase):
cls.private_corpus = Corpus.objects.create(name='private', public=False)
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
cls.worker_run = cls.worker_version.worker_runs.get()
def create_classifications_data(self, classifications, parent=None):
return {
"parent": parent or str(self.page.id),
"worker_version": str(self.worker_version.id),
"classifications": classifications,
}
cls.dog_class = cls.corpus.ml_classes.create(name='dog')
cls.cat_class = cls.corpus.ml_classes.create(name='cat')
def test_requires_login(self):
response = self.client.post(reverse('api:classification-bulk'), format='json')
......@@ -36,13 +31,19 @@ class TestBulkClassification(FixtureAPITestCase):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data(
[{"class_name": 'dog', "confidence": 0.99}],
parent=str(private_page.id)
)
data={
'parent': str(private_page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
'confidence': 0.99,
},
],
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(
response.json(),
{
......@@ -55,25 +56,29 @@ class TestBulkClassification(FixtureAPITestCase):
Classifications are created and linked to a worker version
"""
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data([
{
"class_name": 'dog',
"confidence": 0.99,
"high_confidence": True
},
{
"class_name": 'cat',
"confidence": 0.42,
}
])
data={
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
"confidence": 0.99,
"high_confidence": True
},
{
'ml_class': str(self.cat_class.id),
"confidence": 0.42,
}
],
},
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
first_cl, second_cl = self.page.classifications.values_list('id', flat=True)
first_cl, second_cl = self.page.classifications.order_by('-confidence').values_list('id', flat=True)
self.assertEqual(response.json(), {
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
......@@ -81,14 +86,14 @@ class TestBulkClassification(FixtureAPITestCase):
'classifications': [
{
'id': str(first_cl),
'class_name': 'dog',
'ml_class': str(self.dog_class.id),
'confidence': 0.99,
'high_confidence': True,
'state': 'pending',
},
{
'id': str(second_cl),
'class_name': 'cat',
'ml_class': str(self.cat_class.id),
'confidence': 0.42,
'high_confidence': False,
'state': 'pending',
......@@ -119,7 +124,7 @@ class TestBulkClassification(FixtureAPITestCase):
"parent": str(self.page.id),
"classifications": [
{
"class_name": 'cat',
'ml_class': str(self.cat_class.id),
"confidence": 0.42,
}
],
......@@ -135,7 +140,7 @@ class TestBulkClassification(FixtureAPITestCase):
def test_worker_run(self):
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
......@@ -143,12 +148,12 @@ class TestBulkClassification(FixtureAPITestCase):
"parent": str(self.page.id),
"classifications": [
{
"class_name": 'dog',
'ml_class': str(self.dog_class.id),
"confidence": 0.99,
"high_confidence": True
},
{
"class_name": 'cat',
'ml_class': str(self.cat_class.id),
"confidence": 0.42,
}
],
......@@ -157,7 +162,7 @@ class TestBulkClassification(FixtureAPITestCase):
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
first_cl, second_cl = self.page.classifications.values_list('id', flat=True)
first_cl, second_cl = self.page.classifications.order_by('-confidence').values_list('id', flat=True)
self.assertEqual(response.json(), {
'parent': str(self.page.id),
'worker_version': None,
......@@ -165,14 +170,14 @@ class TestBulkClassification(FixtureAPITestCase):
'classifications': [
{
'id': str(first_cl),
'class_name': 'dog',
'ml_class': str(self.dog_class.id),
'confidence': 0.99,
'high_confidence': True,
'state': 'pending',
},
{
'id': str(second_cl),
'class_name': 'cat',
'ml_class': str(self.cat_class.id),
'confidence': 0.42,
'high_confidence': False,
'state': 'pending',
......@@ -204,12 +209,12 @@ class TestBulkClassification(FixtureAPITestCase):
"parent": str(self.page.id),
"classifications": [
{
"class_name": 'dog',
'ml_class': str(self.dog_class.id),
"confidence": 0.99,
"high_confidence": True
},
{
"class_name": 'cat',
'ml_class': str(self.cat_class.id),
"confidence": 0.42,
}
],
......@@ -222,79 +227,89 @@ class TestBulkClassification(FixtureAPITestCase):
'worker_run_id': ['Invalid pk "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" - object does not exist.'],
})
def test_create_ml_class(self):
"""
Adding classifications with non-existent classes should automatically create them
"""
self.assertEqual(MLClass.objects.filter(name="dog", corpus=self.corpus).count(), 0)
def test_ml_class_not_found(self):
self.dog_class.delete()
self.client.force_login(self.user)
with self.assertNumQueries(10):
self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data(
[{"class_name": 'dog', "confidence": 0.99}],
)
)
self.assertEqual(MLClass.objects.filter(name="dog", corpus=self.corpus).count(), 1)
def test_existing_ml_class(self):
MLClass.objects.create(name="cat", corpus=self.corpus)
self.client.force_login(self.user)
with self.assertNumQueries(9):
self.client.post(
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data(
[{"class_name": 'cat', "confidence": 0.42}],
)
data={
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
'confidence': 0.99,
},
],
},
)
self.assertEqual(MLClass.objects.filter(name="cat", corpus=self.corpus).count(), 1)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'classifications': ['Some ML classes do not exist or are not linked to this corpus.'],
})
def test_delete_worker_version(self):
"""
Test the bulk classification API deletes previous classifications with a similar worker version
"""
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data([
{
"class_name": 'dog',
"confidence": 0.99,
"high_confidence": True
},
{
"class_name": 'cat',
"confidence": 0.42,
}
])
data={
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
"confidence": 0.99,
"high_confidence": True,
},
{
'ml_class': str(self.cat_class.id),
'confidence': 0.42,
},
],
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
with self.assertNumQueries(8):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
another_cat_class = self.corpus.ml_classes.create(name='calico')
best_cat_class = self.corpus.ml_classes.create(name='callico')
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data([
{
"class_name": 'doggo',
"confidence": 0.5,
},
{
"class_name": 'catte',
"confidence": 0.85,
"high_confidence": True
}
])
data={
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': str(another_cat_class.id),
"confidence": 0.5,
},
{
'ml_class': str(best_cat_class.id),
"confidence": 0.85,
"high_confidence": True,
},
],
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertCountEqual(
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')),
[
('doggo', 0.5, False),
('catte', 0.85, True),
('calico', 0.5, False),
('callico', 0.85, True),
],
)
......@@ -303,7 +318,7 @@ class TestBulkClassification(FixtureAPITestCase):
Test the bulk classification API deletes previous classifications with the same worker run
"""
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
......@@ -311,38 +326,41 @@ class TestBulkClassification(FixtureAPITestCase):
"parent": str(self.page.id),
"classifications": [
{
"class_name": 'dog',
'ml_class': str(self.dog_class.id),
"confidence": 0.99,
"high_confidence": True
},
{
"class_name": 'cat',
'ml_class': str(self.cat_class.id),
"confidence": 0.42,
}
],
"worker_run_id": str(self.worker_run.id),
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.json())
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
with self.assertNumQueries(8):
another_cat_class = self.corpus.ml_classes.create(name='calico')
best_cat_class = self.corpus.ml_classes.create(name='callico')
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
"parent": str(self.page.id),
"classifications": [
'parent': str(self.page.id),
'classifications': [
{
"class_name": 'doggo',
'ml_class': str(another_cat_class.id),
"confidence": 0.5,
},
{
"class_name": 'catte',
'ml_class': str(best_cat_class.id),
"confidence": 0.85,
"high_confidence": True
}
"high_confidence": True,
},
],
"worker_run_id": str(self.worker_run.id),
'worker_run_id': str(self.worker_run.id),
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......@@ -350,8 +368,8 @@ class TestBulkClassification(FixtureAPITestCase):
self.assertCountEqual(
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')),
[
('doggo', 0.5, False),
('catte', 0.85, True),
('calico', 0.5, False),
('callico', 0.85, True),
],
)
......@@ -360,16 +378,26 @@ class TestBulkClassification(FixtureAPITestCase):
Test the bulk classification API prevents creating classifications with duplicate ML classes
"""
self.client.force_login(self.user)
with self.assertNumQueries(6):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data([
{"class_name": 'dog', "confidence": 0.99},
{"class_name": 'dog', "confidence": 0.99},
])
data={
'parent': str(self.page.id),
'worker_version': str(self.worker_version.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
'confidence': 0.99,
},
{
'ml_class': str(self.dog_class.id),
'confidence': 0.99,
},
]
}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'classifications': ['Duplicated ML classes are not allowed from the same worker version or worker run.']
})
from uuid import uuid4
from django.contrib.gis.geos import LineString
from django.urls import reverse
from rest_framework import status
......@@ -143,65 +144,6 @@ class TestBulkElements(FixtureAPITestCase):
'non_field_errors': ["Element types with slugs nope do not exist in the parent element's corpus"]
})
def test_bulk_create(self):
self.client.force_login(self.user)
with self.assertNumQueries(13):
response = self.client.post(
reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}),
data=self.payload,
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
element_path = self.element.paths.get()
a, b, c = Element \
.objects \
.get_descending(self.element.id) \
.filter(type__slug__in=('act', 'surface')) \
.order_by('name')
a_path, b_path, c_path = a.paths.get(), b.paths.get(), c.paths.get()
self.assertListEqual(
response.json(),
[
{'id': str(a.id)},
{'id': str(b.id)},
{'id': str(c.id)}
]
)
self.assertEqual(a.name, 'A')
self.assertEqual(b.name, 'B')
self.assertEqual(c.name, 'C')
self.assertEqual(a.type.slug, 'act')
self.assertEqual(b.type.slug, 'surface')
self.assertEqual(c.type.slug, 'surface')
self.assertEqual(a.worker_version, self.worker_version)
self.assertEqual(b.worker_version, self.worker_version)
self.assertEqual(c.worker_version, self.worker_version)
self.assertEqual(a.worker_run, None)
self.assertEqual(b.worker_run, None)
self.assertEqual(c.worker_run, None)
self.assertEqual(a.image_id, self.element.image_id)
self.assertEqual(b.image_id, self.element.image_id)
self.assertEqual(c.image_id, self.element.image_id)
self.assertTupleEqual(a.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)))
self.assertTupleEqual(b.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)))
self.assertTupleEqual(c.polygon.coords, ((0, 0), (0, 9), (9, 9), (9, 0), (0, 0)))
self.assertEqual(a.rotation_angle, self.element.rotation_angle)
self.assertEqual(b.rotation_angle, self.element.rotation_angle)
self.assertEqual(c.rotation_angle, self.element.rotation_angle)
self.assertEqual(a.mirrored, self.element.mirrored)
self.assertEqual(b.mirrored, self.element.mirrored)
self.assertEqual(c.mirrored, self.element.mirrored)
self.assertListEqual(a_path.path, element_path.path + [self.element.id])
self.assertListEqual(b_path.path, element_path.path + [self.element.id])
self.assertListEqual(c_path.path, element_path.path + [self.element.id])
self.assertEqual(a_path.ordering, 0)
self.assertEqual(b_path.ordering, 0)
self.assertEqual(c_path.ordering, 1)
def test_bulk_create_multiple_parent_paths(self):
parent = self.corpus.elements.create(
name='Parent 2',
......@@ -349,7 +291,6 @@ class TestBulkElements(FixtureAPITestCase):
"""
Cannot create elements outside their image
"""
self.maxDiff = None
self.client.force_login(self.user)
payload = {
'worker_version': str(self.worker_version.id),
......@@ -427,7 +368,32 @@ class TestBulkElements(FixtureAPITestCase):
'worker_run_id': [f'Invalid pk "{random_uuid}" - object does not exist.']
})
def test_bulk_create_invalid_parameters_worker_run(self):
def test_bulk_create_worker_run_or_version(self):
"""Either a worker run or a worker version is required"""
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}),
data={
'elements': [
{
'name': 'Blah',
'type': 'surface',
'polygon': [[0, 10], [10, 20], [30, 40], [50, 60], [0, 10]]
}
]
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'non_field_errors': ['Exactly one of `worker_version` or `worker_run_id` must be set.']
})
def test_bulk_create_worker_run_and_version(self):
"""Worker run and worker version cannot be set at the same time"""
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
......@@ -450,6 +416,48 @@ class TestBulkElements(FixtureAPITestCase):
'non_field_errors': ['Only one of `worker_version` and `worker_run_id` may be set.']
})
def test_bulk_create_with_worker_version(self):
self.client.force_login(self.user)
with self.assertNumQueries(13):
response = self.client.post(
reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}),
data=self.payload,
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
elements = (
Element.objects
.get_descending(self.element.id)
.filter(type__slug__in=('act', 'surface'))
.order_by('name')
)
element_ids = list(elements.values_list("id", flat=True))
# Only IDs are returned in the payload
self.assertListEqual(response.json(), [{'id': str(elt_id)} for elt_id in element_ids])
# Test common attributes
parent_path = self.element.paths.get().path
self.assertListEqual(
list(elements.values_list(
'worker_version_id',
'worker_run_id',
'image_id',
'rotation_angle',
'mirrored',
'paths__path',
)),
3 * [(self.worker_version.id, None, self.element.image_id, 0, False, [*parent_path, self.element.id])]
)
# Test specific attributes
self.assertListEqual(
list(elements.values_list('name', 'type__slug', 'paths__ordering', 'polygon')),
[
('A', 'act', 0, LineString(self.payload['elements'][0]['polygon'])),
('B', 'surface', 0, LineString(self.payload['elements'][1]['polygon'])),
('C', 'surface', 1, LineString(self.payload['elements'][2]['polygon'])),
],
)
def test_bulk_create_with_worker_run(self):
self.client.force_login(self.user)
with self.assertNumQueries(13):
......@@ -460,51 +468,34 @@ class TestBulkElements(FixtureAPITestCase):
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
element_path = self.element.paths.get()
a, b, c = Element \
.objects \
.get_descending(self.element.id) \
.filter(type__slug__in=('act', 'surface')) \
elements = (
Element.objects
.get_descending(self.element.id)
.filter(type__slug__in=('act', 'surface'))
.order_by('name')
a_path, b_path, c_path = a.paths.get(), b.paths.get(), c.paths.get()
)
element_ids = list(elements.values_list("id", flat=True))
# Only IDs are returned in the payload
self.assertListEqual(response.json(), [{'id': str(elt_id)} for elt_id in element_ids])
# Test common attributes
parent_path = self.element.paths.get().path
self.assertListEqual(
response.json(),
list(elements.values_list(
'worker_version_id',
'worker_run_id',
'image_id',
'rotation_angle',
'mirrored',
'paths__path',
)),
3 * [(self.worker_version.id, self.worker_run.id, self.element.image_id, 0, False, [*parent_path, self.element.id])]
)
# Test specific attributes
self.assertListEqual(
list(elements.values_list('name', 'type__slug', 'paths__ordering', 'polygon')),
[
{'id': str(a.id)},
{'id': str(b.id)},
{'id': str(c.id)}
]
('A', 'act', 0, LineString(self.payload['elements'][0]['polygon'])),
('B', 'surface', 0, LineString(self.payload['elements'][1]['polygon'])),
('C', 'surface', 1, LineString(self.payload['elements'][2]['polygon'])),
],
)
self.assertEqual(a.name, 'A')
self.assertEqual(b.name, 'B')
self.assertEqual(c.name, 'C')
self.assertEqual(a.type.slug, 'act')
self.assertEqual(b.type.slug, 'surface')
self.assertEqual(c.type.slug, 'surface')
self.assertEqual(a.worker_version, self.worker_run.version)
self.assertEqual(b.worker_version, self.worker_run.version)
self.assertEqual(c.worker_version, self.worker_run.version)
self.assertEqual(a.worker_run, self.worker_run)
self.assertEqual(b.worker_run, self.worker_run)
self.assertEqual(c.worker_run, self.worker_run)
self.assertEqual(a.image_id, self.element.image_id)
self.assertEqual(b.image_id, self.element.image_id)
self.assertEqual(c.image_id, self.element.image_id)
self.assertTupleEqual(a.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)))
self.assertTupleEqual(b.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0)))
self.assertTupleEqual(c.polygon.coords, ((0, 0), (0, 9), (9, 9), (9, 0), (0, 0)))
self.assertEqual(a.rotation_angle, self.element.rotation_angle)
self.assertEqual(b.rotation_angle, self.element.rotation_angle)
self.assertEqual(c.rotation_angle, self.element.rotation_angle)
self.assertEqual(a.mirrored, self.element.mirrored)
self.assertEqual(b.mirrored, self.element.mirrored)
self.assertEqual(c.mirrored, self.element.mirrored)
self.assertListEqual(a_path.path, element_path.path + [self.element.id])
self.assertListEqual(b_path.path, element_path.path + [self.element.id])
self.assertListEqual(c_path.path, element_path.path + [self.element.id])
self.assertEqual(a_path.ordering, 0)
self.assertEqual(b_path.ordering, 0)
self.assertEqual(c_path.ordering, 1)
......@@ -348,3 +348,21 @@ class TestBulkTranscriptions(FixtureAPITestCase):
self.assertDictEqual(response.json(), {
'non_field_errors': ['Only one of `worker_version` and `worker_run_id` may be set.']
})
def test_bulk_transcriptions_requires_version_xor_run(self):
self.client.force_login(self.user)
test_element = self.corpus.elements.get(name='Volume 2, page 1r')
response = self.client.post(reverse('api:transcription-bulk'), {
"transcriptions": [
{
"element_id": str(test_element.id),
"text": "The Glow Cloud does not need to converse with us.",
"orientation": TextOrientation.VerticalRightToLeft.value,
"confidence": 0.33
},
],
}, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'non_field_errors': ['Either `worker_run_id` or `worker_version` must be defined.']
})
from datetime import datetime, timezone
from itertools import cycle
from unittest.mock import patch
......@@ -246,6 +247,7 @@ class TestDestroyElements(FixtureAPITestCase):
element_id=elt.id,
worker_version_id=WorkerVersion.objects.get(worker__slug='reco').id,
state=next(states),
started=datetime.now(timezone.utc),
) for elt in elements
)
elements.trash()
......
import uuid
from django.urls import reverse
from rest_framework import status
......@@ -1197,16 +1199,20 @@ class TestMetaData(FixtureAPITestCase):
entity = self.corpus.entities.create(name='42', type=EntityType.Number)
self.client.force_login(self.user)
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'},
{'type': 'date', 'name': 'date', 'value': '1885'},
{'type': 'numeric', 'name': 'numeric', 'value': '42', 'entity_id': str(entity.id)}
]},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={
"worker_run_id": str(self.worker_run.id),
'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'},
{'type': 'date', 'name': 'date', 'value': '1885'},
{'type': 'numeric', 'name': 'numeric', 'value': '42', 'entity_id': str(entity.id)}
]
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.vol.metadatas.count(), 4) # 1 added during setUp
md1 = self.vol.metadatas.get(type=MetaType.Location, name='location')
self.assertEqual(md1.value, 'Texas')
......@@ -1247,7 +1253,7 @@ class TestMetaData(FixtureAPITestCase):
}
],
'worker_version': None,
'worker_run_id': None
'worker_run_id': str(self.worker_run.id),
})
def test_bulk_create_metadata_worker_version(self):
......@@ -1264,11 +1270,14 @@ class TestMetaData(FixtureAPITestCase):
md = self.vol.metadatas.get(type=MetaType.Text, name='language')
self.assertEqual(md.value, 'Chinese')
self.assertEqual(md.worker_version, self.worker_version)
self.assertEqual(md.worker_run_id, None)
md = self.vol.metadatas.get(type=MetaType.Text, name='alt_language')
self.assertEqual(md.value, 'Mandarin')
self.assertEqual(md.worker_version, self.worker_version)
self.assertEqual(md.worker_run_id, None)
def test_bulk_create_metadata_worker_run(self):
"""Worker version is implicitly set from the worker run on created metadata"""
self.client.force_login(self.user)
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
......@@ -1295,10 +1304,28 @@ class TestMetaData(FixtureAPITestCase):
data={'metadata_list': [
{'type': 'text', 'name': 'language', 'value': 'Chinese'},
{'type': 'text', 'name': 'alt_language', 'value': 'Mandarin'}
], 'worker_run_id': str(self.worker_run.id), 'worker_version': str(self.worker_version.id)},
format='json'
]},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {'non_field_errors': ['Exactly one of `worker_version` or `worker_run` must be set.']})
def test_bulk_create_metadata_worker_run_and_version(self):
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={
'metadata_list': [
{'type': 'text', 'name': 'language', 'value': 'Chinese'},
{'type': 'text', 'name': 'alt_language', 'value': 'Mandarin'}
],
'worker_run_id': str(self.worker_run.id),
'worker_version': str(self.worker_version.id)
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {'non_field_errors': ['Only one of `worker_version` and `worker_run_id` may be set.']})
def test_bulk_create_metadata_bad_worker_run(self):
......@@ -1335,6 +1362,27 @@ class TestMetaData(FixtureAPITestCase):
'metadata_list': {'non_field_errors': ['This list may not be empty.']}
})
def test_bulk_create_metadata_malicious_data(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={'metadata_list': [{
'type': 'text',
'name': 'language',
'value': 'Chinese',
'worker_version_id': str(uuid.uuid4()),
'id': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa',
},
], 'worker_run_id': str(self.worker_run.id)},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
md = self.vol.metadatas.get(type=MetaType.Text, name='language')
self.assertNotEqual(str(md.id), 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
self.assertEqual(md.worker_run_id, self.worker_run.id)
self.assertEqual(md.worker_version, self.worker_version)
def test_admin_bulk_create_any(self):
"""
Admin and internal users are allowed to create any metadata
......@@ -1344,11 +1392,14 @@ class TestMetaData(FixtureAPITestCase):
self.client.force_login(user)
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={'metadata_list': [
{'type': 'text', 'name': 'certainly not allowed', 'value': 'bla bla bla'},
{'type': 'location', 'name': 'not allowed', 'value': 'boo'},
]},
format='json'
data={
"worker_version": str(self.worker_version.id),
'metadata_list': [
{'type': 'text', 'name': 'certainly not allowed', 'value': 'bla bla bla'},
{'type': 'location', 'name': 'not allowed', 'value': 'boo'},
]
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.vol.metadatas.count(), 2 + setup) # 1 added during setUp
......@@ -1375,7 +1426,7 @@ class TestMetaData(FixtureAPITestCase):
'entity_id': None,
},
],
'worker_version': None,
'worker_version': str(self.worker_version.id),
'worker_run_id': None
})
......@@ -1395,14 +1446,17 @@ class TestMetaData(FixtureAPITestCase):
The metadata created with the bulk endpoint must be unique
"""
self.client.force_login(self.user)
with self.assertNumQueries(5):
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'},
{'type': 'location', 'name': 'location', 'value': 'Texas'}
]},
format='json'
data={
"worker_run_id": str(self.worker_run.id),
'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'},
{'type': 'location', 'name': 'location', 'value': 'Texas'}
]
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
......@@ -1419,13 +1473,16 @@ class TestMetaData(FixtureAPITestCase):
data={'type': 'location', 'name': 'location', 'value': 'Texas'}
)
self.assertEqual(response_1.status_code, status.HTTP_201_CREATED)
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response_2 = self.client.post(
reverse('api:element-metadata-bulk', kwargs={'pk': str(self.vol.id)}),
data={'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'}
]},
format='json'
data={
'worker_version': str(self.worker_version.id),
'metadata_list': [
{'type': 'location', 'name': 'location', 'value': 'Texas'}
]
},
format='json',
)
self.assertEqual(response_2.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response_2.json(), {
......
......@@ -7,7 +7,7 @@ from uuid import UUID
from django.conf import settings
from django.core.mail import send_mail
from django.db import transaction
from django.db.models import CharField, Count, F, Max, Q, Value
from django.db.models import Avg, CharField, Count, DurationField, F, Max, Min, Q, Value
from django.db.models.functions import Cast, Coalesce, Concat, Greatest, Now, NullIf
from django.db.models.query import Prefetch
from django.shortcuts import get_object_or_404
......@@ -71,11 +71,10 @@ from arkindex.process.serializers.imports import (
ProcessListSerializer,
ProcessSerializer,
StartProcessSerializer,
WorkerRunEditSerializer,
WorkerRunSerializer,
)
from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer
from arkindex.process.serializers.training import StartTrainingSerializer
from arkindex.process.serializers.worker_runs import WorkerRunEditSerializer, WorkerRunSerializer
from arkindex.process.serializers.workers import (
RepositorySerializer,
WorkerActivitySerializer,
......@@ -93,7 +92,6 @@ from arkindex.project.fields import ArrayRemove
from arkindex.project.mixins import (
ConflictAPIException,
CorpusACLMixin,
CustomPaginationViewMixin,
ProcessACLMixin,
RepositoryACLMixin,
SelectionMixin,
......@@ -101,7 +99,7 @@ from arkindex.project.mixins import (
)
from arkindex.project.pagination import CustomCursorPagination
from arkindex.project.permissions import IsInternal, IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import RTrimChr
from arkindex.project.tools import PercentileCont, RTrimChr
from arkindex.project.triggers import process_delete
from arkindex.users.models import OAuthCredentials, Role, Scope
from arkindex.users.utils import get_max_level
......@@ -984,7 +982,6 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
Guest access is required on private corpora. No specific rights are required on the workers.
"""
pagination_class = CustomCursorPagination
permission_classes = (IsVerifiedOrReadOnly, )
serializer_class = WorkerVersionSerializer
# For OpenAPI type discovery
......@@ -995,10 +992,22 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
return get_object_or_404(self.readable_corpora, pk=self.kwargs['pk'])
def get_queryset(self):
# This queryset does not have any ordering because the CustomCursorPagination always overrides it with .order_by('id')
return self.corpus.worker_versions.prefetch_related(
Prefetch('revision', queryset=Revision.objects.select_related('repo').prefetch_related('refs')),
Prefetch('worker', queryset=Worker.objects.select_related('repository', 'type')),
return (
self.corpus.worker_versions
.select_related(
'worker',
'revision',
)
.order_by(
'worker__name',
'revision__hash',
)
.prefetch_related(
'revision__repo',
'revision__refs',
'worker__repository',
'worker__type',
)
)
......@@ -1365,11 +1374,12 @@ class ImportTranskribus(CreateAPIView):
),
tags=['process']
))
class ListProcessElements(CustomPaginationViewMixin, CorpusACLMixin, ListAPIView):
class ListProcessElements(CorpusACLMixin, ListAPIView):
"""
List all elements for a process with workers.\n\n
Requires an **admin** access to the process corpus.
"""
pagination_class = CustomCursorPagination
permission_classes = (IsVerified, )
# For OpenAPI type discovery
queryset = Element.objects.none()
......@@ -1502,17 +1512,20 @@ class UpdateWorkerActivity(GenericAPIView):
state__in=self.allowed_transitions[state]
)
fields_to_update = {'state': state}
# If we are switching to started, we need to restrict to activities that either are not started,
# or have not been updated for an hour.
# or have not been updated for an hour, and also update the `started` timestamp.
if state == WorkerActivityState.Started.value:
activity_filter &= (
~Q(state=WorkerActivityState.Started)
| Q(updated__lte=Now() - timedelta(seconds=settings.WORKER_ACTIVITY_TIMEOUT))
)
fields_to_update['started'] = Now()
activity = WorkerActivity.objects.filter(activity_filter)
update_count = activity.update(state=state)
update_count = activity.update(**fields_to_update)
if not update_count:
# As no row has been updated the provided data was in conflict with the actual state
......@@ -1537,40 +1550,58 @@ class UpdateWorkerActivity(GenericAPIView):
return Response(serializer.data)
class WorkerActivityBase(ListAPIView):
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects \
.values('worker_version_id', 'configuration_id') \
.annotate(
average_processed_time=Avg(
F('updated') - F('started'),
filter=Q(state=WorkerActivityState.Processed, started__isnull=False),
),
max_processed_time=Max(
F('updated') - F('started'),
filter=Q(state=WorkerActivityState.Processed, started__isnull=False),
),
min_processed_time=Min(
F('updated') - F('started'),
filter=Q(state=WorkerActivityState.Processed, started__isnull=False),
),
median_processed_time=PercentileCont(
0.5,
order_by=F('updated') - F('started'),
filter=Q(state=WorkerActivityState.Processed, started__isnull=False),
output_field=DurationField(),
),
) \
.annotate(
# Add the counts per state in a separate annotate() call, because otherwise the F('started')
# on the processed time annotations would be interpreted as being the started activities count,
# not the started datetime, causing an error: "'<CombinedExpression: F(updated) - F(started)>' is an aggregate"
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
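PercentileCont is imported from arkindex.project.tools, whose code is not part of this diff; a plausible sketch of such an aggregate, wrapping PostgreSQL's percentile_cont(fraction) WITHIN GROUP (ORDER BY ...), might look like this (an assumption, not the actual implementation):
from django.db.models import Aggregate

class PercentileCont(Aggregate):
    # percentile_cont is an ordered-set aggregate, so the fraction goes in the
    # function call and the expression goes in the WITHIN GROUP clause. Extra
    # kwargs such as filter= and output_field= are forwarded to Aggregate.
    function = 'PERCENTILE_CONT'
    template = '%(function)s(%(percentile)s) WITHIN GROUP (ORDER BY %(expressions)s)'

    def __init__(self, percentile, order_by, **extra):
        super().__init__(order_by, percentile=percentile, **extra)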
@extend_schema_view(
get=extend_schema(
operation_id='CorpusWorkersActivity',
tags=['process']
)
)
class CorpusWorkersActivity(CorpusACLMixin, ListAPIView):
class CorpusWorkersActivity(CorpusACLMixin, WorkerActivityBase):
"""
Retrieve corpus-wise statistics about the activity of all its worker processes.\n
Requires a **guest** access.
"""
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects.none()
def list(self, request, *args, **kwargs):
def filter_queryset(self, queryset):
corpus = self.get_corpus(self.kwargs['corpus'], role=Role.Guest)
# Retrieve the distribution of activities on this corpus grouped by worker version
stats = WorkerActivity.objects \
.filter(element_id__in=corpus.elements.values('id')) \
.values('worker_version_id', 'configuration_id') \
.annotate(
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
return Response(
status=status.HTTP_200_OK,
data=WorkerStatisticsSerializer(stats, many=True).data
)
return queryset.filter(element_id__in=corpus.elements.values('id'))
@extend_schema_view(
......@@ -1579,17 +1610,13 @@ class CorpusWorkersActivity(CorpusACLMixin, ListAPIView):
tags=['process']
)
)
class ProcessWorkersActivity(ProcessACLMixin, ListAPIView):
class ProcessWorkersActivity(ProcessACLMixin, WorkerActivityBase):
"""
Retrieve process statistics about the activity of its workers.\n
Requires a **guest** access.
"""
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects.none()
def list(self, request, *args, **kwargs):
def filter_queryset(self, queryset):
process = get_object_or_404(Process.objects.all(), pk=self.kwargs['pk'])
access_level = self.process_access_level(process)
if not access_level:
......@@ -1597,21 +1624,7 @@ class ProcessWorkersActivity(ProcessACLMixin, ListAPIView):
if access_level < Role.Guest.value:
raise PermissionDenied(detail='You do not have a guest access to this process.')
# Retrieve the distribution of activities on this process grouped by worker version
stats = WorkerActivity.objects \
.filter(process_id=process.id) \
.values('worker_version_id', 'configuration_id') \
.annotate(
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
return Response(
status=status.HTTP_200_OK,
data=WorkerStatisticsSerializer(stats, many=True).data
)
return queryset.filter(process_id=process.id)
@extend_schema_view(
......
......@@ -8,22 +8,15 @@ logger = logging.getLogger(__name__)
class ActivityManager(models.Manager):
"""Model management for worker activities"""
def bulk_insert(self, worker_version_id, process_id, configuration_id, elements_qs, state=None):
def bulk_insert(self, worker_version_id, process_id, configuration_id, elements_qs):
"""
Create initial worker activities from a queryset of elements in an efficient way.
Due to the possibly large number of elements, we use a bulk insert from the elements query (best performance).
The `ON CONFLICT` clause allows either skipping existing activities if they are in a `Started` or `Processed`
state, or stealing them from another process if they are `Queued` or `Error`.
"""
from arkindex.process.models import WorkerActivityState
if state is None:
state = WorkerActivityState.Queued
assert isinstance(state, WorkerActivityState), 'State should be an instance of WorkerActivityState'
sql, params = elements_qs.distinct().values('id').query.sql_with_params()
select_params = (worker_version_id, configuration_id, state.value, process_id) + params
select_params = (worker_version_id, configuration_id, process_id) + params
# With ON CONFLICT, the target constraint is only optional when the action is DO NOTHING.
# For DO UPDATE, the target can either be a constraint by name with `ON CONSTRAINT <name>`,
......@@ -42,7 +35,7 @@ class ActivityManager(models.Manager):
INSERT INTO process_workeractivity
(element_id, worker_version_id, configuration_id, state, process_id, id, created, updated)
SELECT
elt.id, %s, %s, %s, %s, uuid_generate_v4(), current_timestamp, current_timestamp
elt.id, %s, %s, 'queued', %s, uuid_generate_v4(), current_timestamp, current_timestamp
FROM ({sql}) AS elt
ON CONFLICT {conflict_target} DO UPDATE SET
process_id = EXCLUDED.process_id,
......
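As a rough illustration of the upsert semantics the docstring describes (generic SQL only, shown as a Python string; the real conflict target and update clause live in the full query above, which is truncated here):
# Shape of a conditional upsert: rows already Started/Processed fail the WHERE
# condition and are left untouched, while Queued/Error rows are re-attached to
# the new process. Column lists and the conflict target are simplified.
UPSERT_SHAPE = """
INSERT INTO process_workeractivity (element_id, worker_version_id, state, process_id, ...)
SELECT ...
ON CONFLICT (...) DO UPDATE
    SET process_id = EXCLUDED.process_id, ...
    WHERE process_workeractivity.state IN ('queued', 'error')
"""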
# Generated by Django 4.0.4 on 2022-12-05 09:44
from django.db import migrations, models
from arkindex.process.models import WorkerActivityState
def set_started_on_started(apps, schema_editor):
"""
In case this migration runs while some processes are running, some activities might be in a `started` state,
and we therefore need to set their start time before we can apply the new check constraint.
We do know their start time though: their last update time is the time when they were set to `started`.
"""
WorkerActivity = apps.get_model('process', 'WorkerActivity')
WorkerActivity.objects.filter(state=WorkerActivityState.Started).update(started=models.F('updated'))
class Migration(migrations.Migration):
dependencies = [
('process', '0061_workeractivity_updated_triggers'),
]
operations = [
migrations.AddField(
model_name='workeractivity',
name='started',
field=models.DateTimeField(blank=True, null=True),
),
migrations.RunPython(
set_started_on_started,
reverse_code=migrations.RunPython.noop,
),
migrations.AddConstraint(
model_name='workeractivity',
constraint=models.CheckConstraint(
check=~models.Q(state=WorkerActivityState.Started) | models.Q(started__isnull=False),
name='worker_activity_started_requires_started',
),
),
]
......@@ -181,7 +181,7 @@ class Process(IndexableModel):
"corpus_id": self.corpus_id,
}
if self.name_contains:
filters['name__contains'] = self.name_contains
filters['name__icontains'] = self.name_contains
if self.element_type:
filters['type_id'] = self.element_type_id
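The switch from `contains` to `icontains` makes the name filter case-insensitive; for instance (hypothetical data):
# With icontains, the lookup matches regardless of case:
Element.objects.filter(name__icontains='act')  # matches 'Act 3', 'ACT 3' and 'act 3'
Element.objects.filter(name__contains='act')   # matches only 'act 3'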
......@@ -892,6 +892,9 @@ class WorkerActivity(models.Model):
# and we still let Django handle setting the timestamp initially, so we use `auto_now_add` and not `auto_now`.
updated = models.DateTimeField(auto_now_add=True)
# To properly compute a processing time, we need to separately store the time at which the state was updated to `started`.
started = models.DateTimeField(blank=True, null=True)
element = models.ForeignKey(
'documents.Element',
on_delete=models.CASCADE,
......@@ -937,7 +940,15 @@ class WorkerActivity(models.Model):
fields=['worker_version', 'element', 'configuration'],
name='worker_activity_unique_configuration',
condition=Q(configuration__isnull=False),
)
),
# Require the `started` timestamp to be set if a WorkerActivity is in a `started` state.
# This does not prevent activities in a processed/error state from having no start time,
# as we do not know when activities that existed before this field was added were started,
# but it is enough to ensure that the start time must be set on any activity that now becomes started.
models.CheckConstraint(
check=~Q(state=WorkerActivityState.Started) | Q(started__isnull=False),
name='worker_activity_started_requires_started',
),
]
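The `~Q(...) | Q(...)` pattern encodes the implication "a started activity must have a start time" as a CHECK constraint, roughly `CHECK (NOT state = 'started' OR started IS NOT NULL)`. A hypothetical violation (the `activity` instance is assumed) now fails at the database level:

from django.db import IntegrityError

try:
    activity.state = WorkerActivityState.Started
    activity.started = None  # violates worker_activity_started_requires_started
    activity.save()
except IntegrityError:
    # The database rejected the row; no inconsistent activity was stored
    ...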
triggers = [
pgtrigger.Trigger(
......
@@ -3,14 +3,10 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from arkindex.documents.models import Corpus, Element, ElementType
from arkindex.documents.serializers.elements import ElementSlimSerializer
from arkindex.process.models import ActivityState, DataFile, Process, ProcessMode, WorkerConfiguration, WorkerRun
from arkindex.process.models import ActivityState, DataFile, Process, ProcessMode, WorkerRun
from arkindex.process.serializers.git import RevisionSerializer
from arkindex.process.serializers.workers import WorkerConfigurationSerializer, WorkerVersionSerializer
from arkindex.project.mixins import ProcessACLMixin
from arkindex.project.serializer_fields import EnumField, LinearRingField
from arkindex.training.models import Model, ModelVersion
from arkindex.training.serializers import ModelVersionLightSerializer
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
from ponos.models import Farm, State
@@ -68,6 +64,7 @@ class ProcessSerializer(ProcessTrainingSerializer):
"""
Serialize a process with its settings
"""
from arkindex.documents.serializers.elements import ElementSlimSerializer
revision = RevisionSerializer(read_only=True)
element = ElementSlimSerializer(read_only=True)
element_id = serializers.PrimaryKeyRelatedField(
@@ -391,78 +388,6 @@ class ElementsWorkflowSerializer(serializers.Serializer):
return data
class WorkerRunSerializer(serializers.ModelSerializer):
"""
Serialize a worker run
"""
parents = serializers.ListField(child=serializers.UUIDField())
worker_version_id = serializers.UUIDField(source='version_id')
worker_version = WorkerVersionSerializer(read_only=True, source='version')
model_version = ModelVersionLightSerializer(read_only=True)
configuration_id = serializers.PrimaryKeyRelatedField(
queryset=WorkerConfiguration.objects.all(),
required=False,
allow_null=True,
style={'base_template': 'input.html'},
)
class Meta:
model = WorkerRun
read_only_fields = ('id', 'process_id', 'model_version_id')
fields = (
'id',
'parents',
'worker_version_id',
'worker_version',
'model_version_id',
'process_id',
'model_version',
'configuration_id',
)
class WorkerRunEditSerializer(WorkerRunSerializer):
"""
Serialize a worker run with only the parents as an editable field
"""
worker_version_id = serializers.UUIDField(source='version_id', read_only=True)
model_version_id = serializers.PrimaryKeyRelatedField(
queryset=ModelVersion.objects.all(),
required=False,
allow_null=True,
source='model_version',
style={'base_template': 'input.html'},
)
configuration = WorkerConfigurationSerializer(read_only=True)
process = ProcessTrainingSerializer(read_only=True)
class Meta(WorkerRunSerializer.Meta):
fields = WorkerRunSerializer.Meta.fields + (
'configuration',
'process',
)
def validate_model_version_id(self, model_version):
model_usage = self.context.get('model_usage')
if not model_usage:
raise ValidationError("This worker version does not support model usage.")
model = Model.objects.get(id=model_version.model_id)
# Check access rights on model version
access_level = get_max_level(self.context["request"].user, model)
if not access_level or access_level < Role.Contributor.value:
raise ValidationError('You do not have access to this model version.')
return model_version
def validate(self, data):
# Store configuration if provided
if 'configuration_id' in data:
data['configuration'] = data['configuration_id']
return data
class ImportTranskribusSerializer(serializers.Serializer):
"""
Serialize a Transkribus import
......
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.process.models import WorkerConfiguration, WorkerRun
from arkindex.process.serializers.imports import ProcessTrainingSerializer
from arkindex.process.serializers.workers import WorkerConfigurationSerializer, WorkerVersionSerializer
from arkindex.training.models import Model, ModelVersion
from arkindex.training.serializers import ModelVersionLightSerializer
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
# Defined here because it causes circular imports
# when defined in `imports.py`, where it belongs
class WorkerRunSummarySerializer(serializers.Serializer):
id = serializers.UUIDField()
summary = serializers.CharField(help_text="Human-readable summary of a WorkerRun's information")
# To prevent each element worker from retrieving contextual information
# (process, worker version, model version…) with extra GET requests, we
# serialize all the related information on WorkerRun serializers.
class WorkerRunSerializer(serializers.ModelSerializer):
"""
Serialize a worker run for creation.
Worker version and parents are required; configuration is optional.
"""
parents = serializers.ListField(child=serializers.UUIDField())
worker_version_id = serializers.UUIDField(write_only=True, source='version_id')
worker_version = WorkerVersionSerializer(read_only=True, source='version')
model_version = ModelVersionLightSerializer(read_only=True)
configuration = WorkerConfigurationSerializer(read_only=True)
process = ProcessTrainingSerializer(read_only=True)
configuration_id = serializers.PrimaryKeyRelatedField(
queryset=WorkerConfiguration.objects.all(),
required=False,
allow_null=True,
write_only=True,
style={'base_template': 'input.html'},
)
class Meta:
model = WorkerRun
fields = (
'id',
'parents',
'worker_version_id',
'worker_version',
'process',
'configuration_id',
'configuration',
'model_version',
)
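As a hypothetical creation sketch (the UUID and view wiring are made up), only the write-only fields are needed as input; the read-only nested serializers then fill the response with the contextual information mentioned above:

serializer = WorkerRunSerializer(data={
    'worker_version_id': '11111111-1111-1111-1111-111111111111',  # made-up UUID
    'parents': [],
    # 'configuration_id' is optional and may be omitted or set to None
})
serializer.is_valid(raise_exception=True)
# Once saved, the serialized run embeds worker_version, process, configuration
# and model_version, sparing element workers extra GET requests.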
class WorkerRunEditSerializer(WorkerRunSerializer):
"""
Serialize a worker run for editing.
Process and worker version cannot be edited.
Parents, model version and configuration can be edited.
"""
model_version_id = serializers.PrimaryKeyRelatedField(
queryset=ModelVersion.objects.all(),
required=False,
allow_null=True,
write_only=True,
source='model_version',
style={'base_template': 'input.html'},
)
class Meta:
model = WorkerRun
fields = (
'id',
'parents',
'worker_version',
'process',
'configuration_id',
'configuration',
'model_version_id',
'model_version',
)
def validate_model_version_id(self, model_version):
model_usage = self.context.get('model_usage')
if not model_usage:
raise ValidationError("This worker version does not support model usage.")
model = Model.objects.get(id=model_version.model_id)
# Check access rights on model version
access_level = get_max_level(self.context["request"].user, model)
if not access_level or access_level < Role.Contributor.value:
raise ValidationError('You do not have access to this model version.')
return model_version
def validate(self, data):
# Store configuration if provided
if 'configuration_id' in data:
data['configuration'] = data['configuration_id']
return data
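A hypothetical view-side usage of the edit serializer (the context wiring and the `model_usage` attribute on the worker version are assumptions based on the lookup in `validate_model_version_id`):

serializer = WorkerRunEditSerializer(
    worker_run,
    data={'model_version_id': str(model_version.id)},
    partial=True,
    context={
        'request': request,
        # Whether the worker version supports models (assumed attribute)
        'model_usage': worker_run.version.model_usage,
    },
)
serializer.is_valid(raise_exception=True)  # checks model access rights
serializer.save()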