Added metrics per subgroup and related functions.

This commit is contained in:
Guilherme G. Schardong
2026-04-19 00:37:19 +01:00
parent 0ebca5c450
commit 923c0aeff7

View File

@@ -1,13 +1,26 @@
# -*- coding: utf-8 -*-
import logging
from typing import Any
import numpy as np
from sklearn.metrics import (
accuracy_score,
cohen_kappa_score,
hamming_loss,
max_error,
mean_absolute_error,
mean_squared_error,
precision_score,
)
from facebias.estimators import Capability
logger = logging.getLogger("facebias:metrics")
def find_common_capabilities(
gt: dict[str, dict[Capability, Any]],
preds: dict[str, dict[Capability, Any]]
gt: dict[str, dict[Capability, Any]], preds: dict[str, dict[Capability, Any]]
) -> list[str]:
"""Iterates on `preds` and `gt`, finding common model capabilities.
@@ -49,3 +62,191 @@ def find_common_capabilities(
gt_keys = set(gt[common_elem].keys())
preds_keys = set(preds[common_elem].keys())
return list(gt_keys & preds_keys)
def calc_model_performance(
gt: dict[str, dict[str, Any]],
preds: dict[str, dict[str, Any]],
keys: list[str] = [],
) -> dict[str, dict[str, float]]:
"""
We assume that both `gt` and `preds` have the same structure. They should
be indexed by individual ID, such as the image name, and each value is a
dictionary with model prediction capabilities as keys (e.g., "age_group",
"sex", "skin-color", etc.), and the values are the predictions, or ground-truth
values for each ID/capability.
if `keys` is empty, then we infer from common keys present in `preds` and `gt`.
Parameters
----------
gt: dict[str, dict[str, Any]]
preds: dict[str, dict[str, Any]]
keys: list[str] | None
Returns
-------
metrics: dict[str, dict[str, float]]
"""
common_caps = keys
if not keys:
common_caps = find_common_capabilities(gt, preds)
if not common_caps:
kgt = next(iter(gt))
kpd = next(iter(preds))
logger.error(
f'No common capabilities found. Predictions has "{preds[kpd].keys()}",'
f' ground-truth has "{gt[kgt].keys()}".'
)
return None
# Finding common images between predictions and ground-truth.
common_inds = set(preds.keys()) & set(gt.keys())
if not common_inds:
logger.error("No common images found between predictions and ground-truth.")
return None
metric_vals = dict()
for cat in common_caps:
pred_data = [None for _ in common_inds]
gt_data = [None for _ in common_inds]
for i, ix in enumerate(common_inds):
pred_data[i] = preds[ix][cat]
gt_data[i] = gt[ix][cat]
if isinstance(pred_data[0], float):
pred_data = np.array(pred_data)
gt_data = np.array(gt_data)
metric_vals[cat] = {
"mean_absolute_error": mean_absolute_error(gt_data, pred_data),
}
else:
metric_vals[cat] = {
"accuracy": accuracy_score(gt_data, pred_data),
"cohen-kappa": cohen_kappa_score(gt_data, pred_data),
}
return metric_vals
def find_unique_values_per_capability(
class_output: dict[str, dict[Capability, Any]], caps: list[Capability] | None = None
) -> dict[Capability, str]:
"""Returns the set of values per capability in `class_output`.
Parameters
----------
class_output: dict[str, dict[Capability, Any]]
The classification results, or ground-truth data indexed by element.
caps: list[Capability] | None
The list of capabilities to find unique values for. If left as `None`,
we will find unique values for all of them.
Results
-------
unique_vals: dict[Capability, str]
The unique values indexed by capability.
"""
if caps is None:
caps = list(next(iter(class_output.values())).keys())
elif not isinstance(caps, (list, tuple)):
caps = [caps]
unique_vals = dict()
for cap in caps:
unique_vals[cap] = set()
for res in class_output.values():
unique_vals[cap].add(res[cap])
return unique_vals
def get_capability_data(
class_outputs: dict[str, dict[Capability, Any]], cap: Capability
) -> dict[str, Any]:
"""Returns data for all individuals regarding a capability.
Parameters
----------
class_outputs: dict[str, dict[Capability, Any]]
The estimator outputs indexed by individual.
cap: Capability
The desired capability.
Returns
-------
data: dict[str, Any]
The capability data indexed by individual.
"""
data_per_id = dict()
for ind, data in class_outputs.items():
if cap not in data:
logger.warning(
f'Entry for capability "{cap.value}" not found for individual "{ind}". Skipping.'
)
continue
data_per_id[ind] = data[cap]
return data_per_id
def filter_by_index(data: dict[str, Any], indx: Any):
return dict((k, v) for k, v in data.items() if k in indx)
def calc_metrics_per_subgroup(
gt: dict[str, dict[str, Any]], preds: dict[str, dict[str, Any]]
) -> dict[Capability, dict[Any, dict]]:
"""Calculate performance metrics per sub-group for each capability.
Parameters
----------
gt: dict[str, dict[str, Any]]
preds: dict[str, dict[str, Any]]
Returns
-------
metrics: dict[Capability, dict[Any, dict]]
"""
common_caps = set(find_common_capabilities(gt, preds))
metrics = {}
for cap in common_caps:
if cap == Capability.AGE:
continue
other_caps = common_caps - set([cap])
unique_values_cap = find_unique_values_per_capability(gt, cap)[cap]
metrics[cap] = {}
for val in unique_values_cap:
ids = [k for k, v in gt.items() if v[cap] == val]
metrics[cap][val] = {"number_of_elements": len(ids)}
for ocap in other_caps:
metrics[cap][val][ocap] = {}
filtered_pred = filter_by_index(get_capability_data(preds, ocap), ids)
filtered_gt = filter_by_index(get_capability_data(gt, ocap), ids)
filtered_pred_data = np.array([filtered_pred[i] for i in ids])
filtered_gt_data = np.array([filtered_gt[i] for i in ids])
if isinstance(filtered_pred_data[0], float):
metrics[cap][val][ocap] = {
"mean_absolute_error": mean_absolute_error(
filtered_gt_data, filtered_pred_data
),
"max_error": max_error(filtered_gt_data, filtered_pred_data),
}
else:
metrics[cap][val][ocap] = {
"accuracy": accuracy_score(
filtered_gt_data, filtered_pred_data
),
}
return metrics