Added the possible_capability_values base method for estimators...

New `possible_capability_values` base method for estimators. The goal
is to return a list of possible values for each non-numerical
capability, as some performance metrics may benefit from this complete
listing.

Additionally, added a new exception to indicate when a capability is
invalid in a given context, e.g., when the model does not support it.
This commit is contained in:
2026-04-22 13:49:32 +01:00
parent a265fd24a1
commit 20cf80d94e
3 changed files with 44 additions and 5 deletions

View File

@@ -16,6 +16,11 @@ class Capability(StrEnum):
ETHNICITY = "ethinicity"
class InvalidCapabilityError(Exception):
def __init__(self, msg="Invalid capability for model."):
super(InvalidCapability, self).__init__(msg=msg)
class BaseEstimator(ABC):
@abstractmethod
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
@@ -35,3 +40,8 @@ class BaseEstimator(ABC):
face image, such as `age_group`, `sex`, `skin_color`, etc.
"""
pass
@abstractclassmethod
def possible_capability_values(cap: Capability) -> list:
"""Returns all possible values for a given model capability."""
pass

View File

@@ -8,7 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
import torchvision
from facebias.estimators import BaseEstimator, Capability
from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
@@ -96,9 +96,20 @@ class FairFace(BaseEstimator):
def capabilities() -> list[Capability]:
return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY]
def possible_capability_values(cap: Capability) -> list:
if cap == Capability.AGEGROUP:
return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
elif cap == Capability.SEX:
return ['m', 'f']
# elif cap == Capability.ETHNICITY:
# return ["White", "Black", "Asian", "Indian"]
else:
raise InvalidCapabilityError()
def _to_age_label(age):
return "{}-{}".format(age * 10, age * 10 + 9)
return FairFace.possible_classes(Capability.AGEGROUP)[age]
# return "{}-{}".format(age * 10, age * 10 + 9)
def _to_ethno_label(val):

View File

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torchvision
from facebias import load_dataset
from facebias.estimators import BaseEstimator, Capability
from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError
from facebias.estimators.mivolo.mi_volo import MiVOLO
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
@@ -80,10 +80,28 @@ class MiVOLOv1(BaseEstimator):
def capabilities() -> list[Capability]:
return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
def possible_capability_values(cap: Capability) -> list[str]:
if cap == Capability.AGE:
return []
elif cap == Capability.AGEGROUP:
return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
elif cap == Capability.SEX:
return ['m', 'f']
else:
raise InvalidCapabilityError()
def _to_age_label(age):
base_age = int(age // 10 * 10)
return "{}-{}".format(base_age, base_age + 9)
iage = int(age)
if iage < 3:
return "0-2"
elif iage < 10:
return "3-9"
elif iage > 69:
return "70+"
d = int(iage // 10 * 10)
return "{}-{}".format(d, d + 9)
if __name__ == '__main__':