diff --git a/src/facebias/metrics.py b/src/facebias/metrics.py index bf1e68f..fa4c8cd 100644 --- a/src/facebias/metrics.py +++ b/src/facebias/metrics.py @@ -1,18 +1,14 @@ # -*- coding: utf-8 -*- import logging -from itertools import product from pathlib import Path from typing import Any import numpy as np import pandas as pd -from sklearn.metrics import (accuracy_score, balanced_accuracy_score, - cohen_kappa_score, confusion_matrix, max_error, - mean_absolute_error, multilabel_confusion_matrix, - precision_score) +from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix -from facebias.estimators import BaseEstimator, Capability +from facebias.estimators import Capability logging.basicConfig(level=logging.INFO) logger = logging.getLogger("facebias:metrics.py") @@ -32,7 +28,7 @@ _agegroup_int_map = { } -def n_off_accuracy(gt: list[int], pred: list[int], n: int=1) -> float: +def n_off_accuracy(gt: list[int], pred: list[int], n: int = 1) -> float: """Returns the n-off accuracy for ordinal class labels encoded as consecutive integers. A prediction is counted as correct if it is exact or off by at most `n`. @@ -88,7 +84,9 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]: } -def multiclass_fpr_fnr(gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None): +def multiclass_fpr_fnr( + gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None +): """Calculates one-vs-rest false-positive and negative rates for each class. Also returns the counts of false-positives, false-negatives, true-positives @@ -169,7 +167,9 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series): return (x_pred == y_pred).sum() / len(x_pred) -def agreement_elements(x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool=True): +def agreement_elements( + x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True +): """Returns the elements of agreement, and optionally, disagreement between models. Note that, as in `agreement_fraction`, the predictions must have the same @@ -223,8 +223,9 @@ def _to_age_bracket(row): return "{}-{}".format(d, d + 9) -if __name__ == '__main__': +if __name__ == "__main__": import os + logger.info(os.getcwd()) from facebias import load_dataset @@ -241,7 +242,7 @@ if __name__ == '__main__': imdict, _ = load_dataset( DATASET_PATH, meta_path=None, imname_proc_fn=lambda x: x.split("_")[0] ) - meta = pd.read_csv(METADATA_PATH, sep=',', index_col="image") + meta = pd.read_csv(METADATA_PATH, sep=",", index_col="image") meta[Capability.AGEGROUP.value] = meta.apply(_to_age_bracket, axis=1) meta = meta.sort_index() @@ -251,8 +252,13 @@ if __name__ == '__main__': print(MiVOLOv1.capabilities()) models = { - "fairface": FairFace(Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu"), - "mivolo": MiVOLOv1(Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"), device="cpu") + "fairface": FairFace( + Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu" + ), + "mivolo": MiVOLOv1( + Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"), + device="cpu", + ), } preds_per_model = dict((k, None) for k in models.keys()) @@ -272,8 +278,7 @@ if __name__ == '__main__': acc_two_off = two_off_accuracy(gt_age_group_ord, preds_age_group_ord) agegroup_subclass, labels = multiclass_fpr_fnr( - meta["age_group"], - preds["age_group"] + meta["age_group"], preds["age_group"] ) print("==== Age group metrics by class ====") @@ -289,11 +294,8 @@ if __name__ == '__main__': model_cls = type(model) ordered_labels = model_cls.possible_capability_values(Capability.SEX) metrics_sex = binary_fpr_fnr( - confusion_matrix( - meta["sex"], - preds["sex"], - labels=ordered_labels - )) + confusion_matrix(meta["sex"], preds["sex"], labels=ordered_labels) + ) print( "==== Sex metrics ====" @@ -306,18 +308,18 @@ if __name__ == '__main__': # Agreement tests model_list = list(models.keys()) for i in range(len(model_list)): - for j in range(i+1, len(model_list)): + for j in range(i + 1, len(model_list)): first, second = model_list[i], model_list[j] print(f"{first} -- {second}") for cap in model_cls.capabilities(): if cap == Capability.AGE: continue - frac = agreement_fraction(preds_per_model[first][cap], preds_per_model[second][cap]) + frac = agreement_fraction( + preds_per_model[first][cap], preds_per_model[second][cap] + ) print(f'Agreement fraction for capability: "{cap}" - {frac}') agreement, disagreement = agreement_elements( - meta[cap], - preds[cap], - return_disagreement=True + meta[cap], preds[cap], return_disagreement=True ) print(disagreement)