diff --git a/.isort.cfg b/.isort.cfg
index 1b59f25c5884d8878969cf823859d24e4d368dbc..98ed93316ffeee0e12cc6158abfa1610e61e1f67 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -8,4 +8,4 @@
 line_length = 88
 default_section=FIRSTPARTY
 known_first_party =
-known_third_party = PIL,apistar,arkindex,cv2,numpy,pytest,requests,setuptools,tqdm
+known_third_party = PIL,cv2,numpy,pytest,requests,setuptools
diff --git a/README.md b/README.md
index 48ff866ea084f1842ca6fca8ab9ea4b8db27239f..d70bf7a44cc6e9c64e441e848c9174a54a0218e1 100644
--- a/README.md
+++ b/README.md
@@ -9,33 +9,33 @@ and converts data to Kaldi format or kraken format. It also generates train, val
 
 `ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined.
 
-Install necessary dependencies
+Install it as a package
 
 ```bash
 virtualenv -p python3 .env
 source .env/bin/activate
-pip install -r requirements.txt
+pip install -e .
 ```
 
 Use help to list possible parameters:
 ```bash
-python kaldi_data_generator.py --help
+kaldi-data-generator --help
 ```
 
 There is also an option that skips all vertical transcriptions and it is `--skip_vertical_lines`
 
 #### Kaldi format
 Simple example:
 ```bash
-python kaldi_data_generator.py -f kaldi --dataset_name my_balsac --out_dir /tmp/balsac/ --volumes 8f4005e9-1921-47b0-be7b-e27c7fd29486 d2f7c563-1622-4721-bd51-96fab97189f7
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir /tmp/balsac/ --volumes 8f4005e9-1921-47b0-be7b-e27c7fd29486 d2f7c563-1622-4721-bd51-96fab97189f7
 ```
 With corpus ids
 ```bash
-python kaldi_data_generator.py -f kaldi --dataset_name cz --out_dir /tmp/home_cz/ --corpora 1ed45e94-9108-4029-a529-9abe37f55ba0
+kaldi-data-generator -f kaldi --dataset_name cz --out_dir /tmp/home_cz/ --corpora 1ed45e94-9108-4029-a529-9abe37f55ba0
 ```
 Polygon example:
 ```bash
-python kaldi_data_generator.py -f kaldi --dataset_name my_balsac2 --extraction_mode polygon --out_dir /tmp/balsac/ --pages 50e1c3c0-2fe9-4216-805e-1a2fd2e7e9f4
+kaldi-data-generator -f kaldi --dataset_name my_balsac2 --extraction_mode polygon --out_dir /tmp/balsac/ --pages 50e1c3c0-2fe9-4216-805e-1a2fd2e7e9f4
 ```
 
 The script creates 3 directories `Lines`, `Transcriptions`, `Partitions` in the specified `out_dir`.
@@ -45,13 +45,13 @@ The contents of these directories must be copied (or symlinked) to the correspon
 
 simple examples:
 ```
-$ python3 kaldi_data_generator.py -f kraken -o <output_dir> --volumes <volume_id> --no_split
+$ kaldi-data-generator -f kraken -o <output_dir> --volumes <volume_id> --no_split
 ```
 
 For instance to download the 4 sets from IAM (2 validation set on Arkindex) in 3 directories :
 ```
-$ python3 kaldi_data_generator.py -f kraken -o iam_training --volumes e7a95479-e5fc-4b20-830c-0c6e38bf8f72 --no_split
-$ python3 kaldi_data_generator.py -f kraken -o iam_validation --volumes edc78ee1-09e0-4671-806b-5fc0392707d9 --no_split
-$ python3 kaldi_data_generator.py -f kraken -o iam_validation --volumes fefbbfca-a6dd-4e00-8797-0d4628cb024d --no_split
-$ python3 kaldi_data_generator.py -f kraken -o iam_test --volumes 0ce2b631-01d7-49bf-b213-ceb6eae74a9b --no_split
+$ kaldi-data-generator -f kraken -o iam_training --volumes e7a95479-e5fc-4b20-830c-0c6e38bf8f72 --no_split
+$ kaldi-data-generator -f kraken -o iam_validation --volumes edc78ee1-09e0-4671-806b-5fc0392707d9 --no_split
+$ kaldi-data-generator -f kraken -o iam_validation --volumes fefbbfca-a6dd-4e00-8797-0d4628cb024d --no_split
+$ kaldi-data-generator -f kraken -o iam_test --volumes 0ce2b631-01d7-49bf-b213-ceb6eae74a9b --no_split
 ```
diff --git a/kaldi_data_generator/kaldi_data_generator.py b/kaldi_data_generator/kaldi_data_generator.py
deleted file mode 100644
index 2a1a03c7b0d57954ac11795da2cc414f32f7b734..0000000000000000000000000000000000000000
--- a/kaldi_data_generator/kaldi_data_generator.py
+++ /dev/null
@@ -1,661 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import argparse
-import os
-import random
-from enum import Enum
-from pathlib import Path
-
-import cv2
-import numpy as np
-import tqdm
-from apistar.exceptions import ErrorResponse
-from arkindex import ArkindexClient, options_from_env
-
-from kaldi_data_generator.image_utils import (
-    determine_rotate_angle,
-    download_image,
-    extract_min_area_rect_image,
-    extract_polygon_image,
-    rotate,
-    trim,
-)
-from kaldi_data_generator.utils import logger, write_file
-
-api_client = ArkindexClient(**options_from_env())
-
-SEED = 42
-random.seed(SEED)
-MANUAL = "manual"
-TEXT_LINE = "text_line"
-WHITE = 255
-
-
-class Extraction(Enum):
-    boundingRect: int = 0
-    polygon: int = 1
-    # minimum containing rectangle with an angle (cv2.min_area_rect)
-    min_area_rect: int = 2
-    deskew_polygon: int = 3
-    deskew_min_area_rect: int = 4
-
-
-class HTRDataGenerator:
-    def __init__(
-        self,
-        module,
-        dataset_name="foo",
-        out_dir_base="/tmp/kaldi_data",
-        grayscale=True,
-        extraction=Extraction.boundingRect,
-        accepted_slugs=None,
-        accepted_classes=None,
-        filter_printed=False,
-        skip_vertical_lines=False,
-        accepted_worker_version_ids=None,
-        transcription_type=TEXT_LINE,
-        max_deskew_angle=45,
-    ):
-
-        self.module = module
-        self.out_dir_base = out_dir_base
-        self.dataset_name = dataset_name
-        self.grayscale = grayscale
-        self.extraction_mode = extraction
-        self.accepted_slugs = accepted_slugs
-        self.should_filter_by_slug = bool(self.accepted_slugs)
-        self.accepted_classes = accepted_classes
-        self.should_filter_by_class = bool(self.accepted_classes)
-        self.accepted_worker_version_ids = accepted_worker_version_ids
-        self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
-        self.should_filter_printed = filter_printed
-        self.transcription_type = transcription_type
-        self.skip_vertical_lines = skip_vertical_lines
-        self.skipped_pages_count = 0
-        self.skipped_vertical_lines_count = 0
-        self.accepted_lines_count = 0
-        self.max_deskew_angle = max_deskew_angle
-
-        if MANUAL in self.accepted_worker_version_ids:
-            self.accepted_worker_version_ids[
-                self.accepted_worker_version_ids.index(MANUAL)
-            ] = None
-
-        if self.module == "kraken":
-            self.out_line_dir = out_dir_base
-            os.makedirs(self.out_line_dir, exist_ok=True)
-        else:
-            self.out_line_text_dir = os.path.join(
-                self.out_dir_base, "Transcriptions", self.dataset_name
-            )
-            os.makedirs(self.out_line_text_dir, exist_ok=True)
-            self.out_line_img_dir = os.path.join(
-                self.out_dir_base, "Lines", self.dataset_name
-            )
-            os.makedirs(self.out_line_img_dir, exist_ok=True)
-
-    def get_image(self, image_url: str, page_id: str) -> "np.ndarray":
-        out_full_img_dir = os.path.join(self.out_dir_base, "full", page_id)
-        os.makedirs(out_full_img_dir, exist_ok=True)
-        out_full_img_path = os.path.join(out_full_img_dir, "full.jpg")
-        if self.grayscale:
-            download_image(image_url).convert("L").save(
-                out_full_img_path, format="jpeg"
-            )
-            img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
-        else:
-            download_image(image_url).save(out_full_img_path, format="jpeg")
-            img = cv2.imread(out_full_img_path)
-        return img
-
-    def get_accepted_zones(self, page_id: str):
-        try:
-            accepted_zones = []
-            for elt in api_client.paginate(
-                "ListElementChildren", id=page_id, with_best_classes=True
-            ):
-                printed = True
-                for classification in elt["best_classes"]:
-                    if classification["ml_class"]["name"] == "handwritten":
-                        printed = False
-                for classification in elt["best_classes"]:
-                    if classification["ml_class"]["name"] in self.accepted_classes:
-                        if self.should_filter_printed:
-                            if not printed:
-                                accepted_zones.append(elt["zone"]["id"])
-                        else:
-                            accepted_zones.append(elt["zone"]["id"])
-            logger.info(
-                "Number of accepted zone for page {} : {}".format(
-                    page_id, len(accepted_zones)
-                )
-            )
-            return accepted_zones
-        except ErrorResponse as e:
-            logger.info(
-                f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}"
-            )
-            raise e
-
-    def get_transcriptions(self, page_id: str, accepted_zones):
-        count = 0
-        count_skipped = 0
-        lines = []
-        try:
-            for res in api_client.paginate(
-                "ListTranscriptions", id=page_id, recursive=True
-            ):
-                if (
-                    self.should_filter_by_slug
-                    and res["source"]["slug"] not in self.accepted_slugs
-                ):
-                    continue
-                if (
-                    self.should_filter_by_worker
-                    and res["worker_version_id"] not in self.accepted_worker_version_ids
-                ):
-                    continue
-                if (
-                    self.should_filter_by_class
-                    and res["element"]["zone"]["id"] not in accepted_zones
-                ):
-                    continue
-                if res["element"]["type"] != self.transcription_type:
-                    continue
-
-                text = res["text"]
-                if not text or not text.strip():
-                    continue
-
-                if "zone" in res:
-                    polygon = res["zone"]["polygon"]
-                elif "element" in res:
-                    polygon = res["element"]["zone"]["polygon"]
-                else:
-                    raise ValueError(f"Data problem with polygon :: {res}")
-
-                polygon = np.asarray(polygon).clip(0)
-                [x, y, w, h] = cv2.boundingRect(polygon)
-                if self.skip_vertical_lines:
-                    if h > w:
-                        count_skipped += 1
-                        continue
-                lines.append(((x, y, w, h), polygon, text))
-                count += 1
-            return (lines, count, count_skipped)
-
-        except ErrorResponse as e:
-            logger.info(
-                f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}"
-            )
-            raise e
-
-    def _save_line_image(self, page_id, i, line_img, manifest_fp=None):
-        if self.module == "kraken":
-            cv2.imwrite(f"{self.out_line_dir}/{page_id}_{i}.png", line_img)
-            manifest_fp.write(f"{page_id}_{i}.png\n")
-        else:
-            cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
-
-    def extract_lines(self, page_id: str, image_data: dict):
-        if self.should_filter_by_class:
-            accepted_zones = self.get_accepted_zones(page_id)
-        else:
-            accepted_zones = []
-        lines, count, count_skipped = self.get_transcriptions(page_id, accepted_zones)
-
-        if count == 0:
-            self.skipped_pages_count += 1
-            logger.info(f"Page {page_id} skipped, because it has no lines")
-            return
-
-        logger.debug(f"Total num of lines {count + count_skipped}")
-        logger.debug(f"Num of accepted lines {count}")
-        logger.debug(f"Num of skipped lines {count_skipped}")
-
-        self.skipped_vertical_lines_count += count_skipped
-        self.accepted_lines_count += count
-
-        full_image_url = image_data["s3_url"]
-        if full_image_url is None:
-            full_image_url = image_data["url"] + "/full/full/0/default.jpg"
-
-        img = self.get_image(full_image_url, page_id=page_id)
-
-        # sort vertically then horizontally
-        sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0]))
-
-        if self.module == "kraken":
-            manifest_fp = open(f"{self.out_line_dir}/manifest.txt", "a")
-            # append to file, not re-write it
-        else:
-            # not needed for kaldi
-            manifest_fp = None
-
-        if self.extraction_mode == Extraction.boundingRect:
-            for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines):
-                cropped = img[y : y + h, x : x + w].copy()
-                self._save_line_image(page_id, i, cropped, manifest_fp)
-
-        elif self.extraction_mode == Extraction.polygon:
-            for i, (rect, polygon, text) in enumerate(sorted_lines):
-                polygon_img = extract_polygon_image(img, polygon=polygon, rect=rect)
-                self._save_line_image(page_id, i, polygon_img, manifest_fp)
-
-        elif self.extraction_mode == Extraction.min_area_rect:
-            for i, (rect, polygon, text) in enumerate(sorted_lines):
-                min_rect_img = extract_min_area_rect_image(
-                    img, polygon=polygon, rect=rect
-                )
-
-                self._save_line_image(page_id, i, min_rect_img, manifest_fp)
-
-        elif self.extraction_mode == Extraction.deskew_polygon:
-            for i, (rect, polygon, text) in enumerate(sorted_lines):
-                # get angle from min area rect
-                rotate_angle = determine_rotate_angle(polygon)
-
-                if abs(rotate_angle) > self.max_deskew_angle:
-                    logger.warning(
-                        f"Deskew angle ({rotate_angle}) over the limit ({self.max_deskew_angle}), won't rotate"
-                    )
-                    rotate_angle = 0
-
-                # get polygon image
-                polygon_img = extract_polygon_image(img, polygon=polygon, rect=rect)
-
-                trimmed_img = self.rotate_and_trim(polygon_img, rotate_angle)
-
-                self._save_line_image(page_id, i, trimmed_img, manifest_fp)
-
-        elif self.extraction_mode == Extraction.deskew_min_area_rect:
-            for i, (rect, polygon, text) in enumerate(sorted_lines):
-                # get angle from min area rect
-                rotate_angle = determine_rotate_angle(polygon)
-
-                if abs(rotate_angle) > self.max_deskew_angle:
-                    logger.warning(
-                        f"Deskew angle ({rotate_angle}) over the limit ({self.max_deskew_angle}), won't rotate"
-                    )
-                    rotate_angle = 0
-
-                min_rect_img = extract_min_area_rect_image(
-                    img, polygon=polygon, rect=rect
-                )
-
-                trimmed_img = self.rotate_and_trim(min_rect_img, rotate_angle)
-
-                self._save_line_image(page_id, i, trimmed_img, manifest_fp)
-        else:
-            raise ValueError(f"Unsupported extraction mode: {self.extraction_mode}")
-
-        if self.module == "kraken":
-            manifest_fp.close()
-
-        for i, (rect, polygon, text) in enumerate(sorted_lines):
-            if self.module == "kraken":
-                write_file(f"{self.out_line_dir}/{page_id}_{i}.gt.txt", text)
-            else:
-                write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
-
-    def rotate_and_trim(self, img, rotate_angle):
-        """
-        Rotate image by given an angle and trim extra whitespace left after rotating
-        """
-        if self.grayscale:
-            background = WHITE
-        else:
-            background = (WHITE, WHITE, WHITE)
-
-        # rotate polygon image
-        deskewed_img = rotate(img, rotate_angle, background)
-        # trim extra whitespace left after rotating
-        trimmed_img = trim(deskewed_img, background)
-        trimmed_img = np.array(trimmed_img)
-
-        return trimmed_img
-
-    def run_pages(self, pages: list):
-        if all(isinstance(n, str) for n in pages):
-            for page in pages:
-                elt = api_client.request("RetrieveElement", id=page)
-                page_id = elt["id"]
-                image_data = elt["zone"]["image"]
-                logger.debug(f"Page {page_id}")
-                self.extract_lines(page_id, image_data)
-        else:
-            for page in tqdm.tqdm(pages):
-                page_id = page["id"]
-                image_data = page["zone"]["image"]
-                logger.debug(f"Page {page_id}")
-                self.extract_lines(page_id, image_data)
-
-    def run_volumes(self, volume_ids: list):
-        for volume_id in tqdm.tqdm(volume_ids):
-            logger.info(f"Volume {volume_id}")
-            pages = [
-                page
-                for page in api_client.paginate(
-                    "ListElementChildren", id=volume_id, recursive=True, type="page"
-                )
-            ]
-            self.run_pages(pages)
-
-    def run_folders(self, element_ids: list, volume_type: str):
-        for elem_id in tqdm.tqdm(element_ids):
-            logger.info(f"Folder {elem_id}")
-            vol_ids = [
-                page["id"]
-                for page in api_client.paginate(
-                    "ListElementChildren", id=elem_id, recursive=True, type=volume_type
-                )
-            ]
-            self.run_volumes(vol_ids)
-
-    def run_corpora(self, corpus_ids: list, volume_type: str):
-        for corpus_id in tqdm.tqdm(corpus_ids):
-            logger.info(f"Corpus {corpus_id}")
-            vol_ids = [
-                vol["id"]
-                for vol in api_client.paginate(
-                    "ListElements", corpus=corpus_id, type=volume_type
-                )
-            ]
-            self.run_volumes(vol_ids)
-
-
-class Split(Enum):
-    Train: int = 0
-    Test: int = 1
-    Validation: int = 2
-
-    @property
-    def short_name(self) -> str:
-        if self == self.Validation:
-            return "val"
-        return self.name.lower()
-
-
-class KaldiPartitionSplitter:
-    def __init__(
-        self,
-        out_dir_base="/tmp/kaldi_data",
-        split_train_ratio=0.8,
-        split_test_ratio=0.1,
-        use_existing_split=False,
-    ):
-        self.out_dir_base = out_dir_base
-        self.split_train_ratio = split_train_ratio
-        self.split_test_ratio = split_test_ratio
-        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
-        self.use_existing_split = use_existing_split
-
-    def page_level_split(self, line_ids: list) -> dict:
-        # need to sort again, because `set` will lose the order
-        page_ids = sorted({"_".join(line_id.split("_")[:-1]) for line_id in line_ids})
-        random.Random(SEED).shuffle(page_ids)
-        page_count = len(page_ids)
-
-        train_page_ids = page_ids[: round(page_count * self.split_train_ratio)]
-        page_ids = page_ids[round(page_count * self.split_train_ratio) :]
-
-        test_page_ids = page_ids[: round(page_count * self.split_test_ratio)]
-        page_ids = page_ids[round(page_count * self.split_test_ratio) :]
-
-        val_page_ids = page_ids
-
-        page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
-        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
-        page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
-        return page_dict
-
-    def existing_split(self, line_ids: list) -> list:
-        split_dict = {split.short_name: [] for split in Split}
-        for line_id in line_ids:
-            split_prefix = line_id.split("/")[0].lower()
-            split_dict[split_prefix].append(line_id)
-        splits = [split_dict[split.short_name] for split in Split]
-        return splits
-
-    def create_partitions(self):
-        logger.info("Creating partitions")
-        lines_path = Path(f"{self.out_dir_base}/Lines")
-        line_ids = [
-            str(file.relative_to(lines_path).with_suffix(""))
-            for file in sorted(lines_path.glob("**/*.jpg"))
-        ]
-
-        if self.use_existing_split:
-            logger.info("Using existing split")
-            datasets = self.existing_split(line_ids)
-        else:
-            page_dict = self.page_level_split(line_ids)
-            datasets = [[] for _ in range(3)]
-            for line_id in line_ids:
-                page_id = "_".join(line_id.split("_")[:-1])
-                split_id = page_dict[page_id]
-                datasets[split_id].append(line_id)
-
-        partitions_dir = os.path.join(self.out_dir_base, "Partitions")
-        os.makedirs(partitions_dir, exist_ok=True)
-        for i, dataset in enumerate(datasets):
-            if not dataset:
-                logger.info(f"Partition {Split(i).name} is empty! Skipping..")
-                continue
-            file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
-            write_file(file_name, "\n".join(dataset) + "\n")
-
-
-def create_parser():
-    parser = argparse.ArgumentParser(
-        description="Script to generate Kaldi or kraken training data from annotations from Arkindex",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument(
-        "-f",
-        "--format",
-        type=str,
-        help="is the data generated going to be used for kaldi or kraken",
-    )
-    parser.add_argument(
-        "-n",
-        "--dataset_name",
-        type=str,
-        help="Name of the dataset being created for kaldi or kraken "
-        "(useful for distinguishing different datasets when in Lines or Transcriptions directory)",
-    )
-    parser.add_argument(
-        "-o", "--out_dir", type=str, required=True, help="output directory"
-    )
-    parser.add_argument(
-        "--train_ratio",
-        type=float,
-        default=0.8,
-        help="Ratio of pages to be used in train (between 0 and 1)",
-    )
-    parser.add_argument(
-        "--test_ratio",
-        type=float,
-        default=0.1,
-        help="Ratio of pages to be used in test (between 0 and 1 - train_ratio)",
-    )
-    parser.add_argument(
-        "--use_existing_split",
-        action="store_true",
-        default=False,
-        help="Use an existing split instead of random. "
-        "Expecting line_ids to be prefixed with (train, val and test)",
-    )
-    parser.add_argument(
-        "--split_only",
-        "--no_download",
-        action="store_true",
-        default=False,
-        help="Create the split from already downloaded lines, don't download the lines",
-    )
-    parser.add_argument(
-        "--no_split",
-        action="store_true",
-        default=False,
-        help="No splitting of the data to be done just download the line in the right format",
-    )
-
-    parser.add_argument(
-        "-e",
-        "--extraction_mode",
-        type=lambda x: Extraction[x],
-        default=Extraction.boundingRect,
-        help=f"Mode for extracting the line images: {[e.name for e in Extraction]}",
-    )
-
-    parser.add_argument(
-        "--max_deskew_angle",
-        type=int,
-        default=45,
-        help="Maximum angle by which deskewing is allowed to rotate the line image. "
-        "If the angle determined by deskew tool is bigger than max "
-        "then that line won't be deskewed/rotated.",
-    )
-
-    parser.add_argument(
-        "--transcription_type",
-        type=str,
-        default="text_line",
-        help="Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)",
-    )
-
-    group = parser.add_mutually_exclusive_group(required=False)
-    group.add_argument(
-        "--grayscale", action="store_true", help="Convert images to grayscale"
-    )
-    group.add_argument("--color", action="store_false", help="Use color images")
-    parser.set_defaults(grayscale=True)
-
-    parser.add_argument(
-        "--corpora",
-        nargs="*",
-        help="List of corpus ids to be used, separated by spaces",
-    )
-    parser.add_argument(
-        "--folders",
-        type=str,
-        nargs="*",
-        help="List of folder ids to be used, separated by spaces. "
-        "Elements of `volume_type` will be searched recursively in these folders",
-    )
-    parser.add_argument(
-        "--volumes",
-        nargs="*",
-        help="List of volume ids to be used, separated by spaces",
-    )
-    parser.add_argument(
-        "--pages", nargs="*", help="List of page ids to be used, separated by spaces"
-    )
-    parser.add_argument(
-        "-v",
-        "--volume_type",
-        type=str,
-        default="volume",
-        help="Volumes (1 level above page) may have a different name on corpora",
-    )
-    parser.add_argument(
-        "--skip_vertical_lines",
-        action="store_true",
-        default=False,
-        help="skips vertical lines when downloading",
-    )
-
-    parser.add_argument(
-        "--accepted_slugs",
-        nargs="*",
-        help="List of accepted slugs for downloading transcriptions",
-    )
-
-    parser.add_argument(
-        "--accepted_classes",
-        nargs="*",
-        help="List of accepted ml_class names. Filter lines by class of related elements",
-    )
-
-    parser.add_argument(
-        "--accepted_worker_version_ids",
-        nargs="*",
-        default=[],
-        help="List of accepted worker version ids. Filter lines by worker version ids of related elements"
-        "Use `--accepted_worker_version_ids manual` to get only manual transcriptions",
-    )
-
-    parser.add_argument(
-        "--filter_printed",
-        action="store_true",
-        help="Filter lines annotated as printed",
-    )
-    return parser
-
-
-def main():
-    parser = create_parser()
-    args = parser.parse_args()
-
-    if not args.dataset_name and not args.split_only and not args.format == "kraken":
-        parser.error("--dataset_name must be specified (unless --split-only)")
-
-    logger.info(f"ARGS {args} \n")
-
-    if not args.split_only:
-        data_generator = HTRDataGenerator(
-            module=args.format,
-            dataset_name=args.dataset_name,
-            out_dir_base=args.out_dir,
-            grayscale=args.grayscale,
-            extraction=args.extraction_mode,
-            accepted_slugs=args.accepted_slugs,
-            accepted_classes=args.accepted_classes,
-            filter_printed=args.filter_printed,
-            skip_vertical_lines=args.skip_vertical_lines,
-            transcription_type=args.transcription_type,
-            accepted_worker_version_ids=args.accepted_worker_version_ids,
-            max_deskew_angle=args.max_deskew_angle,
-        )
-
-        # extract all the lines and transcriptions
-        if args.pages:
-            data_generator.run_pages(args.pages)
-        if args.volumes:
-            data_generator.run_volumes(args.volumes)
-        if args.folders:
-            data_generator.run_folders(args.folders, args.volume_type)
-        if args.corpora:
-            data_generator.run_corpora(args.corpora, args.volume_type)
-        if data_generator.skipped_vertical_lines_count > 0:
-            logger.info(
-                f"Number of skipped pages: {data_generator.skipped_pages_count}"
-            )
-            skipped_ratio = data_generator.skipped_vertical_lines_count / (
-                data_generator.skipped_vertical_lines_count
-                + data_generator.accepted_lines_count
-            )
-            logger.info(
-                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)"
-            )
-    else:
-        logger.info("Creating a split from already downloaded files")
-    if not args.no_split:
-        kaldi_partitioner = KaldiPartitionSplitter(
-            out_dir_base=args.out_dir,
-            split_train_ratio=args.train_ratio,
-            split_test_ratio=args.test_ratio,
-            use_existing_split=args.use_existing_split,
-        )
-
-        # create partitions from all the extracted data
-        kaldi_partitioner.create_partitions()
-    else:
-        logger.info("No split to be done")
-
-    logger.info("DONE")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/setup.py b/setup.py
index be27dd46a0c436a98691a848a01ac1abcd399be7..0958c08116d46fb3e0de1052169fc955801a86e8 100755
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,6 @@ setup(
     author="Martin",
     author_email="maarand@teklia.com",
     install_requires=parse_requirements(),
-    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.kaldi_data_generator:main"]},
+    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.main:main"]},
     packages=find_packages(),
 )