Added capabilities (Capability) to the estimators.

This commit is contained in:
2026-04-16 14:00:55 +01:00
parent 27efada40e
commit 250d218f9b
2 changed files with 45 additions and 17 deletions

View File

@@ -5,13 +5,33 @@ from pathlib import Path
from typing import Any
import numpy as np
from enum import StrEnum
class Capability(StrEnum):
AGE = "age"
AGEGROUP = "age_group"
SEX = "sex"
SKINCOLOR = "skin_color"
ETHNICITY = "ethinicity"
class BaseEstimator(ABC):
@abstractmethod
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
"""Runs the estimator on a batch of images.
The input `images` is a dictionary indexed by the image name, or ID,
and values are image data in the appropriate format for the estimator
(RGB, BGR' Lab, ...).
"""
pass
@abstractclassmethod
def capabilities() -> list[str]:
def capabilities() -> list[Capability]:
"""Returns the estimator capabilities
In this context, estimator capabilities mean what is infers from the
face image, such as `age_group`, `sex`, `skin_color`, etc.
"""
pass

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
@@ -9,12 +8,10 @@ import numpy as np
import torch
import torch.nn as nn
import torchvision
from facebias import load_dataset
from facebias.estimators import BaseEstimator, Capability
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from facebias.estimators import BaseEstimator
class FairFace(BaseEstimator):
"""
@@ -41,7 +38,7 @@ class FairFace(BaseEstimator):
)
return self
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
preds = OrderedDict()
for p, im in ims.items():
im = self.T(im).view(1, 3, 224, 224).to(self.device)
@@ -51,9 +48,9 @@ class FairFace(BaseEstimator):
y = y.cpu().detach().squeeze().numpy()
# Ethnicity prediction
# y_ethno = y[:4]
# ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
# ethno_pred = np.argmax(ethno_score)
y_ethno = y[:4]
ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
ethno_pred = np.argmax(ethno_score)
# Age prediction
y_age = y[9:18]
@@ -67,26 +64,37 @@ class FairFace(BaseEstimator):
sex_label = "m" if not sex_pred else "f"
preds[p] = {
"age_group": _to_age_label(age_pred),
"sex": sex_label,
Capability.AGEGROUP: _to_age_label(age_pred),
Capability.SEX: sex_label,
# Capability.ETHNICITY: _to_ethno_label(ethno_pred)
}
return preds
def capabilities() -> list[str]:
return ["age_group", "sex"]
def capabilities() -> list[Capability]:
return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY]
def _to_age_label(age):
return "{}-{}".format(age * 10, age * 10 + 9)
def _to_ethno_label(val):
if val not in range(4):
return "unknown"
opts = ["white", "black", "asian", "indian"]
return opts[val]
if __name__ == '__main__':
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
model = FairFace()
model.load_state_dict(MODEL_PATH)
from facebias import load_dataset
DATASET_PATH = Path("../../../data/frll_neutral_front")
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
print(FairFace.capabilities())
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
model = FairFace()
model.load_state_dict(MODEL_PATH)
out = model.predict(dataset)