diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 4a739288a9c82595da18ca75e416e6193872daa8..c33c13d1c785d405e01cafcc42fa55d29144cae0 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -30,9 +30,7 @@ FileSize = NewType("FileSize", int) @contextmanager -def create_archive( - path: DirPath, use_parent_folder: bool = False -) -> Tuple[Path, Hash, FileSize, Hash]: +def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]: """ Create a tar archive from the files at the given location then compress it to a zst archive. @@ -54,17 +52,21 @@ def create_archive( # Files hierarchy ifs kept in the archive. file_list = [] with tarfile.open(path_to_tar_archive, "w") as tar: - if use_parent_folder: - # Only publish contents of the given folder without parent folder - tar.add(path, arcname=".") - else: - tar.add(path) - file_list = [member for member in tar.getnames() if os.path.isfile(member)] + for p in path.glob("**/*"): + x = p.relative_to(path) + + tar.add(p, arcname=x, recursive=False) + + if p.is_dir(): + continue + + file_list.append(p) # Sort by path file_list.sort() # Compute hash of the files for file_path in file_list: + print("BOOM PATH", file_path) with open(file_path, "rb") as file_data: for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""): content_hasher.update(chunk) @@ -100,6 +102,7 @@ class TrainingMixin(object): model_id: str, tag: Optional[str] = None, description: Optional[str] = None, + use_parent_folder=False, ): """ This method creates a model archive and its associated hash, diff --git a/tests/conftest.py b/tests/conftest.py index d058dc878d05de7160d06ce8d0380008d5711a14..03a1b841bd85235163c2c688708e843f253acfd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -280,6 +280,11 @@ def model_file_dir(): return SAMPLES_DIR / "model_files" +@pytest.fixture +def model_file_dir_with_subfolder(): + return SAMPLES_DIR / "model_files_with_subfolder" + + @pytest.fixture def fake_dummy_worker(): api_client = MockApiClient() diff --git a/tests/samples/model_files_with_subfolder/model-best/model_file.pth b/tests/samples/model_files_with_subfolder/model-best/model_file.pth new file mode 100644 index 0000000000000000000000000000000000000000..cc78ba3026c620f5d0e8c5b65071ae8ae2dfe157 --- /dev/null +++ b/tests/samples/model_files_with_subfolder/model-best/model_file.pth @@ -0,0 +1 @@ +Wow this is actually the data of the best model ever created on Arkindex \ No newline at end of file diff --git a/tests/samples/model_files_with_subfolder/model-last/model_file.pth b/tests/samples/model_files_with_subfolder/model-last/model_file.pth new file mode 100644 index 0000000000000000000000000000000000000000..cc78ba3026c620f5d0e8c5b65071ae8ae2dfe157 --- /dev/null +++ b/tests/samples/model_files_with_subfolder/model-last/model_file.pth @@ -0,0 +1 @@ +Wow this is actually the data of the best model ever created on Arkindex \ No newline at end of file diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index e8f419621bbf964a4313c93c130ffe14d049af7f..4ecbf0f7302bf6cd351e025fca76ec520b20ecd8 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -45,6 +45,24 @@ def test_create_archive(model_file_dir): assert not os.path.exists(zst_archive_path), "Auto removal failed" +def test_create_archive_with_subfolder(model_file_dir_with_subfolder): + """Create an archive when the model's file is in a folder""" + + with create_archive(path=model_file_dir_with_subfolder) as ( + zst_archive_path, + hash, + size, + archive_hash, + ): + assert os.path.exists(zst_archive_path), "The archive was not created" + assert ( + hash == "e2fa86cefc33b24502ad4151a638dd29" + ), "Hash was not properly computed" + assert 300 < size < 1200 + + assert not os.path.exists(zst_archive_path), "Auto removal failed" + + @pytest.mark.parametrize( "tag, description", [