From 5a84a75ec792dbb02e994f652245dd46c1f37e63 Mon Sep 17 00:00:00 2001
From: Bastien Abadie <bastien@nextcairn.com>
Date: Tue, 30 Mar 2021 11:38:43 +0200
Subject: [PATCH] Add test case from worker

---
 arkindex_worker/cache.py  | 18 +++++-----
 arkindex_worker/worker.py |  2 +-
 tests/test_merge.py       | 70 +++++++++++++++++++++++++++++++++++++--
 3 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 2391c45b..15f1514f 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -102,13 +102,15 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         filenames.append(f"db_{chunk}.sqlite")
 
     # Find all the paths for these databases
-    paths = filter(
-        lambda p: os.path.isfile(p),
-        [
-            os.path.join(data_dir, parent, name)
-            for parent in parent_ids
-            for name in filenames
-        ],
+    paths = list(
+        filter(
+            lambda p: os.path.isfile(p),
+            [
+                os.path.join(data_dir, parent, name)
+                for parent in parent_ids
+                for name in filenames
+            ],
+        )
     )
 
     if not paths:
@@ -133,5 +135,3 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
         for statement in statements:
             cursor.execute(statement)
         connection.commit()
-
-    # TODO: maybe reopen peewee connection ?
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index c76c0118..f1aaa5ae 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -160,7 +160,7 @@ class BaseWorker(object):
             merge_parents_cache(
                 task["parents"],
                 self.cache_path,
-                data_dir=os.environ.get("PONOS_DATA_DIR"),
+                data_dir=os.environ.get("PONOS_DATA_DIR", "/data"),
                 chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
             )
 
diff --git a/tests/test_merge.py b/tests/test_merge.py
index 5b405989..f2f4203a 100644
--- a/tests/test_merge.py
+++ b/tests/test_merge.py
@@ -1,7 +1,12 @@
 # -*- coding: utf-8 -*-
 from uuid import UUID
 
-from arkindex_worker.cache import MODELS, CachedElement, merge_parents_cache
+from arkindex_worker.cache import (
+    MODELS,
+    CachedElement,
+    CachedTranscription,
+    merge_parents_cache,
+)
 
 
 def test_merge_no_parent(mock_databases, tmpdir):
@@ -271,4 +276,65 @@ def test_merge_chunk(mock_databases, tmpdir, monkeypatch):
         ]
 
 
-# TODO: add a unit test using base worker
+def test_merge_from_worker(
+    responses, mock_base_worker_with_cache, mock_databases, tmpdir, monkeypatch
+):
+    """
+    Configuring the base worker merges the parent task databases into its cache
+    """
+    responses.add(
+        responses.GET,
+        "http://testserver/ponos/v1/task/my_task/from-agent/",
+        status=200,
+        json={"parents": ["first", "second"]},
+    )
+
+    # Add two distinct elements in the first parent database
+    with mock_databases["first"]["db"].bind_ctx(MODELS):
+        CachedElement.create(
+            id=UUID("12341234-1234-1234-1234-123412341234"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+        CachedElement.create(
+            id=UUID("56785678-5678-5678-5678-567856785678"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+
+    # Second parent database: the same first element again, plus a transcription
+    with mock_databases["second"]["db"].bind_ctx(MODELS):
+        CachedElement.create(
+            id=UUID("12341234-1234-1234-1234-123412341234"),
+            type="page",
+            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+        CachedTranscription.create(
+            id=UUID("11111111-1111-1111-1111-111111111111"),
+            element_id=UUID("12341234-1234-1234-1234-123412341234"),
+            text="Hello!",
+            confidence=0.42,
+            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+
+    # Before configuring the worker, the main database holds no data
+    assert CachedElement.select().count() == 0
+    assert CachedTranscription.select().count() == 0
+
+    # Point the worker at tmpdir as its data directory, then configure it
+    monkeypatch.setenv("PONOS_DATA_DIR", str(tmpdir))
+    mock_base_worker_with_cache.configure()
+
+    # Then we have 2 elements (the shared id is counted once) and a transcription
+    assert CachedElement.select().count() == 2
+    assert CachedTranscription.select().count() == 1
+    assert [e.id for e in CachedElement.select().order_by("id")] == [
+        UUID("56785678-5678-5678-5678-567856785678"),
+        UUID("12341234-1234-1234-1234-123412341234"),
+    ]
+    assert [t.id for t in CachedTranscription.select().order_by("id")] == [
+        UUID("11111111-1111-1111-1111-111111111111"),
+    ]
-- 
GitLab