diff --git a/src/facebias/metrics.py b/src/facebias/metrics.py index f0b5e9c..21bb9a4 100644 --- a/src/facebias/metrics.py +++ b/src/facebias/metrics.py @@ -1,13 +1,26 @@ # -*- coding: utf-8 -*- +import logging from typing import Any +import numpy as np +from sklearn.metrics import ( + accuracy_score, + cohen_kappa_score, + hamming_loss, + max_error, + mean_absolute_error, + mean_squared_error, + precision_score, +) + from facebias.estimators import Capability +logger = logging.getLogger("facebias:metrics") + def find_common_capabilities( - gt: dict[str, dict[Capability, Any]], - preds: dict[str, dict[Capability, Any]] + gt: dict[str, dict[Capability, Any]], preds: dict[str, dict[Capability, Any]] ) -> list[str]: """Iterates on `preds` and `gt`, finding common model capabilities. @@ -49,3 +62,191 @@ def find_common_capabilities( gt_keys = set(gt[common_elem].keys()) preds_keys = set(preds[common_elem].keys()) return list(gt_keys & preds_keys) + + +def calc_model_performance( + gt: dict[str, dict[str, Any]], + preds: dict[str, dict[str, Any]], + keys: list[str] = [], +) -> dict[str, dict[str, float]]: + """ + We assume that both `gt` and `preds` have the same structure. They should + be indexed by individual ID, such as the image name, and each value is a + dictionary with model prediction capabilities as keys (e.g., "age_group", + "sex", "skin-color", etc.), and the values are the predictions, or ground-truth + values for each ID/capability. + + if `keys` is empty, then we infer from common keys present in `preds` and `gt`. + + Parameters + ---------- + gt: dict[str, dict[str, Any]] + preds: dict[str, dict[str, Any]] + keys: list[str] | None + + Returns + ------- + metrics: dict[str, dict[str, float]] + """ + common_caps = keys + if not keys: + common_caps = find_common_capabilities(gt, preds) + if not common_caps: + kgt = next(iter(gt)) + kpd = next(iter(preds)) + logger.error( + f'No common capabilities found. Predictions has "{preds[kpd].keys()}",' + f' ground-truth has "{gt[kgt].keys()}".' + ) + return None + + # Finding common images between predictions and ground-truth. + common_inds = set(preds.keys()) & set(gt.keys()) + if not common_inds: + logger.error("No common images found between predictions and ground-truth.") + return None + + metric_vals = dict() + for cat in common_caps: + pred_data = [None for _ in common_inds] + gt_data = [None for _ in common_inds] + for i, ix in enumerate(common_inds): + pred_data[i] = preds[ix][cat] + gt_data[i] = gt[ix][cat] + + if isinstance(pred_data[0], float): + pred_data = np.array(pred_data) + gt_data = np.array(gt_data) + metric_vals[cat] = { + "mean_absolute_error": mean_absolute_error(gt_data, pred_data), + } + else: + metric_vals[cat] = { + "accuracy": accuracy_score(gt_data, pred_data), + "cohen-kappa": cohen_kappa_score(gt_data, pred_data), + } + + return metric_vals + + +def find_unique_values_per_capability( + class_output: dict[str, dict[Capability, Any]], caps: list[Capability] | None = None +) -> dict[Capability, str]: + """Returns the set of values per capability in `class_output`. + + Parameters + ---------- + class_output: dict[str, dict[Capability, Any]] + The classification results, or ground-truth data indexed by element. + + caps: list[Capability] | None + The list of capabilities to find unique values for. If left as `None`, + we will find unique values for all of them. + + Results + ------- + unique_vals: dict[Capability, str] + The unique values indexed by capability. + """ + if caps is None: + caps = list(next(iter(class_output.values())).keys()) + elif not isinstance(caps, (list, tuple)): + caps = [caps] + + unique_vals = dict() + for cap in caps: + unique_vals[cap] = set() + for res in class_output.values(): + unique_vals[cap].add(res[cap]) + + return unique_vals + + +def get_capability_data( + class_outputs: dict[str, dict[Capability, Any]], cap: Capability +) -> dict[str, Any]: + """Returns data for all individuals regarding a capability. + + Parameters + ---------- + class_outputs: dict[str, dict[Capability, Any]] + The estimator outputs indexed by individual. + + cap: Capability + The desired capability. + + Returns + ------- + data: dict[str, Any] + The capability data indexed by individual. + """ + data_per_id = dict() + + for ind, data in class_outputs.items(): + if cap not in data: + logger.warning( + f'Entry for capability "{cap.value}" not found for individual "{ind}". Skipping.' + ) + continue + data_per_id[ind] = data[cap] + + return data_per_id + + +def filter_by_index(data: dict[str, Any], indx: Any): + return dict((k, v) for k, v in data.items() if k in indx) + + +def calc_metrics_per_subgroup( + gt: dict[str, dict[str, Any]], preds: dict[str, dict[str, Any]] +) -> dict[Capability, dict[Any, dict]]: + """Calculate performance metrics per sub-group for each capability. + + Parameters + ---------- + gt: dict[str, dict[str, Any]] + + preds: dict[str, dict[str, Any]] + + Returns + ------- + metrics: dict[Capability, dict[Any, dict]] + """ + common_caps = set(find_common_capabilities(gt, preds)) + + metrics = {} + for cap in common_caps: + if cap == Capability.AGE: + continue + + other_caps = common_caps - set([cap]) + unique_values_cap = find_unique_values_per_capability(gt, cap)[cap] + + metrics[cap] = {} + for val in unique_values_cap: + ids = [k for k, v in gt.items() if v[cap] == val] + + metrics[cap][val] = {"number_of_elements": len(ids)} + for ocap in other_caps: + metrics[cap][val][ocap] = {} + filtered_pred = filter_by_index(get_capability_data(preds, ocap), ids) + filtered_gt = filter_by_index(get_capability_data(gt, ocap), ids) + + filtered_pred_data = np.array([filtered_pred[i] for i in ids]) + filtered_gt_data = np.array([filtered_gt[i] for i in ids]) + + if isinstance(filtered_pred_data[0], float): + metrics[cap][val][ocap] = { + "mean_absolute_error": mean_absolute_error( + filtered_gt_data, filtered_pred_data + ), + "max_error": max_error(filtered_gt_data, filtered_pred_data), + } + else: + metrics[cap][val][ocap] = { + "accuracy": accuracy_score( + filtered_gt_data, filtered_pred_data + ), + } + + return metrics