Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
Backend
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container Registry
Analyze
Contributor analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Arkindex
Backend
Commits
2dae2711
Commit
2dae2711
authored
3 years ago
by
Bastien Abadie
Browse files
Options
Downloads
Plain Diff
Merge branch 'class-filters' into 'master'
Replace best_class with classification filters Closes
#908
See merge request
!1566
parents
96560e0b
a162c210
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!1566
Replace best_class with classification filters
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
arkindex/documents/api/elements.py
+120
-51
120 additions, 51 deletions
arkindex/documents/api/elements.py
arkindex/documents/tests/test_classes.py
+123
-150
123 additions, 150 deletions
arkindex/documents/tests/test_classes.py
with
243 additions
and
201 deletions
arkindex/documents/api/elements.py
+
120
−
51
View file @
2dae2711
...
...
@@ -43,6 +43,7 @@ from arkindex.documents.models import (
ElementType
,
MetaData
,
MetaType
,
MLClass
,
Selection
,
Transcription
,
)
...
...
@@ -131,13 +132,20 @@ def _fetch_has_children(elements):
return
elements
# Operators available for numeric filters in element lists
# Maps valid operator names in the API to Django QuerySet lookups
METADATA
_OPERATORS
=
{
NUMERIC
_OPERATORS
=
{
'
eq
'
:
'
exact
'
,
'
lt
'
:
'
lt
'
,
'
gt
'
:
'
gt
'
,
'
lte
'
:
'
lte
'
,
'
gte
'
:
'
gte
'
,
}
# Operators available for metadata values.
METADATA_OPERATORS
=
{
# Only for numeric metadata
**
NUMERIC_OPERATORS
,
# The contains operator should be case-insensitive
'
contains
'
:
'
icontains
'
,
}
...
...
@@ -202,17 +210,6 @@ class ElementsListAutoSchema(AutoSchema):
# Add method-specific parameters
if
self
.
method
.
lower
()
==
'
get
'
:
parameters
.
extend
([
OpenApiParameter
(
'
best_class
'
,
description
=
'
Restrict to or exclude elements with a best class,
'
'
or restrict to elements with specific best class
'
,
type
=
{
'
oneOf
'
:
[
{
'
type
'
:
'
string
'
,
'
format
'
:
'
uuid
'
},
{
'
type
'
:
'
boolean
'
}
]
}
),
OpenApiParameter
(
'
metadata_name
'
,
description
=
'
Restrict to elements having a metadata with the given name.
'
,
...
...
@@ -221,7 +218,7 @@ class ElementsListAutoSchema(AutoSchema):
OpenApiParameter
(
'
metadata_value
'
,
description
=
'
Restrict to elements having a metadata with the given value.
'
'
Can be set it exclude elements with a value or filter by numerical values
using `metadata_operator`.
'
'
The comparison operator can be set
using `metadata_operator`.
'
'
Requires `metadata_name` to be set.
'
,
required
=
False
,
),
...
...
@@ -243,6 +240,47 @@ class ElementsListAutoSchema(AutoSchema):
default
=
'
eq
'
,
required
=
False
,
),
OpenApiParameter
(
'
class_id
'
,
description
=
'
Restrict to elements having a classification with the specified ML class ID.
'
'
If `classification_confidence` or `classification_high_confidence` are set,
'
'
the elements must have a classification that satisfies all of the parameters at once.
'
,
type
=
UUID
,
required
=
False
,
),
OpenApiParameter
(
'
classification_confidence
'
,
description
=
'
Restrict to elements having a classification with the given confidence.
'
'
The comparison operator can be set using `classification_confidence_operator`.
'
'
If `class_id` or `classification_high_confidence` are set, the elements must have a
'
'
classification that satisfies all of the parameters at once.
'
,
required
=
False
,
),
OpenApiParameter
(
'
classification_confidence_operator
'
,
description
=
dedent
(
"""
Set the comparison operator to filter on classification confidence scores:
* `eq` (default): Elements having a classification with this exact confidence.
* `lt`: Elements having a classification with a confidence strictly lower than the filter.
* `lte`: Elements having a classification with a confidence lower than or equal to the filter.
* `lt`: Elements having a classification with a confidence strictly greater than the filter.
* `gte`: Elements having a classification with a confidence greather than or equal to the filter.
This requires `classification_confidence` to be set.
"""
),
enum
=
NUMERIC_OPERATORS
.
keys
(),
default
=
'
eq
'
,
required
=
False
,
),
OpenApiParameter
(
'
classification_high_confidence
'
,
description
=
'
Restrict to elements having a classification marked as `high_confidence`.
'
'
If `class_id` or `classification_confidence` are set, the elements must have a
'
'
classification that satisfies all of the parameters at once.
'
,
type
=
bool
,
required
=
False
,
),
OpenApiParameter
(
'
order
'
,
description
=
'
Sort elements by a specific field
'
,
...
...
@@ -260,7 +298,7 @@ class ElementsListAutoSchema(AutoSchema):
OpenApiParameter
(
'
with_best_classes
'
,
description
=
'
Returns best classifications for each element.
'
'
If not
se
t
,
elements
best_classes
field
will always be null
'
,
'
Otherwi
se,
`
best_classes
`
will always be null
.
'
,
type
=
bool
,
required
=
False
,
),
...
...
@@ -460,6 +498,66 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
return
queryset
.
values
(
'
element_id
'
)
def
get_classification_queryset
(
self
):
"""
Returns a queryset that includes matched element IDs from classification filters,
or None if no classification filters apply.
"""
class_id
=
self
.
clean_params
.
get
(
'
class_id
'
)
confidence
=
self
.
clean_params
.
get
(
'
classification_confidence
'
)
confidence_operator
=
self
.
clean_params
.
get
(
'
classification_confidence_operator
'
,
''
).
lower
().
strip
()
high_confidence
=
self
.
clean_params
.
get
(
'
classification_high_confidence
'
,
''
).
lower
().
strip
()
if
len
(
high_confidence
):
high_confidence
=
high_confidence
not
in
(
'
false
'
,
'
0
'
)
else
:
# An empty string should be treated as no filter at all
high_confidence
=
None
if
not
class_id
and
confidence
is
None
and
not
confidence_operator
and
high_confidence
is
None
:
# No filters apply
return
None
queryset
=
Classification
.
objects
.
all
()
errors
=
defaultdict
(
list
)
if
class_id
:
try
:
ml_class
=
self
.
selected_corpus
.
ml_classes
.
get
(
id
=
class_id
)
except
DjangoValidationError
as
e
:
# An invalid UUID would cause a Django ValidationError
errors
[
'
class_id
'
].
extend
(
e
.
messages
)
except
MLClass
.
DoesNotExist
:
errors
[
'
class_id
'
].
append
(
f
'
ML class
"
{
class_id
}
"
not found
'
)
else
:
queryset
=
queryset
.
filter
(
ml_class
=
ml_class
)
if
confidence_operator
:
if
confidence_operator
not
in
NUMERIC_OPERATORS
:
errors
[
'
classification_confidence_operator
'
].
append
(
'
This operator is not supported.
'
)
if
confidence
is
None
:
errors
[
'
classification_confidence_operator
'
].
append
(
'
This option is not supported without classification_confidence.
'
)
else
:
confidence_operator
=
'
eq
'
if
confidence
:
try
:
confidence
=
float
(
confidence
)
assert
0
<=
confidence
<=
1
,
'
Confidence must be between 0 and 1
'
except
(
TypeError
,
ValueError
,
AssertionError
)
as
e
:
errors
[
'
classification_confidence
'
].
append
(
str
(
e
))
else
:
lookup
=
NUMERIC_OPERATORS
.
get
(
confidence_operator
,
'
exact
'
)
queryset
=
queryset
.
filter
(
**
{
f
'
confidence__
{
lookup
}
'
:
confidence
})
if
high_confidence
is
not
None
:
queryset
=
queryset
.
filter
(
high_confidence
=
high_confidence
)
if
errors
:
raise
ValidationError
(
errors
)
return
queryset
.
values
(
'
element_id
'
)
def
get_filters
(
self
):
filters
=
{
'
corpus
'
:
self
.
selected_corpus
...
...
@@ -500,43 +598,19 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
if
metadata_queryset
is
not
None
:
filters
[
'
id__in
'
]
=
metadata_queryset
try
:
classification_queryset
=
self
.
get_classification_queryset
()
except
ValidationError
as
e
:
errors
.
update
(
e
.
detail
)
else
:
if
classification_queryset
is
not
None
:
filters
[
'
id__in
'
]
=
classification_queryset
if
errors
:
raise
ValidationError
(
errors
)
return
filters
def
get_classifications_filters
(
self
):
"""
Build Django ORM filters using Q expressions related to best classes.
Supports 3 modes:
- elements without any best classes
- elements with any best classes
- elements with a specific best class
"""
class_filter
=
self
.
clean_params
.
get
(
'
best_class
'
)
if
class_filter
is
None
:
return
# Generic ORM query to find best classes:
# - elements with a validated classification
# - OR where high confidence is True and not rejected
best_classifications
=
Q
(
classifications__state
=
ClassificationState
.
Validated
)
|
(
Q
(
classifications__high_confidence
=
True
)
&
~
Q
(
classifications__state
=
ClassificationState
.
Rejected
)
)
# List elements without any best classes, by inverting the query above
if
class_filter
.
lower
()
in
(
'
false
'
,
'
0
'
):
return
~
best_classifications
try
:
# Filter on a specific class
class_filter
=
uuid
.
UUID
(
class_filter
)
return
best_classifications
&
Q
(
classifications__ml_class_id
=
class_filter
)
except
(
TypeError
,
ValueError
):
# By default, use all best classifications
return
best_classifications
def
get_serializer_context
(
self
):
context
=
super
().
get_serializer_context
()
context
[
'
corpus
'
]
=
self
.
selected_corpus
...
...
@@ -577,11 +651,6 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
.
prefetch_related
(
*
self
.
get_prefetch
())
\
.
order_by
(
*
self
.
get_order_by
())
class_filters
=
self
.
get_classifications_filters
()
if
class_filters
is
not
None
:
# Use queryset.distinct() whenever best_class is defined
queryset
=
queryset
.
filter
(
class_filters
).
distinct
()
with_has_children
=
self
.
clean_params
.
get
(
'
with_has_children
'
)
if
with_has_children
and
with_has_children
.
lower
()
not
in
(
'
false
'
,
'
0
'
):
queryset
=
BulkMap
(
_fetch_has_children
,
queryset
)
...
...
This diff is collapsed.
Click to expand it.
arkindex/documents/tests/test_classes.py
+
123
−
150
View file @
2dae2711
...
...
@@ -424,165 +424,22 @@ class TestClasses(FixtureAPITestCase):
[(
str
(
self
.
version1
.
id
),
.
99
),
(
str
(
self
.
version2
.
id
),
.
99
)]
)
def
test_class_filter_list_elements
(
self
):
element
=
Element
.
objects
.
filter
(
type
=
self
.
classified
.
id
).
first
()
element
.
classifications
.
create
(
ml_class
=
self
.
text
,
confidence
=
.
1337
,
high_confidence
=
True
,
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
text
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
1
)
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
element
.
id
))
def
test_class_filter_list_parents
(
self
):
parent
=
Element
.
objects
.
get_ascending
(
self
.
common_children
.
id
).
get
(
name
=
'
elt_1
'
)
self
.
assertEqual
(
parent
.
classifications
.
filter
(
ml_class
=
self
.
text
).
count
(),
2
)
parent
.
classifications
.
filter
(
ml_class
=
self
.
text
).
update
(
state
=
ClassificationState
.
Validated
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
reverse
(
'
api:elements-parents
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
common_children
.
id
)}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
text
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
1
)
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
parent
.
id
))
def
test_class_filter_list_children
(
self
):
child
=
Element
.
objects
.
filter
(
type
=
self
.
classified
.
id
).
first
()
child
.
classifications
.
all
().
filter
(
confidence
=
.
7
).
update
(
state
=
ClassificationState
.
Validated
)
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
reverse
(
'
api:elements-children
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
parent
.
id
)}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
text
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
1
)
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
child
.
id
))
def
test_class_filter_list_elements_distinct
(
self
):
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
count
(),
24
)
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
distinct
(
'
element_id
'
).
count
(),
12
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
cover
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
12
)
# Ensure element IDs are unique
ids
=
[
e
[
'
id
'
]
for
e
in
data
[
'
results
'
]]
self
.
assertCountEqual
(
ids
,
set
(
ids
))
def
test_class_filter_list_parents_distinct
(
self
):
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
count
(),
24
)
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
distinct
(
'
element_id
'
).
count
(),
12
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
get
(
reverse
(
'
api:elements-parents
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
common_children
.
id
)}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
cover
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
12
)
# Ensure element IDs are unique
ids
=
[
e
[
'
id
'
]
for
e
in
data
[
'
results
'
]]
self
.
assertCountEqual
(
ids
,
set
(
ids
))
def
test_class_filter_list_children_distinct
(
self
):
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
count
(),
24
)
self
.
assertEqual
(
Classification
.
objects
.
filter
(
high_confidence
=
True
).
distinct
(
'
element_id
'
).
count
(),
12
)
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
reverse
(
'
api:elements-children
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
parent
.
id
)}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
str
(
self
.
cover
.
id
)}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
12
)
# Ensure element IDs are unique
ids
=
[
e
[
'
id
'
]
for
e
in
data
[
'
results
'
]]
self
.
assertCountEqual
(
ids
,
set
(
ids
))
def
test_class_filter_true
(
self
):
element
=
Element
.
objects
.
filter
(
type
=
self
.
classified
.
id
).
first
()
element
.
classifications
.
all
().
delete
()
element
.
classifications
.
create
(
worker_version
=
self
.
version2
,
ml_class_id
=
self
.
text
.
id
,
confidence
=
.
1337
,
high_confidence
=
True
,
)
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
'
true
'
,
'
with_best_classes
'
:
'
true
'
},
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
12
)
best_class_ids
=
set
(
best_class
[
'
ml_class
'
][
'
id
'
]
for
element
in
data
[
'
results
'
]
for
best_class
in
element
[
'
best_classes
'
]
)
self
.
assertSetEqual
(
best_class_ids
,
{
str
(
self
.
text
.
id
),
str
(
self
.
cover
.
id
)})
def
test_class_filter_false
(
self
):
element
=
Element
.
objects
.
filter
(
type
=
self
.
classified
.
id
).
first
()
element
.
classifications
.
all
().
delete
()
element
.
classifications
.
create
(
worker_version
=
self
.
version2
,
ml_class_id
=
self
.
text
.
id
,
confidence
=
.
1337
,
high_confidence
=
False
,
)
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
'
false
'
,
'
with_best_classes
'
:
'
true
'
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
1
)
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
element
.
id
))
self
.
assertListEqual
(
data
[
'
results
'
][
0
][
'
best_classes
'
],
[])
def
test_exclude_rejected
(
self
):
def
test_with_best_classes_exclude_rejected
(
self
):
"""
Ensure that
best_classes and
with_best_classes ignore rejected high confidence classifications
Ensure that with_best_classes ignore
s
rejected high confidence classifications
"""
Classification
.
objects
.
all
().
delete
()
# One element with only a rejected classification; should be ignored
element1
=
Element
.
objects
.
get
(
type
=
self
.
classified
.
id
,
name
=
'
elt_1
'
)
element1
.
classifications
.
create
(
worker_version
=
self
.
version2
,
ml_class_id
=
self
.
text
.
id
,
confidence
=
.
1337
,
high_confidence
=
True
,
state
=
'
rejected
'
)
# One element with a rejected classification and a best class, should be included with 1 best class
element
2
=
Element
.
objects
.
get
(
type
=
self
.
classified
.
id
,
name
=
'
elt_2
'
)
element
2
.
classifications
.
create
(
element
=
Element
.
objects
.
get
(
type
=
self
.
classified
.
id
,
name
=
'
elt_2
'
)
element
.
classifications
.
create
(
worker_version
=
self
.
version1
,
ml_class_id
=
self
.
text
.
id
,
confidence
=
.
1337
,
high_confidence
=
True
,
state
=
'
rejected
'
)
expected_classification
=
element
2
.
classifications
.
create
(
expected_classification
=
element
.
classifications
.
create
(
worker_version
=
self
.
version2
,
ml_class_id
=
self
.
text
.
id
,
confidence
=
.
1337
,
...
...
@@ -592,13 +449,13 @@ class TestClasses(FixtureAPITestCase):
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
best_class
'
:
'
true
'
,
'
with_best_classes
'
:
'
true
'
},
data
=
{
'
type
'
:
self
.
classified
.
slug
,
'
name
'
:
'
elt_2
'
,
'
with_best_classes
'
:
'
true
'
},
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
data
=
response
.
json
()
self
.
assertEqual
(
data
[
'
count
'
],
1
)
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
element
2
.
id
))
self
.
assertEqual
(
data
[
'
results
'
][
0
][
'
id
'
],
str
(
element
.
id
))
self
.
assertListEqual
(
data
[
'
results
'
][
0
][
'
best_classes
'
],
[
{
"
id
"
:
str
(
expected_classification
.
id
),
...
...
@@ -612,3 +469,119 @@ class TestClasses(FixtureAPITestCase):
}
}
])
def
test_element_lists_invalid_class_filters
(
self
):
corpus2
=
Corpus
.
objects
.
create
(
name
=
'
Corpus 2
'
)
other_class
=
corpus2
.
ml_classes
.
create
(
name
=
'
something
'
)
endpoints
=
[
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
reverse
(
'
api:elements-children
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
parent
.
id
)}),
reverse
(
'
api:elements-parents
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
common_children
.
id
)}),
]
cases
=
[
({
'
class_id
'
:
'
lol
'
},
{
'
class_id
'
:
[
'
“lol” is not a valid UUID.
'
]}),
({
'
class_id
'
:
str
(
other_class
.
id
)},
{
'
class_id
'
:
[
f
'
ML class
"
{
other_class
.
id
}
"
not found
'
]}),
({
'
classification_confidence
'
:
'
very
'
},
{
'
classification_confidence
'
:
[
"
could not convert string to float:
'
very
'"
]}),
({
'
classification_confidence
'
:
'
nan
'
},
{
'
classification_confidence
'
:
[
'
Confidence must be between 0 and 1
'
]}),
({
'
classification_confidence
'
:
'
inf
'
},
{
'
classification_confidence
'
:
[
'
Confidence must be between 0 and 1
'
]}),
({
'
classification_confidence
'
:
'
42
'
},
{
'
classification_confidence
'
:
[
'
Confidence must be between 0 and 1
'
]}),
({
'
classification_confidence
'
:
'
-.5
'
},
{
'
classification_confidence
'
:
[
'
Confidence must be between 0 and 1
'
]}),
(
{
'
classification_confidence_operator
'
:
'
hah
'
},
{
'
classification_confidence_operator
'
:
[
'
This operator is not supported.
'
,
'
This option is not supported without classification_confidence.
'
,
]}
),
(
{
'
classification_confidence
'
:
'
0.3
'
,
'
classification_confidence_operator
'
:
'
lol
'
},
{
'
classification_confidence_operator
'
:
[
'
This operator is not supported.
'
]}
)
]
for
endpoint
in
endpoints
:
for
data
,
expected_errors
in
cases
:
with
self
.
subTest
(
endpoint
=
endpoint
,
**
data
):
response
=
self
.
client
.
get
(
endpoint
,
data
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
expected_errors
)
def
test_element_lists_class_filters
(
self
):
# Create more diverse test data for these filters
Classification
.
objects
.
all
().
delete
()
for
i
in
range
(
1
,
7
):
self
.
corpus
.
elements
.
get
(
name
=
f
'
elt_
{
i
}
'
).
classifications
.
create
(
worker_version
=
self
.
version1
,
ml_class
=
self
.
text
,
confidence
=
i
/
6
,
# Give every even element a high confidence classification
high_confidence
=
i
%
2
==
0
)
# Second half of elements gets self.cover
for
i
in
range
(
7
,
13
):
self
.
corpus
.
elements
.
get
(
name
=
f
'
elt_
{
i
}
'
).
classifications
.
create
(
worker_version
=
self
.
version2
,
ml_class
=
self
.
cover
,
confidence
=
(
i
-
6
)
/
6
,
high_confidence
=
i
%
2
==
0
)
endpoints
=
[
reverse
(
'
api:corpus-elements
'
,
kwargs
=
{
'
corpus
'
:
self
.
corpus
.
id
}),
reverse
(
'
api:elements-children
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
parent
.
id
)}),
reverse
(
'
api:elements-parents
'
,
kwargs
=
{
'
pk
'
:
str
(
self
.
common_children
.
id
)}),
]
cases
=
[
(
{
'
class_id
'
:
str
(
self
.
text
.
id
)},
[
'
elt_1
'
,
'
elt_2
'
,
'
elt_3
'
,
'
elt_4
'
,
'
elt_5
'
,
'
elt_6
'
],
),
(
{
'
classification_confidence
'
:
'
0.5
'
},
[
'
elt_3
'
,
'
elt_9
'
],
),
(
{
'
classification_confidence
'
:
'
0.51
'
},
[],
),
(
{
'
classification_confidence
'
:
'
0.5
'
,
'
classification_confidence_operator
'
:
'
gt
'
},
[
'
elt_4
'
,
'
elt_5
'
,
'
elt_6
'
,
'
elt_10
'
,
'
elt_11
'
,
'
elt_12
'
]
),
(
{
'
classification_confidence
'
:
'
0.7
'
,
'
classification_confidence_operator
'
:
'
lte
'
},
[
'
elt_1
'
,
'
elt_2
'
,
'
elt_3
'
,
'
elt_4
'
,
'
elt_7
'
,
'
elt_8
'
,
'
elt_9
'
,
'
elt_10
'
],
),
(
{
'
class_id
'
:
str
(
self
.
cover
.
id
),
'
classification_confidence
'
:
'
0.7
'
},
[],
),
(
{
'
classification_high_confidence
'
:
True
},
[
'
elt_2
'
,
'
elt_4
'
,
'
elt_6
'
,
'
elt_8
'
,
'
elt_10
'
,
'
elt_12
'
],
),
(
{
'
classification_high_confidence
'
:
False
},
[
'
elt_1
'
,
'
elt_3
'
,
'
elt_5
'
,
'
elt_7
'
,
'
elt_9
'
,
'
elt_11
'
],
),
(
{
'
class_id
'
:
str
(
self
.
text
.
id
),
'
classification_high_confidence
'
:
True
},
[
'
elt_2
'
,
'
elt_4
'
,
'
elt_6
'
],
),
(
{
'
class_id
'
:
str
(
self
.
text
.
id
),
'
classification_high_confidence
'
:
False
},
[
'
elt_1
'
,
'
elt_3
'
,
'
elt_5
'
],
),
]
for
endpoint
in
endpoints
:
for
data
,
element_names
in
cases
:
with
self
.
subTest
(
endpoint
=
endpoint
,
**
data
):
response
=
self
.
client
.
get
(
endpoint
,
data
=
{
**
data
,
'
type
'
:
self
.
classified
.
slug
},
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
,
response
.
json
())
self
.
assertCountEqual
(
[
element
[
'
name
'
]
for
element
in
response
.
json
()[
'
results
'
]],
element_names
,
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment