Added metrics per subgroup and related functions.
This commit is contained in:
@@ -1,13 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
cohen_kappa_score,
|
||||
hamming_loss,
|
||||
max_error,
|
||||
mean_absolute_error,
|
||||
mean_squared_error,
|
||||
precision_score,
|
||||
)
|
||||
|
||||
from facebias.estimators import Capability
|
||||
|
||||
logger = logging.getLogger("facebias:metrics")
|
||||
|
||||
|
||||
def find_common_capabilities(
|
||||
gt: dict[str, dict[Capability, Any]],
|
||||
preds: dict[str, dict[Capability, Any]]
|
||||
gt: dict[str, dict[Capability, Any]], preds: dict[str, dict[Capability, Any]]
|
||||
) -> list[str]:
|
||||
"""Iterates on `preds` and `gt`, finding common model capabilities.
|
||||
|
||||
@@ -49,3 +62,191 @@ def find_common_capabilities(
|
||||
gt_keys = set(gt[common_elem].keys())
|
||||
preds_keys = set(preds[common_elem].keys())
|
||||
return list(gt_keys & preds_keys)
|
||||
|
||||
|
||||
def calc_model_performance(
|
||||
gt: dict[str, dict[str, Any]],
|
||||
preds: dict[str, dict[str, Any]],
|
||||
keys: list[str] = [],
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
We assume that both `gt` and `preds` have the same structure. They should
|
||||
be indexed by individual ID, such as the image name, and each value is a
|
||||
dictionary with model prediction capabilities as keys (e.g., "age_group",
|
||||
"sex", "skin-color", etc.), and the values are the predictions, or ground-truth
|
||||
values for each ID/capability.
|
||||
|
||||
if `keys` is empty, then we infer from common keys present in `preds` and `gt`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
gt: dict[str, dict[str, Any]]
|
||||
preds: dict[str, dict[str, Any]]
|
||||
keys: list[str] | None
|
||||
|
||||
Returns
|
||||
-------
|
||||
metrics: dict[str, dict[str, float]]
|
||||
"""
|
||||
common_caps = keys
|
||||
if not keys:
|
||||
common_caps = find_common_capabilities(gt, preds)
|
||||
if not common_caps:
|
||||
kgt = next(iter(gt))
|
||||
kpd = next(iter(preds))
|
||||
logger.error(
|
||||
f'No common capabilities found. Predictions has "{preds[kpd].keys()}",'
|
||||
f' ground-truth has "{gt[kgt].keys()}".'
|
||||
)
|
||||
return None
|
||||
|
||||
# Finding common images between predictions and ground-truth.
|
||||
common_inds = set(preds.keys()) & set(gt.keys())
|
||||
if not common_inds:
|
||||
logger.error("No common images found between predictions and ground-truth.")
|
||||
return None
|
||||
|
||||
metric_vals = dict()
|
||||
for cat in common_caps:
|
||||
pred_data = [None for _ in common_inds]
|
||||
gt_data = [None for _ in common_inds]
|
||||
for i, ix in enumerate(common_inds):
|
||||
pred_data[i] = preds[ix][cat]
|
||||
gt_data[i] = gt[ix][cat]
|
||||
|
||||
if isinstance(pred_data[0], float):
|
||||
pred_data = np.array(pred_data)
|
||||
gt_data = np.array(gt_data)
|
||||
metric_vals[cat] = {
|
||||
"mean_absolute_error": mean_absolute_error(gt_data, pred_data),
|
||||
}
|
||||
else:
|
||||
metric_vals[cat] = {
|
||||
"accuracy": accuracy_score(gt_data, pred_data),
|
||||
"cohen-kappa": cohen_kappa_score(gt_data, pred_data),
|
||||
}
|
||||
|
||||
return metric_vals
|
||||
|
||||
|
||||
def find_unique_values_per_capability(
|
||||
class_output: dict[str, dict[Capability, Any]], caps: list[Capability] | None = None
|
||||
) -> dict[Capability, str]:
|
||||
"""Returns the set of values per capability in `class_output`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_output: dict[str, dict[Capability, Any]]
|
||||
The classification results, or ground-truth data indexed by element.
|
||||
|
||||
caps: list[Capability] | None
|
||||
The list of capabilities to find unique values for. If left as `None`,
|
||||
we will find unique values for all of them.
|
||||
|
||||
Results
|
||||
-------
|
||||
unique_vals: dict[Capability, str]
|
||||
The unique values indexed by capability.
|
||||
"""
|
||||
if caps is None:
|
||||
caps = list(next(iter(class_output.values())).keys())
|
||||
elif not isinstance(caps, (list, tuple)):
|
||||
caps = [caps]
|
||||
|
||||
unique_vals = dict()
|
||||
for cap in caps:
|
||||
unique_vals[cap] = set()
|
||||
for res in class_output.values():
|
||||
unique_vals[cap].add(res[cap])
|
||||
|
||||
return unique_vals
|
||||
|
||||
|
||||
def get_capability_data(
|
||||
class_outputs: dict[str, dict[Capability, Any]], cap: Capability
|
||||
) -> dict[str, Any]:
|
||||
"""Returns data for all individuals regarding a capability.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_outputs: dict[str, dict[Capability, Any]]
|
||||
The estimator outputs indexed by individual.
|
||||
|
||||
cap: Capability
|
||||
The desired capability.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data: dict[str, Any]
|
||||
The capability data indexed by individual.
|
||||
"""
|
||||
data_per_id = dict()
|
||||
|
||||
for ind, data in class_outputs.items():
|
||||
if cap not in data:
|
||||
logger.warning(
|
||||
f'Entry for capability "{cap.value}" not found for individual "{ind}". Skipping.'
|
||||
)
|
||||
continue
|
||||
data_per_id[ind] = data[cap]
|
||||
|
||||
return data_per_id
|
||||
|
||||
|
||||
def filter_by_index(data: dict[str, Any], indx: Any):
|
||||
return dict((k, v) for k, v in data.items() if k in indx)
|
||||
|
||||
|
||||
def calc_metrics_per_subgroup(
|
||||
gt: dict[str, dict[str, Any]], preds: dict[str, dict[str, Any]]
|
||||
) -> dict[Capability, dict[Any, dict]]:
|
||||
"""Calculate performance metrics per sub-group for each capability.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
gt: dict[str, dict[str, Any]]
|
||||
|
||||
preds: dict[str, dict[str, Any]]
|
||||
|
||||
Returns
|
||||
-------
|
||||
metrics: dict[Capability, dict[Any, dict]]
|
||||
"""
|
||||
common_caps = set(find_common_capabilities(gt, preds))
|
||||
|
||||
metrics = {}
|
||||
for cap in common_caps:
|
||||
if cap == Capability.AGE:
|
||||
continue
|
||||
|
||||
other_caps = common_caps - set([cap])
|
||||
unique_values_cap = find_unique_values_per_capability(gt, cap)[cap]
|
||||
|
||||
metrics[cap] = {}
|
||||
for val in unique_values_cap:
|
||||
ids = [k for k, v in gt.items() if v[cap] == val]
|
||||
|
||||
metrics[cap][val] = {"number_of_elements": len(ids)}
|
||||
for ocap in other_caps:
|
||||
metrics[cap][val][ocap] = {}
|
||||
filtered_pred = filter_by_index(get_capability_data(preds, ocap), ids)
|
||||
filtered_gt = filter_by_index(get_capability_data(gt, ocap), ids)
|
||||
|
||||
filtered_pred_data = np.array([filtered_pred[i] for i in ids])
|
||||
filtered_gt_data = np.array([filtered_gt[i] for i in ids])
|
||||
|
||||
if isinstance(filtered_pred_data[0], float):
|
||||
metrics[cap][val][ocap] = {
|
||||
"mean_absolute_error": mean_absolute_error(
|
||||
filtered_gt_data, filtered_pred_data
|
||||
),
|
||||
"max_error": max_error(filtered_gt_data, filtered_pred_data),
|
||||
}
|
||||
else:
|
||||
metrics[cap][val][ocap] = {
|
||||
"accuracy": accuracy_score(
|
||||
filtered_gt_data, filtered_pred_data
|
||||
),
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
Reference in New Issue
Block a user