Skip to content
Snippets Groups Projects
Commit 5075f131 authored by Valentin Rigal, committed by Erwan Rouchet
Browse files

Restrict usage of CreateClassifications

parent 30283405
No related branches found
No related tags found
1 merge request: !1986 Restrict usage of CreateClassifications
......@@ -542,10 +542,14 @@ class ClassificationsSerializer(serializers.Serializer):
style={'base_template': 'input.html'},
)
worker_version = ForbiddenField()
worker_run_id = serializers.PrimaryKeyRelatedField(
queryset=WorkerRun.objects.all(),
style={'base_template': 'input.html'},
source='worker_run',
worker_run_id = WorkerRunIDField(
help_text=dedent("""
A WorkerRun ID that the classifications will refer to.
Regular users may only use the WorkerRuns of their own `Local` process.
Tasks authenticated via the Ponos task authentication may only use the WorkerRuns of their process.
""").strip(),
)
classifications = ClassificationBulkSerializer(many=True, allow_empty=False)
......
......@@ -15,6 +15,7 @@ class TestBulkClassification(FixtureAPITestCase):
cls.private_corpus = Corpus.objects.create(name='private', public=False)
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
cls.worker_run = cls.worker_version.worker_runs.filter(process__mode=ProcessMode.Workers).get()
cls.local_worker_run = cls.worker_version.worker_runs.filter(process__mode=ProcessMode.Local).get()
cls.dog_class = cls.corpus.ml_classes.create(name='dog')
cls.cat_class = cls.corpus.ml_classes.create(name='cat')
......@@ -23,17 +24,21 @@ class TestBulkClassification(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_wrong_acl(self):
"""
The user must have access to the parent element
"""
self.client.force_login(self.user)
private_page = self.private_corpus.elements.create(
type=self.private_corpus.types.create(slug='page'),
)
local_worker_run = self.user.processes.get(mode=ProcessMode.Local).worker_runs.get()
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
'parent': str(private_page.id),
'worker_run_id': str(self.worker_run.id),
'worker_run_id': str(local_worker_run.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
......@@ -51,9 +56,9 @@ class TestBulkClassification(FixtureAPITestCase):
}
)
def test_worker_version(self):
def test_worker_run_required(self):
"""
Classifications cannot be linked to a worker version
Classifications must be linked to a worker run
"""
self.client.force_login(self.user)
with self.assertNumQueries(5):
......@@ -82,119 +87,6 @@ class TestBulkClassification(FixtureAPITestCase):
'worker_version': ['This field is forbidden.'],
})
def test_worker_version_or_worker_run(self):
    """
    A request that provides neither `worker_run_id` nor `worker_version`
    is rejected with a 400, and the error is reported on `worker_run_id`.
    """
    self.client.force_login(self.user)

    # Neither `worker_run_id` nor `worker_version` is part of the payload
    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.cat_class.id),
                "confidence": 0.42,
            }
        ]
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(5):
        response = self.client.post(endpoint, format='json', data=payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = {
        'worker_run_id': ['This field is required.'],
    }
    self.assertDictEqual(response.json(), expected_errors)
def test_worker_version_and_worker_run(self):
    """
    Sending `worker_version` together with `worker_run_id` is rejected
    with a 400, and the error is reported on `worker_version`.
    """
    self.client.force_login(self.user)

    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.cat_class.id),
                "confidence": 0.42,
            }
        ],
        "worker_run_id": str(self.worker_run.id),
        "worker_version": str(self.worker_version.id),
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(6):
        response = self.client.post(endpoint, format='json', data=payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = {
        'worker_version': ['This field is forbidden.'],
    }
    self.assertDictEqual(response.json(), expected_errors)
def test_worker_run(self):
    """
    Two classifications posted with a `worker_run_id` are created in the
    `pending` state, and both are linked to that worker run and to
    `self.worker_version`.
    """
    self.client.force_login(self.user)

    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.dog_class.id),
                "confidence": 0.99,
                "high_confidence": True
            },
            {
                'ml_class': str(self.cat_class.id),
                "confidence": 0.42,
            }
        ],
        "worker_run_id": str(self.worker_run.id),
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(9):
        response = self.client.post(endpoint, format='json', data=payload)
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # Highest confidence first: the dog classification, then the cat one
    first_cl, second_cl = self.page.classifications.order_by('-confidence')
    expected_rows = [
        (first_cl, self.dog_class, 0.99, True),
        (second_cl, self.cat_class, 0.42, False),
    ]
    self.assertEqual(response.json(), {
        'parent': str(self.page.id),
        'worker_run_id': str(self.worker_run.id),
        'classifications': [
            {
                'id': str(classification.id),
                'ml_class': str(ml_class.id),
                'confidence': confidence,
                'high_confidence': high_confidence,
                'state': 'pending',
            }
            for classification, ml_class, confidence, high_confidence in expected_rows
        ]
    })

    stored = list(self.page.classifications.values_list(
        'ml_class__name',
        'confidence',
        'high_confidence',
        'worker_version_id',
        'worker_run_id',
    ))
    self.assertCountEqual(
        stored,
        [
            ('dog', 0.99, True, self.worker_version.id, self.worker_run.id),
            ('cat', 0.42, False, self.worker_version.id, self.worker_run.id),
        ],
    )

    # Only the worker run was sent: the worker version must come along with it
    for classification in (first_cl, second_cl):
        self.assertEqual(classification.worker_version, self.worker_version)
    for classification in (first_cl, second_cl):
        self.assertEqual(classification.worker_run, self.worker_run)
def test_worker_run_not_found(self):
self.client.force_login(self.user)
with self.assertNumQueries(6):
......@@ -225,15 +117,15 @@ class TestBulkClassification(FixtureAPITestCase):
def test_ml_class_not_found(self):
self.dog_class.delete()
self.client.force_login(self.user)
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
with self.assertNumQueries(5):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
'parent': str(self.page.id),
'worker_run_id': str(self.worker_run.id),
'worker_run_id': str(self.local_worker_run.id),
'classifications': [
{
'ml_class': "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
......@@ -252,8 +144,8 @@ class TestBulkClassification(FixtureAPITestCase):
"""
Test the bulk classification API deletes previous classifications with the same worker run
"""
self.client.force_login(self.user)
with self.assertNumQueries(9):
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
......@@ -270,7 +162,7 @@ class TestBulkClassification(FixtureAPITestCase):
"confidence": 0.42,
}
],
"worker_run_id": str(self.worker_run.id),
'worker_run_id': str(self.local_worker_run.id),
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......@@ -295,7 +187,7 @@ class TestBulkClassification(FixtureAPITestCase):
"high_confidence": True,
},
],
'worker_run_id': str(self.worker_run.id),
'worker_run_id': str(self.local_worker_run.id),
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......@@ -312,14 +204,14 @@ class TestBulkClassification(FixtureAPITestCase):
"""
Test the bulk classification API prevents creating classifications with duplicate ML classes
"""
self.client.force_login(self.user)
with self.assertNumQueries(7):
self.client.force_login(self.superuser)
with self.assertNumQueries(5):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
'parent': str(self.page.id),
'worker_run_id': str(self.worker_run.id),
'worker_run_id': str(self.local_worker_run.id),
'classifications': [
{
'ml_class': str(self.dog_class.id),
......@@ -336,3 +228,255 @@ class TestBulkClassification(FixtureAPITestCase):
self.assertDictEqual(response.json(), {
'classifications': ['Duplicated ML classes are not allowed from the same worker run.']
})
def test_worker_run_non_local(self):
    """
    Without Ponos task authentication, a user cannot create classifications
    with a WorkerRun that belongs to a non-local (`Workers`) process.
    """
    self.client.force_login(self.superuser)

    # `self.worker_run` is attached to a process in `Workers` mode
    payload = {
        'parent': str(self.page.id),
        'worker_run_id': str(self.worker_run.id),
        'classifications': [
            {
                'ml_class': str(self.dog_class.id),
                'confidence': 0.99,
            },
        ]
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(4):
        response = self.client.post(endpoint, format='json', data=payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = {
        'worker_run_id': [
            "Ponos task authentication is required to use a WorkerRun "
            "of a process other than the user's local process."
        ]
    }
    self.assertDictEqual(response.json(), expected_errors)
def test_worker_run_other_user(self):
    """
    A user cannot create classifications with a WorkerRun that belongs to
    someone else's local process.

    The request is sent as `self.superuser`, with the WorkerRun of the local
    process of `self.user`, and must be rejected with a 400 on `worker_run_id`.
    """
    # Use `.get()` rather than `.first()`: a missing or ambiguous fixture then
    # fails here with an explicit ORM error instead of an AttributeError on `None`.
    other_user_worker_run = self.user.processes.get(mode=ProcessMode.Local).worker_runs.get()
    self.client.force_login(self.superuser)

    payload = {
        'parent': str(self.page.id),
        'worker_run_id': str(other_user_worker_run.id),
        'classifications': [
            {
                'ml_class': str(self.dog_class.id),
                'confidence': 0.99,
            },
        ]
    }

    with self.assertNumQueries(4):
        response = self.client.post(
            reverse('api:classification-bulk'),
            format='json',
            data=payload,
        )

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        'worker_run_id': [
            "Ponos task authentication is required to use a WorkerRun "
            "of a process other than the user's local process."
        ]
    })
def test_worker_run_other_process(self):
    """
    A request authenticated as a Ponos task is rejected when it uses a
    WorkerRun that belongs to a process other than the task's own.
    """
    # Second `Workers` process, created by the same user, with its own worker run
    creator = self.worker_run.process.creator
    other_process = creator.processes.create(
        mode=ProcessMode.Workers,
        corpus=self.corpus,
    )
    other_worker_run = other_process.worker_runs.create(version=self.worker_run.version, parents=[])

    # Authenticate with a task of the first process, not of `other_process`
    self.worker_run.process.start()
    task = self.worker_run.process.workflow.tasks.first()

    payload = {
        'parent': str(self.page.id),
        'worker_run_id': str(other_worker_run.id),
        'classifications': [
            {
                'ml_class': str(self.dog_class.id),
                'confidence': 0.99,
            },
        ]
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(5):
        response = self.client.post(
            endpoint,
            format='json',
            data=payload,
            HTTP_AUTHORIZATION=f'Ponos {task.token}',
        )

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = {
        'worker_run_id': [
            "Only the WorkerRuns of the authenticated task's process may be used."
        ]
    }
    self.assertDictEqual(response.json(), expected_errors)
def test_create_local(self):
    """
    Classifications can be created by a logged-in user with the WorkerRun of
    a `Local` process (`self.local_worker_run`); they are stored as `pending`
    and linked to both the worker run and its worker version.
    """
    self.client.force_login(self.superuser)

    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.dog_class.id),
                "confidence": 0.99,
                "high_confidence": True
            },
            {
                'ml_class': str(self.cat_class.id),
                "confidence": 0.42,
            },
        ],
        "worker_run_id": str(self.local_worker_run.id),
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(7):
        response = self.client.post(endpoint, format='json', data=payload)
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # Highest confidence first: the dog classification, then the cat one
    first_cl, second_cl = self.page.classifications.order_by('-confidence')
    expected_rows = [
        (first_cl, self.dog_class, 0.99, True),
        (second_cl, self.cat_class, 0.42, False),
    ]
    self.assertEqual(response.json(), {
        'parent': str(self.page.id),
        'worker_run_id': str(self.local_worker_run.id),
        'classifications': [
            {
                'id': str(classification.id),
                'ml_class': str(ml_class.id),
                'confidence': confidence,
                'high_confidence': high_confidence,
                'state': 'pending',
            }
            for classification, ml_class, confidence, high_confidence in expected_rows
        ]
    })

    stored = list(self.page.classifications.values_list(
        'ml_class__name',
        'confidence',
        'high_confidence',
        'worker_version_id',
        'worker_run_id',
    ))
    self.assertCountEqual(
        stored,
        [
            ('dog', 0.99, True, self.worker_version.id, self.local_worker_run.id),
            ('cat', 0.42, False, self.worker_version.id, self.local_worker_run.id),
        ],
    )

    # Only the worker run was sent: the worker version must come along with it
    for classification in (first_cl, second_cl):
        self.assertEqual(classification.worker_version, self.worker_version)
    for classification in (first_cl, second_cl):
        self.assertEqual(classification.worker_run, self.local_worker_run)
def test_create_task_auth(self):
    """
    With Ponos task authentication, classifications can be created using a
    WorkerRun of a non-local process, provided the task belongs to that process.
    """
    # Start the process so that it has tasks, then authenticate as its first task
    self.worker_run.process.start()
    task = self.worker_run.process.workflow.tasks.first()

    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.dog_class.id),
                "confidence": 0.99,
                "high_confidence": True
            },
        ],
        "worker_run_id": str(self.worker_run.id),
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(8):
        response = self.client.post(
            endpoint,
            format='json',
            data=payload,
            HTTP_AUTHORIZATION=f'Ponos {task.token}',
        )
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # Exactly one classification must now exist on the page
    classification = self.page.classifications.get()
    expected_body = {
        'parent': str(self.page.id),
        'worker_run_id': str(self.worker_run.id),
        'classifications': [
            {
                'id': str(classification.id),
                'ml_class': str(self.dog_class.id),
                'confidence': 0.99,
                'high_confidence': True,
                'state': 'pending',
            },
        ]
    }
    self.assertEqual(response.json(), expected_body)
    self.assertEqual(classification.worker_version, self.worker_version)
    self.assertEqual(classification.worker_run, self.worker_run)
def test_worker_run_local_task_auth(self):
    """
    With Ponos task authentication, classifications can still be created
    using the WorkerRun of a `Local` process, even though the authenticated
    task belongs to a different process.
    """
    local_worker_run = self.user.processes.get(mode=ProcessMode.Local).worker_runs.get()

    # Start the non-local process so that it has tasks, then authenticate as its first task
    self.worker_run.process.start()
    task = self.worker_run.process.workflow.tasks.first()

    # Sanity check: the task's process is not the local process
    self.assertNotEqual(self.worker_run.process_id, local_worker_run.process_id)

    payload = {
        "parent": str(self.page.id),
        "classifications": [
            {
                'ml_class': str(self.dog_class.id),
                "confidence": 0.99,
                "high_confidence": True
            },
        ],
        "worker_run_id": str(local_worker_run.id),
    }
    endpoint = reverse('api:classification-bulk')

    with self.assertNumQueries(8):
        response = self.client.post(
            endpoint,
            format='json',
            data=payload,
            HTTP_AUTHORIZATION=f'Ponos {task.token}',
        )
    self.assertEqual(response.status_code, status.HTTP_201_CREATED)

    # Exactly one classification must now exist on the page
    classification = self.page.classifications.get()
    expected_body = {
        'parent': str(self.page.id),
        'worker_run_id': str(local_worker_run.id),
        'classifications': [
            {
                'id': str(classification.id),
                'ml_class': str(self.dog_class.id),
                'confidence': 0.99,
                'high_confidence': True,
                'state': 'pending',
            },
        ]
    }
    self.assertEqual(response.json(), expected_body)
    self.assertEqual(classification.worker_version, local_worker_run.version)
    self.assertEqual(classification.worker_run, local_worker_run)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment