Skip to content
Snippets Groups Projects

Remove source field from creation endpoints

Merged Valentin Rigal requested to merge remove-sources into master
6 files
+ 104
588
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -62,71 +62,6 @@ class TestBulkClassification(FixtureAPITestCase):
}
)
def test_bulk_classification_requires_source_xor_worker_version(self):
"""
A classifier data source XOR a worker_version is required to push classifications on an element
"""
wrong_payloads = (
{'classifier': self.src.slug, 'worker_version': self.worker_version.id},
{'classifier': None, 'worker_version': None},
{'classifier': ''},
{}
)
self.client.force_login(self.user)
for payload in wrong_payloads:
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
'parent': str(self.page.id),
'classifications': [{'class_name': 'cat', 'confidence': 0.42}],
**payload
}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
if payload.get('classifier') and payload.get('worker_version'):
self.assertDictEqual(response.json(), {
'classifier': ['This field XOR worker_version field must be set to create classifications'],
'worker_version': ['This field XOR classifier field must be set to create classifications']
})
def test_bulk_classification_source(self):
"""
Bulk classifications are created using an existing classifier source
"""
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data={
"parent": str(self.page.id),
"classifier": self.src.slug,
"classifications": [{
"class_name": 'dog',
"confidence": 0.99,
"high_confidence": True
}, {
"class_name": 'cat',
"confidence": 0.42,
}]
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertCountEqual(
list(self.page.classifications.values_list(
'ml_class__name',
'confidence',
'high_confidence',
'source',
'worker_version'
)),
[
('dog', 0.99, True, self.src.id, None),
('cat', 0.42, False, self.src.id, None)
],
)
def test_bulk_classification_worker_version(self):
"""
Classifications are created and linked to a worker version
@@ -255,5 +190,5 @@ class TestBulkClassification(FixtureAPITestCase):
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.']
'classifications': ['Duplicated ML classes are not allowed from the same worker version.']
})
Loading