From b5d18e70a03ae7a63b7d4c2b8d60140ba79c2865 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 31 Aug 2023 12:49:47 +0200
Subject: [PATCH] Use WorkerRun ID in task slugs

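A WorkerVersion can be used several times in a single process, so task
slugs derived from the WorkerVersion slug alone are not unique. Task
slugs are now built by a new WorkerRun.task_slug property from the
worker slug and the first 6 characters of the WorkerRun ID (for
example, reco_1a2b3c); WorkerRun.build_task no longer takes a
task_name argument and still appends the chunk suffix
(e.g. reco_1a2b3c_2) when chunks are enabled.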
---
 arkindex/process/models.py                    | 31 +++++----
 arkindex/process/tests/test_create_process.py | 12 ++--
 arkindex/process/tests/test_processes.py      | 67 ++++++++++++++-----
 arkindex/process/tests/test_workeractivity.py |  2 +-
 arkindex/process/tests/test_workerruns.py     | 28 ++++----
 5 files changed, 89 insertions(+), 51 deletions(-)

diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index 78742b096b..53dc26c9ec 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -733,11 +733,8 @@ class Process(IndexableModel):
 
                 # Generate a task for each WorkerRun on the Process
                 for worker_run in worker_runs:
-                    task_name = worker_run.version.slug
-                    # The suffix is handled by WorkerRun.build_task
                     task, parent_slugs = worker_run.build_task(
                         self,
-                        task_name,
                         env,
                         import_task_name,
                         elements_path,
@@ -1192,7 +1189,18 @@ class WorkerRun(models.Model):
             ),
         ]
 
-    def build_task(self, process, task_name, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None):
+    @property
+    def task_slug(self):
+        """
+        A slug that can be used to create a Ponos task from this WorkerRun.
+        This does not include the chunk suffix when chunks are enabled.
+        """
+        # Since a WorkerVersion can be used multiple times in a process,
+        # the WorkerVersion alone is not enough to build a unique task slug:
+        # we append the first 6 characters of the WorkerRun ID to the worker slug
+        return f'{self.version.worker.slug}_{str(self.id)[:6]}'
+
+    def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None):
         '''
         Build the Task that will represent this WorkerRun in ponos using :
         - the docker image name given by the WorkerVersion
@@ -1200,12 +1208,7 @@ class WorkerRun(models.Model):
         - the artifact given by the WorkerVersion that will help us download the distant Docker image (S3)
         Return that Task and the list of its parent tasks slugs, to be added after they have all been created.
         '''
-        task_env = env.copy()
-        suffix = f'_{chunk}' if chunk else ''
-        task_name = f'{task_name}{suffix}'
-
-        token = task_token_default()
-        task_env['ARKINDEX_TASK_TOKEN'] = token
+        slug_suffix = f'_{chunk}' if chunk else ''
 
         parents = None
         if self.parents:
@@ -1224,7 +1227,7 @@ class WorkerRun(models.Model):
                     .only('id', 'version_id', 'version__worker_id', 'version__worker__slug')
                 )
             parents = [
-                f'{worker_run.version.slug}{suffix}'
+                worker_run.task_slug + slug_suffix
                 for worker_run in parent_runs
             ]
         elif import_task_name:
@@ -1239,6 +1242,9 @@ class WorkerRun(models.Model):
             f"Worker Version {self.version.id} is not available and cannot be used to build a task."
         )
 
+        task_env = env.copy()
+        token = task_token_default()
+        task_env["ARKINDEX_TASK_TOKEN"] = token
         task_env["TASK_ELEMENTS"] = elements_path
         task_env["ARKINDEX_WORKER_RUN_ID"] = str(self.id)
         if chunk:
@@ -1256,7 +1262,8 @@ class WorkerRun(models.Model):
             image_artifact_id=self.version.docker_image_id,
             env=task_env,
             shm_size=self.version.docker_shm_size,
-            slug=task_name,
+            slug=self.task_slug + slug_suffix,
+            # The depth will be recomputed before creating all tasks for the process
             depth=0,
             run=run,
             token=token,
diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py
index 171cde6b53..1476ba6df9 100644
--- a/arkindex/process/tests/test_create_process.py
+++ b/arkindex/process/tests/test_create_process.py
@@ -601,7 +601,7 @@ class TestCreateProcess(FixtureAPITestCase):
         self.assertEqual(init_task.command, f'python -m arkindex_tasks.init_elements {process_2.id} --chunks-number 1')
         self.assertEqual(init_task.image, 'registry.teklia.com/tasks')
 
-        reco_task = process_2.tasks.get(slug=f'reco_{str(self.version_1.id)[0:6]}')
+        reco_task = process_2.tasks.get(slug=run_1.task_slug)
         self.assertEqual(reco_task.command, None)
         self.assertEqual(reco_task.image, f'my_repo.fake/workers/worker/reco:{self.version_1.id}')
         self.assertEqual(reco_task.shm_size, None)
@@ -614,7 +614,7 @@ class TestCreateProcess(FixtureAPITestCase):
             'ARKINDEX_TASK_TOKEN': reco_task.token,
         })
 
-        dla_task = process_2.tasks.get(slug=f'dla_{str(self.version_2.id)[0:6]}')
+        dla_task = process_2.tasks.get(slug=run_2.task_slug)
         self.assertEqual(dla_task.command, None)
         self.assertEqual(dla_task.image, f'my_repo.fake/workers/worker/dla:{self.version_2.id}')
         self.assertEqual(dla_task.shm_size, None)
@@ -669,7 +669,7 @@ class TestCreateProcess(FixtureAPITestCase):
         """
         token_mock.side_effect = ['12345', '67891']
         process_2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
-        process_2.worker_runs.create(
+        run = process_2.worker_runs.create(
             version=self.version_1,
             parents=[],
         )
@@ -691,7 +691,7 @@ class TestCreateProcess(FixtureAPITestCase):
         self.assertEqual(init_task.command, f'python -m arkindex_tasks.init_elements {process_2.id} --chunks-number 1 --use-cache')
         self.assertEqual(init_task.image, 'registry.teklia.com/tasks')
 
-        worker_task = process_2.tasks.get(slug=f'reco_{str(self.version_1.id)[0:6]}')
+        worker_task = process_2.tasks.get(slug=run.task_slug)
         self.assertEqual(worker_task.command, None)
         self.assertEqual(worker_task.image, f'my_repo.fake/workers/worker/reco:{self.version_1.id}')
         self.assertEqual(worker_task.image_artifact.id, self.version_1.docker_image.id)
@@ -717,7 +717,7 @@ class TestCreateProcess(FixtureAPITestCase):
         """
         token_mock.side_effect = ['12345', '67891']
         process_2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
-        process_2.worker_runs.create(
+        run = process_2.worker_runs.create(
             version=self.version_3,
             parents=[],
         )
@@ -738,7 +738,7 @@ class TestCreateProcess(FixtureAPITestCase):
         self.assertEqual(init_task.command, f'python -m arkindex_tasks.init_elements {process_2.id} --chunks-number 1')
         self.assertEqual(init_task.image, 'registry.teklia.com/tasks')
 
-        worker_task = process_2.tasks.get(slug=f'worker-gpu_{str(self.version_3.id)[0:6]}')
+        worker_task = process_2.tasks.get(slug=run.task_slug)
         self.assertEqual(worker_task.command, None)
         self.assertEqual(worker_task.image, f'my_repo.fake/workers/worker/worker-gpu:{self.version_3.id}')
         self.assertEqual(worker_task.image_artifact.id, self.version_3.docker_image.id)
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index dba6c0dd11..f161f6be95 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -2169,7 +2169,7 @@ class TestProcesses(FixtureAPITestCase):
         Default chunks, thumbnails and farm are used. Nor cache or workers activity is set.
         """
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
-        process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
+        run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
         self.client.force_login(self.user)
@@ -2188,7 +2188,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process2.tasks.count(), 2)
         task1, task2 = process2.tasks.order_by('slug')
         self.assertEqual(task1.slug, 'initialisation')
-        self.assertEqual(task2.slug, f'reco_{str(self.recognizer.id)[:6]}')
+        self.assertEqual(task2.slug, f'reco_{str(run.id)[:6]}')
         self.assertIn('--chunks-number 1', task1.command)
 
     def test_start_process_dataset_requires_datasets(self):
@@ -2251,7 +2251,7 @@ class TestProcesses(FixtureAPITestCase):
     def test_start_process_dataset(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
         process2.datasets.set([self.dataset1, self.private_dataset])
-        process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
+        run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
         self.client.force_login(self.user)
@@ -2269,7 +2269,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process2.farm_id, get_default_farm_id())
         self.assertEqual(process2.tasks.count(), 1)
         task = process2.tasks.get()
-        self.assertEqual(task.slug, f'reco_{str(self.recognizer.id)[:6]}')
+        self.assertEqual(task.slug, run.task_slug)
         self.assertQuerysetEqual(process2.datasets.order_by('name'), [
             self.private_dataset, self.dataset1
         ])
@@ -2284,7 +2284,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(self.recognizer.state, WorkerVersionState.Available)
 
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
-        process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
+        run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process.tasks.exists())
 
         self.client.force_login(self.user)
@@ -2299,7 +2299,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process.state, State.Unscheduled)
         task1, task2 = process.tasks.order_by('slug')
         self.assertEqual(task1.slug, 'initialisation')
-        self.assertEqual(task2.slug, f'reco_{str(self.recognizer.id)[:6]}')
+        self.assertEqual(task2.slug, run.task_slug)
 
     def test_start_process_select_farm_id(self):
         """
@@ -2382,7 +2382,7 @@ class TestProcesses(FixtureAPITestCase):
         """
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
         # Add a worker run to this process
-        process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
+        run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.client.force_login(self.user)
         response = self.client.post(
@@ -2394,9 +2394,9 @@ class TestProcesses(FixtureAPITestCase):
 
         self.assertEqual(list(process.tasks.order_by('slug').values_list('slug', flat=True)), [
             'initialisation',
-            f'reco_{str(self.recognizer.id)[:6]}_1',
-            f'reco_{str(self.recognizer.id)[:6]}_2',
-            f'reco_{str(self.recognizer.id)[:6]}_3',
+            f'{run.task_slug}_1',
+            f'{run.task_slug}_2',
+            f'{run.task_slug}_3',
             'thumbnails_1',
             'thumbnails_2',
             'thumbnails_3'
@@ -2409,7 +2409,7 @@ class TestProcesses(FixtureAPITestCase):
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
         process.datasets.set([self.dataset1, self.dataset2])
         # Add a worker run to this process
-        process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
+        run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.client.force_login(self.user)
         response = self.client.post(
@@ -2420,9 +2420,9 @@ class TestProcesses(FixtureAPITestCase):
         process.refresh_from_db()
 
         self.assertEqual(list(process.tasks.order_by('slug').values_list('slug', flat=True)), [
-            f'reco_{str(self.recognizer.id)[:6]}_1',
-            f'reco_{str(self.recognizer.id)[:6]}_2',
-            f'reco_{str(self.recognizer.id)[:6]}_3'
+            f'{run.task_slug}_1',
+            f'{run.task_slug}_2',
+            f'{run.task_slug}_3',
         ])
 
     @patch('arkindex.process.models.Process.worker_runs')
@@ -2491,6 +2491,40 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process.activity_state, ActivityState.Pending)
         self.assertEqual(activities_delay_mock.call_count, 1)
 
+    def test_start_duplicated_worker_version(self):
+        """
+        A WorkerRun depending on another WorkerRun of the same WorkerVersion should be supported
+        """
+        process = self.corpus.processes.create(
+            creator=self.user,
+            mode=ProcessMode.Workers,
+        )
+        run_1 = process.worker_runs.create(
+            version=self.recognizer,
+            configuration=self.recognizer.worker.configurations.create(
+                name='some_config',
+                configuration={"a": "b"},
+            ),
+        )
+        run_2 = process.worker_runs.create(
+            version=self.recognizer,
+            parents=[run_1.id],
+        )
+        self.assertNotEqual(run_1.task_slug, run_2.task_slug)
+
+        self.client.force_login(self.user)
+        with self.assertNumQueries(19):
+            response = self.client.post(reverse('api:process-start', kwargs={'pk': str(process.id)}))
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+        self.assertEqual(process.tasks.count(), 3)
+        init_task = process.tasks.get(slug='initialisation')
+        task_1 = process.tasks.get(slug=run_1.task_slug)
+        task_2 = process.tasks.get(slug=run_2.task_slug)
+
+        self.assertQuerysetEqual(task_1.parents.all(), [init_task])
+        self.assertQuerysetEqual(task_2.parents.all(), [task_1])
+
     @override_settings(PONOS_DEFAULT_ENV={'ARKINDEX_API_TOKEN': 'testToken'})
     @override_settings(ARKINDEX_TASKS_IMAGE='registry.teklia.com/tasks')
     @patch('arkindex.process.models.task_token_default')
@@ -2530,9 +2564,6 @@ class TestProcesses(FixtureAPITestCase):
     @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost')
     @patch('arkindex.process.models.task_token_default')
     def test_worker_run_model_version_build_workflow(self, token_mock):
-        """
-        Build tasks for a PDF import with a worker version defined in settings
-        """
         process = self.corpus.processes.create(
             creator=self.user,
             mode=ProcessMode.Workers,
@@ -2553,7 +2584,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(initialization_task.command, f'python -m arkindex_tasks.init_elements {process.id} --chunks-number 1')
         self.assertEqual(initialization_task.image, 'registry.teklia.com/tasks')
 
-        worker_task = process.tasks.get(slug=f'generic_{str(self.version_with_model.id)[:6]}')
+        worker_task = process.tasks.get(slug=run.task_slug)
         self.assertEqual(worker_task.env, {
             'ARKINDEX_API_TOKEN': 'testToken',
             'ARKINDEX_PROCESS_ID': str(process.id),
diff --git a/arkindex/process/tests/test_workeractivity.py b/arkindex/process/tests/test_workeractivity.py
index c7ef9aae48..1254fbcd12 100644
--- a/arkindex/process/tests/test_workeractivity.py
+++ b/arkindex/process/tests/test_workeractivity.py
@@ -51,7 +51,7 @@ class TestWorkerActivity(FixtureTestCase):
         )
         cls.worker_type = WorkerType.objects.get(slug='recognizer')
         cls.process.run()
-        cls.task = cls.process.tasks.get(slug=cls.worker_version.slug)
+        cls.task = cls.process.tasks.get(slug=cls.worker_run.task_slug)
 
     def setUp(self):
         super().setUp()
diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py
index 16a6a867f9..3047b2f752 100644
--- a/arkindex/process/tests/test_workerruns.py
+++ b/arkindex/process/tests/test_workerruns.py
@@ -2634,9 +2634,9 @@ class TestWorkerRuns(FixtureAPITestCase):
 
     def test_build_task_no_parent(self):
         self.version_1.docker_image_id = self.artifact.id
-        task, parent_slugs = self.run_1.build_task(self.process_1, 'test', ENV.copy(), 'import', '/data/import/elements.json')
+        task, parent_slugs = self.run_1.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json')
 
-        self.assertEqual(task.slug, 'test')
+        self.assertEqual(task.slug, f'reco_{str(self.run_1.id)[0:6]}')
         self.assertEqual(task.image, f'my_repo.fake/workers/worker/reco:{str(self.version_1.id)}')
         self.assertEqual(task.command, None)
         self.assertEqual(task.image_artifact, self.artifact)
@@ -2651,9 +2651,9 @@ class TestWorkerRuns(FixtureAPITestCase):
 
     def test_build_task_with_chunk(self):
         self.version_1.docker_image_id = self.artifact.id
-        task, parent_slugs = self.run_1.build_task(self.process_1, 'test', ENV.copy(), 'import', '/data/import/elements.json', chunk=4)
+        task, parent_slugs = self.run_1.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json', chunk=4)
 
-        self.assertEqual(task.slug, 'test_4')
+        self.assertEqual(task.slug, f'reco_{str(self.run_1.id)[0:6]}_4')
         self.assertEqual(task.image, f'my_repo.fake/workers/worker/reco:{str(self.version_1.id)}')
         self.assertEqual(task.command, None)
         self.assertEqual(task.image_artifact, self.artifact)
@@ -2685,14 +2685,14 @@ class TestWorkerRuns(FixtureAPITestCase):
             parents=[self.run_1.id],
         )
 
-        task, parent_slugs = run_2.build_task(self.process_1, f'reco_{str(version_2.id)[0:6]}', ENV.copy(), 'import', '/data/import/elements.json')
+        task, parent_slugs = run_2.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json')
 
-        self.assertEqual(task.slug, f'reco_{str(version_2.id)[0:6]}')
+        self.assertEqual(task.slug, f'reco_{str(run_2.id)[0:6]}')
         self.assertEqual(task.image, f'my_repo.fake/workers/worker/reco:{str(version_2.id)}')
         self.assertEqual(task.command, None)
         self.assertEqual(task.image_artifact, self.artifact)
         self.assertEqual(task.shm_size, None)
-        self.assertEqual(parent_slugs, [f'reco_{str(self.version_1.id)[0:6]}'])
+        self.assertEqual(parent_slugs, [f'reco_{str(self.run_1.id)[0:6]}'])
         self.assertEqual(task.env, {
             'ARKINDEX_PROCESS_ID': '12345',
             'ARKINDEX_TASK_TOKEN': str(task.token),
@@ -2718,14 +2718,14 @@ class TestWorkerRuns(FixtureAPITestCase):
             parents=[self.run_1.id],
         )
 
-        task, parent_slugs = run_2.build_task(self.process_1, f'reco_{str(version_2.id)[0:6]}', ENV.copy(), 'import', '/data/import/elements.json', chunk=4)
+        task, parent_slugs = run_2.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json', chunk=4)
 
-        self.assertEqual(task.slug, f'reco_{str(version_2.id)[0:6]}_4')
+        self.assertEqual(task.slug, f'reco_{str(run_2.id)[0:6]}_4')
         self.assertEqual(task.image, f'my_repo.fake/workers/worker/reco:{str(version_2.id)}')
         self.assertEqual(task.command, None)
         self.assertEqual(task.image_artifact, self.artifact)
         self.assertEqual(task.shm_size, None)
-        self.assertEqual(parent_slugs, [f'reco_{str(self.version_1.id)[0:6]}_4'])
+        self.assertEqual(parent_slugs, [f'reco_{str(self.run_1.id)[0:6]}_4'])
         self.assertEqual(task.env, {
             'ARKINDEX_PROCESS_ID': '12345',
             'ARKINDEX_TASK_TOKEN': str(task.token),
@@ -2741,9 +2741,9 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'shm_size': 505,
             }
         }
-        task, parent_slugs = self.run_1.build_task(self.process_1, 'test', ENV.copy(), 'import', '/data/import/elements.json')
+        task, parent_slugs = self.run_1.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json')
 
-        self.assertEqual(task.slug, 'test')
+        self.assertEqual(task.slug, f'reco_{str(self.run_1.id)[0:6]}')
         self.assertEqual(task.image, f'my_repo.fake/workers/worker/reco:{str(self.version_1.id)}')
         self.assertEqual(task.command, None)
         self.assertEqual(task.image_artifact, self.artifact)
@@ -2778,7 +2778,7 @@ class TestWorkerRuns(FixtureAPITestCase):
             AssertionError,
             f"Worker Version {version_2.id} is not available and cannot be used to build a task."
         ):
-            run_2.build_task(self.process_1, 'test', ENV.copy(), 'import', '/data/import/elements.json')
+            run_2.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json')
 
     def test_build_task_unavailable_model_version(self):
         self.model_version_1.state = ModelVersionState.Created
@@ -2789,4 +2789,4 @@ class TestWorkerRuns(FixtureAPITestCase):
             AssertionError,
             f"ModelVersion {self.model_version_1.id} is not available and cannot be used to build a task."
         ):
-            self.run_1.build_task(self.process_1, 'test', ENV.copy(), 'import', '/data/import/elements.json')
+            self.run_1.build_task(self.process_1, ENV.copy(), 'import', '/data/import/elements.json')
-- 
GitLab