Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: arkindex/backend

Commits on Source (6)
@@ -8,7 +8,19 @@ from uuid import UUID
from django.conf import settings
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import connection, transaction
from django.db.models import CharField, Count, F, FloatField, Prefetch, Q, QuerySet, Value, prefetch_related_objects
from django.db.models import (
CharField,
Count,
Exists,
F,
FloatField,
OuterRef,
Prefetch,
Q,
QuerySet,
Value,
prefetch_related_objects,
)
from django.db.models.functions import Cast
from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
@@ -82,7 +94,7 @@ from arkindex.project.triggers import (
selection_worker_results_delete,
worker_results_delete,
)
from arkindex.training.models import ModelVersion
from arkindex.training.models import DatasetElement, ModelVersion
from arkindex.users.models import Role
from arkindex.users.utils import filter_rights
@@ -1189,8 +1201,12 @@ class ElementChildren(ElementsListBase):
patch=extend_schema(description='Rename an element'),
put=extend_schema(description="Edit an element's attributes. Requires a write access on the corpus."),
delete=extend_schema(
description='Delete an element. Requires either an admin access on the corpus, '
'or a write access and to be the creator of this element.',
description=dedent("""
Delete an element.
This element cannot be part of a dataset.
Requires either admin access on the corpus, or write access and being the creator of this element.
""").strip(),
parameters=[
OpenApiParameter(
'delete_children',
@@ -1218,18 +1234,25 @@ class ElementRetrieve(ACLMixin, RetrieveUpdateDestroyAPIView):
queryset = Element.objects.filter(corpus__in=corpora)
if self.request and self.request.method == 'DELETE':
# Only include corpus and creator for ACL check and ID for deletion
return queryset.select_related('corpus').only('id', 'creator_id', 'corpus')
return (
queryset
.select_related('corpus')
.annotate(has_dataset=Exists(DatasetElement.objects.filter(element_id=OuterRef('pk'))))
.only('id', 'creator_id', 'corpus')
)
return queryset \
return (
queryset
.select_related(
'corpus',
'type',
'image__server',
'creator',
'worker_run'
) \
.prefetch_related(Prefetch('classifications', queryset=classifications_queryset)) \
)
.prefetch_related(Prefetch('classifications', queryset=classifications_queryset))
.annotate(metadata_count=Count('metadatas'))
)
def check_object_permissions(self, request, obj):
super().check_object_permissions(request, obj)
@@ -1242,6 +1265,9 @@ class ElementRetrieve(ACLMixin, RetrieveUpdateDestroyAPIView):
if not self.has_access(obj.corpus, role.value):
access_repr = 'admin' if role == Role.Admin else 'write'
raise PermissionDenied(detail=f'You do not have {access_repr} access to this element.')
# Prevent the direct deletion of an element that is part of a dataset
if request.method == 'DELETE' and getattr(obj, 'has_dataset', False):
raise PermissionDenied(detail='You cannot delete an element that is part of a dataset.')
def get_serializer_context(self):
context = super().get_serializer_context()
......
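The DELETE guard above flags dataset membership with a correlated subquery rather than a separate lookup. A minimal sketch of the pattern, using only standard Django ORM calls (the element_id variable is illustrative):

from django.db.models import Exists, OuterRef
from rest_framework.exceptions import PermissionDenied

from arkindex.documents.models import Element
from arkindex.training.models import DatasetElement

# Exists() compiles to an SQL EXISTS subquery, so the element gets a
# boolean flag in the same query, without a join or an extra COUNT.
element = (
    Element.objects
    .annotate(has_dataset=Exists(DatasetElement.objects.filter(element_id=OuterRef('pk'))))
    .get(pk=element_id)
)
if element.has_dataset:
    # Same refusal as in check_object_permissions above
    raise PermissionDenied(detail='You cannot delete an element that is part of a dataset.')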
from unittest.mock import call, patch
from django.db import connections
from django.db.utils import IntegrityError
from arkindex.documents.tasks import selection_worker_results_delete
from arkindex.process.models import Worker, WorkerVersion
from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Model, ModelVersionState
from arkindex.training.models import Dataset, Model, ModelVersionState
class TestDeleteSelectionWorkerResults(FixtureTestCase):
@@ -112,3 +115,21 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase):
element_id=self.page2.id,
),
])
@patch('arkindex.documents.tasks.get_current_job')
def test_run_dataset_failure(self, job_mock):
"""
Elements that are part of a dataset cannot be deleted
"""
job_mock.return_value.user_id = self.user.id
self.page1.worker_version = self.version
self.page1.save()
Dataset.objects.get(name='First Dataset').dataset_elements.create(element=self.page1, set='test')
self.user.selected_elements.set([self.page1])
selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id)
# Deferred constraint checks would normally only run at the end of the
# test transaction; force them to run now with "SET CONSTRAINTS ALL IMMEDIATE"
# https://code.djangoproject.com/ticket/11665
with self.assertRaises(IntegrityError):
connections['default'].check_constraints()
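The IntegrityError can only surface here because the foreign key from DatasetElement to Element is deferred: inside the test transaction, the orphaned dataset row is not checked until commit, and Django's TestCase never commits. check_constraints() forces the check; on PostgreSQL it is roughly equivalent to this sketch:

with connections['default'].cursor() as cursor:
    # Run all deferred constraint checks now; raises IntegrityError when a
    # DatasetElement row still points at the deleted element
    cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
    cursor.execute('SET CONSTRAINTS ALL DEFERRED')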
from django.core.exceptions import ObjectDoesNotExist
from django.db import connections
from django.db.utils import IntegrityError
from arkindex.documents.models import Entity, EntityType, MLClass, TranscriptionEntity
from arkindex.documents.tasks import worker_results_delete
from arkindex.process.models import ProcessMode, WorkerVersion
from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Model, ModelVersionState
from arkindex.training.models import Dataset, Model, ModelVersionState
class TestDeleteWorkerResults(FixtureTestCase):
@@ -260,3 +262,19 @@ class TestDeleteWorkerResults(FixtureTestCase):
self.transcription_entity2,
self.page2,
)
def test_run_dataset_failure(self):
"""
Elements that are part of a dataset cannot be deleted
"""
self.page1.worker_run = self.worker_run_1
self.page1.worker_version = self.version_1
self.page1.save()
Dataset.objects.get(name='First Dataset').dataset_elements.create(element=self.page1, set='test')
worker_results_delete(corpus_id=self.corpus.id)
# Deferred constraint checks would normally only run at the end of the
# test transaction; force them to run now with "SET CONSTRAINTS ALL IMMEDIATE"
# https://code.djangoproject.com/ticket/11665
with self.assertRaises(IntegrityError):
connections['default'].check_constraints()
@@ -2,6 +2,8 @@ from datetime import datetime, timezone
from itertools import cycle
from unittest.mock import patch
from django.db import connections
from django.db.utils import IntegrityError
from django.urls import reverse
from rest_framework import status
@@ -9,6 +11,7 @@ from arkindex.documents.models import Corpus, Element
from arkindex.process.models import Process, WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.project.tools import build_tree
from arkindex.training.models import Dataset
from arkindex.users.models import Role
@@ -139,6 +142,17 @@ class TestDestroyElements(FixtureAPITestCase):
{'detail': 'You do not have write access to this element.'}
)
def test_element_destroy_in_dataset_forbidden(self):
"""
An element cannot be deleted via the API if linked to a dataset
"""
Dataset.objects.get(name='First Dataset').dataset_elements.create(element=self.vol, set='test')
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.delete(reverse('api:element-retrieve', kwargs={'pk': str(self.vol.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'You cannot delete an element that is part of a dataset.'})
@patch('arkindex.project.triggers.documents_tasks.element_trash.delay')
def test_non_empty_element(self, delay_mock):
"""
@@ -159,6 +173,22 @@ class TestDestroyElements(FixtureAPITestCase):
'description': 'Element deletion',
})
def test_element_trash_dataset_failure(self):
"""
Elements that are part of a dataset cannot be deleted
"""
Dataset.objects.get(name='First Dataset').dataset_elements.create(
element=Element.objects.get_descending(self.vol.id).first(),
set='test',
)
Element.objects.filter(id=self.vol.id).trash()
# Deferred constraint checks would normally only run at the end of the
# test transaction; force them to run now with "SET CONSTRAINTS ALL IMMEDIATE"
# https://code.djangoproject.com/ticket/11665
with self.assertRaises(IntegrityError):
connections['default'].check_constraints()
def test_element_trash_children(self):
# Check this volume has children
children = Element.objects.get_descending(self.vol.id)
......
@@ -159,6 +159,9 @@ class Agent(models.Model):
:param tasks: Number of tasks to estimate the cost for.
:returns: A cost expressed as a percentage. If > 1, the agent would be overloaded.
"""
if self.cpu_load is None or self.ram_load is None:
# The agent has not shared its state yet
return 1
current_tasks_count = getattr(self, 'current_tasks', 0)
if current_tasks_count + AGENT_SLOT["cpu"] >= self.cpu_cores:
return 1
......
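With this guard, a freshly registered agent that has never reported its load is treated as fully loaded, so the task assignment algorithm skips it until its first state ping. A hypothetical check, with field names taken from the test fixtures below and an assumed method name, since the hunk only shows the method body:

agent = Agent(cpu_cores=2, cpu_frequency=1e9, ram_total=2e9,
              cpu_load=None, ram_load=None)
# A cost >= 1 means the agent would be overloaded, so nothing is assigned
assert agent.estimate_cost(tasks=1) == 1  # estimate_cost is an assumed name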
@@ -56,6 +56,8 @@ class TestAPI(FixtureAPITestCase):
public_key=pubkey,
ram_total=2e9,
last_ping=timezone.now(),
cpu_load=.1,
ram_load=.1e9,
)
cls.rev = Revision.objects.first()
cls.process = Process.objects.get(mode=ProcessMode.Workers)
@@ -834,7 +836,7 @@ class TestAPI(FixtureAPITestCase):
"agent": {
"cpu_cores": 2,
"cpu_frequency": 1000000000,
"cpu_load": None,
"cpu_load": .1,
"farm": {"id": str(self.agent.farm_id), "name": "Wheat farm"},
"gpus": [
{
@@ -853,7 +855,7 @@ class TestAPI(FixtureAPITestCase):
"hostname": "ghostname",
"id": str(self.agent.id),
"last_ping": str_date(self.agent.last_ping),
"ram_load": None,
"ram_load": 100000000,
"ram_total": 2000000000,
},
"gpu": {
@@ -890,7 +892,7 @@ class TestAPI(FixtureAPITestCase):
"agent": {
"cpu_cores": 2,
"cpu_frequency": 1000000000,
"cpu_load": None,
"cpu_load": .1,
"farm": {"id": str(self.agent.farm_id), "name": "Wheat farm"},
"gpus": [
{
@@ -909,7 +911,7 @@ class TestAPI(FixtureAPITestCase):
"hostname": "ghostname",
"id": str(self.agent.id),
"last_ping": str_date(self.agent.last_ping),
"ram_load": None,
"ram_load": 100000000,
"ram_total": 2000000000,
},
"gpu": {
@@ -1650,7 +1652,7 @@ class TestAPI(FixtureAPITestCase):
{
"cpu_cores": 12,
"cpu_frequency": 1000000000,
"cpu_load": None,
"cpu_load": .1,
"farm": str(self.wheat_farm.id),
"gpus": [
{
@@ -1663,7 +1665,7 @@ class TestAPI(FixtureAPITestCase):
"hostname": "ghostname",
"id": str(self.agent.id),
"last_ping": "2000-01-01T12:00:00Z",
"ram_load": None,
"ram_load": 100000000,
"ram_total": 32000000000,
},
)
@@ -1859,25 +1861,52 @@ class TestAPI(FixtureAPITestCase):
},
)
def test_agent_null_state(self):
"""
Agents with unknown CPU or RAM load are excluded by the task assignment algorithm
"""
self.agent.cpu_load = None
self.agent.ram_load = None
self.agent.save()
pubkey = build_public_key()
second_agent = AgentUser.objects.create(
id=hashlib.md5(pubkey.encode("utf-8")).hexdigest(),
farm=self.wheat_farm,
hostname="new agent",
cpu_cores=2,
cpu_frequency=1e9,
public_key=pubkey,
ram_total=1e9,
last_ping=timezone.now(),
)
with self.assertNumQueries(6):
resp = self.client.get(
reverse("api:agent-actions"),
HTTP_AUTHORIZATION=f'Bearer {second_agent.token.access_token}',
data={"cpu_load": 1.9, "ram_load": 0.49},
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertDictEqual(resp.json(), {"actions": []})
def test_agent_non_pending_actions(self):
"""
Only pending tasks may be retrieved as new actions
"""
self.process.tasks.update(state=State.Error)
Task.objects.filter(process__farm=self.agent.farm_id).update(state=State.Error)
with self.assertNumQueries(7):
resp = self.client.get(
reverse("api:agent-actions"),
HTTP_AUTHORIZATION=f'Bearer {self.agent.token.access_token}',
data={"cpu_load": 0.9, "ram_load": 0.49},
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertDictEqual(resp.json(), {"actions": []})
def test_agent_no_stealing(self):
"""
An agent may not take another agent's tasks
"""
self.process.tasks.update(agent=self.agent, state=State.Pending)
Task.objects.filter(process__farm=self.agent.farm_id).update(agent=self.agent, state=State.Pending)
pubkey = build_public_key()
agent2 = AgentUser.objects.create(
id=uuid.UUID(hashlib.md5(pubkey.encode("utf-8")).hexdigest()),
@@ -2377,7 +2406,7 @@ class TestAPI(FixtureAPITestCase):
"active": True,
"cpu_cores": 2,
"cpu_frequency": 1000000000,
"cpu_load": None,
"cpu_load": .1,
"farm": {
"id": str(self.wheat_farm.id),
"name": "Wheat farm",
@@ -2397,7 +2426,7 @@ class TestAPI(FixtureAPITestCase):
},
],
"hostname": "ghostname",
"ram_load": None,
"ram_load": 100000000,
"ram_total": 2000000000,
"running_tasks_count": 1,
},
@@ -2458,7 +2487,7 @@ class TestAPI(FixtureAPITestCase):
"active": True,
"cpu_cores": 2,
"cpu_frequency": 1000000000,
"cpu_load": None,
"cpu_load": .1,
"farm": {
"id": str(self.wheat_farm.id),
"name": "Wheat farm",
@@ -2478,7 +2507,7 @@ class TestAPI(FixtureAPITestCase):
},
],
"hostname": "ghostname",
"ram_load": None,
"ram_load": 100000000,
"ram_total": 2000000000,
"running_tasks": [
{
......
@@ -1867,9 +1867,11 @@ class WorkerActivityBase(ListAPIView):
.annotate(
# Completion time estimate based on median time per task from the previously annotated data
completion_estimate_time=ExpressionWrapper(
F("median_processed_time") * (
F(WorkerActivityState.Queued.value)
+ F(WorkerActivityState.Started.value)
(
F("median_processed_time") * (
F(WorkerActivityState.Queued.value)
+ F(WorkerActivityState.Started.value)
) / F("process__chunks")
),
output_field=DurationField(),
)
......
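The new division by process__chunks is what produces the 00:52:30 estimate asserted in the tests below. Spelling out the arithmetic with the test values:

from datetime import timedelta

median = timedelta(minutes=7)  # median of the 4, 6, 8 and 10 minute activities
remaining = 15                 # queued + started activities for process_3
chunks = 2                     # process__chunks
assert median * remaining / chunks == timedelta(minutes=52, seconds=30)  # '00:52:30'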
@@ -28,6 +28,7 @@ class TestWorkerActivityStats(FixtureAPITestCase):
super().setUpTestData()
cls.version_1 = WorkerVersion.objects.get(worker__slug='reco')
cls.version_2 = WorkerVersion.objects.get(worker__slug='dla')
cls.version_3 = WorkerVersion.objects.get(worker__slug='worker-gpu')
cls.configuration_1 = WorkerConfiguration.objects.create(name='some_config', configuration={'aa': 'bb'}, worker=cls.version_1.worker)
cls.model = Model.objects.create(name='Loukoum', public=False)
cls.model_version = cls.model.versions.create()
@@ -38,6 +39,12 @@ class TestWorkerActivityStats(FixtureAPITestCase):
creator=cls.user,
corpus=cls.corpus,
)
cls.process_3 = Process.objects.create(
mode=ProcessMode.Workers,
creator=cls.user,
corpus=cls.corpus,
chunks=2
)
# Generate worker activities
WorkerActivity.objects.bulk_create([
@@ -58,6 +65,14 @@ class TestWorkerActivityStats(FixtureAPITestCase):
worker_version_id=cls.version_2.id,
process_id=cls.process_2.id,
) for elt in cls.corpus.elements.all()
), *(
WorkerActivity(
element_id=elt.id,
state=state,
worker_version_id=cls.version_3.id,
process_id=cls.process_3.id,
started=datetime.now(timezone.utc),
) for elt, state in zip(cls.corpus.elements.all(), itertools.cycle(WorkerActivityState))
)
])
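The zip over itertools.cycle deals the activity states out in round-robin fashion, so every state ends up with a similar share of elements. A shorter illustration of the same pattern:

import itertools

states = ['error', 'processed', 'queued', 'started']
elements = ['e1', 'e2', 'e3', 'e4', 'e5', 'e6']
assert list(zip(elements, itertools.cycle(states))) == [
    ('e1', 'error'), ('e2', 'processed'), ('e3', 'queued'),
    ('e4', 'started'), ('e5', 'error'), ('e6', 'processed'),
]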
@@ -84,6 +99,20 @@ class TestWorkerActivityStats(FixtureAPITestCase):
]
]
cls.error_2, cls.processed_2, cls.queued_2, cls.started_2 = [
WorkerActivity.objects.filter(
element__corpus_id=cls.corpus.id,
worker_version_id=cls.version_3.id,
state=state
).count()
for state in [
WorkerActivityState.Error,
WorkerActivityState.Processed,
WorkerActivityState.Queued,
WorkerActivityState.Started
]
]
def test_corpus_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(
@@ -160,22 +189,45 @@ class TestWorkerActivityStats(FixtureAPITestCase):
'min_processed_time': '00:00:00',
'max_processed_time': '00:00:00',
'median_processed_time': '00:00:00',
}, {
'worker_version_id': str(self.version_3.id),
'configuration_id': None,
'model_version_id': None,
'queued': self.queued_2,
'started': self.started_2,
'processed': self.processed_2,
'error': self.error_2,
'average_processed_time': '00:00:00',
'completion_estimate_time': '00:00:00',
'min_processed_time': '00:00:00',
'max_processed_time': '00:00:00',
'median_processed_time': '00:00:00',
}
])
def test_corpus_timing(self):
# Grab 4 processed WorkerActivities, and delete all the other processed activities.
processed_activities = WorkerActivity.objects.filter(
# Grab 4 processed WorkerActivities from the first process and 4 from
# process_3 (which uses chunks), then delete all the other processed activities.
processed_activities_1 = WorkerActivity.objects.filter(
element__corpus_id=self.corpus.id,
worker_version_id=self.version_1.id,
state=WorkerActivityState.Processed,
).order_by('id')
self.assertGreaterEqual(processed_activities.count(), 4)
self.assertGreaterEqual(processed_activities_1.count(), 4)
processed_activities_2 = WorkerActivity.objects.filter(
element__corpus_id=self.corpus.id,
worker_version_id=self.version_3.id,
state=WorkerActivityState.Processed,
).order_by('id')
self.assertGreaterEqual(processed_activities_2.count(), 4)
# id__in is necessary because Django does not support LIMIT/OFFSET on a DELETE
WorkerActivity.objects.filter(id__in=processed_activities[4:].values('id')).delete()
WorkerActivity.objects.filter(id__in=processed_activities_1[4:].values('id')).delete()
WorkerActivity.objects.filter(id__in=processed_activities_2[4:].values('id')).delete()
# Set a fixed start time so we can edit `updated` to change the processed time
processed_activities.update(started=datetime(2012, 12, 12, 12, 12, 12, 12, timezone.utc))
act1, act2, act3, act4 = processed_activities
processed_activities_1.update(started=datetime(2012, 12, 12, 12, 12, 12, 12, timezone.utc))
processed_activities_2.update(started=datetime(2012, 12, 12, 12, 12, 12, 12, timezone.utc))
act1, act2, act3, act4 = processed_activities_1
act5, act6, act7, act8 = processed_activities_2
# With 4 activities, we can set some simple values to test the min, max, average and median.
# An even count also ensures the median is computed as the mean of the two middle values
@@ -185,12 +237,20 @@ class TestWorkerActivityStats(FixtureAPITestCase):
act3.updated = act3.started + timedelta(minutes=3)
act4.updated = act4.started + timedelta(minutes=10)
act5.updated = act5.started + timedelta(minutes=4)
act6.updated = act6.started + timedelta(minutes=8)
act7.updated = act7.started + timedelta(minutes=6)
act8.updated = act8.started + timedelta(minutes=10)
# 15 remaining tasks, so the time estimate is 15 * median time
self.assertEqual(self.started + self.queued, 15)
# 15 remaining tasks; for process_3, which has 2 chunks, the time estimate is 15 * median time / 2
self.assertEqual(self.started_2 + self.queued_2, 15)
# Disable the Postgres triggers so we can force our updated times
with pgtrigger.ignore('process.WorkerActivity:read_only_workeractivity_updated', 'process.WorkerActivity:update_workeractivity_updated'):
WorkerActivity.objects.bulk_update([act1, act2, act3, act4], ['updated'])
WorkerActivity.objects.bulk_update([act5, act6, act7, act8], ['updated'])
self.client.force_login(self.user)
with self.assertNumQueries(4):
@@ -226,6 +286,19 @@ class TestWorkerActivityStats(FixtureAPITestCase):
'max_processed_time': '00:00:00',
'median_processed_time': '00:00:00',
'completion_estimate_time': '00:00:00',
}, {
'worker_version_id': str(self.version_3.id),
'configuration_id': None,
'model_version_id': None,
'queued': self.queued_2,
'started': self.started_2,
'processed': 4,
'error': self.error_2,
'average_processed_time': '00:07:00',
'completion_estimate_time': '00:52:30',
'min_processed_time': '00:04:00',
'max_processed_time': '00:10:00',
'median_processed_time': '00:07:00',
}
])
@@ -237,7 +310,7 @@ class TestWorkerActivityStats(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(4):
response = self.client.get(
reverse('api:corpus-activity-stats', kwargs={'corpus': str(self.corpus.id)})
reverse('api:corpus-activity-stats', kwargs={'corpus': str(self.corpus.id)}),
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertCountEqual(response.json(), [
@@ -254,6 +327,19 @@ class TestWorkerActivityStats(FixtureAPITestCase):
'min_processed_time': None,
'max_processed_time': None,
'median_processed_time': None,
}, {
'worker_version_id': str(self.version_3.id),
'configuration_id': None,
'model_version_id': None,
'queued': self.queued_2,
'started': self.started_2,
'processed': 0,
'error': self.error_2,
'completion_estimate_time': None,
'average_processed_time': None,
'min_processed_time': None,
'max_processed_time': None,
'median_processed_time': None,
}
])
......
@@ -122,6 +122,7 @@ from arkindex.training.api import (
CorpusDataset,
CreateDatasetElementsSelection,
DatasetClone,
DatasetElementDestroy,
DatasetElements,
DatasetUpdate,
ElementDatasets,
@@ -215,6 +216,7 @@ api = [
path('datasets/<uuid:pk>/', DatasetUpdate.as_view(), name='dataset-update'),
path('datasets/<uuid:pk>/clone/', DatasetClone.as_view(), name='dataset-clone'),
path('datasets/<uuid:pk>/elements/', DatasetElements.as_view(), name='dataset-elements'),
path('datasets/<uuid:dataset>/elements/<uuid:element>/', DatasetElementDestroy.as_view(), name='dataset-element'),
# Moderation
path('classifications/', ClassificationCreate.as_view(), name='classification-create'),
......
@@ -35,6 +35,7 @@ from arkindex.training.models import (
)
from arkindex.training.serializers import (
CreateModelErrorResponseSerializer,
DatasetElementInfoSerializer,
DatasetElementSerializer,
DatasetLightSerializer,
DatasetSerializer,
@@ -607,6 +608,49 @@ class DatasetElements(CorpusACLMixin, ListAPIView):
)
@extend_schema_view(
delete=extend_schema(
operation_id='DestroyDatasetElement',
parameters=[
OpenApiParameter(
'set',
type=str,
description='Name of the set from which to remove the element.',
required=True,
)
],
tags=['datasets']
)
)
class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
"""
Remove an element from a dataset.
Elements can only be removed from **open** datasets.
Requires **contributor** access to the dataset corpus.
"""
permission_classes = (IsVerified, )
serializer_class = DatasetElementInfoSerializer
lookup_url_kwarg = 'element'
def destroy(self, request, *args, **kwargs):
if not self.request.query_params.get('set'):
raise ValidationError({'set': ['This field is required.']})
dataset_element = get_object_or_404(
DatasetElement.objects.select_related('dataset__corpus'),
dataset_id=self.kwargs['dataset'],
element_id=self.kwargs['element'],
set=self.request.query_params.get('set')
)
if dataset_element.dataset.state != DatasetState.Open:
raise ValidationError({'dataset': ['Elements can only be removed from open Datasets.']})
if not self.has_write_access(dataset_element.dataset.corpus):
raise PermissionDenied(detail='You need a Contributor access to the dataset to perform this action.')
dataset_element.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
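For reference, a sketch of how a client calls the new endpoint, mirroring the tests further down (the set query parameter is mandatory and must name the set the element belongs to):

response = client.delete(
    reverse('api:dataset-element', kwargs={'dataset': dataset.id, 'element': element.id})
    + '?set=train'
)
assert response.status_code == 204  # removed from the 'train' set only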
@extend_schema(tags=['datasets'])
@extend_schema_view(
post=extend_schema(
@@ -749,7 +793,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
clone = copy.copy(dataset)
# Make Django think it is a new dataset that it should insert
clone.id = None
clone.name = clone_name
clone.name = clone_name[:100]
clone.state = DatasetState.Open
clone.creator = request.user
clone.save()
......
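The new [:100] slice keeps the clone's name within the model's 100-character limit. Judging from the test below, the name is built as 'Clone of ' followed by the original name, so the boundary case works out to exactly 100 characters:

clone_name = ('Clone of ' + 'A' * 99)[:100]  # 'Clone of ' is 9 characters
assert clone_name == 'Clone of ' + 'A' * 91  # 9 + 91 == 100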
@@ -9,6 +9,7 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.validators import UniqueTogetherValidator
from arkindex.documents.models import Element
from arkindex.documents.serializers.elements import ElementListSerializer
from arkindex.ponos.models import Task
from arkindex.process.models import Worker
@@ -531,6 +532,21 @@ class DatasetElementSerializer(serializers.ModelSerializer):
read_only_fields = fields
class DatasetElementInfoSerializer(DatasetElementSerializer):
dataset = serializers.PrimaryKeyRelatedField(
queryset=Dataset.objects.none(),
style={'base_template': 'input.html'},
)
set = serializers.CharField(max_length=50)
element = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.none(),
style={'base_template': 'input.html'},
)
class Meta(DatasetElementSerializer.Meta):
fields = DatasetElementSerializer.Meta.fields + ('dataset',)
class ElementDatasetSerializer(serializers.ModelSerializer):
dataset_name = serializers.CharField(max_length=100, source='dataset__name')
......
@@ -1049,13 +1049,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_list_elements(self):
# id: ddc6a864-82d3-4f4d-9078-6fb547db7279
self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
# id: ebb0adee-91ff-4d21-bdfd-9dee5e61cd84
self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
# id: 1a55da30-0c90-4330-9533-02107aaaa93e
self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
# id: b7b53865-24f3-4e09-b61c-de8bb1fcbd01
self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
self.page3.confidence = 0.42
self.page3.mirrored = True
@@ -1072,127 +1068,129 @@ class TestDatasetsAPI(FixtureAPITestCase):
data = response.json()
self.assertTrue('?cursor=' in data['next'])
self.assertIsNone(data['count'])
self.maxDiff = None
self.assertListEqual(data['results'], [{
'set': 'validation',
'element': {
'id': str(self.page2.id),
'confidence': None,
'type': 'page',
'mirrored': False,
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
},
'name': self.page2.name,
'zone': {
self.assertListEqual(data['results'], sorted(
[{
'set': 'validation',
'element': {
'id': str(self.page2.id),
'image': {
'id': str(self.page2.image.id),
'height': 1000,
'width': 1000,
'url': 'http://server/img2',
'path': 'img2',
'status': 'unchecked',
's3_url': None,
'server': {
'display_name': 'Test Server',
'max_height': None,
'max_width': None,
'url': 'http://server'
'confidence': None,
'type': 'page',
'mirrored': False,
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
},
'name': self.page2.name,
'zone': {
'id': str(self.page2.id),
'image': {
'id': str(self.page2.image.id),
'height': 1000,
'width': 1000,
'url': 'http://server/img2',
'path': 'img2',
'status': 'unchecked',
's3_url': None,
'server': {
'display_name': 'Test Server',
'max_height': None,
'max_width': None,
'url': 'http://server'
},
},
'polygon': [
[0, 0],
[0, 1000],
[1000, 1000],
[1000, 0],
[0, 0]
],
'url': 'http://server/img2/0,0,1000,1000/full/0/default.jpg'
},
'polygon': [
[0, 0],
[0, 1000],
[1000, 1000],
[1000, 0],
[0, 0]
],
'url': 'http://server/img2/0,0,1000,1000/full/0/default.jpg'
},
'rotation_angle': 0,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': None,
'created': FAKE_CREATED
},
}, {
'set': 'validation',
'element': {
'id': str(self.page3.id),
'confidence': 0.42,
'type': 'page',
'mirrored': True,
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
'rotation_angle': 0,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': None,
'created': FAKE_CREATED
},
'name': self.page3.name,
'zone': {
}, {
'set': 'validation',
'element': {
'id': str(self.page3.id),
'image': {
'id': str(self.page3.image.id),
'height': 1000,
'width': 1000,
'url': 'http://server/img3',
'path': 'img3',
'status': 'unchecked',
's3_url': None,
'server': {
'display_name': 'Test Server',
'max_height': None,
'max_width': None,
'url': 'http://server'
'confidence': 0.42,
'type': 'page',
'mirrored': True,
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
},
'name': self.page3.name,
'zone': {
'id': str(self.page3.id),
'image': {
'id': str(self.page3.image.id),
'height': 1000,
'width': 1000,
'url': 'http://server/img3',
'path': 'img3',
'status': 'unchecked',
's3_url': None,
'server': {
'display_name': 'Test Server',
'max_height': None,
'max_width': None,
'url': 'http://server'
},
},
'polygon': [
[0, 0],
[0, 1000],
[1000, 1000],
[1000, 0],
[0, 0]
],
'url': 'http://server/img3/0,0,1000,1000/full/0/default.jpg'
},
'polygon': [
[0, 0],
[0, 1000],
[1000, 1000],
[1000, 0],
[0, 0]
],
'url': 'http://server/img3/0,0,1000,1000/full/0/default.jpg'
'rotation_angle': 42,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': None,
'created': FAKE_CREATED
},
'rotation_angle': 42,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': None,
'created': FAKE_CREATED
},
}, {
'set': 'test',
'element': {
'id': str(self.vol.id),
'confidence': None,
'type': 'volume',
'mirrored': False,
'name': 'Volume 1',
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
}, {
'set': 'test',
'element': {
'id': str(self.vol.id),
'confidence': None,
'type': 'volume',
'mirrored': False,
'name': 'Volume 1',
'corpus': {
'id': str(self.corpus.id),
'name': 'Unit Tests',
'public': True
},
'zone': None,
'rotation_angle': 0,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': self.vol.thumbnail.s3_url,
'created': FAKE_CREATED
},
'zone': None,
'rotation_angle': 0,
'classes': None,
'has_children': None,
'metadata': None,
'worker_run': None,
'worker_version_id': None,
'thumbnail_url': self.vol.thumbnail.s3_url,
'created': FAKE_CREATED
},
}, ])
}],
key=lambda dataset_elt: dataset_elt['element']['id']
))
def test_add_from_selection_requires_login(self):
with self.assertNumQueries(0):
@@ -1393,6 +1391,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
}]
})
# DatasetClone
def test_clone_requires_login(self):
with self.assertNumQueries(0):
response = self.client.post(reverse('api:dataset-clone', kwargs={'pk': self.dataset.id}))
@@ -1518,3 +1518,162 @@ class TestDatasetsAPI(FixtureAPITestCase):
'task_id': None,
},
)
def test_clone_name_too_long(self):
dataset = self.corpus.datasets.create(name='A' * 99, creator=self.user)
self.client.force_login(self.user)
with self.assertNumQueries(13):
response = self.client.post(
reverse('api:dataset-clone', kwargs={'pk': dataset.id}),
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
clone = Dataset.objects.get(id=data.pop('id'))
self.assertEqual(clone.name, 'Clone of ' + 'A' * 91)
self.assertEqual(data['name'], clone.name)
# DatasetElementDestroy
def test_destroy_dataset_element_requires_login(self):
with self.assertNumQueries(0):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_destroy_dataset_element_requires_verified(self):
user = User.objects.create(email='not_verified@mail.com', display_name='Not Verified', verified_email=False)
self.client.force_login(user)
with self.assertNumQueries(2):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_destroy_dataset_element_requires_contributor(self):
self.client.force_login(self.read_user)
self.dataset.dataset_elements.create(element=self.page1, set='train')
self.dataset.dataset_elements.create(element=self.page1, set='validation')
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 1)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)
with self.assertNumQueries(6):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'You need a Contributor access to the dataset to perform this action.'})
self.dataset.refresh_from_db()
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 1)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)
def test_destroy_dataset_element_set_required(self):
self.client.force_login(self.user)
with self.assertNumQueries(2):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {'set': ['This field is required.']})
def test_destroy_dataset_element_requires_open_dataset(self):
self.client.force_login(self.user)
self.dataset.dataset_elements.create(element=self.page1, set='train')
self.dataset.dataset_elements.create(element=self.page1, set='validation')
self.dataset.state = DatasetState.Error
self.dataset.save()
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 1)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {'dataset': ['Elements can only be removed from open Datasets.']})
self.dataset.refresh_from_db()
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 1)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)
def test_destroy_dataset_element_dataset_doesnt_exist(self):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa', 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {'detail': 'Not found.'})
def test_destroy_dataset_element_set_doesnt_exist(self):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=match'
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {'detail': 'Not found.'})
def test_destroy_dataset_element_element_doesnt_exist(self):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {'detail': 'Not found.'})
def test_destroy_dataset_element_element_not_in_dataset(self):
self.dataset.dataset_elements.create(element=self.page1, set='train')
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page2.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {'detail': 'Not found.'})
def test_destroy_dataset_element_wrong_set(self):
self.dataset.dataset_elements.create(element=self.page1, set='train')
self.dataset.dataset_elements.create(element=self.page2, set='validation')
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page2.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertDictEqual(response.json(), {'detail': 'Not found.'})
def test_destroy_dataset_element(self):
self.client.force_login(self.user)
self.dataset.dataset_elements.create(element=self.page1, set='train')
self.dataset.dataset_elements.create(element=self.page1, set='validation')
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 1)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)
with self.assertNumQueries(7):
response = self.client.delete(reverse(
'api:dataset-element',
kwargs={'dataset': str(self.dataset.id), 'element': str(self.page1.id)})
+ '?set=train'
)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.dataset.refresh_from_db()
self.assertEqual(self.dataset.dataset_elements.filter(set='train').count(), 0)
self.assertEqual(self.dataset.dataset_elements.filter(set='validation').count(), 1)