Skip to content
Snippets Groups Projects
Commit b27ff5b4 authored by Valentin Rigal's avatar Valentin Rigal Committed by Erwan Rouchet
Browse files

Handle worker version in ListTranscriptions

parent c4ba6842
No related branches found
No related tags found
No related merge requests found
......@@ -659,7 +659,7 @@ class TranscriptionsPagination(PageNumberPagination):
class ElementTranscriptions(ListAPIView):
"""
List all transcriptions for an element, optionally filtered by type.
List all transcriptions for an element, optionally filtered by type or worker version id.
Recursive parameter allow listing transcriptions on sub-elements,
otherwise element fields in the response will be set to null.
"""
......@@ -684,9 +684,13 @@ class ElementTranscriptions(ListAPIView):
'in': 'query',
'required': False,
'description': 'Recursively list transcriptions on sub-elements',
'schema': {
'type': 'boolean',
}
'schema': {'type': 'boolean'}
}, {
'name': 'worker_version',
'in': 'query',
'required': False,
'description': 'Filter transcriptions by worker version',
'schema': {'type': 'string', 'format': 'uuid'}
},
]
}
......@@ -732,6 +736,10 @@ class ElementTranscriptions(ListAPIView):
else:
queryset = queryset.filter(element_id=element.id)
return queryset
def filter_queryset(self, queryset):
# Filter by transcription type
req_type = self.request.query_params.get('type')
if req_type:
try:
......@@ -739,6 +747,12 @@ class ElementTranscriptions(ListAPIView):
queryset = queryset.filter(type=req_type)
except ValueError:
raise ValidationError({'type': 'Not a valid transcription type'})
# Filter by worker version
worker_version = self.request.query_params.get('worker_version')
if worker_version:
queryset = queryset.filter(worker_version_id=worker_version)
return queryset
......
......@@ -246,6 +246,7 @@ class TranscriptionSerializer(serializers.ModelSerializer):
'score',
'zone',
'source',
'worker_version_id',
)
......
......@@ -127,6 +127,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
},
'text': 'A perfect day in a perfect place',
'type': 'line',
'worker_version_id': None,
'zone': None
})
......@@ -340,6 +341,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
'source': None,
'text': 'NEKUDOTAYIM',
'type': 'word',
'worker_version_id': str(version.id),
'zone': None
})
......
......@@ -62,6 +62,7 @@ class TestEditTranscription(FixtureAPITestCase):
},
'text': 'A manual transcription',
'type': 'line',
'worker_version_id': None,
'zone': None
})
......@@ -133,6 +134,7 @@ class TestEditTranscription(FixtureAPITestCase):
},
'text': 'a knight was living lonely',
'type': 'line',
'worker_version_id': None,
'zone': None
})
......
......@@ -3,7 +3,9 @@ from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
from arkindex.project.polygon import Polygon
from arkindex_common.enums import TranscriptionType
from arkindex_common.ml_tool import MLToolType
from arkindex.documents.models import Corpus, DataSource
from arkindex.dataimport.models import Worker, WorkerVersion
from arkindex.users.models import User
......@@ -25,6 +27,17 @@ class TestTranscriptions(FixtureAPITestCase):
cls.private_read_user = User.objects.create_user('a@bc.de', 'a')
cls.private_read_user.verified_email = True
cls.private_read_user.save()
cls.repo = cls.user.credentials.get().repos.get()
cls.worker_version = WorkerVersion.objects.create(
worker=Worker.objects.create(
repository=cls.repo,
name='Test Worker',
slug='test_worker',
type=MLToolType.Classifier
),
revision=cls.repo.revisions.get(),
configuration={"test": "test1"}
)
def test_list_transcriptions_read_right(self):
# A read right on the element corpus is required to access transcriptions
......@@ -74,7 +87,7 @@ class TestTranscriptions(FixtureAPITestCase):
def test_list_transcriptions_recursive(self):
for i in range(1, 5):
# Add 4 transcriptions on the page line
self.line.transcriptions.create(source_id=self.src.id, type=TranscriptionType.Line, text=f'Text {i}')
self.line.transcriptions.create(worker_version=self.worker_version, type=TranscriptionType.Line, text=f'Text {i}')
for i in range(1, 5):
# Add 4 transcribed line children
zone, _ = self.page.zone.image.zones.get_or_create(polygon=Polygon.from_coords(0, 0, i + 1, i + 1))
......@@ -97,20 +110,22 @@ class TestTranscriptions(FixtureAPITestCase):
page_polygon = [[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]
line_polygon = [[400, 400], [400, 500], [500, 500], [500, 400], [400, 400]]
self.assertCountEqual(
[(data['element']['type'], data['element']['zone']['polygon'], data['text']) for data in results],
[
('page', page_polygon, 'PARIS'),
('page', page_polygon, 'ROY'),
('page', page_polygon, 'Lorem ipsum dolor sit amet'),
('page', page_polygon, 'DATUM'),
('text_line', line_polygon, 'Text 1'),
('text_line', line_polygon, 'Text 2'),
('text_line', line_polygon, 'Text 3'),
('text_line', line_polygon, 'Text 4'),
('text_line', [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]], 'Added text 1'),
('text_line', [[0, 0], [0, 3], [3, 3], [3, 0], [0, 0]], 'Added text 2'),
('text_line', [[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]], 'Added text 3'),
('text_line', [[0, 0], [0, 5], [5, 5], [5, 0], [0, 0]], 'Added text 4')
(data['element']['type'], data['worker_version_id'], data['element']['zone']['polygon'], data['text'])
for data in results
], [
('page', None, page_polygon, 'PARIS'),
('page', None, page_polygon, 'ROY'),
('page', None, page_polygon, 'Lorem ipsum dolor sit amet'),
('page', None, page_polygon, 'DATUM'),
('text_line', str(self.worker_version.id), line_polygon, 'Text 1'),
('text_line', str(self.worker_version.id), line_polygon, 'Text 2'),
('text_line', str(self.worker_version.id), line_polygon, 'Text 3'),
('text_line', str(self.worker_version.id), line_polygon, 'Text 4'),
('text_line', None, [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]], 'Added text 1'),
('text_line', None, [[0, 0], [0, 3], [3, 3], [3, 0], [0, 0]], 'Added text 2'),
('text_line', None, [[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]], 'Added text 3'),
('text_line', None, [[0, 0], [0, 5], [5, 5], [5, 0], [0, 0]], 'Added text 4')
]
)
......@@ -132,3 +147,27 @@ class TestTranscriptions(FixtureAPITestCase):
self.assertEqual(len(results), 4)
for tr in results:
self.assertEqual(tr.get('type'), 'line')
def test_list_worker_version_transcriptions(self):
for i in range(1, 5):
# Add 4 transcriptions on the page line with a specific worker_version
self.line.transcriptions.create(
type=TranscriptionType.Line,
text=f'Text {i}',
worker_version=self.worker_version
)
self.client.force_login(self.user)
with self.assertNumQueries(12):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'worker_version': str(self.worker_version.id)}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.json()['results']
self.assertEqual(len(results), 4)
for tr in results:
self.assertEqual(tr.get('worker_version_id'), str(self.worker_version.id))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment