Files
facebias/src/facebias/estimators/mivolov1.py

100 lines
2.9 KiB
Python

# -*- coding: utf-8 -*-
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision
from facebias import load_dataset
from facebias.estimators import BaseEstimator, Capability
from facebias.estimators.mivolo.mi_volo import MiVOLO
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
class MiVOLOv1(BaseEstimator):
    """Generic wrapper for a local MiVOLO v1-style face-only model.

    Parameters
    ----------
    checkpoint_path: Path
        Path to the saved checkpoint. Note that we expect the checkpoint to be
        in PyTorch format.
    device: Optional[torch.device]
        Device to load the checkpoint into. Defaults to CPU; passing ``None``
        also selects CPU.

    References
    ----------
    MiVOLO repository: https://github.com/wildchlamydia/mivolo
    """

    # Side length (in pixels) of the square image fed to the network.
    _INPUT_SIZE = 224

    def __init__(
        self,
        checkpoint_path: Path,
        device: Optional[torch.device] = torch.device("cpu"),
    ):
        # The annotation advertises None as valid: map it to the documented
        # CPU default instead of forwarding None to MiVOLO and Tensor.to().
        if device is None:
            device = torch.device("cpu")
        # Preprocessing: array -> PIL -> fixed-size square -> tensor normalised
        # with the ImageNet statistics from timm.
        self.T = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self._INPUT_SIZE, self._INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ])
        self.device = device
        self.model = MiVOLO(checkpoint_path, half=False, device=device)

    def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
        """Estimate age, age group and sex for every image in ``ims``.

        Parameters
        ----------
        ims: dict[Path, np.ndarray]
            Mapping from image path to image array. Each array goes through
            ``transforms.ToPILImage`` and must yield a 3-channel image
            (presumably HxWx3 uint8 -- TODO confirm the channel order the
            checkpoint expects).

        Returns
        -------
        dict[Path, dict[Capability, Any]]
            Mapping in the same order as ``ims`` from path to a dict holding
            ``Capability.AGE`` (age rounded to 2 decimals),
            ``Capability.AGEGROUP`` (decade label such as ``"20-29"``) and
            ``Capability.SEX`` (``'m'`` or ``'f'``).
        """
        preds = OrderedDict()
        meta = self.model.meta
        # The network outputs a normalised age; undo it as
        # age = y * (max_age - min_age) + avg_age.
        age_scale = meta.max_age - meta.min_age
        age_offset = meta.avg_age
        size = self._INPUT_SIZE
        # One no_grad context for the whole loop rather than one per image.
        with torch.no_grad():
            for p, im in ims.items():
                x = self.T(im).view(1, 3, size, size).to(self.device)
                y = self.model.inference(x).cpu()
                # Column 2 holds the normalised age for this single image.
                age = round(y[0, 2].item() * age_scale + age_offset, 2)
                # Columns 0-1 are the sex logits; index 0 maps to male.
                # argmax on the raw logits picks the same class as
                # softmax + topk(1), without float-saturation ties.
                gender_label = 'm' if int(y[0, :2].argmax()) == 0 else 'f'
                preds[p] = {
                    Capability.AGE: age,
                    Capability.AGEGROUP: _to_age_label(age),
                    Capability.SEX: gender_label,
                }
        return preds

    @staticmethod
    def capabilities() -> list[Capability]:
        """Return the capabilities this estimator provides.

        Declared as a static method so it can be called both on the class
        (``MiVOLOv1.capabilities()``) and on an instance.
        """
        return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
def _to_age_label(age):
base_age = int(age // 10 * 10)
return "{}-{}".format(base_age, base_age + 9)
if __name__ == '__main__':
    # Manual smoke test: run the estimator over a local dataset and print
    # the predictions. Paths are relative to the current working directory.
    # `load_dataset` is already imported at module level.
    DATASET_PATH = Path("data/frll_neutral_front")
    dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
    print(MiVOLOv1.capabilities())
    MODEL_PATH = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar")
    model = MiVOLOv1(MODEL_PATH)
    out = model.predict(dataset)
    # Show the results instead of silently discarding them.
    for path, pred in out.items():
        print(path, pred)