From 20cf80d94eafa1f55f0e5a7d080660afab94f0ef Mon Sep 17 00:00:00 2001 From: Guilherme Schardong Date: Wed, 22 Apr 2026 13:49:32 +0100 Subject: [PATCH] Added the `possible_capability_values` base method for estimators... New `possible_capability_values` base method for estimators. The goal is to return a list of possible values for each non-numerical capability, as some performance metrics may benefit from this complete listing. Additionally, added a new exception to indicate when a capability is invalid in a given context, e.g., when the model does not support it. --- src/facebias/estimators/__init__.py | 10 ++++++++++ src/facebias/estimators/fairface.py | 15 +++++++++++++-- src/facebias/estimators/mivolov1.py | 24 +++++++++++++++++++++--- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/facebias/estimators/__init__.py b/src/facebias/estimators/__init__.py index e632b5d..f369c98 100644 --- a/src/facebias/estimators/__init__.py +++ b/src/facebias/estimators/__init__.py @@ -16,6 +16,11 @@ class Capability(StrEnum): ETHNICITY = "ethinicity" +class InvalidCapabilityError(Exception): + def __init__(self, msg="Invalid capability for model."): + super(InvalidCapability, self).__init__(msg=msg) + + class BaseEstimator(ABC): @abstractmethod def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]: @@ -35,3 +40,8 @@ class BaseEstimator(ABC): face image, such as `age_group`, `sex`, `skin_color`, etc. """ pass + + @abstractclassmethod + def possible_capability_values(cap: Capability) -> list: + """Returns all possible values for a given model capability.""" + pass diff --git a/src/facebias/estimators/fairface.py b/src/facebias/estimators/fairface.py index 703f9d2..e884741 100644 --- a/src/facebias/estimators/fairface.py +++ b/src/facebias/estimators/fairface.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.nn as nn import torchvision -from facebias.estimators import BaseEstimator, Capability +from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision import transforms @@ -96,9 +96,20 @@ class FairFace(BaseEstimator): def capabilities() -> list[Capability]: return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY] + def possible_capability_values(cap: Capability) -> list: + if cap == Capability.AGEGROUP: + return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"] + elif cap == Capability.SEX: + return ['m', 'f'] + # elif cap == Capability.ETHNICITY: + # return ["White", "Black", "Asian", "Indian"] + else: + raise InvalidCapabilityError() + def _to_age_label(age): - return "{}-{}".format(age * 10, age * 10 + 9) + return FairFace.possible_classes(Capability.AGEGROUP)[age] + # return "{}-{}".format(age * 10, age * 10 + 9) def _to_ethno_label(val): diff --git a/src/facebias/estimators/mivolov1.py b/src/facebias/estimators/mivolov1.py index ac40abb..26bbbff 100644 --- a/src/facebias/estimators/mivolov1.py +++ b/src/facebias/estimators/mivolov1.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torchvision from facebias import load_dataset -from facebias.estimators import BaseEstimator, Capability +from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError from facebias.estimators.mivolo.mi_volo import MiVOLO from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision import transforms @@ -80,10 +80,28 @@ class MiVOLOv1(BaseEstimator): def capabilities() -> list[Capability]: return [Capability.AGE, Capability.AGEGROUP, Capability.SEX] + def possible_capability_values(cap: Capability) -> list[str]: + if cap == Capability.AGE: + return [] + elif cap == Capability.AGEGROUP: + return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"] + elif cap == Capability.SEX: + return ['m', 'f'] + else: + raise InvalidCapabilityError() + def _to_age_label(age): - base_age = int(age // 10 * 10) - return "{}-{}".format(base_age, base_age + 9) + iage = int(age) + if iage < 3: + return "0-2" + elif iage < 10: + return "3-9" + elif iage > 69: + return "70+" + + d = int(iage // 10 * 10) + return "{}-{}".format(d, d + 9) if __name__ == '__main__':