From 250d218f9b77b2afd586bf83a343784c0fa16a6c Mon Sep 17 00:00:00 2001 From: Guilherme Schardong Date: Thu, 16 Apr 2026 14:00:55 +0100 Subject: [PATCH] Added capabilities (`Capability`) to the estimators. --- src/facebias/estimators/__init__.py | 24 ++++++++++++++++-- src/facebias/estimators/fairface.py | 38 +++++++++++++++++------------ 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/src/facebias/estimators/__init__.py b/src/facebias/estimators/__init__.py index 1146c34..e632b5d 100644 --- a/src/facebias/estimators/__init__.py +++ b/src/facebias/estimators/__init__.py @@ -5,13 +5,33 @@ from pathlib import Path from typing import Any import numpy as np +from enum import StrEnum + + +class Capability(StrEnum): + AGE = "age" + AGEGROUP = "age_group" + SEX = "sex" + SKINCOLOR = "skin_color" + ETHNICITY = "ethinicity" class BaseEstimator(ABC): @abstractmethod - def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]: + def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]: + """Runs the estimator on a batch of images. + + The input `images` is a dictionary indexed by the image name, or ID, + and values are image data in the appropriate format for the estimator + (RGB, BGR' Lab, ...). + """ pass @abstractclassmethod - def capabilities() -> list[str]: + def capabilities() -> list[Capability]: + """Returns the estimator capabilities + + In this context, estimator capabilities mean what is infers from the + face image, such as `age_group`, `sex`, `skin_color`, etc. + """ pass diff --git a/src/facebias/estimators/fairface.py b/src/facebias/estimators/fairface.py index 7795b76..687094e 100644 --- a/src/facebias/estimators/fairface.py +++ b/src/facebias/estimators/fairface.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import logging from collections import OrderedDict from pathlib import Path from typing import Any, Optional @@ -9,12 +8,10 @@ import numpy as np import torch import torch.nn as nn import torchvision -from facebias import load_dataset +from facebias.estimators import BaseEstimator, Capability from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision import transforms -from facebias.estimators import BaseEstimator - class FairFace(BaseEstimator): """ @@ -41,7 +38,7 @@ class FairFace(BaseEstimator): ) return self - def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]: + def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]: preds = OrderedDict() for p, im in ims.items(): im = self.T(im).view(1, 3, 224, 224).to(self.device) @@ -51,9 +48,9 @@ class FairFace(BaseEstimator): y = y.cpu().detach().squeeze().numpy() # Ethnicity prediction - # y_ethno = y[:4] - # ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno)) - # ethno_pred = np.argmax(ethno_score) + y_ethno = y[:4] + ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno)) + ethno_pred = np.argmax(ethno_score) # Age prediction y_age = y[9:18] @@ -67,26 +64,37 @@ class FairFace(BaseEstimator): sex_label = "m" if not sex_pred else "f" preds[p] = { - "age_group": _to_age_label(age_pred), - "sex": sex_label, + Capability.AGEGROUP: _to_age_label(age_pred), + Capability.SEX: sex_label, + # Capability.ETHNICITY: _to_ethno_label(ethno_pred) } return preds - def capabilities() -> list[str]: - return ["age_group", "sex"] + def capabilities() -> list[Capability]: + return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY] def _to_age_label(age): return "{}-{}".format(age * 10, age * 10 + 9) +def _to_ethno_label(val): + if val not in range(4): + return "unknown" + opts = ["white", "black", "asian", "indian"] + return opts[val] + + if __name__ == '__main__': - MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt") - model = FairFace() - model.load_state_dict(MODEL_PATH) + from facebias import load_dataset DATASET_PATH = Path("../../../data/frll_neutral_front") dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0]) + print(FairFace.capabilities()) + MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt") + model = FairFace() + model.load_state_dict(MODEL_PATH) + out = model.predict(dataset)