Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
Backend
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container Registry
Analyze
Contributor analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Arkindex
Backend
Commits
5075f131
Commit
5075f131
authored
1 year ago
by
Valentin Rigal
Committed by
Erwan Rouchet
1 year ago
Browse files
Options
Downloads
Patches
Plain Diff
Restrict usage of CreateClassifications
parent
30283405
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!1986
Restrict usage of CreateClassifications
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
arkindex/documents/serializers/ml.py
+8
-4
8 additions, 4 deletions
arkindex/documents/serializers/ml.py
arkindex/documents/tests/test_bulk_classification.py
+270
-126
270 additions, 126 deletions
arkindex/documents/tests/test_bulk_classification.py
with
278 additions
and
130 deletions
arkindex/documents/serializers/ml.py
+
8
−
4
View file @
5075f131
...
...
@@ -542,10 +542,14 @@ class ClassificationsSerializer(serializers.Serializer):
style
=
{
'
base_template
'
:
'
input.html
'
},
)
worker_version
=
ForbiddenField
()
worker_run_id
=
serializers
.
PrimaryKeyRelatedField
(
queryset
=
WorkerRun
.
objects
.
all
(),
style
=
{
'
base_template
'
:
'
input.html
'
},
source
=
'
worker_run
'
,
worker_run_id
=
WorkerRunIDField
(
help_text
=
dedent
(
"""
A WorkerRun ID that the classifications will refer to.
Regular users may only use the WorkerRuns of their own `Local` process.
Tasks authenticated via the Ponos task authentication may only use the WorkerRuns of their process.
"""
).
strip
(),
)
classifications
=
ClassificationBulkSerializer
(
many
=
True
,
allow_empty
=
False
)
...
...
This diff is collapsed.
Click to expand it.
arkindex/documents/tests/test_bulk_classification.py
+
270
−
126
View file @
5075f131
...
...
@@ -15,6 +15,7 @@ class TestBulkClassification(FixtureAPITestCase):
cls
.
private_corpus
=
Corpus
.
objects
.
create
(
name
=
'
private
'
,
public
=
False
)
cls
.
worker_version
=
WorkerVersion
.
objects
.
get
(
worker__slug
=
'
reco
'
)
cls
.
worker_run
=
cls
.
worker_version
.
worker_runs
.
filter
(
process__mode
=
ProcessMode
.
Workers
).
get
()
cls
.
local_worker_run
=
cls
.
worker_version
.
worker_runs
.
filter
(
process__mode
=
ProcessMode
.
Local
).
get
()
cls
.
dog_class
=
cls
.
corpus
.
ml_classes
.
create
(
name
=
'
dog
'
)
cls
.
cat_class
=
cls
.
corpus
.
ml_classes
.
create
(
name
=
'
cat
'
)
...
...
@@ -23,17 +24,21 @@ class TestBulkClassification(FixtureAPITestCase):
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_403_FORBIDDEN
)
def
test_wrong_acl
(
self
):
"""
The user must have access to the parent element
"""
self
.
client
.
force_login
(
self
.
user
)
private_page
=
self
.
private_corpus
.
elements
.
create
(
type
=
self
.
private_corpus
.
types
.
create
(
slug
=
'
page
'
),
)
local_worker_run
=
self
.
user
.
processes
.
get
(
mode
=
ProcessMode
.
Local
).
worker_runs
.
get
()
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
private_page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
worker_run_id
'
:
str
(
local_
worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
...
...
@@ -51,9 +56,9 @@ class TestBulkClassification(FixtureAPITestCase):
}
)
def
test_worker_
version
(
self
):
def
test_worker_
run_required
(
self
):
"""
Classifications
canno
t be linked to a worker
versio
n
Classifications
mus
t be linked to a worker
ru
n
"""
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
5
):
...
...
@@ -82,119 +87,6 @@ class TestBulkClassification(FixtureAPITestCase):
'
worker_version
'
:
[
'
This field is forbidden.
'
],
})
def
test_worker_version_or_worker_run
(
self
):
"""
Either a worker run or a worker version is required
"""
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
"
confidence
"
:
0.42
,
}
]
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
{
'
worker_run_id
'
:
[
'
This field is required.
'
],
})
def
test_worker_version_and_worker_run
(
self
):
"""
Worker run and worker version cannot be set at the same time
"""
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
"
confidence
"
:
0.42
,
}
],
"
worker_run_id
"
:
str
(
self
.
worker_run
.
id
),
"
worker_version
"
:
str
(
self
.
worker_version
.
id
),
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
{
'
worker_version
'
:
[
'
This field is forbidden.
'
],
})
def
test_worker_run
(
self
):
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
9
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
"
confidence
"
:
0.99
,
"
high_confidence
"
:
True
},
{
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
"
confidence
"
:
0.42
,
}
],
"
worker_run_id
"
:
str
(
self
.
worker_run
.
id
),
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
first_cl
,
second_cl
=
self
.
page
.
classifications
.
order_by
(
'
-confidence
'
).
all
()
self
.
assertEqual
(
response
.
json
(),
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
classifications
'
:
[
{
'
id
'
:
str
(
first_cl
.
id
),
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
'
high_confidence
'
:
True
,
'
state
'
:
'
pending
'
,
},
{
'
id
'
:
str
(
second_cl
.
id
),
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
'
confidence
'
:
0.42
,
'
high_confidence
'
:
False
,
'
state
'
:
'
pending
'
,
},
]
})
self
.
assertCountEqual
(
list
(
self
.
page
.
classifications
.
values_list
(
'
ml_class__name
'
,
'
confidence
'
,
'
high_confidence
'
,
'
worker_version_id
'
,
'
worker_run_id
'
,
)),
[
(
'
dog
'
,
0.99
,
True
,
self
.
worker_version
.
id
,
self
.
worker_run
.
id
),
(
'
cat
'
,
0.42
,
False
,
self
.
worker_version
.
id
,
self
.
worker_run
.
id
),
],
)
# Worker run is set, and worker version is deduced from it
self
.
assertEqual
(
first_cl
.
worker_version
,
self
.
worker_version
)
self
.
assertEqual
(
second_cl
.
worker_version
,
self
.
worker_version
)
self
.
assertEqual
(
first_cl
.
worker_run
,
self
.
worker_run
)
self
.
assertEqual
(
second_cl
.
worker_run
,
self
.
worker_run
)
def
test_worker_run_not_found
(
self
):
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
6
):
...
...
@@ -225,15 +117,15 @@ class TestBulkClassification(FixtureAPITestCase):
def
test_ml_class_not_found
(
self
):
self
.
dog_class
.
delete
()
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
super
user
)
with
self
.
assertNumQueries
(
7
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
worker_run_id
'
:
str
(
self
.
local_
worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
"
aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa
"
,
...
...
@@ -252,8 +144,8 @@ class TestBulkClassification(FixtureAPITestCase):
"""
Test the bulk classification API deletes previous classifications with the same worker run
"""
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
9
):
self
.
client
.
force_login
(
self
.
super
user
)
with
self
.
assertNumQueries
(
7
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
...
...
@@ -270,7 +162,7 @@ class TestBulkClassification(FixtureAPITestCase):
"
confidence
"
:
0.42
,
}
],
"
worker_run_id
"
:
str
(
self
.
worker_run
.
id
),
'
worker_run_id
'
:
str
(
self
.
local_
worker_run
.
id
),
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
...
...
@@ -295,7 +187,7 @@ class TestBulkClassification(FixtureAPITestCase):
"
high_confidence
"
:
True
,
},
],
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
worker_run_id
'
:
str
(
self
.
local_
worker_run
.
id
),
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
...
...
@@ -312,14 +204,14 @@ class TestBulkClassification(FixtureAPITestCase):
"""
Test the bulk classification API prevents creating classifications with duplicate ML classes
"""
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
7
):
self
.
client
.
force_login
(
self
.
super
user
)
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
worker_run_id
'
:
str
(
self
.
local_
worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
...
...
@@ -336,3 +228,255 @@ class TestBulkClassification(FixtureAPITestCase):
self
.
assertDictEqual
(
response
.
json
(),
{
'
classifications
'
:
[
'
Duplicated ML classes are not allowed from the same worker run.
'
]
})
def
test_worker_run_non_local
(
self
):
"""
A regular user cannot create classifications with a WorkerRun of a non-local process
"""
self
.
client
.
force_login
(
self
.
superuser
)
with
self
.
assertNumQueries
(
4
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
},
]
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
{
'
worker_run_id
'
:
[
"
Ponos task authentication is required to use a WorkerRun
"
"
of a process other than the user
'
s local process.
"
]
})
def
test_worker_run_other_user
(
self
):
"""
A regular user cannot create classifications with a WorkerRun of someone else
'
s local process
"""
worker_run
=
self
.
user
.
processes
.
get
(
mode
=
ProcessMode
.
Local
).
worker_runs
.
first
()
self
.
client
.
force_login
(
self
.
superuser
)
with
self
.
assertNumQueries
(
4
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
},
]
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
{
'
worker_run_id
'
:
[
"
Ponos task authentication is required to use a WorkerRun
"
"
of a process other than the user
'
s local process.
"
]
})
def
test_worker_run_other_process
(
self
):
"""
A Ponos task cannot create classifications with a WorkerRun of another process
"""
process2
=
self
.
worker_run
.
process
.
creator
.
processes
.
create
(
mode
=
ProcessMode
.
Workers
,
corpus
=
self
.
corpus
,
)
other_worker_run
=
process2
.
worker_runs
.
create
(
version
=
self
.
worker_run
.
version
,
parents
=
[])
self
.
worker_run
.
process
.
start
()
task
=
self
.
worker_run
.
process
.
workflow
.
tasks
.
first
()
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
other_worker_run
.
id
),
'
classifications
'
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
},
]
},
HTTP_AUTHORIZATION
=
f
'
Ponos
{
task
.
token
}
'
,
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_400_BAD_REQUEST
)
self
.
assertDictEqual
(
response
.
json
(),
{
'
worker_run_id
'
:
[
"
Only the WorkerRuns of the authenticated task
'
s process may be used.
"
]
})
def
test_create_local
(
self
):
"""
A regular user can create classifications with a WorkerRun of their own local process
"""
self
.
client
.
force_login
(
self
.
superuser
)
with
self
.
assertNumQueries
(
7
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
"
confidence
"
:
0.99
,
"
high_confidence
"
:
True
},
{
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
"
confidence
"
:
0.42
,
},
],
"
worker_run_id
"
:
str
(
self
.
local_worker_run
.
id
),
}
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
first_cl
,
second_cl
=
self
.
page
.
classifications
.
order_by
(
'
-confidence
'
).
all
()
self
.
assertEqual
(
response
.
json
(),
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
local_worker_run
.
id
),
'
classifications
'
:
[
{
'
id
'
:
str
(
first_cl
.
id
),
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
'
high_confidence
'
:
True
,
'
state
'
:
'
pending
'
,
},
{
'
id
'
:
str
(
second_cl
.
id
),
'
ml_class
'
:
str
(
self
.
cat_class
.
id
),
'
confidence
'
:
0.42
,
'
high_confidence
'
:
False
,
'
state
'
:
'
pending
'
,
},
]
})
self
.
assertCountEqual
(
list
(
self
.
page
.
classifications
.
values_list
(
'
ml_class__name
'
,
'
confidence
'
,
'
high_confidence
'
,
'
worker_version_id
'
,
'
worker_run_id
'
,
)),
[
(
'
dog
'
,
0.99
,
True
,
self
.
worker_version
.
id
,
self
.
local_worker_run
.
id
),
(
'
cat
'
,
0.42
,
False
,
self
.
worker_version
.
id
,
self
.
local_worker_run
.
id
),
],
)
# Worker run is set, and worker version is deduced from it
self
.
assertEqual
(
first_cl
.
worker_version
,
self
.
worker_version
)
self
.
assertEqual
(
second_cl
.
worker_version
,
self
.
worker_version
)
self
.
assertEqual
(
first_cl
.
worker_run
,
self
.
local_worker_run
)
self
.
assertEqual
(
second_cl
.
worker_run
,
self
.
local_worker_run
)
def
test_create_task_auth
(
self
):
"""
Classifications can be created with a WorkerRun of a non-local process
when authenticated as a Ponos task of this process
"""
self
.
worker_run
.
process
.
start
()
task
=
self
.
worker_run
.
process
.
workflow
.
tasks
.
first
()
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
"
confidence
"
:
0.99
,
"
high_confidence
"
:
True
},
],
"
worker_run_id
"
:
str
(
self
.
worker_run
.
id
),
},
HTTP_AUTHORIZATION
=
f
'
Ponos
{
task
.
token
}
'
,
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
ml_class
=
self
.
page
.
classifications
.
get
()
self
.
assertEqual
(
response
.
json
(),
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
self
.
worker_run
.
id
),
'
classifications
'
:
[
{
'
id
'
:
str
(
ml_class
.
id
),
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
'
high_confidence
'
:
True
,
'
state
'
:
'
pending
'
,
},
]
})
self
.
assertEqual
(
ml_class
.
worker_version
,
self
.
worker_version
)
self
.
assertEqual
(
ml_class
.
worker_run
,
self
.
worker_run
)
def
test_worker_run_local_task_auth
(
self
):
"""
Classifications can be created with a WorkerRun of a Local process
even when authenticated as a Ponos task from a different process
"""
local_process
=
self
.
user
.
processes
.
get
(
mode
=
ProcessMode
.
Local
)
local_worker_run
=
local_process
.
worker_runs
.
get
()
self
.
worker_run
.
process
.
start
()
task
=
self
.
worker_run
.
process
.
workflow
.
tasks
.
first
()
self
.
assertNotEqual
(
self
.
worker_run
.
process_id
,
local_worker_run
.
process_id
)
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
post
(
reverse
(
'
api:classification-bulk
'
),
format
=
'
json
'
,
data
=
{
"
parent
"
:
str
(
self
.
page
.
id
),
"
classifications
"
:
[
{
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
"
confidence
"
:
0.99
,
"
high_confidence
"
:
True
},
],
"
worker_run_id
"
:
str
(
local_worker_run
.
id
),
},
HTTP_AUTHORIZATION
=
f
'
Ponos
{
task
.
token
}
'
,
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_201_CREATED
)
ml_class
=
self
.
page
.
classifications
.
get
()
self
.
assertEqual
(
response
.
json
(),
{
'
parent
'
:
str
(
self
.
page
.
id
),
'
worker_run_id
'
:
str
(
local_worker_run
.
id
),
'
classifications
'
:
[
{
'
id
'
:
str
(
ml_class
.
id
),
'
ml_class
'
:
str
(
self
.
dog_class
.
id
),
'
confidence
'
:
0.99
,
'
high_confidence
'
:
True
,
'
state
'
:
'
pending
'
,
},
]
})
self
.
assertEqual
(
ml_class
.
worker_version
,
local_worker_run
.
version
)
self
.
assertEqual
(
ml_class
.
worker_run
,
local_worker_run
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment