Formatting buffer and fixing imports.

This commit is contained in:
2026-05-05 15:43:04 +01:00
parent 8d3f039bba
commit 4b6a2b2335

View File

@@ -1,18 +1,14 @@
# -*- coding: utf-8 -*-
import logging
from itertools import product
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
cohen_kappa_score, confusion_matrix, max_error,
mean_absolute_error, multilabel_confusion_matrix,
precision_score)
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix
from facebias.estimators import BaseEstimator, Capability
from facebias.estimators import Capability
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("facebias:metrics.py")
@@ -32,7 +28,7 @@ _agegroup_int_map = {
}
def n_off_accuracy(gt: list[int], pred: list[int], n: int=1) -> float:
def n_off_accuracy(gt: list[int], pred: list[int], n: int = 1) -> float:
"""Returns the n-off accuracy for ordinal class labels encoded as consecutive integers.
A prediction is counted as correct if it is exact or off by at most `n`.
@@ -88,7 +84,9 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]:
}
def multiclass_fpr_fnr(gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None):
def multiclass_fpr_fnr(
gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None
):
"""Calculates one-vs-rest false-positive and negative rates for each class.
Also returns the counts of false-positives, false-negatives, true-positives
@@ -169,7 +167,9 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
return (x_pred == y_pred).sum() / len(x_pred)
def agreement_elements(x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool=True):
def agreement_elements(
x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True
):
"""Returns the elements of agreement, and optionally, disagreement between models.
Note that, as in `agreement_fraction`, the predictions must have the same
@@ -223,8 +223,9 @@ def _to_age_bracket(row):
return "{}-{}".format(d, d + 9)
if __name__ == '__main__':
if __name__ == "__main__":
import os
logger.info(os.getcwd())
from facebias import load_dataset
@@ -241,7 +242,7 @@ if __name__ == '__main__':
imdict, _ = load_dataset(
DATASET_PATH, meta_path=None, imname_proc_fn=lambda x: x.split("_")[0]
)
meta = pd.read_csv(METADATA_PATH, sep=',', index_col="image")
meta = pd.read_csv(METADATA_PATH, sep=",", index_col="image")
meta[Capability.AGEGROUP.value] = meta.apply(_to_age_bracket, axis=1)
meta = meta.sort_index()
@@ -251,8 +252,13 @@ if __name__ == '__main__':
print(MiVOLOv1.capabilities())
models = {
"fairface": FairFace(Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu"),
"mivolo": MiVOLOv1(Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"), device="cpu")
"fairface": FairFace(
Path("../../models/fairface_alldata_4race_20191111.pt"), device="cpu"
),
"mivolo": MiVOLOv1(
Path("../../models/volo-v1_model_imdb_age_gender_4.22.pth.tar"),
device="cpu",
),
}
preds_per_model = dict((k, None) for k in models.keys())
@@ -272,8 +278,7 @@ if __name__ == '__main__':
acc_two_off = two_off_accuracy(gt_age_group_ord, preds_age_group_ord)
agegroup_subclass, labels = multiclass_fpr_fnr(
meta["age_group"],
preds["age_group"]
meta["age_group"], preds["age_group"]
)
print("==== Age group metrics by class ====")
@@ -289,11 +294,8 @@ if __name__ == '__main__':
model_cls = type(model)
ordered_labels = model_cls.possible_capability_values(Capability.SEX)
metrics_sex = binary_fpr_fnr(
confusion_matrix(
meta["sex"],
preds["sex"],
labels=ordered_labels
))
confusion_matrix(meta["sex"], preds["sex"], labels=ordered_labels)
)
print(
"==== Sex metrics ===="
@@ -306,18 +308,18 @@ if __name__ == '__main__':
# Agreement tests
model_list = list(models.keys())
for i in range(len(model_list)):
for j in range(i+1, len(model_list)):
for j in range(i + 1, len(model_list)):
first, second = model_list[i], model_list[j]
print(f"{first} -- {second}")
for cap in model_cls.capabilities():
if cap == Capability.AGE:
continue
frac = agreement_fraction(preds_per_model[first][cap], preds_per_model[second][cap])
frac = agreement_fraction(
preds_per_model[first][cap], preds_per_model[second][cap]
)
print(f'Agreement fraction for capability: "{cap}" - {frac}')
agreement, disagreement = agreement_elements(
meta[cap],
preds[cap],
return_disagreement=True
meta[cap], preds[cap], return_disagreement=True
)
print(disagreement)