diff --git a/dan/datasets/analyze/statistics.py b/dan/datasets/analyze/statistics.py index efdc0ef6b96accccb091acc01df26593ee08d596..d6a716ed1238717fdf33d4854f30f9c68a4e3402 100644 --- a/dan/datasets/analyze/statistics.py +++ b/dan/datasets/analyze/statistics.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging from collections import Counter, defaultdict +from functools import partial from pathlib import Path from typing import Dict, List, Optional @@ -32,23 +33,32 @@ def create_table( operations = [] if count: - operations.append(("Count", len)) + operations.append(("Count", len, None)) operations.extend( [ - ("Min", np.min), - ("Max", np.max), - ("Mean", np.mean), - ("Median", np.median), + ("Min", np.min, None), + ("Max", np.max, None), + ("Mean", np.mean, 2), + ("Median", np.median, 2), ] ) if total: - operations.append(("Total", np.sum)) + operations.append(("Total", np.sum, None)) statistics.add_rows( [ - [col_name, *list(map(operator, data.values()))] - for col_name, operator in operations + [ + col_name, + *list( + map( + # Round values if needed + partial(round, ndigits=digits), + map(operator, data.values()), + ) + ), + ] + for col_name, operator, digits in operations ] )