Moving the gini function to metrics.py.

This commit is contained in:
2026-05-06 16:26:52 +01:00
parent 4b6a2b2335
commit e2790d2d5d

View File

@@ -1,6 +1,9 @@
# -*- coding: utf-8 -*-
"""Model evaluation metrics."""
import logging
from itertools import combinations
from pathlib import Path
from typing import Any
@@ -86,7 +89,7 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]:
def multiclass_fpr_fnr(
gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None
):
) -> tuple[dict, list[str]]:
"""Calculates one-vs-rest false-positive and negative rates for each class.
Also returns the counts of false-positives, false-negatives, true-positives
@@ -128,7 +131,7 @@ def multiclass_fpr_fnr(
return results, labels
def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series) -> bool:
if len(x_pred) != len(y_pred):
raise ValueError(
f"Predictions have different lengths. len(x_pred) = {len(x_pred)}"
@@ -140,7 +143,7 @@ def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
return True
def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series) -> float:
"""Calculates the fraction of agreement between predictions by two models.
Note that the predictions must both have the same indices and lengths.
@@ -169,7 +172,7 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
def agreement_elements(
x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True
):
) -> tuple[pd.Series, pd.Series | None]:
"""Returns the elements of agreement, and optionally, disagreement between models.
Note that, as in `agreement_fraction`, the predictions must have the same
@@ -209,6 +212,13 @@ def agreement_elements(
return idx[x_pred == y_pred], None
def gini(x: list[float]) -> float:
x = np.array(x, dtype=np.float32)
n = len(x)
diffs = sum(abs(i - j) for i, j in combinations(x, r=2))
return (diffs / (n**2 * x.mean())).item()
# TODO(gschardong): Move to the same file as `load_dataset`
def _to_age_bracket(row):
iage = int(row["age"])