Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
Backend
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container Registry
Analyze
Contributor analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Arkindex
Backend
Commits
7e4fc11d
Commit
7e4fc11d
authored
1 year ago
by
ml bonhomme
Browse files
Options
Downloads
Patches
Plain Diff
fixed datasets API
parent
3b4b7072
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
arkindex/training/api.py
+10
-7
10 additions, 7 deletions
arkindex/training/api.py
arkindex/training/serializers.py
+6
-5
6 additions, 5 deletions
arkindex/training/serializers.py
arkindex/training/tests/test_datasets_api.py
+19
-16
19 additions, 16 deletions
arkindex/training/tests/test_datasets_api.py
with
35 additions
and
28 deletions
arkindex/training/api.py
+
10
−
7
View file @
7e4fc11d
...
@@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
...
@@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
SELECT
SELECT
n.id,
n.id,
lag(element_id) OVER (
lag(element_id) OVER (
partition BY (n.
data
set_id
, n.set
)
partition BY (n.set_id)
order by
order by
n.element_id
n.element_id
) as previous,
) as previous,
lead(element_id) OVER (
lead(element_id) OVER (
partition BY (n.
data
set_id
, n.set
)
partition BY (n.set_id)
order by
order by
n.element_id
n.element_id
) as next
) as next
FROM training_datasetelement as n
FROM training_datasetelement as n
WHERE
(data
set_id
, set)
IN (
WHERE set_id IN (
SELECT
data
set_id
, set
SELECT set_id
FROM training_datasetelement
FROM training_datasetelement
WHERE id IN %(ids)s
WHERE id IN %(ids)s
)
)
...
@@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
...
@@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
serializer_class
=
DatasetSerializer
serializer_class
=
DatasetSerializer
def
get_queryset
(
self
):
def
get_queryset
(
self
):
queryset
=
Dataset
.
objects
.
filter
(
corpus__in
=
Corpus
.
objects
.
readable
(
self
.
request
.
user
))
.
prefetch_related
(
"
sets
"
)
queryset
=
Dataset
.
objects
.
filter
(
corpus__in
=
Corpus
.
objects
.
readable
(
self
.
request
.
user
))
return
queryset
.
select_related
(
"
corpus
"
,
"
creator
"
)
return
queryset
.
select_related
(
"
corpus
"
,
"
creator
"
)
def
check_object_permissions
(
self
,
request
,
obj
):
def
check_object_permissions
(
self
,
request
,
obj
):
...
@@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView):
...
@@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView):
qs
=
(
qs
=
(
self
.
element
.
dataset_elements
self
.
element
.
dataset_elements
.
select_related
(
"
set__dataset__creator
"
)
.
select_related
(
"
set__dataset__creator
"
)
.
order_by
(
"
set__name
"
,
"
id
"
)
.
prefetch_related
(
"
set__dataset__sets
"
)
.
order_by
(
"
set__dataset__name
"
,
"
set__name
"
)
)
)
with_neighbors
=
self
.
request
.
query_params
.
get
(
"
with_neighbors
"
,
"
false
"
)
with_neighbors
=
self
.
request
.
query_params
.
get
(
"
with_neighbors
"
,
"
false
"
)
...
@@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
...
@@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
serializer_class
=
DatasetSerializer
serializer_class
=
DatasetSerializer
def
get_queryset
(
self
):
def
get_queryset
(
self
):
return
Dataset
.
objects
.
filter
(
corpus__in
=
Corpus
.
objects
.
readable
(
self
.
request
.
user
))
return
(
Dataset
.
objects
.
filter
(
corpus__in
=
Corpus
.
objects
.
readable
(
self
.
request
.
user
))
)
def
check_object_permissions
(
self
,
request
,
dataset
):
def
check_object_permissions
(
self
,
request
,
dataset
):
if
not
self
.
has_write_access
(
dataset
.
corpus
):
if
not
self
.
has_write_access
(
dataset
.
corpus
):
...
...
This diff is collapsed.
Click to expand it.
arkindex/training/serializers.py
+
6
−
5
View file @
7e4fc11d
...
@@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer):
...
@@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer):
help_text
=
"
Display name of the user who created the dataset.
"
,
help_text
=
"
Display name of the user who created the dataset.
"
,
)
)
set_names
=
serializers
.
ListField
(
child
=
serializers
.
CharField
(
max_length
=
50
),
write_only
=
True
,
required
=
False
)
set_names
=
serializers
.
ListField
(
child
=
serializers
.
CharField
(
max_length
=
50
),
write_only
=
True
,
default
=
serializers
.
CreateOnlyDefault
([
"
training
"
,
"
validation
"
,
"
test
"
])
)
sets
=
DatasetSetSerializer
(
many
=
True
,
read_only
=
True
)
sets
=
DatasetSetSerializer
(
many
=
True
,
read_only
=
True
)
# When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
# When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
...
@@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer):
...
@@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer):
@transaction.atomic
@transaction.atomic
def
create
(
self
,
validated_data
):
def
create
(
self
,
validated_data
):
if
"
set_names
"
not
in
validated_data
:
sets
=
validated_data
.
pop
(
"
set_names
"
)
sets
=
[
"
training
"
,
"
validation
"
,
"
test
"
]
else
:
sets
=
validated_data
.
pop
(
"
set_names
"
)
dataset
=
Dataset
.
objects
.
create
(
**
validated_data
)
dataset
=
Dataset
.
objects
.
create
(
**
validated_data
)
DatasetSet
.
objects
.
bulk_create
(
DatasetSet
.
objects
.
bulk_create
(
DatasetSet
(
DatasetSet
(
...
...
This diff is collapsed.
Click to expand it.
arkindex/training/tests/test_datasets_api.py
+
19
−
16
View file @
7e4fc11d
...
@@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def
test_retrieve
(
self
):
def
test_retrieve
(
self
):
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
response
=
self
.
client
.
get
(
reverse
(
"
api:dataset-update
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
pk
})
reverse
(
"
api:dataset-update
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
pk
})
)
)
...
@@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
self
.
dataset
.
task
=
self
.
task
self
.
dataset
.
task
=
self
.
task
self
.
dataset
.
save
()
self
.
dataset
.
save
()
with
self
.
assertNumQueries
(
5
):
with
self
.
assertNumQueries
(
6
):
response
=
self
.
client
.
get
(
response
=
self
.
client
.
get
(
reverse
(
"
api:dataset-update
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
pk
})
reverse
(
"
api:dataset-update
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
pk
})
)
)
...
@@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def
test_add_element_wrong_element
(
self
):
def
test_add_element_wrong_element
(
self
):
element
=
self
.
private_corpus
.
elements
.
create
(
type
=
self
.
private_corpus
.
types
.
create
(
slug
=
"
folder
"
))
element
=
self
.
private_corpus
.
elements
.
create
(
type
=
self
.
private_corpus
.
types
.
create
(
slug
=
"
folder
"
))
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
4
):
with
self
.
assertNumQueries
(
5
):
response
=
self
.
client
.
post
(
response
=
self
.
client
.
post
(
reverse
(
"
api:dataset-elements
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
id
}),
reverse
(
"
api:dataset-elements
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
id
}),
data
=
{
"
set
"
:
"
test
"
,
"
element_id
"
:
str
(
element
.
id
)},
data
=
{
"
set
"
:
"
test
"
,
"
element_id
"
:
str
(
element
.
id
)},
...
@@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
None
,
"
previous
"
:
None
,
"
next
"
:
None
"
next
"
:
None
},
{
},
{
...
@@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
None
,
"
previous
"
:
None
,
"
next
"
:
None
"
next
"
:
None
}]
}]
...
@@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
None
,
"
previous
"
:
None
,
"
next
"
:
None
"
next
"
:
None
},
{
},
{
...
@@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
None
,
"
previous
"
:
None
,
"
next
"
:
None
"
next
"
:
None
}]
}]
...
@@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
sorted_dataset2_elements
=
sorted
([
str
(
self
.
page1
.
id
),
str
(
self
.
page3
.
id
)])
sorted_dataset2_elements
=
sorted
([
str
(
self
.
page1
.
id
),
str
(
self
.
page3
.
id
)])
page1_index_2
=
sorted_dataset2_elements
.
index
(
str
(
self
.
page1
.
id
))
page1_index_2
=
sorted_dataset2_elements
.
index
(
str
(
self
.
page1
.
id
))
with
self
.
assertNumQueries
(
7
):
with
self
.
assertNumQueries
(
8
):
response
=
self
.
client
.
get
(
reverse
(
"
api:element-datasets
"
,
kwargs
=
{
"
pk
"
:
str
(
self
.
page1
.
id
)}),
{
"
with_neighbors
"
:
True
})
response
=
self
.
client
.
get
(
reverse
(
"
api:element-datasets
"
,
kwargs
=
{
"
pk
"
:
str
(
self
.
page1
.
id
)}),
{
"
with_neighbors
"
:
True
})
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
self
.
maxDiff
=
None
self
.
assertDictEqual
(
response
.
json
(),
{
self
.
assertDictEqual
(
response
.
json
(),
{
"
count
"
:
3
,
"
count
"
:
3
,
"
next
"
:
None
,
"
next
"
:
None
,
...
@@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
(
"
previous
"
:
(
sorted_dataset_elements
[
page1_index_1
-
1
]
sorted_dataset_elements
[
page1_index_1
-
1
]
if
page1_index_1
-
1
>=
0
if
page1_index_1
-
1
>=
0
...
@@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
id
"
:
str
(
ds
.
id
),
"
id
"
:
str
(
ds
.
id
),
"
name
"
:
ds
.
name
"
name
"
:
ds
.
name
}
}
for
ds
in
self
.
dataset
.
sets
.
all
()
for
ds
in
self
.
dataset
2
.
sets
.
all
()
],
],
"
set_elements
"
:
None
,
"
set_elements
"
:
None
,
"
state
"
:
"
open
"
,
"
state
"
:
"
open
"
,
...
@@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
created
"
:
self
.
dataset2
.
created
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
"
updated
"
:
self
.
dataset2
.
updated
.
isoformat
().
replace
(
"
+00:00
"
,
"
Z
"
),
},
},
"
set
"
:
"
train
"
,
"
set
"
:
"
train
ing
"
,
"
previous
"
:
(
"
previous
"
:
(
sorted_dataset2_elements
[
page1_index_2
-
1
]
sorted_dataset2_elements
[
page1_index_2
-
1
]
if
page1_index_2
==
1
if
page1_index_2
==
1
...
@@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def
test_clone_existing_name
(
self
):
def
test_clone_existing_name
(
self
):
self
.
corpus
.
datasets
.
create
(
name
=
"
Clone of First Dataset
"
,
creator
=
self
.
user
)
self
.
corpus
.
datasets
.
create
(
name
=
"
Clone of First Dataset
"
,
creator
=
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
1
1
):
with
self
.
assertNumQueries
(
1
5
):
response
=
self
.
client
.
post
(
response
=
self
.
client
.
post
(
reverse
(
"
api:dataset-clone
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
id
}),
reverse
(
"
api:dataset-clone
"
,
kwargs
=
{
"
pk
"
:
self
.
dataset
.
id
}),
format
=
"
json
"
,
format
=
"
json
"
,
...
@@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
])
])
data
=
response
.
json
()
data
=
response
.
json
()
data
.
pop
(
"
id
"
)
data
.
pop
(
"
created
"
)
data
.
pop
(
"
created
"
)
data
.
pop
(
"
updated
"
)
data
.
pop
(
"
updated
"
)
cloned_dataset
=
Dataset
.
objects
.
get
(
id
=
data
[
"
id
"
])
self
.
maxDiff
=
None
self
.
assertDictEqual
(
self
.
assertDictEqual
(
response
.
json
(),
response
.
json
(),
{
{
"
id
"
:
str
(
cloned_dataset
.
id
),
"
name
"
:
"
Clone of First Dataset 1
"
,
"
name
"
:
"
Clone of First Dataset 1
"
,
"
description
"
:
self
.
dataset
.
description
,
"
description
"
:
self
.
dataset
.
description
,
"
creator
"
:
self
.
user
.
display_name
,
"
creator
"
:
self
.
user
.
display_name
,
...
@@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
"
id
"
:
str
(
ds
.
id
),
"
id
"
:
str
(
ds
.
id
),
"
name
"
:
ds
.
name
"
name
"
:
ds
.
name
}
}
for
ds
in
self
.
dataset
.
sets
.
all
()
for
ds
in
cloned_
dataset
.
sets
.
all
()
],
],
"
set_elements
"
:
{
k
:
0
for
k
in
self
.
dataset
.
sets
.
all
()},
"
set_elements
"
:
{
str
(
k
.
name
)
:
0
for
k
in
self
.
dataset
.
sets
.
all
()},
"
state
"
:
DatasetState
.
Open
.
value
,
"
state
"
:
DatasetState
.
Open
.
value
,
"
task_id
"
:
None
,
"
task_id
"
:
None
,
},
},
...
@@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
...
@@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def
test_clone_name_too_long
(
self
):
def
test_clone_name_too_long
(
self
):
dataset
=
self
.
corpus
.
datasets
.
create
(
name
=
"
A
"
*
99
,
creator
=
self
.
user
)
dataset
=
self
.
corpus
.
datasets
.
create
(
name
=
"
A
"
*
99
,
creator
=
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
self
.
client
.
force_login
(
self
.
user
)
with
self
.
assertNumQueries
(
1
1
):
with
self
.
assertNumQueries
(
1
4
):
response
=
self
.
client
.
post
(
response
=
self
.
client
.
post
(
reverse
(
"
api:dataset-clone
"
,
kwargs
=
{
"
pk
"
:
dataset
.
id
}),
reverse
(
"
api:dataset-clone
"
,
kwargs
=
{
"
pk
"
:
dataset
.
id
}),
format
=
"
json
"
,
format
=
"
json
"
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment