diff --git a/.isort.cfg b/.isort.cfg
index de1cf11504411d76ff9f6f19d4bfe30e427ecc20..ad4d2fb8c0e010ee8b098fe2393efd319b3cc8b0 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -8,4 +8,4 @@ line_length = 88
 default_section=FIRSTPARTY
 known_first_party = arkindex,arkindex_common
-known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml
+known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e175f795163fb31864b7e84be62aef6733d183
--- /dev/null
+++ b/arkindex_worker/worker/training.py
@@ -0,0 +1,191 @@
+# -*- coding: utf-8 -*-
+import hashlib
+import os
+import tarfile
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+from typing import NewType, Tuple
+
+import requests
+import zstandard as zstd
+from apistar.exceptions import ErrorResponse
+
+from arkindex_worker import logger
+
+CHUNK_SIZE = 1024
+
+DirPath = NewType("DirPath", Path)
+Hash = NewType("Hash", str)
+FileSize = NewType("FileSize", int)
+Archive = Tuple[DirPath, Hash, FileSize, Hash]
+
+
+@contextmanager
+def create_archive(path: DirPath) -> Archive:
+    """First create a tar archive, then compress it into a zst archive.
+    Finally, yield the archive path along with its content hash, size and archive hash.
+    """
+    assert path.is_dir(), "create_archive needs a directory"
+
+    compressor = zstd.ZstdCompressor(level=3)
+    content_hasher = hashlib.md5()
+    archive_hasher = hashlib.md5()
+
+    # Create a temporary file to hold the uncompressed tar archive
+    _, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
+
+    # Create an uncompressed tar archive with all the needed files
+    # The file hierarchy is kept in the archive.
+    file_list = []
+    with tarfile.open(path_to_tar_archive, "w") as tar:
+        for p in path.glob("**/*"):
+            x = p.relative_to(path)
+            tar.add(p, arcname=x, recursive=False)
+            file_list.append(p)
+
+    # Sort the files by path
+    file_list.sort()
+    # Compute the hash of the files' content
+    for file_path in file_list:
+        with open(file_path, "rb") as file_data:
+            for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
+                content_hasher.update(chunk)
+
+    _, path_to_zst_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar.zst")
+
+    # Compress the tar archive
+    with open(path_to_zst_archive, "wb") as archive_file:
+        with open(path_to_tar_archive, "rb") as model_data:
+            for model_chunk in iter(lambda: model_data.read(CHUNK_SIZE), b""):
+                compressed_chunk = compressor.compress(model_chunk)
+                archive_hasher.update(compressed_chunk)
+                archive_file.write(compressed_chunk)
+
+    # Remove the tar archive
+    os.remove(path_to_tar_archive)
+
+    # Get the content hash, the archive size and the archive hash
+    hash = content_hasher.hexdigest()
+    size = os.path.getsize(path_to_zst_archive)
+    archive_hash = archive_hasher.hexdigest()
+
+    yield path_to_zst_archive, hash, size, archive_hash
+
+    # Remove the zst archive
+    os.remove(path_to_zst_archive)
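+
+
+# Illustrative usage (a sketch, not part of this module's API; ``model_dir``
+# is a hypothetical pathlib.Path pointing at a directory of model files):
+#
+#     with create_archive(path=model_dir) as (archive, hash, size, archive_hash):
+#         ...  # upload ``archive`` somewhere
+#     # the temporary zst archive is removed once the context exits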
+ """ + + # Create the zst archive, get its hash and size + with create_archive(path=model_path) as ( + path_to_archive, + hash, + size, + archive_hash, + ): + # Create a new model version with hash and size + model_version_details = self.create_model_version( + model_id=model_id, + hash=hash, + size=size, + archive_hash=archive_hash, + ) + if model_version_details is None: + return + self.upload_to_s3( + archive_path=path_to_archive, + model_version_details=model_version_details, + ) + + # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker) + self.update_model_version( + model_version_details=model_version_details, + ) + + def create_model_version( + self, + model_id: str, + hash: str, + size: int, + archive_hash: str, + ) -> dict: + """ + Create a new version of the specified model with the given information (hashes and size). + If a version matching the information already exist, there are two cases: + - The version is in `Created` state: this version's details is used + - The version is in `Available` state: you cannot create twice the same version, an error is raised + """ + + # Create a new model version with hash and size + try: + model_version_details = self.request( + "CreateModelVersion", + id=model_id, + body={"hash": hash, "size": size, "archive_hash": archive_hash}, + ) + except ErrorResponse as e: + if e.status_code >= 500: + logger.error(f"Failed to create model version: {e.content}") + + model_version_details = e.content.get("hash") + # If the existing model is in Created state, this model is returned as a dict. + # Else an error in a list is returned. + if model_version_details and isinstance( + model_version_details, (list, tuple) + ): + logger.error(model_version_details[0]) + return + + return model_version_details + + def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None: + """ + Upload the archive of the model's files to an Amazon s3 compatible storage + """ + + s3_put_url = model_version_details.get("s3_put_url") + logger.info("Uploading to s3...") + # Upload the archive on s3 + with open(archive_path, "rb") as archive: + r = requests.put( + url=s3_put_url, + data=archive, + headers={"Content-Type": "application/zstd"}, + ) + r.raise_for_status() + + def update_model_version( + self, + model_version_details: dict, + description: str = None, + configuration: dict = None, + tag: str = None, + ) -> None: + """ + Update the specified model version to the state `Available` and use the given information" + """ + + logger.info("Updating the model version...") + try: + self.request( + "UpdateModelVersion", + id=model_version_details.get("id"), + body={ + "state": "available", + "description": description, + "configuration": configuration, + "tag": tag, + }, + ) + except ErrorResponse as e: + logger.error(f"Failed to update model version: {e.content}") diff --git a/requirements.txt b/requirements.txt index 2d792ccabe915f6c1bdf08733f87b90b6fd47770..da41d9dce69a286c1f5f0a826d46a8c7f3cbe0bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ python-gnupg==0.4.8 sh==1.14.2 shapely==1.8.2 tenacity==8.0.1 +zstandard==0.18.0 diff --git a/tests-requirements.txt b/tests-requirements.txt index cbfbfa256a156026a0070ca834ffcac16a867d2f..d1f9171264f9bebaf629492863299aabe1c3592e 100644 --- a/tests-requirements.txt +++ b/tests-requirements.txt @@ -1,3 +1,4 @@ pytest==7.1.1 pytest-mock==3.7.0 pytest-responses==0.5.0 +requests==2.27.1 diff --git a/tests/conftest.py b/tests/conftest.py index 
+
+    def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
+        """
+        Upload the archive of the model's files to an Amazon S3-compatible storage
+        """
+
+        s3_put_url = model_version_details.get("s3_put_url")
+        logger.info("Uploading to S3...")
+        # Upload the archive to S3
+        with open(archive_path, "rb") as archive:
+            r = requests.put(
+                url=s3_put_url,
+                data=archive,
+                headers={"Content-Type": "application/zstd"},
+            )
+            r.raise_for_status()
+
+    def update_model_version(
+        self,
+        model_version_details: dict,
+        description: str = None,
+        configuration: dict = None,
+        tag: str = None,
+    ) -> None:
+        """
+        Update the specified model version to the `Available` state using the
+        given information
+        """
+
+        logger.info("Updating the model version...")
+        try:
+            self.request(
+                "UpdateModelVersion",
+                id=model_version_details.get("id"),
+                body={
+                    "state": "available",
+                    "description": description,
+                    "configuration": configuration,
+                    "tag": tag,
+                },
+            )
+        except ErrorResponse as e:
+            logger.error(f"Failed to update model version: {e.content}")
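+
+
+# How a worker gains these helpers (an illustrative sketch; MyTrainingWorker
+# is a hypothetical subclass, not part of this module):
+#
+#     from arkindex_worker.worker import BaseWorker
+#
+#     class MyTrainingWorker(BaseWorker, TrainingMixin):
+#         pass
+#
+#     # publish_model_version() then chains create_model_version(),
+#     # upload_to_s3() and update_model_version() in a single call.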
"http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } + }, + 400, + ), + ({"hash": ["A version for this model with this hash already exists."]}, 403), + ], +) +def test_retrieve_created_model_version(content, status_code): + """ + If there is an existing model version in Created mode, + A 400 was raised, but the model is still returned in error content. + Else if an existing model version in Available mode, + 403 was raised, but None will be returned + """ + + model_id = "fake_model_id" + training = TrainingWorker() + training.api_client = MockApiClient() + model_hash = "hash" + archive_hash = "archive_hash" + size = "30" + training.api_client.add_error_response( + "CreateModelVersion", + id=model_id, + status_code=status_code, + body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, + content=content, + ) + if status_code == 400: + assert ( + training.create_model_version(model_id, model_hash, size, archive_hash) + == content["hash"] + ) + elif status_code == 403: + assert ( + training.create_model_version(model_id, model_hash, size, archive_hash) + is None + ) + + +def test_handle_s3_uploading_errors(model_file_dir): + training = TrainingWorker() + training.api_client = MockApiClient() + s3_endpoint_url = "http://s3.localhost.com" + responses.add_passthru(s3_endpoint_url) + responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) + file_path = model_file_dir / "model_file.pth" + with pytest.raises(Exception): + training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})