From 426ca93161c897230746182ddb7ccfa3971ba235 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 23 May 2019 15:34:23 +0000
Subject: [PATCH] Multiple ML tasks per import

---
 arkindex/dataimport/models.py                    | 16 +++++++---------
 .../tests/ml_tools/classifier/config.yml         |  2 ++
 .../tests/ml_tools/recognizer/config.yml         |  2 ++
 arkindex/dataimport/tests/test_imports.py        |  4 ++++
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py
index 1e7407bdab..b2321c698a 100644
--- a/arkindex/dataimport/models.py
+++ b/arkindex/dataimport/models.py
@@ -75,13 +75,11 @@ class DataImport(IndexableModel):
                 },
             }
 
-        if self.mode == DataImportMode.Repository or self.ml_tools:
-            # Add the ML task if ML is required - for Git, let the Git import decide
-            tasks['ml'] = {
-                'image': settings.ARKINDEX_ML_IMAGE,
-                'command': 'arkindex-ml-analysis /data/ml_analysis.json',
-                'parent': 'import',
-            }
+        for tool in self.ml_tools:
+            # Add the import task as parent of all ML tasks
+            task = tool.task_recipe
+            task['parent'] = 'import'
+            tasks['ml_{}'.format(tool.slug)] = task
 
         # Build the workflow in db
         recipe = settings.PONOS_RECIPE.copy()
@@ -152,10 +150,10 @@ class DataImport(IndexableModel):
         tools = []
         if 'classifier' in self.payload:
             tool = MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Classifier, self.payload['classifier'])
-            tools.append((MLToolType.Classifier, tool.slug))
+            tools.append(tool)
         if 'recognizer' in self.payload:
             tool = MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Recognizer, self.payload['recognizer'])
-            tools.append((MLToolType.Recognizer, tool.slug))
+            tools.append(tool)
         return tools
 
 
diff --git a/arkindex/dataimport/tests/ml_tools/classifier/config.yml b/arkindex/dataimport/tests/ml_tools/classifier/config.yml
index e1e9ddc0f3..b58ed1f7bd 100644
--- a/arkindex/dataimport/tests/ml_tools/classifier/config.yml
+++ b/arkindex/dataimport/tests/ml_tools/classifier/config.yml
@@ -3,6 +3,8 @@ name: Unit test classifier
 type: classifier
 slug: dummy_classifier
 version: 1.0.0
+docker:
+  image: some_image
 classes:
   - class_1
   - class_2
diff --git a/arkindex/dataimport/tests/ml_tools/recognizer/config.yml b/arkindex/dataimport/tests/ml_tools/recognizer/config.yml
index 83cb84d587..5ec28028d0 100644
--- a/arkindex/dataimport/tests/ml_tools/recognizer/config.yml
+++ b/arkindex/dataimport/tests/ml_tools/recognizer/config.yml
@@ -3,4 +3,6 @@ name: Unit test recognizer
 type: recognizer
 slug: dummy_recognizer
 version: 0.4.2
+docker:
+  image: some_image
 tesseract: []
diff --git a/arkindex/dataimport/tests/test_imports.py b/arkindex/dataimport/tests/test_imports.py
index 5d42ea7c62..6092cf4324 100644
--- a/arkindex/dataimport/tests/test_imports.py
+++ b/arkindex/dataimport/tests/test_imports.py
@@ -137,6 +137,8 @@ class TestImports(FixtureAPITestCase):
         classifier_mock, recognizer_mock = MagicMock(), MagicMock()
         classifier_mock.slug = 'someclassifier'
         recognizer_mock.slug = 'somerecognizer'
+        classifier_mock.task_recipe = {'image': 'some_image1', 'command': 'some_command1'}
+        recognizer_mock.task_recipe = {'image': 'some_image2', 'command': 'some_command2'}
         ml_get_mock.side_effect = [classifier_mock, recognizer_mock] * 2
         self.client.force_login(self.user)
 
@@ -382,6 +384,8 @@ class TestImports(FixtureAPITestCase):
         classifier_mock, recognizer_mock = MagicMock(), MagicMock()
         classifier_mock.slug = 'someclassifier'
         recognizer_mock.slug = 'somerecognizer'
+        classifier_mock.task_recipe = {'image': 'some_image1', 'command': 'some_command1'}
+        recognizer_mock.task_recipe = {'image': 'some_image2', 'command': 'some_command2'}
         ml_get_mock.side_effect = [classifier_mock, recognizer_mock] * 3
         self.client.force_login(self.user)
 
-- 
GitLab