Added capabilities (Capability) to the estimators.
This commit is contained in:
@@ -5,13 +5,33 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class Capability(StrEnum):
|
||||
AGE = "age"
|
||||
AGEGROUP = "age_group"
|
||||
SEX = "sex"
|
||||
SKINCOLOR = "skin_color"
|
||||
ETHNICITY = "ethinicity"
|
||||
|
||||
|
||||
class BaseEstimator(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
|
||||
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
|
||||
"""Runs the estimator on a batch of images.
|
||||
|
||||
The input `images` is a dictionary indexed by the image name, or ID,
|
||||
and values are image data in the appropriate format for the estimator
|
||||
(RGB, BGR' Lab, ...).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def capabilities() -> list[str]:
|
||||
def capabilities() -> list[Capability]:
|
||||
"""Returns the estimator capabilities
|
||||
|
||||
In this context, estimator capabilities mean what is infers from the
|
||||
face image, such as `age_group`, `sex`, `skin_color`, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
@@ -9,12 +8,10 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from facebias import load_dataset
|
||||
from facebias.estimators import BaseEstimator, Capability
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision import transforms
|
||||
|
||||
from facebias.estimators import BaseEstimator
|
||||
|
||||
|
||||
class FairFace(BaseEstimator):
|
||||
"""
|
||||
@@ -41,7 +38,7 @@ class FairFace(BaseEstimator):
|
||||
)
|
||||
return self
|
||||
|
||||
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
|
||||
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
|
||||
preds = OrderedDict()
|
||||
for p, im in ims.items():
|
||||
im = self.T(im).view(1, 3, 224, 224).to(self.device)
|
||||
@@ -51,9 +48,9 @@ class FairFace(BaseEstimator):
|
||||
y = y.cpu().detach().squeeze().numpy()
|
||||
|
||||
# Ethnicity prediction
|
||||
# y_ethno = y[:4]
|
||||
# ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
|
||||
# ethno_pred = np.argmax(ethno_score)
|
||||
y_ethno = y[:4]
|
||||
ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
|
||||
ethno_pred = np.argmax(ethno_score)
|
||||
|
||||
# Age prediction
|
||||
y_age = y[9:18]
|
||||
@@ -67,26 +64,37 @@ class FairFace(BaseEstimator):
|
||||
sex_label = "m" if not sex_pred else "f"
|
||||
|
||||
preds[p] = {
|
||||
"age_group": _to_age_label(age_pred),
|
||||
"sex": sex_label,
|
||||
Capability.AGEGROUP: _to_age_label(age_pred),
|
||||
Capability.SEX: sex_label,
|
||||
# Capability.ETHNICITY: _to_ethno_label(ethno_pred)
|
||||
}
|
||||
|
||||
return preds
|
||||
|
||||
def capabilities() -> list[str]:
|
||||
return ["age_group", "sex"]
|
||||
def capabilities() -> list[Capability]:
|
||||
return [Capability.AGEGROUP, Capability.SEX] #, Capability.ETHNICITY]
|
||||
|
||||
|
||||
def _to_age_label(age):
|
||||
return "{}-{}".format(age * 10, age * 10 + 9)
|
||||
|
||||
|
||||
def _to_ethno_label(val):
|
||||
if val not in range(4):
|
||||
return "unknown"
|
||||
opts = ["white", "black", "asian", "indian"]
|
||||
return opts[val]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
|
||||
model = FairFace()
|
||||
model.load_state_dict(MODEL_PATH)
|
||||
from facebias import load_dataset
|
||||
|
||||
DATASET_PATH = Path("../../../data/frll_neutral_front")
|
||||
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
|
||||
|
||||
print(FairFace.capabilities())
|
||||
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
|
||||
model = FairFace()
|
||||
model.load_state_dict(MODEL_PATH)
|
||||
|
||||
out = model.predict(dataset)
|
||||
|
||||
Reference in New Issue
Block a user