Moving the gini function to metrics.py.
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Model evaluation metrics."""
|
||||
|
||||
import logging
|
||||
from itertools import combinations
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -86,7 +89,7 @@ def binary_fpr_fnr(cm: np.ndarray) -> dict[str, np.number]:
|
||||
|
||||
def multiclass_fpr_fnr(
|
||||
gt: pd.Series, preds: pd.Series, labels: list[Any] | None = None
|
||||
):
|
||||
) -> tuple[dict, list[str]]:
|
||||
"""Calculates one-vs-rest false-positive and negative rates for each class.
|
||||
|
||||
Also returns the counts of false-positives, false-negatives, true-positives
|
||||
@@ -128,7 +131,7 @@ def multiclass_fpr_fnr(
|
||||
return results, labels
|
||||
|
||||
|
||||
def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
|
||||
def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series) -> bool:
|
||||
if len(x_pred) != len(y_pred):
|
||||
raise ValueError(
|
||||
f"Predictions have different lengths. len(x_pred) = {len(x_pred)}"
|
||||
@@ -140,7 +143,7 @@ def _agreement_sanity_checks(x_pred: pd.Series, y_pred: pd.Series):
|
||||
return True
|
||||
|
||||
|
||||
def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
|
||||
def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series) -> float:
|
||||
"""Calculates the fraction of agreement between predictions by two models.
|
||||
|
||||
Note that the predictions must both have the same indices and lengths.
|
||||
@@ -169,7 +172,7 @@ def agreement_fraction(x_pred: pd.Series, y_pred: pd.Series):
|
||||
|
||||
def agreement_elements(
|
||||
x_pred: pd.Series, y_pred: pd.Series, return_disagreement: bool = True
|
||||
):
|
||||
) -> tuple[pd.Series, pd.Series | None]:
|
||||
"""Returns the elements of agreement, and optionally, disagreement between models.
|
||||
|
||||
Note that, as in `agreement_fraction`, the predictions must have the same
|
||||
@@ -209,6 +212,13 @@ def agreement_elements(
|
||||
return idx[x_pred == y_pred], None
|
||||
|
||||
|
||||
def gini(x: list[float]) -> float:
|
||||
x = np.array(x, dtype=np.float32)
|
||||
n = len(x)
|
||||
diffs = sum(abs(i - j) for i, j in combinations(x, r=2))
|
||||
return (diffs / (n**2 * x.mean())).item()
|
||||
|
||||
|
||||
# TODO(gschardong): Move to the same file as `load_dataset`
|
||||
def _to_age_bracket(row):
|
||||
iage = int(row["age"])
|
||||
|
||||
Reference in New Issue
Block a user