From e2790d2d5dea8dc980958728efa1ec6f2b5b1e90 Mon Sep 17 00:00:00 2001
From: Guilherme Schardong
Date: Wed, 6 May 2026 16:26:52 +0100
Subject: [PATCH] Moving the `gini` function to `metrics.py`.

---
 src/facebias/metrics.py | 41 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/src/facebias/metrics.py b/src/facebias/metrics.py
index fa4c8cd..427a4cb 100644
--- a/src/facebias/metrics.py
+++ b/src/facebias/metrics.py
@@ -1,6 +1,9 @@
 # -*- coding: utf-8 -*-
 
+"""Model evaluation metrics."""
+
 import logging
+from itertools import combinations
 from pathlib import Path
 from typing import Any
 
@@ -86,7 +89,7 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]:
 
 def multiclass_fpr_fnr(
     gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None
-):
+) -> tuple[dict, list[str]]:
     """Calculates one-vs-rest false-positive and negative rates for each class.
 
     Also returns the counts of false-positives, false-negatives, true-positives
@@ -128,7 +131,7 @@ def multiclass_fpr_fnr(
     return results, labels
 
 
-def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
+def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series) -> bool:
     if len(x_pred) != len(y_pred):
         raise ValueError(
             f"Predictions have different lengths. len(x_pred) = {len(x_pred)}"
@@ -140,7 +143,7 @@ def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
     return True
 
 
-def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
+def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series) -> float:
     """Calculates the fraction of agreement between predictions by two models.
 
     Note that the predictions must both have the same indices and lengths.
@@ -169,7 +172,7 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
 
 def agreement_elements(
     x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True
-):
+) -> tuple[pd.Series, pd.Series | None]:
     """Returns the elements of agreement, and optionally, disagreement between models.
 
     Note that, as in `agreement_fraction`, the predictions must have the same
@@ -209,6 +212,36 @@ def agreement_elements(
     return idx[x_pred == y_pred], None
 
 
+def gini(x: list[float]) -> float:
+    """Calculates the Gini coefficient of a collection of values.
+
+    The coefficient is the mean absolute difference over all pairs of values,
+    divided by twice the mean of the values.
+
+    Parameters
+    ----------
+    x: list[float]
+        One-dimensional collection of values.
+
+    Returns
+    -------
+    float
+        The Gini coefficient of `x`, or NaN when `x` is empty or its values sum
+        to zero, since the coefficient is undefined in both cases.
+    """
+    # Use float64: float32 accumulates visible rounding error on long inputs.
+    values = np.sort(np.asarray(x, dtype=np.float64))
+    n = values.size
+    total = values.sum()
+    if n == 0 or total == 0:
+        return float("nan")
+
+    # For ascending values, the sum of |x_i - x_j| over all unordered pairs
+    # equals sum((2k - n - 1) * x_k) for k = 1..n, so no pair loop is needed.
+    ranks = np.arange(1, n + 1)
+    return float(np.dot(2 * ranks - n - 1, values) / (n * total))
+
+
 # TODO(gschardong): Move to the same file as `load_dataset`
 def _to_age_bracket(row):
     iage = int(row["age"])