From f9b83350ab94c067b1921768b80743fde694bc73 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 14 Nov 2022 17:40:51 +0100
Subject: [PATCH] remove not needed files, reorganize inside dan module

---
 .../dataset_formatters/bessin_formatter.py    | 133 -----------------
 .../dataset_formatters/synist_formatter.py    | 137 ------------------
 dan/cli.py                                    |   4 +-
 dan/datasets/extract/arkindex_utils.py        |   1 -
 dan/datasets/extract/extract_from_arkindex.py |   8 -
 dan/datasets/extract/utils.py                 |  26 +++-
 dan/datasets/format/bessin.py                 | 113 +++++++++++++++
 dan/datasets/format/synist.py                 | 120 +++++++++++++++
 dan/datasets/utils.py                         |  33 +++++
 dan/decoder.py                                |   1 -
 dan/ocr/train.py                              |   3 +-
 prediction-requirements.txt                   |   8 -
 requirements.txt                              |   8 +
 setup.py                                      |   1 -
 14 files changed, 303 insertions(+), 293 deletions(-)
 delete mode 100644 Datasets/dataset_formatters/bessin_formatter.py
 delete mode 100644 Datasets/dataset_formatters/synist_formatter.py
 create mode 100644 dan/datasets/format/bessin.py
 create mode 100644 dan/datasets/format/synist.py
 delete mode 100644 prediction-requirements.txt

diff --git a/Datasets/dataset_formatters/bessin_formatter.py b/Datasets/dataset_formatters/bessin_formatter.py
deleted file mode 100644
index c7c3db66..00000000
--- a/Datasets/dataset_formatters/bessin_formatter.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# contributors :
-# - Denis Coquenet
-#
-#
-# This software is a computer program written in Python whose purpose is to
-# provide public implementation of deep learning works, in pytorch.
-#
-# This software is governed by the CeCILL-C license under French law and
-# abiding by the rules of distribution of free software. You can use,
-# modify and/ or redistribute the software under the terms of the CeCILL-C
-# license as circulated by CEA, CNRS and INRIA at the following URL
-# "http://www.cecill.info".
-#
-# As a counterpart to the access to the source code and rights to copy,
-# modify and redistribute granted by the license, users are provided only
-# with a limited warranty and the software's author, the holder of the
-# economic rights, and the successive licensors have only limited
-# liability.
-#
-# In this respect, the user's attention is drawn to the risks associated
-# with loading, using, modifying and/or developing or reproducing the
-# software by the user in light of its specific status of free software,
-# that may mean that it is complicated to manipulate, and that also
-# therefore means that it is reserved for developers and experienced
-# professionals having in-depth computer knowledge. Users are therefore
-# encouraged to load and test the software's suitability as regards their
-# requirements in conditions enabling the security of their systems and/or
-# data to be ensured and, more generally, to use and operate it in the
-# same conditions as regards security.
-#
-# The fact that you are presently reading this means that you have had
-# knowledge of the CeCILL-C license and that you accept its terms.
-
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-from collections import Counter
-
-
-def remove_spaces(text):
-    # remove begin/ending spaces
-    text = text.strip()
-    # replace \t with regular space
-    text = re.sub("\t", " ", text)
-    # remove consecutive spaces
-    text = re.sub(" +", " ", text)
-# text = text.encode('ascii', 'ignore').decode("utf-8")
-    return text
-
-
-class SynistDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
-        super(SynistDatasetFormatter, self).__init__("bessin", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.counter = Counter()
-        self.map_datasets_files.update({
-            "bessin": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "line": {
-                    "needed_files": [],
-                    "arx_files": [],
-                    "format_function": self.format_bessin_zone
-                }
-            }
-        })
-
-    def preformat_bessin_zone(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
-
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
-
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
-                with open(label_file, 'r') as f:
-                    text = remove_spaces(f.read())
-
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
-                    "label": text.strip()
-                })
-
-        return dataset
-
-    def format_bessin_zone(self):
-        """
-        Format synist page dataset
-        """
-        dataset = self.preformat_bessin_zone()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                self.load_resize_save(sample["img_path"], new_img_path)
-                zone = {
-                    "text": sample["label"],
-                }
-                self.charset = self.charset.union(set(zone["text"]))
-                self.gt[set_name][new_name] = zone
-                self.counter.update(zone['text'])
-
-
-if __name__ == "__main__":
-    formatter = SynistDatasetFormatter("line", sem_token=False)
-    formatter.format()
-    print("Character freq: ")
-    for k,v in formatter.counter.items():
-        print(k, v)
\ No newline at end of file
diff --git a/Datasets/dataset_formatters/synist_formatter.py b/Datasets/dataset_formatters/synist_formatter.py
deleted file mode 100644
index c028ee73..00000000
--- a/Datasets/dataset_formatters/synist_formatter.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# contributors :
-# - Denis Coquenet
-#
-#
-# This software is a computer program written in Python whose purpose is to
-# provide public implementation of deep learning works, in pytorch.
-#
-# This software is governed by the CeCILL-C license under French law and
-# abiding by the rules of distribution of free software. You can use,
-# modify and/ or redistribute the software under the terms of the CeCILL-C
-# license as circulated by CEA, CNRS and INRIA at the following URL
-# "http://www.cecill.info".
-#
-# As a counterpart to the access to the source code and rights to copy,
-# modify and redistribute granted by the license, users are provided only
-# with a limited warranty and the software's author, the holder of the
-# economic rights, and the successive licensors have only limited
-# liability.
-#
-# In this respect, the user's attention is drawn to the risks associated
-# with loading, using, modifying and/or developing or reproducing the
-# software by the user in light of its specific status of free software,
-# that may mean that it is complicated to manipulate, and that also
-# therefore means that it is reserved for developers and experienced
-# professionals having in-depth computer knowledge. Users are therefore
-# encouraged to load and test the software's suitability as regards their
-# requirements in conditions enabling the security of their systems and/or
-# data to be ensured and, more generally, to use and operate it in the
-# same conditions as regards security.
-#
-# The fact that you are presently reading this means that you have had
-# knowledge of the CeCILL-C license and that you accept its terms.
-
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-from collections import Counter
-
-
-def remove_spaces(text):
-    # remove begin/ending spaces
-    text = text.strip()
-    # replace \t with regular space
-    text = re.sub("\t", " ", text)
-    # remove consecutive spaces
-    text = re.sub(" +", " ", text)
-# text = text.encode('ascii', 'ignore').decode("utf-8")
-    return text
-
-
-class SynistDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
-        super(SynistDatasetFormatter, self).__init__("synist_synth", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.counter = Counter()
-        self.map_datasets_files.update({
-            "synist_synth": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "line": {
-                    "needed_files": [],
-                    "arx_files": [],
-                    "format_function": self.format_synist_page
-                }
-            }
-        })
-
-    def preformat_synist_page(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "labels")
-
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
-
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
-                with open(label_file, 'r') as f:
-                    text = remove_spaces(f.read())
-
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'png')),
-                    "label": text.strip()
-                })
-
-        return dataset
-
-    def format_synist_page(self):
-        """
-        Format synist page dataset
-        """
-        dataset = self.preformat_synist_page()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                #self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
-                self.load_resize_save(sample["img_path"], new_img_path)
-                #self.load_flip_save(new_img_path, new_img_path)
-                page = {
-                    "text": sample["label"],
-                }
-                self.charset = self.charset.union(set(page["text"]))
-                self.gt[set_name][new_name] = page
-                self.counter.update(page['text'])
-
-
-if __name__ == "__main__":
-    formatter = SynistDatasetFormatter("line", sem_token=False)
-    formatter.format()
-    print(formatter.counter)
-    print(formatter.counter.most_common(80))
-    for k,v in formatter.counter.items():
-        print(k)
-        print(k.encode('utf-8'), v)
diff --git a/dan/cli.py b/dan/cli.py
index b0ed2199..fff057b9 100644
--- a/dan/cli.py
+++ b/dan/cli.py
@@ -1,13 +1,12 @@
 # -*- coding: utf-8 -*-
 import argparse
 import errno
+
 from dan.datasets.extract.extract_from_arkindex import add_extract_parser
 from dan.ocr.line.generate_synthetic import add_generate_parser
-
 from dan.ocr.train import add_train_parser
-
 
 def get_parser():
     parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
     subcommands = parser.add_subparsers(metavar="subcommand")
@@ -17,6 +16,7 @@ def get_parser():
     add_generate_parser(subcommands)
     return parser
 
+
 def main():
     parser = get_parser()
     args = vars(parser.parse_args())
diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py
index 5216a7e0..d9d2e065 100644
--- a/dan/datasets/extract/arkindex_utils.py
+++ b/dan/datasets/extract/arkindex_utils.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 """
diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py
index 1a6e207d..c9c1c2c1 100644
--- a/dan/datasets/extract/extract_from_arkindex.py
+++ b/dan/datasets/extract/extract_from_arkindex.py
@@ -1,12 +1,4 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
-#
-# Example of minimal usage:
-# python extract_from_arkindex.py
-#     --corpus "AN-Simara annotations E (2022-06-20)"
-#     --parents-types folder
-#     --parents-names FRAN_IR_032031_4538.pdf
-#     --output-dir ../
 
 """
 The extraction module
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 72ab461b..77d15804 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 """
@@ -25,6 +24,13 @@ def get_cli_args():
         help="Name of the corpus from which the data will be retrieved.",
         required=True,
     )
+    parser.add_argument(
+        "--element-type",
+        nargs="+",
+        type=str,
+        help="Type of elements to retrieve",
+        required=True,
+    )
     parser.add_argument(
         "--parents-types",
         nargs="+",
@@ -47,4 +53,22 @@ def get_cli_args():
         help="Names of parents of the elements.",
         default=None,
     )
+    parser.add_argument(
+        "--no-entities", action="store_true", help="Extract text without entities"
+    )
+
+    parser.add_argument(
+        "--use-existing-split",
+        action="store_true",
+        help="Do not partition pages into train/val/test",
+    )
+
+    parser.add_argument(
+        "--train-prob", type=float, default=0.7, help="Training set probability"
+    )
+
+    parser.add_argument(
+        "--val-prob", type=float, default=0.15, help="Validation set probability"
+    )
+
     return parser.parse_args()
diff --git a/dan/datasets/format/bessin.py b/dan/datasets/format/bessin.py
new file mode 100644
index 00000000..1b2248b2
--- /dev/null
+++ b/dan/datasets/format/bessin.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+from collections import Counter
+
+from tqdm import tqdm
+
+from dan.datasets.format.generic import OCRDatasetFormatter
+
+
+def remove_spaces(text):
+    # remove begin/ending spaces
+    text = text.strip()
+    # replace \t with regular space
+    text = re.sub("\t", " ", text)
+    # remove consecutive spaces
+    text = re.sub(" +", " ", text)
+    # text = text.encode('ascii', 'ignore').decode("utf-8")
+    return text
+
+
+class BessinDatasetFormatter(OCRDatasetFormatter):
+    def __init__(
+        self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
+    ):
+        super(BessinDatasetFormatter, self).__init__(
+            "bessin", level, "_sem" if sem_token else "", set_names
+        )
+
+        self.dpi = dpi
+        self.counter = Counter()
+        self.map_datasets_files.update(
+            {
+                "bessin": {
+                    # (1,050 for train, 100 for validation and 100 for test)
+                    "line": {
+                        "needed_files": [],
+                        "arx_files": [],
+                        "format_function": self.format_bessin_zone,
+                    }
+                }
+            }
+        )
+
+    def preformat_bessin_zone(self):
+        """
+        Extract all information from dataset and correct some annotations
+        """
+
+        dataset = {"train": list(), "valid": list(), "test": list()}
+        img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
+        labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
+
+        train_files = [
+            os.path.join(labels_folder_path, "train", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "train"))
+        ]
+        valid_files = [
+            os.path.join(labels_folder_path, "valid", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "valid"))
+        ]
+        test_files = [
+            os.path.join(labels_folder_path, "test", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "test"))
+        ]
+
+        for set_name, files in zip(
+            self.set_names, [train_files, valid_files, test_files]
+        ):
+            for i, label_file in enumerate(
+                tqdm(files, desc="Pre-formatting " + set_name)
+            ):
+                with open(label_file, "r") as f:
+                    text = remove_spaces(f.read())
+
+                dataset[set_name].append(
+                    {
+                        "img_path": os.path.join(
+                            img_folder_path,
+                            set_name,
+                            label_file.split("/")[-1].replace("txt", "jpg"),
+                        ),
+                        "label": text.strip(),
+                    }
+                )
+
+        return dataset
+
+    def format_bessin_zone(self):
+        """
+        Format synist page dataset
+        """
+        dataset = self.preformat_bessin_zone()
+        for set_name in self.set_names:
+            fold = os.path.join(self.target_fold_path, set_name)
+            for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
+                new_name = sample["img_path"].split("/")[-1]
+                new_img_path = os.path.join(fold, new_name)
+                self.load_resize_save(sample["img_path"], new_img_path)
+                zone = {
+                    "text": sample["label"],
+                }
+                self.charset = self.charset.union(set(zone["text"]))
+                self.gt[set_name][new_name] = zone
+                self.counter.update(zone["text"])
+
+
+if __name__ == "__main__":
+    formatter = BessinDatasetFormatter("line", sem_token=False)
+    formatter.format()
+    print("Character freq: ")
+    for k, v in formatter.counter.items():
+        print(k, v)
diff --git a/dan/datasets/format/synist.py b/dan/datasets/format/synist.py
new file mode 100644
index 00000000..87c18fea
--- /dev/null
+++ b/dan/datasets/format/synist.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+from collections import Counter
+
+from tqdm import tqdm
+
+from dan.datasets.format.generic import OCRDatasetFormatter
+
+
+def remove_spaces(text):
+    # remove begin/ending spaces
+    text = text.strip()
+    # replace \t with regular space
+    text = re.sub("\t", " ", text)
+    # remove consecutive spaces
+    text = re.sub(" +", " ", text)
+    return text
+
+
+class SynistDatasetFormatter(OCRDatasetFormatter):
+    def __init__(
+        self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
+    ):
+        super(SynistDatasetFormatter, self).__init__(
+            "synist_synth", level, "_sem" if sem_token else "", set_names
+        )
+
+        self.dpi = dpi
+        self.counter = Counter()
+        self.map_datasets_files.update(
+            {
+                "synist_synth": {
+                    # (1,050 for train, 100 for validation and 100 for test)
+                    "line": {
+                        "needed_files": [],
+                        "arx_files": [],
+                        "format_function": self.format_synist_page,
+                    }
+                }
+            }
+        )
+
+    def preformat_synist_page(self):
+        """
+        Extract all information from dataset and correct some annotations
+        """
+
+        dataset = {"train": list(), "valid": list(), "test": list()}
+        img_folder_path = os.path.join(
+            "Datasets", "raw", "synist_synth_lines", "images"
+        )
+        labels_folder_path = os.path.join(
+            "Datasets", "raw", "synist_synth_lines", "labels"
+        )
+
+        train_files = [
+            os.path.join(labels_folder_path, "train", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "train"))
+        ]
+        valid_files = [
+            os.path.join(labels_folder_path, "valid", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "valid"))
+        ]
+        test_files = [
+            os.path.join(labels_folder_path, "test", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "test"))
+        ]
+
+        for set_name, files in zip(
+            self.set_names, [train_files, valid_files, test_files]
+        ):
+            for i, label_file in enumerate(
+                tqdm(files, desc="Pre-formatting " + set_name)
+            ):
+                with open(label_file, "r") as f:
+                    text = remove_spaces(f.read())
+
+                dataset[set_name].append(
+                    {
+                        "img_path": os.path.join(
+                            img_folder_path,
+                            set_name,
+                            label_file.split("/")[-1].replace("txt", "png"),
+                        ),
+                        "label": text.strip(),
+                    }
+                )
+
+        return dataset
+
+    def format_synist_page(self):
+        """
+        Format synist page dataset
+        """
+        dataset = self.preformat_synist_page()
+        for set_name in self.set_names:
+            fold = os.path.join(self.target_fold_path, set_name)
+            for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
+                new_name = sample["img_path"].split("/")[-1]
+                new_img_path = os.path.join(fold, new_name)
+                # self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
+                self.load_resize_save(sample["img_path"], new_img_path)
+                # self.load_flip_save(new_img_path, new_img_path)
+                page = {
+                    "text": sample["label"],
+                }
+                self.charset = self.charset.union(set(page["text"]))
+                self.gt[set_name][new_name] = page
+                self.counter.update(page["text"])
+
+
+if __name__ == "__main__":
+    formatter = SynistDatasetFormatter("line", sem_token=False)
+    formatter.format()
+    print(formatter.counter)
+    print(formatter.counter.most_common(80))
+    for k, v in formatter.counter.items():
+        print(k)
+        print(k.encode("utf-8"), v)
diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py
index c911cc9b..6422357c 100644
--- a/dan/datasets/utils.py
+++ b/dan/datasets/utils.py
@@ -1,6 +1,12 @@
 # -*- coding: utf-8 -*-
+import json
+import random
 import re
 
+import cv2
+
+random.seed(42)
+
 
 def convert(text):
     return int(text) if text.isdigit() else text.lower()
@@ -8,3 +14,30 @@ def convert(text):
 
 def natural_sort(data):
     return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
+
+
+def assign_random_split(train_prob, val_prob):
+    """
+    assuming train_prob + val_prob + test_prob = 1
+    """
+    prob = random.random()
+    if prob <= train_prob:
+        return "train"
+    elif prob <= train_prob + val_prob:
+        return "val"
+    else:
+        return "test"
+
+
+def save_text(path, text):
+    with open(path, "w") as f:
+        f.write(text)
+
+
+def save_image(path, image):
+    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
+
+def save_json(path, dict):
+    with open(path, "w") as outfile:
+        json.dump(dict, outfile, indent=4)
diff --git a/dan/decoder.py b/dan/decoder.py
index d2070fc6..f69d8e9a 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 import torch
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 1bcbe1c7..dda43ba8 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -1,7 +1,8 @@
 # -*- coding: utf-8 -*-
-from dan.ocr.line.train import add_line_parser
 from dan.ocr.document.train import add_document_parser
+from dan.ocr.line.train import add_line_parser
+
 
 def add_train_parser(subcommands) -> None:
     parser = subcommands.add_parser(
diff --git a/prediction-requirements.txt b/prediction-requirements.txt
deleted file mode 100644
index 86bfaeb8..00000000
--- a/prediction-requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-arkindex-client==1.0.11
-editdistance==0.6.0
-fontTools==4.29.1
-imageio==2.16.0
-networkx==2.6.3
-tensorboard==0.2.1
-torchvision==0.12.0
-tqdm==4.62.3
diff --git a/requirements.txt b/requirements.txt
index 1fe32d37..27758aa6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,12 @@
+arkindex-client==1.0.11
+editdistance==0.6.0
+fontTools==4.29.1
+imageio==2.16.0
+networkx==2.6.3
 numpy==1.22.3
 opencv-python==4.5.5.64
 PyYAML==6.0
+tensorboard==0.2.1
 torch==1.11.0
+torchvision==0.12.0
+tqdm==4.62.3
diff --git a/setup.py b/setup.py
index a0002964..51d55e0f 100755
--- a/setup.py
+++ b/setup.py
@@ -28,5 +28,4 @@ setup(
             "teklia-dan=dan.cli:main",
         ]
     },
-    extras_require={"predict": parse_requirements("prediction-requirements.txt")},
 )
--
GitLab
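
The split helpers added to `dan/datasets/utils.py` above are meant to be driven by the new `--train-prob` and `--val-prob` flags, but the patch does not show the calling code in `extract_from_arkindex.py`. The snippet below is only a minimal sketch of how the pieces could fit together once the patch is applied; the `elements` list and the output layout are assumptions made for illustration, not part of the commit.

```python
# Hypothetical wiring of the new helpers (assumes the dan package from this
# patch is installed). `elements` stands in for whatever the Arkindex
# extraction step actually returns.
import os

from dan.datasets.utils import assign_random_split, save_text

elements = [
    {"id": "element-1", "text": "premier exemple"},
    {"id": "element-2", "text": "second exemple"},
]

output_dir = "out"
for element in elements:
    # assign_random_split returns "train", "val" or "test" with
    # probabilities train_prob, val_prob and the remainder.
    split = assign_random_split(train_prob=0.7, val_prob=0.15)
    split_dir = os.path.join(output_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    save_text(os.path.join(split_dir, f"{element['id']}.txt"), element["text"])
```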