From ceb2adbcca88fecfa8f75d31a320ce7594212e62 Mon Sep 17 00:00:00 2001
From: Nolan <nboukachab@teklia.com>
Date: Tue, 29 Nov 2022 12:42:54 +0000
Subject: [PATCH] Only use files when computing model archive hash

---
 arkindex_worker/worker/training.py             |  8 +++-----
 tests/conftest.py                              |  7 ++++++-
 .../model-best/model_file.pth                  |  1 +
 .../model-last/model_file.pth                  |  1 +
 tests/test_elements_worker/test_training.py    | 18 ++++++++++++++++++
 5 files changed, 29 insertions(+), 6 deletions(-)
 create mode 100644 tests/samples/model_files_with_subfolder/model-best/model_file.pth
 create mode 100644 tests/samples/model_files_with_subfolder/model-last/model_file.pth

diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
index c53cffee..ffe52d6f 100644
--- a/arkindex_worker/worker/training.py
+++ b/arkindex_worker/worker/training.py
@@ -50,12 +50,10 @@ def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]:
 
     # Create an uncompressed tar archive with all the needed files
     # Files hierarchy ifs kept in the archive.
-    file_list = []
+
     with tarfile.open(path_to_tar_archive, "w") as tar:
-        for p in path.glob("**/*"):
-            x = p.relative_to(path)
-            tar.add(p, arcname=x, recursive=False)
-            file_list.append(p)
+        tar.add(path)
+        file_list = [member for member in tar.getnames() if os.path.isfile(member)]
 
     # Sort by path
     file_list.sort()
diff --git a/tests/conftest.py b/tests/conftest.py
index d058dc87..8b6f16ef 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -26,7 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker
 from arkindex_worker.worker.transcription import TextOrientation
 
 FIXTURES_DIR = Path(__file__).resolve().parent / "data"
-SAMPLES_DIR = Path(__file__).resolve().parent / "samples"
+SAMPLES_DIR = Path("tests") / "samples"
 
 __yaml_cache = {}
 
@@ -280,6 +280,11 @@ def model_file_dir():
     return SAMPLES_DIR / "model_files"
 
 
+@pytest.fixture
+def model_file_dir_with_subfolder():
+    return SAMPLES_DIR / "model_files_with_subfolder"
+
+
 @pytest.fixture
 def fake_dummy_worker():
     api_client = MockApiClient()
diff --git a/tests/samples/model_files_with_subfolder/model-best/model_file.pth b/tests/samples/model_files_with_subfolder/model-best/model_file.pth
new file mode 100644
index 00000000..cc78ba30
--- /dev/null
+++ b/tests/samples/model_files_with_subfolder/model-best/model_file.pth
@@ -0,0 +1 @@
+Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
diff --git a/tests/samples/model_files_with_subfolder/model-last/model_file.pth b/tests/samples/model_files_with_subfolder/model-last/model_file.pth
new file mode 100644
index 00000000..cc78ba30
--- /dev/null
+++ b/tests/samples/model_files_with_subfolder/model-last/model_file.pth
@@ -0,0 +1 @@
+Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py
index e8f41962..4448554d 100644
--- a/tests/test_elements_worker/test_training.py
+++ b/tests/test_elements_worker/test_training.py
@@ -45,6 +45,24 @@ def test_create_archive(model_file_dir):
     assert not os.path.exists(zst_archive_path), "Auto removal failed"
 
 
+def test_create_archive_with_subfolder(model_file_dir_with_subfolder):
+    """Create an archive when the model's file is in a folder containing a subfolder"""
+
+    with create_archive(path=model_file_dir_with_subfolder) as (
+        zst_archive_path,
+        hash,
+        size,
+        archive_hash,
+    ):
+        assert os.path.exists(zst_archive_path), "The archive was not created"
+        assert (
+            hash == "e2fa86cefc33b24502ad4151a638dd29"
+        ), "Hash was not properly computed"
+        assert 300 < size < 1500
+
+    assert not os.path.exists(zst_archive_path), "Auto removal failed"
+
+
 @pytest.mark.parametrize(
     "tag, description",
     [
-- 
GitLab