diff --git a/src/facebias/estimators/__init__.py b/src/facebias/estimators/__init__.py index e632b5d..f369c98 100644 --- a/src/facebias/estimators/__init__.py +++ b/src/facebias/estimators/__init__.py @@ -16,6 +16,11 @@ class Capability(StrEnum): ETHNICITY = "ethinicity" +class InvalidCapabilityError(Exception): + def __init__(self, msg="Invalid capability for model."): + super(InvalidCapability, self).__init__(msg=msg) + + class BaseEstimator(ABC): @abstractmethod def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]: @@ -35,3 +40,8 @@ class BaseEstimator(ABC): face image, such as `age_group`, `sex`, `skin_color`, etc. """ pass + + @abstractclassmethod + def possible_capability_values(cap: Capability) -> list: + """Returns all possible values for a given model capability.""" + pass diff --git a/src/facebias/estimators/fairface.py b/src/facebias/estimators/fairface.py index 703f9d2..e884741 100644 --- a/src/facebias/estimators/fairface.py +++ b/src/facebias/estimators/fairface.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.nn as nn import torchvision -from facebias.estimators import BaseEstimator, Capability +from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision import transforms @@ -96,9 +96,20 @@ class FairFace(BaseEstimator): def capabilities() -> list[Capability]: return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY] + def possible_capability_values(cap: Capability) -> list: + if cap == Capability.AGEGROUP: + return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"] + elif cap == Capability.SEX: + return ['m', 'f'] + # elif cap == Capability.ETHNICITY: + # return ["White", "Black", "Asian", "Indian"] + else: + raise InvalidCapabilityError() + def _to_age_label(age): - return "{}-{}".format(age * 10, age * 10 + 9) + return FairFace.possible_classes(Capability.AGEGROUP)[age] + # return "{}-{}".format(age * 10, age * 10 + 9) def _to_ethno_label(val): diff --git a/src/facebias/estimators/mivolov1.py b/src/facebias/estimators/mivolov1.py index ac40abb..26bbbff 100644 --- a/src/facebias/estimators/mivolov1.py +++ b/src/facebias/estimators/mivolov1.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn import torchvision from facebias import load_dataset -from facebias.estimators import BaseEstimator, Capability +from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError from facebias.estimators.mivolo.mi_volo import MiVOLO from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torchvision import transforms @@ -80,10 +80,28 @@ class MiVOLOv1(BaseEstimator): def capabilities() -> list[Capability]: return [Capability.AGE, Capability.AGEGROUP, Capability.SEX] + def possible_capability_values(cap: Capability) -> list[str]: + if cap == Capability.AGE: + return [] + elif cap == Capability.AGEGROUP: + return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"] + elif cap == Capability.SEX: + return ['m', 'f'] + else: + raise InvalidCapabilityError() + def _to_age_label(age): - base_age = int(age // 10 * 10) - return "{}-{}".format(base_age, base_age + 9) + iage = int(age) + if iage < 3: + return "0-2" + elif iage < 10: + return "3-9" + elif iage > 69: + return "70+" + + d = int(iage // 10 * 10) + return "{}-{}".format(d, d + 9) if __name__ == '__main__':