From 9c55c68389249de47059e6d208e4479a5202fa97 Mon Sep 17 00:00:00 2001
From: NolanB <nboukachab@teklia.com>
Date: Tue, 16 Aug 2022 17:30:39 +0200
Subject: [PATCH] First draft, tests are missing

---
 .isort.cfg                         |   2 +-
 arkindex_worker/worker/training.py | 159 +++++++++++++++++++++++++++++
 requirements.txt                   |   1 +
 3 files changed, 161 insertions(+), 1 deletion(-)
 create mode 100644 arkindex_worker/worker/training.py

diff --git a/.isort.cfg b/.isort.cfg
index de1cf115..f03c5435 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -8,4 +8,4 @@ line_length = 88
 
 default_section=FIRSTPARTY
 known_first_party = arkindex,arkindex_common
-known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml
+known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard
diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
new file mode 100644
index 00000000..cae83c0c
--- /dev/null
+++ b/arkindex_worker/worker/training.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+import hashlib
+import os
+import tarfile
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+from typing import NewType, Tuple
+
+import requests
+import zstandard as zstd
+from apistar.exceptions import ErrorResponse
+
+from arkindex import ArkindexClient
+from arkindex_worker import logger
+
+CHUNK_SIZE = 1024
+
+FilePath = NewType("FilePath", str)
+Hash = NewType("Hash", str)
+FileSize = NewType("FileSize", int)
+# (archive path, content hash, archive size, archive hash)
+Archive = Tuple[FilePath, Hash, FileSize, Hash]
+
+
+@contextmanager
+def create_archive(path: Path) -> Archive:
+    """Create a tar archive of all files under ``path``, compress it to
+    zstandard, then yield (archive path, content hash, archive size,
+    archive hash). Both temporary files are removed when the context exits.
+    """
+    assert path.is_dir(), "create_archive needs a directory"
+
+    compressor = zstd.ZstdCompressor(level=3)
+    content_hasher = hashlib.md5()
+    archive_hasher = hashlib.md5()
+
+    _, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
+
+    # Create an uncompressed tar archive with all the needed files
+    # Files hierarchy is kept in the archive.
+    file_list = []
+    with tarfile.open(path_to_tar_archive, "w") as tar:
+        for p in path.glob("**/*"):
+            tar.add(p, arcname=p.relative_to(path), recursive=False)
+            file_list.append(p)
+
+    # Hash the raw file contents in a deterministic (sorted) order
+    file_list.sort()
+    for file_path in file_list:
+        with open(file_path, "rb") as file_data:
+            for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
+                content_hasher.update(chunk)
+
+    _, path_to_zst_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar.zst")
+
+    try:
+        # Compress the tar archive chunk by chunk, hashing the output.
+        # NOTE: each compress() call emits an independent zstd frame;
+        # concatenated frames remain a valid zstd stream.
+        with open(path_to_zst_archive, "wb") as archive_file:
+            with open(path_to_tar_archive, "rb") as model_data:
+                for model_chunk in iter(lambda: model_data.read(CHUNK_SIZE), b""):
+                    compressed_chunk = compressor.compress(model_chunk)
+                    archive_hasher.update(compressed_chunk)
+                    archive_file.write(compressed_chunk)
+
+        yield (
+            path_to_zst_archive,
+            content_hasher.hexdigest(),
+            os.path.getsize(path_to_zst_archive),
+            archive_hasher.hexdigest(),
+        )
+    finally:
+        # Always remove both temporary archives, even if the caller raises
+        os.remove(path_to_tar_archive)
+        os.remove(path_to_zst_archive)
+
+
+class TrainingMixin(object):
+    """Helpers to publish a trained model version on Arkindex."""
+
+    def publish_model_version(self, client, model_path, model_id):
+        """Compress the model directory, create a model version, upload
+        the archive to S3 and flag the version as available.
+        """
+        # Create the zst archive, get its hash and size
+        with create_archive(path=model_path) as (
+            path_to_archive,
+            hash,
+            size,
+            archive_hash,
+        ):
+            # Create a new model version with hash and size
+            model_version_details = self.create_model_version(
+                client=client,
+                model_id=model_id,
+                hash=hash,
+                size=size,
+                archive_hash=archive_hash,
+            )
+            if model_version_details is None:
+                return
+            self.upload_to_s3(
+                archive_path=path_to_archive,
+                model_version_details=model_version_details,
+            )
+
+        # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
+        self.update_model_version(
+            client=client,
+            model_version_details=model_version_details,
+        )
+
+    def create_model_version(
+        self,
+        client: ArkindexClient,
+        model_id: str,
+        hash: str,
+        size: int,
+        archive_hash: str,
+    ) -> dict:
+        """Create a new model version with the archive hash and size.
+        Returns the version details, the existing version if one in
+        Created state matches the hash, or None on failure.
+        """
+        try:
+            return client.request(
+                "CreateModelVersion",
+                id=model_id,
+                body={"hash": hash, "size": size, "archive_hash": archive_hash},
+            )
+        except ErrorResponse as e:
+            if e.status_code >= 500:
+                logger.error(f"Failed to create model version: {e.content}")
+                return None
+            # If the existing model is in Created state, this model is
+            # returned as a dict. Else an error in a list is returned.
+            existing = e.content.get("hash") if isinstance(e.content, dict) else None
+            if isinstance(existing, (list, tuple)):
+                logger.error(existing[0])
+                return None
+            return existing
+
+    def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
+        """PUT the zst archive on the model version's S3 upload URL."""
+        s3_put_url = model_version_details.get("s3_put_url")
+        logger.info("Uploading to s3...")
+        # Upload the archive on s3
+        with open(archive_path, "rb") as archive:
+            r = requests.put(
+                url=s3_put_url,
+                data=archive,
+                headers={"Content-Type": "application/zstd"},
+            )
+        r.raise_for_status()
+
+    def update_model_version(
+        self, client: ArkindexClient, model_version_details: dict
+    ) -> None:
+        """Flag the model version as available with its configuration."""
+        logger.info("Updating the model version...")
+        try:
+            client.request(
+                "UpdateModelVersion",
+                id=model_version_details.get("id"),
+                body={
+                    "state": "available",
+                    "description": "DOC UFCN",
+                    "configuration": {},
+                },
+            )
+        except ErrorResponse as e:
+            logger.error(f"Failed to update model version: {e.content}")
diff --git a/requirements.txt b/requirements.txt
index 2d792cca..da41d9dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ python-gnupg==0.4.8
 sh==1.14.2
 shapely==1.8.2
 tenacity==8.0.1
+zstandard==0.18.0
-- 
GitLab