diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index e75d36d8b32ed346dbfe18669797560f93f13f92..b9ecef256ae37ef2a6514b54683d9105152a2057 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -12,30 +12,30 @@ class TestClasses(FixtureAPITestCase): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.text = MLClass.objects.create(name='text', corpus=cls.corpus) - cls.cover = MLClass.objects.create(name='cover', corpus=cls.corpus) + cls.text = cls.corpus.ml_classes.create(name='text') + cls.cover = cls.corpus.ml_classes.create(name='cover') + cls.classified = cls.corpus.types.create(slug='classified', folder=True) + cls.folder_type = cls.corpus.types.create(slug='folder', folder=True) + + cls.parent = cls.corpus.elements.create(type=cls.folder_type, name='parent') + cls.common_children = cls.corpus.elements.create(type=cls.folder_type, name='common_children') - def populate_classified_elements(self): - self.folder_type = self.corpus.types.create(slug='folder', folder=True) - self.parent = self.corpus.elements.create(type=self.folder_type) - self.common_children = self.corpus.elements.create(type=self.folder_type) + cls.version1 = WorkerVersion.objects.get(worker__slug='reco') + cls.version2 = WorkerVersion.objects.get(worker__slug='dla') - self.version1 = WorkerVersion.objects.get(worker__slug='reco') - self.version2 = WorkerVersion.objects.get(worker__slug='dla') for elt_num in range(1, 13): - elt = Element.objects.create( + elt = cls.corpus.elements.create( name='elt_{}'.format(elt_num), - type=self.classified, - corpus_id=self.corpus.id + type=cls.classified, ) - elt.add_parent(self.parent) - self.common_children.add_parent(elt) - for ml_class, score in zip((self.text, self.cover), (.7, .99)): - for worker_version in (self.version1, self.version2): + elt.add_parent(cls.parent) + cls.common_children.add_parent(elt) + for ml_class, score in ((cls.text, .7), (cls.cover, .99)): + for worker_version in (cls.version1, cls.version2): elt.classifications.create( worker_version=worker_version, - ml_class_id=ml_class.id, + ml_class=ml_class, confidence=score, high_confidence=bool(score == .99) ) @@ -57,7 +57,7 @@ class TestClasses(FixtureAPITestCase): { "id": str(self.cover.pk), "name": "cover", - "nb_best": 0 + "nb_best": 24 }, { "id": str(self.text.pk), @@ -82,7 +82,7 @@ class TestClasses(FixtureAPITestCase): { "id": str(self.cover.pk), "name": "cover", - "nb_best": 0 + "nb_best": 24 }, { "id": str(self.image.pk), @@ -115,7 +115,6 @@ class TestClasses(FixtureAPITestCase): """ Test nb_best attribute on ListCorpusMLClasses endpoint """ - self.populate_classified_elements() self.client.force_login(self.user) with self.assertNumQueries(5): response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {}) @@ -140,7 +139,6 @@ class TestClasses(FixtureAPITestCase): def test_list_classes_search(self): buffer_class = MLClass.objects.create(name='buffer overflow', corpus=self.corpus) - self.populate_classified_elements() self.client.force_login(self.user) with self.assertNumQueries(5): response = self.client.get( @@ -173,7 +171,6 @@ class TestClasses(FixtureAPITestCase): @override_settings(SEARCH_FILTER_MAX_TERMS=3) def test_search_filter_limit(self): self.corpus.ml_classes.create(name='buffer overflow') - self.populate_classified_elements() self.client.force_login(self.user) response = self.client.get( @@ -218,7 +215,6 @@ class TestClasses(FixtureAPITestCase): }) def test_list_elements_db_queries(self): - self.populate_classified_elements() with self.assertNumQueries(3): response = self.client.get( reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), @@ -228,7 +224,6 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(response.json()['count'], 12) def test_list_elements_best_classes(self): - self.populate_classified_elements() with self.assertNumQueries(4): response = self.client.get( reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), @@ -244,7 +239,6 @@ class TestClasses(FixtureAPITestCase): ) def test_list_elements_best_classes_false(self): - self.populate_classified_elements() with self.assertNumQueries(3): response = self.client.get( reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}), @@ -258,7 +252,6 @@ class TestClasses(FixtureAPITestCase): self.assertIsNone(elt['best_classes']) def test_element_parents_best_classes(self): - self.populate_classified_elements() with self.assertNumQueries(4): response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -274,7 +267,6 @@ class TestClasses(FixtureAPITestCase): ) def test_element_children_best_classes(self): - self.populate_classified_elements() with self.assertNumQueries(5): response = self.client.get( reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), @@ -293,7 +285,6 @@ class TestClasses(FixtureAPITestCase): """ A machine classification that have been rejected by a human must not appear """ - self.populate_classified_elements() child = Element.objects.filter(type=self.classified.id).first() child.classifications.all().update(state=ClassificationState.Rejected) with self.assertNumQueries(5): @@ -314,7 +305,6 @@ class TestClasses(FixtureAPITestCase): """ A non best class validated by a human is considered as best as it is for the human """ - self.populate_classified_elements() parent = Element.objects.get_ascending(self.common_children.id).last() parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) response = self.client.get( @@ -342,7 +332,6 @@ class TestClasses(FixtureAPITestCase): """ A manual classification rejected by a human may not appear in best classes """ - self.populate_classified_elements() element = Element.objects.filter(type=self.classified.id).first() classif = element.classifications.create( ml_class_id=self.text.id, @@ -387,7 +376,6 @@ class TestClasses(FixtureAPITestCase): ) def test_class_filter_list_elements(self): - self.populate_classified_elements() element = Element.objects.filter(type=self.classified.id).first() element.classifications.create( ml_class=self.text, @@ -405,9 +393,9 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data['results'][0]['id'], str(element.id)) def test_class_filter_list_parents(self): - self.populate_classified_elements() - parent = Element.objects.get_ascending(self.common_children.id).last() - parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) + parent = Element.objects.get_ascending(self.common_children.id).get(name='elt_1') + self.assertEqual(parent.classifications.filter(ml_class=self.text).count(), 2) + parent.classifications.filter(ml_class=self.text).update(state=ClassificationState.Validated) with self.assertNumQueries(3): response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -419,7 +407,6 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data['results'][0]['id'], str(parent.id)) def test_class_filter_list_children(self): - self.populate_classified_elements() child = Element.objects.filter(type=self.classified.id).first() child.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) with self.assertNumQueries(4): @@ -433,7 +420,6 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(data['results'][0]['id'], str(child.id)) def test_class_filter_list_elements_distinct(self): - self.populate_classified_elements() self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(3): @@ -449,7 +435,6 @@ class TestClasses(FixtureAPITestCase): self.assertCountEqual(ids, set(ids)) def test_class_filter_list_parents_distinct(self): - self.populate_classified_elements() self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(3): @@ -465,7 +450,6 @@ class TestClasses(FixtureAPITestCase): self.assertCountEqual(ids, set(ids)) def test_class_filter_list_children_distinct(self): - self.populate_classified_elements() self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(4): @@ -481,7 +465,6 @@ class TestClasses(FixtureAPITestCase): self.assertCountEqual(ids, set(ids)) def test_class_filter_true(self): - self.populate_classified_elements() element = Element.objects.filter(type=self.classified.id).first() element.classifications.all().delete() element.classifications.create( @@ -506,7 +489,6 @@ class TestClasses(FixtureAPITestCase): self.assertSetEqual(best_class_ids, {str(self.text.id), str(self.cover.id)}) def test_class_filter_false(self): - self.populate_classified_elements() element = Element.objects.filter(type=self.classified.id).first() element.classifications.all().delete() element.classifications.create(