Added prototype code for dataset distribution evaluation.

This commit is contained in:
2026-05-06 16:29:13 +01:00
parent 536c29978d
commit 92db500bd0

141
src/facebias/evaluation.py Normal file
View File

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
"""Dataset/model evaluation functions."""
import logging
from itertools import permutations, combinations
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import entropy
from facebias.estimators import Capability
from facebias.metrics import gini
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("facebias:evaluation.py")
# TODO(gschardong): Move to the same file as `load_dataset`
def _to_age_bracket(row):
iage = int(row["age"])
if iage < 3:
return "00-02"
elif iage < 10:
return "03-09"
elif iage > 69:
return "70+"
d = iage // 10 * 10
return "{}-{}".format(d, d + 9)
if __name__ == "__main__":
import os
logger.info(os.getcwd())
DATASET_PATH = Path("../../data/facing2-train/")
METADATA_PATH = DATASET_PATH / "meta-w-age.csv"
meta = pd.read_csv(METADATA_PATH, sep=",", index_col="image")
meta[Capability.AGEGROUP.value] = meta.apply(_to_age_bracket, axis=1)
meta = meta.sort_index()
meta[Capability.AGEGROUP + "_cat"], _ = pd.factorize(
meta[Capability.AGEGROUP], sort=True
)
meta[Capability.SEX + "_cat"], _ = pd.factorize(meta[Capability.SEX], sort=True)
# GINI IMPURITY
# Lower values means a concentration of values around a single class, i.e. bias.
age_gini = gini(meta["age"])
# gt_age_group_ord = meta["age_group"].apply(lambda x: _agegroup_int_map[x])
agegroup_gini = gini(meta[Capability.AGEGROUP + "_cat"])
# Should be close to 0.5, indicating a 50/50 split of males and females,
# representing maximum uncertainty.
sex_gini = gini(meta[Capability.SEX + "_cat"])
# SHANNON'S ENTROPY
count_per_agegroup = meta["age_group"].value_counts()
prob_per_agegroup = count_per_agegroup / count_per_agegroup.sum()
H_agegroup = entropy(prob_per_agegroup)
count_per_sex = meta["sex"].value_counts()
prob_per_sex = count_per_sex / count_per_sex.sum()
H_sex = entropy(prob_per_sex)
# Now, onto the subgroup metrics.
# The goal is to be able to answer the following types of questions:
# 1) How many women are in each age-bracket?
# 2) Given the population in age-bracket 20-49 years, how is their gender distribution?
# 3) Do we need to collect more images of new individuals? If so, what population should we focus on?
sex_gb = meta.groupby(Capability.SEX)[["age_group_cat"]]
agegroup_gb = meta.groupby(Capability.AGEGROUP)[["sex_cat"]]
gini_per_sex = sex_gb.apply(gini)
gini_per_agegroup = agegroup_gb.apply(gini)
# Prototype textual description of the dataset. To be incorporated into a
# "generate_report" function.
print(
f'The dataset "{DATASET_PATH.name}" has a total of {len(meta)} {meta.index.name}s,'
" with the following features/capabilities:"
)
caps = []
for c in Capability:
if c.value in meta:
caps.append(c)
print(f"- {c.value}")
print("\nEach feature/capability has the following types and values:")
for c in caps:
if c == Capability.AGE:
print(f"{c.value}: numeric")
else:
print(f"{c.value}: categorical")
print(f" - {sorted(meta[c].unique())}")
print("\nData distribution statistics.")
for c in caps:
print(f'The feature/capability "{c}" has the following distribution of values:')
if c == Capability.AGE:
m1 = meta[c].min()
m2 = meta[c].max()
mean = meta[c].mean()
std = meta[c].std()
p25 = meta[c].quantile(0.25)
median = meta[c].median()
p75 = meta[c].quantile(0.75)
print(
f" - min = {m1}, max = {m2}, mean = {mean:.2f}, std = {std:.2f}"
f" p25 = {p25}, p50 = {median}, p75 = {p75}"
)
print(" Interqualtile ranges:")
print(f" - p25-min = {p25 - m1}")
print(f" - p50-p25 = {median - p25}")
print(f" - p75-p50 = {p75 - median}")
print(f" - max-p75 = {m2 - p75}")
else:
series = meta[c].value_counts().sort_index()
for s in series.index:
print(f" - {s}: {series[s]}")
print("\nPer capability/class data distribution statistics.")
for c1, c2 in combinations(caps, 2):
if c1 == Capability.AGE:
continue
if c2 != Capability.AGE:
gb = meta.groupby(c1)[[c2]]
print(
f'Grouping by "{c1}", the dataset has the following data distribution for "{c2}"'
)
print(gb.value_counts().sort_index().unstack(level=c1).fillna(0))
# Diagnostics of biases in the dataset. To be incorporated into a
# "generate_diagnostics" function later on.