From 0116798ca9df11c0a2f6f39d5a001984751db837 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Fri, 22 Apr 2022 13:41:29 +0000
Subject: [PATCH] add endpoint to list worker types

---
 arkindex/dataimport/api.py                    | 13 ++++++++
 arkindex/dataimport/serializers/workers.py    | 10 ++++++
 .../dataimport/tests/test_worker_types.py     | 31 +++++++++++++++++++
 arkindex/project/api_v1.py                    |  2 ++
 4 files changed, 56 insertions(+)
 create mode 100644 arkindex/dataimport/tests/test_worker_types.py

diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py
index c3d82a3acd..136c1b79ee 100644
--- a/arkindex/dataimport/api.py
+++ b/arkindex/dataimport/api.py
@@ -75,6 +75,7 @@ from arkindex.dataimport.serializers.workers import (
     WorkerConfigurationSerializer,
     WorkerSerializer,
     WorkerStatisticsSerializer,
+    WorkerTypeSerializer,
     WorkerVersionEditSerializer,
     WorkerVersionSerializer,
 )
@@ -851,6 +852,18 @@ class WorkerList(WorkerACLMixin, ListCreateAPIView):
         return Response(WorkerSerializer(worker).data, status=reponse_status)
 
 
+@extend_schema(
+    tags=['repos'],
+    description='List available worker types on an Arkindex instance.',
+)
+class WorkerTypesList(ListAPIView):
+    """
+    List available worker types on instance
+    """
+    serializer_class = WorkerTypeSerializer
+    queryset = WorkerType.objects.all().order_by('display_name')
+
+
 @extend_schema(tags=['repos'])
 @extend_schema_view(
     get=extend_schema(
diff --git a/arkindex/dataimport/serializers/workers.py b/arkindex/dataimport/serializers/workers.py
index c1583d2b6c..df93b565e8 100644
--- a/arkindex/dataimport/serializers/workers.py
+++ b/arkindex/dataimport/serializers/workers.py
@@ -47,6 +47,16 @@ class WorkerSerializer(WorkerLightSerializer):
         fields = WorkerLightSerializer.Meta.fields + ('repository_id', )
 
 
+class WorkerTypeSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = WorkerType
+        fields = (
+            'id',
+            'slug',
+            'display_name'
+        )
+
+
 class UserConfigurationFieldType(Enum):
     Int = 'int'
     Float = 'float'
diff --git a/arkindex/dataimport/tests/test_worker_types.py b/arkindex/dataimport/tests/test_worker_types.py
new file mode 100644
index 0000000000..011df45879
--- /dev/null
+++ b/arkindex/dataimport/tests/test_worker_types.py
@@ -0,0 +1,31 @@
+from django.urls import reverse
+from rest_framework import status
+
+from arkindex.dataimport.models import WorkerType
+from arkindex.project.tests import FixtureAPITestCase
+
+
+class TestWorkersWorkerTypes(FixtureAPITestCase):
+    """
+    Test workers and worker versions endpoints
+    """
+
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.worker_type_classifier = WorkerType.objects.get(slug="classifier")
+        cls.worker_type_dla = WorkerType.objects.get(slug="dla")
+        cls.worker_type_recognizer = WorkerType.objects.get(slug="recognizer")
+        cls.worker_type_worker = WorkerType.objects.get(slug="worker")
+
+    def test_list_worker_types(self):
+        # An un-authenticated user has access to worker types
+        with self.assertNumQueries(2):
+            response = self.client.get(reverse('api:worker-type-list'))
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.json()['results'], [
+            {'id': str(self.worker_type_classifier.id), 'slug': 'classifier', 'display_name': 'Classifier'},
+            {'id': str(self.worker_type_dla.id), 'slug': 'dla', 'display_name': 'Document layout analyser'},
+            {'id': str(self.worker_type_recognizer.id), 'slug': 'recognizer', 'display_name': 'Recognizer'},
+            {'id': str(self.worker_type_worker.id), 'slug': 'worker', 'display_name': 'Worker requiring a GPU'}
+        ])
diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index 6f70edfc86..cd994c47bb 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -31,6 +31,7 @@ from arkindex.dataimport.api import (
     WorkerRetrieve,
     WorkerRunDetails,
     WorkerRunList,
+    WorkerTypesList,
     WorkerVersionList,
     WorkerVersionRetrieve,
 )
@@ -213,6 +214,7 @@ api = [
 
     # Workers
     path('workers/', WorkerList.as_view(), name='workers-list'),
+    path('workers/types/', WorkerTypesList.as_view(), name='worker-type-list'),
     path('workers/<uuid:pk>/', WorkerRetrieve.as_view(), name='worker-retrieve'),
     path('workers/<uuid:pk>/configurations/', WorkerConfigurationList.as_view(), name='worker-configurations'),
     path('workers/<uuid:pk>/versions/', WorkerVersionList.as_view(), name='worker-versions'),
-- 
GitLab