From 3b97300590e435534502ed4c0155f53d4bbadf9c Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 25 Apr 2023 15:54:06 +0200
Subject: [PATCH] use configuration to get folder IDs

---
 .arkindex.yml                             | 13 +++++++++++++
 worker_generic_training_dataset/worker.py | 15 +++++++--------
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/.arkindex.yml b/.arkindex.yml
index 1f435a4..276ff64 100644
--- a/.arkindex.yml
+++ b/.arkindex.yml
@@ -9,3 +9,16 @@ workers:
     type: data-extract
     docker:
       build: Dockerfile
+    user_configuration:
+      train_folder_id:
+        type: str
+        title: ID of the training folder on Arkindex
+        required: true  # worker.py asserts that this ID is set
+      validation_folder_id:
+        type: str
+        title: ID of the validation folder on Arkindex
+        required: true  # worker.py asserts that this ID is set
+      test_folder_id:
+        type: str
+        title: ID of the testing folder on Arkindex
+        required: false  # optional: worker.py uses None when this ID is unset
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 4c475a4..78ba703 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -44,14 +44,13 @@ class DatasetExtractor(BaseWorker):
         self.args = self.parser.parse_args()
         if self.is_read_only:
             super().configure_for_developers()
-            self.process_information = {
-                "train_folder_id": "47a0e07b-d07a-4969-aced-44450d132f0d",
-                "validation_folder_id": "8cbc4b53-9e07-4a72-b4e6-93f7f5b0cbed",
-                "test_folder_id": "659a37ea-3b26-42f0-8b65-78964f9e433e",
-            }
         else:
             super().configure()
 
+        if self.user_configuration:
+            logger.info("Overriding with user_configuration")
+            self.config.update(self.user_configuration)
+
         # database arg is mandatory in dev mode
         assert (
             not self.is_read_only or self.args.database is not None
@@ -79,15 +78,15 @@ class DatasetExtractor(BaseWorker):
         """
         logger.info("Retrieving information from process_information")
 
-        train_folder_id = self.process_information.get("train_folder_id")
+        train_folder_id = self.config.get("train_folder_id")
         assert train_folder_id, "A training folder id is necessary to use this worker"
         self.training_folder_id = UUID(train_folder_id)
 
-        val_folder_id = self.process_information.get("validation_folder_id")
+        val_folder_id = self.config.get("validation_folder_id")
         assert val_folder_id, "A validation folder id is necessary to use this worker"
         self.validation_folder_id = UUID(val_folder_id)
 
-        test_folder_id = self.process_information.get("test_folder_id")
+        test_folder_id = self.config.get("test_folder_id")
         self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
 
     def initialize_database(self):
-- 
GitLab