Formatting buffer and fixing imports.
This commit is contained in:
@@ -1,18 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
|
||||
cohen_kappa_score, confusion_matrix, max_error,
|
||||
mean_absolute_error, multilabel_confusion_matrix,
|
||||
precision_score)
|
||||
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix
|
||||
|
||||
from facebias.estimators import BaseEstimator, Capability
|
||||
from facebias.estimators import Capability
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("facebias:metrics.py")
|
||||
@@ -32,7 +28,7 @@ _agegroup_int_map = {
|
||||
}
|
||||
|
||||
|
||||
def n_off_accuracy(gt: list[int], pred: list[int], n: int=1) -> float:
|
||||
def n_off_accuracy(gt: list[int], pred: list[int], n: int = 1) -> float:
|
||||
"""Returns the n-off accuracy for ordinal class labels encoded as consecutive integers.
|
||||
|
||||
A prediction is counted as correct if it is exact or off by at most `n`.
|
||||
@@ -88,7 +84,9 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]:
|
||||
}
|
||||
|
||||
|
||||
def multiclass_fpr_fnr(gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None):
|
||||
def multiclass_fpr_fnr(
|
||||
gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None
|
||||
):
|
||||
"""Calculates one-vs-rest false-positive and negative rates for each class.
|
||||
|
||||
Also returns the counts of false-positives, false-negatives, true-positives
|
||||
@@ -169,7 +167,9 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
|
||||
return (x_pred == y_pred).sum() / len(x_pred)
|
||||
|
||||
|
||||
def agreement_elements(x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool=True):
|
||||
def agreement_elements(
|
||||
x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True
|
||||
):
|
||||
"""Returns the elements of agreement, and optionally, disagreement between models.
|
||||
|
||||
Note that, as in `agreement_fraction`, the predictions must have the same
|
||||
@@ -223,8 +223,9 @@ def _to_age_bracket(row):
|
||||
return "{}-{}".format(d, d + 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
logger.info(os.getcwd())
|
||||
|
||||
from facebias import load_dataset
|
||||
@@ -241,7 +242,7 @@ if __name__ == '__main__':
|
||||
imdict, _ = load_dataset(
|
||||
DATASET_PATH, meta_path=None, imname_proc_fn=lambda x: x.split("_")[0]
|
||||
)
|
||||
meta = pd.read_csv(METADATA_PATH, sep=',', index_col="image")
|
||||
meta = pd.read_csv(METADATA_PATH, sep=",", index_col="image")
|
||||
meta[Capability.AGEGROUP.value] = meta.apply(_to_age_bracket, axis=1)
|
||||
meta = meta.sort_index()
|
||||
|
||||
@@ -251,8 +252,13 @@ if __name__ == '__main__':
|
||||
print(MiVOLOv1.capabilities())
|
||||
|
||||
models = {
|
||||
"fairface": FairFace(Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu"),
|
||||
"mivolo": MiVOLOv1(Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"), device="cpu")
|
||||
"fairface": FairFace(
|
||||
Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu"
|
||||
),
|
||||
"mivolo": MiVOLOv1(
|
||||
Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"),
|
||||
device="cpu",
|
||||
),
|
||||
}
|
||||
|
||||
preds_per_model = dict((k, None) for k in models.keys())
|
||||
@@ -272,8 +278,7 @@ if __name__ == '__main__':
|
||||
acc_two_off = two_off_accuracy(gt_age_group_ord, preds_age_group_ord)
|
||||
|
||||
agegroup_subclass, labels = multiclass_fpr_fnr(
|
||||
meta["age_group"],
|
||||
preds["age_group"]
|
||||
meta["age_group"], preds["age_group"]
|
||||
)
|
||||
|
||||
print("==== Age group metrics by class ====")
|
||||
@@ -289,11 +294,8 @@ if __name__ == '__main__':
|
||||
model_cls = type(model)
|
||||
ordered_labels = model_cls.possible_capability_values(Capability.SEX)
|
||||
metrics_sex = binary_fpr_fnr(
|
||||
confusion_matrix(
|
||||
meta["sex"],
|
||||
preds["sex"],
|
||||
labels=ordered_labels
|
||||
))
|
||||
confusion_matrix(meta["sex"], preds["sex"], labels=ordered_labels)
|
||||
)
|
||||
|
||||
print(
|
||||
"==== Sex metrics ===="
|
||||
@@ -306,18 +308,18 @@ if __name__ == '__main__':
|
||||
# Agreement tests
|
||||
model_list = list(models.keys())
|
||||
for i in range(len(model_list)):
|
||||
for j in range(i+1, len(model_list)):
|
||||
for j in range(i + 1, len(model_list)):
|
||||
first, second = model_list[i], model_list[j]
|
||||
print(f"{first} -- {second}")
|
||||
|
||||
for cap in model_cls.capabilities():
|
||||
if cap == Capability.AGE:
|
||||
continue
|
||||
frac = agreement_fraction(preds_per_model[first][cap], preds_per_model[second][cap])
|
||||
frac = agreement_fraction(
|
||||
preds_per_model[first][cap], preds_per_model[second][cap]
|
||||
)
|
||||
print(f'Agreement fraction for capability: "{cap}" - {frac}')
|
||||
agreement, disagreement = agreement_elements(
|
||||
meta[cap],
|
||||
preds[cap],
|
||||
return_disagreement=True
|
||||
meta[cap], preds[cap], return_disagreement=True
|
||||
)
|
||||
print(disagreement)
|
||||
|
||||
Reference in New Issue
Block a user