Added the possible_capability_values base method for estimators...
New `possible_capability_values` base method for estimators. The goal is to return a list of possible values for each non-numerical capability, as some performance metrics may benefit from this complete listing. Additionally, added a new exception to indicate when a capability is invalid in a given context, e.g., when the model does not support it.
This commit is contained in:
@@ -16,6 +16,11 @@ class Capability(StrEnum):
|
||||
ETHNICITY = "ethinicity"
|
||||
|
||||
|
||||
class InvalidCapabilityError(Exception):
|
||||
def __init__(self, msg="Invalid capability for model."):
|
||||
super(InvalidCapability, self).__init__(msg=msg)
|
||||
|
||||
|
||||
class BaseEstimator(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
|
||||
@@ -35,3 +40,8 @@ class BaseEstimator(ABC):
|
||||
face image, such as `age_group`, `sex`, `skin_color`, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def possible_capability_values(cap: Capability) -> list:
|
||||
"""Returns all possible values for a given model capability."""
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from facebias.estimators import BaseEstimator, Capability
|
||||
from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision import transforms
|
||||
|
||||
@@ -96,9 +96,20 @@ class FairFace(BaseEstimator):
|
||||
def capabilities() -> list[Capability]:
|
||||
return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY]
|
||||
|
||||
def possible_capability_values(cap: Capability) -> list:
|
||||
if cap == Capability.AGEGROUP:
|
||||
return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
|
||||
elif cap == Capability.SEX:
|
||||
return ['m', 'f']
|
||||
# elif cap == Capability.ETHNICITY:
|
||||
# return ["White", "Black", "Asian", "Indian"]
|
||||
else:
|
||||
raise InvalidCapabilityError()
|
||||
|
||||
|
||||
def _to_age_label(age):
|
||||
return "{}-{}".format(age * 10, age * 10 + 9)
|
||||
return FairFace.possible_classes(Capability.AGEGROUP)[age]
|
||||
# return "{}-{}".format(age * 10, age * 10 + 9)
|
||||
|
||||
|
||||
def _to_ethno_label(val):
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from facebias import load_dataset
|
||||
from facebias.estimators import BaseEstimator, Capability
|
||||
from facebias.estimators import BaseEstimator, Capability, InvalidCapabilityError
|
||||
from facebias.estimators.mivolo.mi_volo import MiVOLO
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision import transforms
|
||||
@@ -80,10 +80,28 @@ class MiVOLOv1(BaseEstimator):
|
||||
def capabilities() -> list[Capability]:
|
||||
return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
|
||||
|
||||
def possible_capability_values(cap: Capability) -> list[str]:
|
||||
if cap == Capability.AGE:
|
||||
return []
|
||||
elif cap == Capability.AGEGROUP:
|
||||
return ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
|
||||
elif cap == Capability.SEX:
|
||||
return ['m', 'f']
|
||||
else:
|
||||
raise InvalidCapabilityError()
|
||||
|
||||
|
||||
def _to_age_label(age):
|
||||
base_age = int(age // 10 * 10)
|
||||
return "{}-{}".format(base_age, base_age + 9)
|
||||
iage = int(age)
|
||||
if iage < 3:
|
||||
return "0-2"
|
||||
elif iage < 10:
|
||||
return "3-9"
|
||||
elif iage > 69:
|
||||
return "70+"
|
||||
|
||||
d = int(iage // 10 * 10)
|
||||
return "{}-{}".format(d, d + 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user