#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import random
from collections import Counter, defaultdict
from enum import Enum
from itertools import groupby
from pathlib import Path
from typing import List

import numpy as np
import tqdm
from apistar.exceptions import ErrorResponse
from arkindex import options_from_env
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import WHITE, Extraction, rotate_and_trim

import jsonargparse
from kaldi_data_generator.arguments import (
    CommonArgs,
    FilterArgs,
    ImageArgs,
    SelectArgs,
    SplitArgs,
    Style,
)
from kaldi_data_generator.image_utils import download_image, resize_transcription_data
from kaldi_data_generator.utils import (
    CachedApiClient,
    TranscriptionData,
    export_parameters,
    logger,
    write_file,
)

SEED = 42
random.seed(SEED)
MANUAL = "manual"
TEXT_LINE = "text_line"
DEFAULT_RESCALE = 1.0
STYLE_CLASSES = [el.value for el in Style]

ROTATION_CLASSES_TO_ANGLES = {
    "rotate_0": 0,
    "rotate_left_90": 90,
    "rotate_180": 180,
    "rotate_right_90": -90,
}


def create_api_client(cache_dir=None):
    logger.info("Creating API client")
    # return ArkindexClient(**options_from_env())
    return CachedApiClient(cache_root=cache_dir, **options_from_env())


class HTRDataGenerator:
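    """
    Downloads page images from Arkindex and extracts the accepted line images
    and their transcriptions, in the layout expected by Kaldi or kraken.
    """
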
    def __init__(
        self,
        format,
        dataset_name="my_dataset",
        out_dir_base="data",
        image=ImageArgs(),
        common=CommonArgs(),
        filter=FilterArgs(),
        api_client=None,
    ):

        self.format = format
        self.out_dir_base = out_dir_base
        self.dataset_name = dataset_name
        self.grayscale = image.grayscale
        self.extraction_mode = Extraction[image.extraction_mode.value]
        self.accepted_classes = filter.accepted_classes
        self.ignored_classes = filter.ignored_classes
        self.should_filter_by_class = bool(self.accepted_classes) or bool(
            self.ignored_classes
        )
        self.accepted_worker_version_ids = filter.accepted_worker_version_ids
        self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
        self.style = filter.style
        self.should_filter_by_style = bool(self.style)
        self.accepted_metadatas = filter.accepted_metadatas
        self.should_filter_by_metadatas = bool(self.accepted_metadatas)
        self.transcription_type = filter.transcription_type.value
        self.skip_vertical_lines = filter.skip_vertical_lines
        self.skipped_pages_count = 0
        self.skipped_vertical_lines_count = 0
        self.accepted_lines_count = 0
        self.max_deskew_angle = image.max_deskew_angle
        self.skew_angle = image.skew_angle
        self.should_rotate = image.should_rotate
        if image.scale_x or image.scale_y_top or image.scale_y_bottom:
            self.should_resize_polygons = True
            # default to 1.0 (no resizing) for any scale that is not specified
            self.scale_x = image.scale_x or DEFAULT_RESCALE
            self.scale_y_top = image.scale_y_top or DEFAULT_RESCALE
            self.scale_y_bottom = image.scale_y_bottom or DEFAULT_RESCALE
        else:
            self.should_resize_polygons = False
        self.api_client = api_client

        # "manual" stands for transcriptions without a worker version (null worker_version_id)
        if MANUAL in self.accepted_worker_version_ids:
            self.accepted_worker_version_ids[
                self.accepted_worker_version_ids.index(MANUAL)
            ] = None

        if self.format == "kraken":
            self.out_line_dir = out_dir_base
            os.makedirs(self.out_line_dir, exist_ok=True)
        else:
            self.out_line_text_dir = os.path.join(
                self.out_dir_base, "Transcriptions", self.dataset_name
            )
            os.makedirs(self.out_line_text_dir, exist_ok=True)
            self.out_line_img_dir = os.path.join(
                self.out_dir_base, "Lines", self.dataset_name
            )
            os.makedirs(self.out_line_img_dir, exist_ok=True)

        self.cache_dir = Path(common.cache_dir)
        logger.info(f"Setting up cache to {self.cache_dir}")
        self.img_cache_dir = self.cache_dir / "images"
        self.img_cache_dir.mkdir(exist_ok=True, parents=True)
        self._cache_is_empty = not any(self.img_cache_dir.iterdir())
        if self._cache_is_empty:
            logger.info("Cache is empty, no need to check")

        if self.grayscale:
            self._color = "grayscale"
        else:
            self._color = "rgb"

    def get_image(self, image_url: str, page_id: str) -> "np.ndarray":
        # id is last part before full/full/0/default.jpg
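        # e.g. ".../iiif/2/some%2Fimage%2Fid/full/full/0/default.jpg" -> "some/image/id"
        # (illustrative URL only; the exact path depends on the image server)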
        img_id = image_url.split("/")[-5].replace("%2F", "/")

        cached_img_path = self.img_cache_dir / self._color / img_id
        if not self._cache_is_empty and cached_img_path.exists():
            logger.info(f"Cached image exists: {cached_img_path} - {page_id}")
        else:
            logger.info(f"Image not in cache: {cached_img_path} - {page_id}")
            cached_img_path.parent.mkdir(exist_ok=True, parents=True)
            pil_img = download_image(image_url)
            if self.grayscale:
                pil_img = pil_img.convert("L")
            pil_img.save(cached_img_path, format="jpeg")

        img = read_img(cached_img_path, self.grayscale)
        return img

    def metadata_filtering(self, elt):
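        """Return True only if the element carries every accepted metadata name/value pair."""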
        metadatas = {
            metadata["name"]: metadata["value"] for metadata in elt["metadata"]
        }
        for meta in self.accepted_metadatas:
            if not (
                meta in metadatas and metadatas[meta] == self.accepted_metadatas[meta]
            ):
                return False
        return True

    def get_accepted_zones(self, page_id: str):
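        """Return the zone ids of the page children that pass the class, style and metadata filters."""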
        try:
            accepted_zones = []
            for elt in self.api_client.cached_paginate(
                "ListElementChildren",
                id=page_id,
                with_classes=self.should_filter_by_class,
                with_metadata=self.should_filter_by_metadatas,
            ):
                elem_classes = [c for c in elt["classes"] if c["state"] != "rejected"]

                should_accept = True
                if self.should_filter_by_class:
                    # at first filter to only have elements with accepted classes
                    # if accepted classes list is empty then should accept all
                    # except for ignored classes
                    should_accept = len(self.accepted_classes) == 0
                    for classification in elem_classes:
                        class_name = classification["ml_class"]["name"]
                        if class_name in self.accepted_classes:
                            should_accept = True
                            break
                        elif class_name in self.ignored_classes:
                            should_accept = False
                            break

                if not should_accept:
                    continue

                if self.should_filter_by_style:
                    style_counts = Counter()
                    for classification in elem_classes:
                        class_name = classification["ml_class"]["name"]
                        if class_name in STYLE_CLASSES:
                            style_counts[class_name] += 1

                    if len(style_counts) == 0:
                        # no handwritten or typewritten class found, so other
                        found_class = Style.other
                    elif len(style_counts) == 1:
                        found_class = Style(list(style_counts.keys())[0])
                    else:
                        raise ValueError(
                            f"Multiple style classes on the same element! {elt['id']} - {elem_classes}"
                        )
                    should_accept = found_class == self.style

                if not should_accept:
                    continue

                if self.should_filter_by_metadatas:
                    if self.metadata_filtering(elt):
                        accepted_zones.append(elt["zone"]["id"])
                else:
                    accepted_zones.append(elt["zone"]["id"])
            logger.info(
                f"Number of accepted zones for page {page_id}: {len(accepted_zones)}"
            )
            return accepted_zones
        except ErrorResponse as e:
            logger.info(
                f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}"
            )
            raise e

    def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
        if not lines:
            return

        line_elem_counter = Counter([trans.element_id for trans in lines])
        most_common = line_elem_counter.most_common(10)
        if most_common[0][-1] > 1:
            logger.error("Line elements have multiple transcriptions! Showing top 10:")
            logger.error(f"{most_common}")
            raise ValueError(f"Multiple transcriptions: {most_common[0]}")

        worker_version_counter = Counter([trans.worker_version_id for trans in lines])
        if len(worker_version_counter) > 1:
            logger.warning(
                f"There are transcriptions from multiple worker versions on this page: {page_id}:"
            )
            logger.warning(
                f"Top 10 worker versions: {worker_version_counter.most_common(10)}"
            )

    def _choose_best_transcriptions(
        self, lines: List[TranscriptionData]
    ) -> List[TranscriptionData]:
        """
        Get the best transcription based on the order of accepted worker version ids.
        :param lines:
        :return:
        """
        if not lines:
            return []

        trans_by_element = defaultdict(list)
        for line in lines:
            trans_by_element[line.element_id].append(line)

        best_transcriptions = []
        for elem, trans_list in trans_by_element.items():
            tmp_dict = {t.worker_version_id: t for t in trans_list}

            for wv in self.accepted_worker_version_ids:
                if wv in tmp_dict:
                    best_transcriptions.append(tmp_dict[wv])
                    break
            else:
                logger.info(f"No suitable trans found for {elem}")
        return best_transcriptions

    def get_transcriptions(self, page_id: str, accepted_zones):
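        """
        Collect the line transcriptions of a page, applying the worker version,
        zone and element type filters. Returns the kept lines, their count and
        the number of skipped vertical lines.
        """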
        lines = []
        try:
            for res in self.api_client.cached_paginate(
                "ListTranscriptions", id=page_id, recursive=True
            ):
                if (
                    self.should_filter_by_worker
                    and res["worker_version_id"] not in self.accepted_worker_version_ids
                ):
                    continue
                if (
                    self.should_filter_by_class
                    or self.should_filter_by_style
                    or self.should_filter_by_metadatas
                ) and (res["element"]["zone"]["id"] not in accepted_zones):
                    continue
                if res["element"]["type"] != self.transcription_type:
                    continue

                text = res["text"]
                if not text or not text.strip():
                    continue

                if "\n" in text.strip() and not self.transcription_type == "text":
                    elem_id = res["element"]["id"]
                    raise ValueError(
                        f"Newlines are not allowed in line transcriptions - {page_id} - {elem_id} - {text}"
                    )

                if "zone" in res:
                    polygon = res["zone"]["polygon"]
                elif "element" in res:
                    polygon = res["element"]["zone"]["polygon"]
                else:
                    raise ValueError(f"Data problem with polygon :: {res}")

                trans_data = TranscriptionData(
                    element_id=res["element"]["id"],
                    element_name=res["element"]["name"],
                    polygon=polygon,
                    text=text,
                    trans_id=res["id"],
                    worker_version_id=res["worker_version_id"],
                )

                lines.append(trans_data)

            if self.accepted_worker_version_ids:
                # if accepted worker versions have been defined then use them
                lines = self._choose_best_transcriptions(lines)
            else:
                # if no accepted worker versions have been defined
                # then check that there aren't multiple transcriptions
                # on the same text line
                self._validate_transcriptions(page_id, lines)

            if self.should_rotate:
                classes_by_elem = self.get_children_classes(page_id)

                for trans in lines:
                    rotation_classes = [
                        c
                        for c in classes_by_elem[trans.element_id]
                        if c in ROTATION_CLASSES_TO_ANGLES
                    ]
                    if len(rotation_classes) > 0:
                        if len(rotation_classes) > 1:
                            logger.warning(
                                f"Several rotation classes = {len(rotation_classes)} - {trans.element_id}"
                            )
                        trans.rotation_class = rotation_classes[0]
                    else:
                        logger.warning(f"No rotation classes on {trans.element_id}")

            count_skipped = 0
            if self.skip_vertical_lines:
                filtered_lines = []
                for line in lines:
                    if line.is_vertical:
                        count_skipped += 1
                        continue
                    filtered_lines.append(line)

                lines = filtered_lines

            count = len(lines)

            return lines, count, count_skipped

        except ErrorResponse as e:
            logger.info(
                f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}"
            )
            raise e

    def get_children_classes(self, page_id):
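        """Map each text_line child of the page to its non-rejected class names."""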
        return {
            elem["id"]: [
                best_class["ml_class"]["name"]
                for best_class in elem["classes"]
                if best_class["state"] != "rejected"
            ]
            for elem in self.api_client.cached_paginate(
                "ListElementChildren",
                id=page_id,
                recursive=True,
                type=TEXT_LINE,
                with_classes=True,
            )
        }

    def _save_line_image(
        self, page_id, line_img, manifest_fp=None, trans: TranscriptionData = None
    ):
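        """
        Rotate the line image if requested and save it as
        {page_id}_{line_number}_{line_id} (plus a manifest entry for kraken).
        """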
        # Get line id
        line_id = trans.element_id

        # Get line number from its name
        line_number = trans.element_name.split("_")[-1]

        if self.should_rotate:
            if trans.rotation_class:
                rotate_angle = ROTATION_CLASSES_TO_ANGLES[trans.rotation_class]
                line_img = rotate_and_trim(line_img, rotate_angle, WHITE)
        if self.format == "kraken":
            # Save image using the template {page_id}_{line_number}_{line_id}
            # TODO: check if (0>3) is enough (pad line_number to 3 digits)
            save_img(
                f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.png",
                line_img,
            )
            manifest_fp.write(f"{page_id}_{line_number:0>3}_{line_id}.png\n")
        else:
            save_img(
                f"{self.out_line_img_dir}/{page_id}_{line_number:0>3}_{line_id}.jpg",
                line_img,
            )

    def extract_lines(self, page_id: str, image_data: dict):
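        """Extract, crop and save the accepted line images and transcriptions of a page."""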
        if (
            self.should_filter_by_class
            or self.should_filter_by_style
            or self.should_filter_by_metadatas
        ):
            accepted_zones = self.get_accepted_zones(page_id)
        else:
            accepted_zones = []
        lines, count, count_skipped = self.get_transcriptions(page_id, accepted_zones)

        if count == 0:
            self.skipped_pages_count += 1
            logger.info(f"Page {page_id} skipped, because it has no lines")
            return

        logger.debug(f"Total num of lines {count + count_skipped}")
        logger.debug(f"Num of accepted lines {count}")
        logger.debug(f"Num of skipped lines {count_skipped}")

        self.skipped_vertical_lines_count += count_skipped
        self.accepted_lines_count += count

        full_image_url = image_data["s3_url"]
        if full_image_url is None:
            full_image_url = image_data["url"] + "/full/full/0/default.jpg"

        img = self.get_image(full_image_url, page_id=page_id)

        # sort vertically then horizontally
        sorted_lines = sorted(lines, key=lambda key: (key.rect.y, key.rect.x))

        if self.should_resize_polygons:
            sorted_lines = [
                resize_transcription_data(
                    line,
                    image_data["width"],
                    image_data["height"],
                    self.scale_x,
                    self.scale_y_top,
                    self.scale_y_bottom,
                )
                for line in sorted_lines
            ]

        if self.format == "kraken":
            # append to the manifest file rather than overwriting it
            manifest_fp = open(f"{self.out_line_dir}/manifest.txt", "a")
        else:
            # not needed for kaldi
            manifest_fp = None

        for trans in sorted_lines:
            extracted_img = extract(
                img=img,
                polygon=trans.polygon,
                bbox=trans.rect,
                extraction_mode=self.extraction_mode,
                max_deskew_angle=self.max_deskew_angle,
                skew_angle=self.skew_angle,
                grayscale=self.grayscale,
            )

            # don't enumerate: read the line number from the element's name (e.g. line_xx) so that it matches Arkindex
            self._save_line_image(page_id, extracted_img, manifest_fp, trans)

        if self.format == "kraken":
            manifest_fp.close()

        for trans in sorted_lines:
            line_number = trans.element_name.split("_")[-1]
            line_id = trans.element_id
            if self.format == "kraken":
                write_file(
                    f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.gt.txt",
                    trans.text,
                )
            else:
                write_file(
                    f"{self.out_line_text_dir}/{page_id}_{line_number:0>3}_{line_id}.txt",
                    trans.text,
                )

    def run_selection(self, select):
        """
        Update select to keep track of selected ids.
        """
        selected_elems = [e for e in self.api_client.paginate("ListSelection")]
        for elem_type, elems_of_type in groupby(
            selected_elems, key=lambda x: x["type"]
        ):
            elements_ids = [el["id"] for el in elems_of_type]
            if elem_type == "page":
                select.pages += elements_ids
            elif elem_type == "volume":
                select.volumes += elements_ids
            elif elem_type == "folder":
                select.folders += elements_ids
            else:
                raise ValueError(f"Unsupported element type {elem_type} in selection!")
        return select

    def run_pages(self, pages: list):
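        """Process pages given either as a list of page ids or as a list of page dicts from pagination."""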
        if all(isinstance(n, str) for n in pages):
            for page in pages:
                elt = self.api_client.request("RetrieveElement", id=page)
                page_id = elt["id"]
                image_data = elt["zone"]["image"]
                logger.debug(f"Page {page_id}")
                self.extract_lines(page_id, image_data)
        else:
            for page in tqdm.tqdm(pages):
                page_id = page["id"]
                image_data = page["zone"]["image"]
                logger.debug(f"Page {page_id}")
                self.extract_lines(page_id, image_data)

    def run_volumes(self, volume_ids: list):
        for volume_id in tqdm.tqdm(volume_ids):
            logger.info(f"Volume {volume_id}")
            pages = [
                page
                for page in self.api_client.cached_paginate(
                    "ListElementChildren", id=volume_id, recursive=True, type="page"
                )
            ]
            self.run_pages(pages)

    def run_folders(self, element_ids: list, volume_type: str):
        for elem_id in tqdm.tqdm(element_ids):
            logger.info(f"Folder {elem_id}")
            vol_ids = [
                page["id"]
                for page in self.api_client.cached_paginate(
                    "ListElementChildren", id=elem_id, recursive=True, type=volume_type
                )
            ]
            self.run_volumes(vol_ids)

    def run_corpora(self, corpus_ids: list, volume_type: str):
        for corpus_id in tqdm.tqdm(corpus_ids):
            logger.info(f"Corpus {corpus_id}")
            vol_ids = [
                vol["id"]
                for vol in self.api_client.cached_paginate(
                    "ListElements", corpus=corpus_id, type=volume_type
                )
            ]
            self.run_volumes(vol_ids)


class Split(Enum):
    Train: str = "train"
    Test: str = "test"
    Validation: str = "val"


class KaldiPartitionSplitter:
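    """Splits the extracted line images into train, validation and test partition lists."""
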
    def __init__(
        self,
        out_dir_base="/tmp/kaldi_data",
        split_train_ratio=0.8,
        split_val_ratio=0.1,
        split_test_ratio=0.1,
        use_existing_split=False,
    ):
        self.out_dir_base = out_dir_base
        self.split_train_ratio = split_train_ratio
        self.split_test_ratio = split_test_ratio
        self.split_val_ratio = split_val_ratio
        self.use_existing_split = use_existing_split

    def page_level_split(self, line_ids: list) -> dict:
        """
        Split pages into train, validation and test subsets.
        Don't split lines to avoid data leakage.
            line_ids (list): a list of line ids named {page_id}_{line_number}_{line_id}
        """
        # Get page ids from line ids to create splits at page level
        page_ids = ["_".join(line_id.split("_")[:-2]) for line_id in line_ids]
        # Remove duplicates and sort for reproducibility
        page_ids = sorted(set(page_ids))
        random.Random(SEED).shuffle(page_ids)
        page_count = len(page_ids)

        # Use np.split to split in three sets
        stop_train_idx = round(page_count * self.split_train_ratio)
        stop_val_idx = stop_train_idx + round(page_count * self.split_val_ratio)
        train_page_ids, val_page_ids, test_page_ids = np.split(
            page_ids, [stop_train_idx, stop_val_idx]
        )

        # Build dictionary that will be used to split lines {id: split}
        page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
        page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
        return page_dict

    def existing_split(self, line_ids: list) -> dict:
        """
        Expect line_ids to be named {split}/{path_to_image} where split is one of "train", "val" or "test".
        """
        split_dict = {split.value: [] for split in Split}
        for line_id in line_ids:
            split_prefix = line_id.split("/")[0].lower()
            split_dict[split_prefix].append(line_id)
        return split_dict

    def create_partitions(self):
        """ """
        logger.info(f"Creating {[split.value for split in Split]} partitions")
        # Get all images ids (and remove extension)
        lines_path = Path(f"{self.out_dir_base}/Lines")
        line_ids = [
            str(file.relative_to(lines_path).with_suffix(""))
            for file in sorted(lines_path.glob("**/*.jpg"))
        ]

        if self.use_existing_split:
            logger.info("Using existing split")
            datasets = self.existing_split(line_ids)
        else:
            page_dict = self.page_level_split(line_ids)
            # extend this split for lines
            datasets = {s.value: [] for s in Split}
            for line_id in line_ids:
                page_id = "_".join(line_id.split("_")[:-2])
                split_id = page_dict[page_id]
                datasets[split_id].append(line_id)

        partitions_dir = os.path.join(self.out_dir_base, "Partitions")
        os.makedirs(partitions_dir, exist_ok=True)
        for split, split_line_ids in datasets.items():
            if not split_line_ids:
                logger.info(f"Partition {split} is empty! Skipping...")
                continue
            file_name = f"{partitions_dir}/{Split(split).name}Lines.lst"
            write_file(file_name, "\n".join(split_line_ids) + "\n")
        return datasets


def run(format, dataset_name, out_dir, common, image, split, select, filter):

    api_client = create_api_client(Path(common.cache_dir))

    if not split.split_only:
        data_generator = HTRDataGenerator(
            format=format,
            dataset_name=dataset_name,
            out_dir_base=out_dir,
            common=common,
            image=image,
            filter=filter,
            api_client=api_client,
        )

        # extract all the lines and transcriptions
        if select.selection:
            select = data_generator.run_selection(select)
        if select.pages:
            data_generator.run_pages(select.pages)
        if select.volumes:
            data_generator.run_volumes(select.volumes)
        if select.folders:
            data_generator.run_folders(select.folders, select.volume_type)
        if select.corpora:
            data_generator.run_corpora(select.corpora, select.volume_type)
        if data_generator.skipped_vertical_lines_count > 0:
            logger.info(
                f"Number of skipped pages: {data_generator.skipped_pages_count}"
            )
            _skipped_vertical_count = data_generator.skipped_vertical_lines_count
            _total_count = _skipped_vertical_count + data_generator.accepted_lines_count
            skipped_ratio = _skipped_vertical_count / _total_count * 100

            logger.info(
                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)"
            )
    else:
        logger.info("Creating a split from already downloaded files")
        data_generator = None

    if not split.no_split:
        kaldi_partitioner = KaldiPartitionSplitter(
            out_dir_base=out_dir,
            split_train_ratio=split.train_ratio,
            split_val_ratio=split.val_ratio,
            split_test_ratio=split.test_ratio,
            use_existing_split=split.use_existing_split,
        )

        # create partitions from all the extracted data
        datasets = kaldi_partitioner.create_partitions()
    else:
        logger.info("No split to be done")
        datasets = {}

    logger.info("DONE")

    export_parameters(
        format,
        dataset_name,
        out_dir,
        common,
        image,
        split,
        select,
        filter,
        datasets,
        arkindex_api_url=options_from_env()["base_url"],
    )

    logger.warning(
        f"Consider cleaning your cache directory {common.cache_dir} if you are done."
    )


def get_args():
    parser = jsonargparse.ArgumentParser(
        description="Script to generate Kaldi or kraken training data from annotations from Arkindex",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parse_as_dict=True,
    )
    parser.add_argument(
        "--config", action=jsonargparse.ActionConfigFile, help="Configuration file"
    )
    parser.add_argument(
        "-f",
        "--format",
        type=str,
        required=True,
        help="is the data generated going to be used for kaldi or kraken",
    )
    parser.add_argument(
        "-d",
        "--dataset_name",
        type=str,
        required=True,
        help="Name of the dataset being created for kaldi or kraken "
        "(useful for distinguishing different datasets when in Lines or Transcriptions directory)",
    )
    parser.add_argument(
        "-o", "--out_dir", type=str, required=True, help="output directory"
    )

    parser.add_class_arguments(CommonArgs, "common")
    parser.add_class_arguments(ImageArgs, "image")
    parser.add_class_arguments(SplitArgs, "split")
    parser.add_class_arguments(FilterArgs, "filter")
    parser.add_class_arguments(SelectArgs, "select")

    args = parser.parse_args(with_meta=False)
    args["common"] = CommonArgs(**args["common"])
    args["image"] = ImageArgs(**args["image"])
    args["split"] = SplitArgs(**args["split"])
    args["select"] = SelectArgs(**args["select"])
    args["filter"] = FilterArgs(**args["filter"])

    # Check overlap of accepted and ignored classes
    accepted_classes = args["filter"].accepted_classes
    ignored_classes = args["filter"].ignored_classes
    if accepted_classes and ignored_classes:
        if set(accepted_classes) & set(ignored_classes):
            parser.error(
                f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})"
            )

    if args["filter"].style and (accepted_classes or ignored_classes):
        if set(STYLE_CLASSES) & (set(accepted_classes) | set(ignored_classes)):
            parser.error(
                f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
                f"(or ignored_classes list) "
                "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes."
            )

    del args["config"]
    return args


def main():
    args = get_args()
    logger.info(f"Arguments: {args} \n")
    run(**args)


if __name__ == "__main__":
    main()