100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from collections import OrderedDict
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
from facebias import load_dataset
|
|
from facebias.estimators import BaseEstimator, Capability
|
|
from facebias.estimators.mivolo.mi_volo import MiVOLO
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from torchvision import transforms
|
|
|
|
|
|
class MiVOLOv1(BaseEstimator):
    """Generic wrapper for a local MiVOLO v1-style face-only model.

    Parameters
    ----------
    checkpoint_path: Path
        Path to the saved checkpoint. Note that we expect the checkpoint to be
        in PyTorch format.

    device: Optional[torch.device]
        Device to load the checkpoint onto. By default is "cpu"; ``None`` is
        treated the same as the default.

    Reference
    ---------
    MiVOLO repository: https://github.com/wildchlamydia/mivolo
    """

    def __init__(
        self,
        checkpoint_path: Path,
        device: Optional[torch.device] = torch.device("cpu"),
    ):
        # The signature advertises Optional, so normalise None to the
        # documented default instead of forwarding it to the model.
        if device is None:
            device = torch.device("cpu")

        # Pre-processing applied to every image before inference: convert to
        # PIL, resize to 224x224, convert to a tensor and normalise with the
        # ImageNet mean/std constants from timm.
        self.T = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ])

        self.device = device
        self.model = MiVOLO(checkpoint_path, half=False, device=device)

    def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
        """Estimate age, age group and sex for every image in ``ims``.

        Parameters
        ----------
        ims: dict[Path, np.ndarray]
            Mapping from an image key to the image array. Each image is run
            through the model on its own (batch size 1).

        Returns
        -------
        dict[Path, dict[Capability, Any]]
            An ordered mapping with the same keys (and iteration order) as
            ``ims``. Each value holds ``Capability.AGE`` (age rounded to two
            decimals), ``Capability.AGEGROUP`` (decade label such as
            ``"20-29"``) and ``Capability.SEX`` (``'m'`` or ``'f'``).
        """
        preds = OrderedDict()

        # Age de-normalisation constants stored with the model.
        meta = self.model.meta
        min_age = meta.min_age
        max_age = meta.max_age
        avg_age = meta.avg_age

        # Inference only: no gradients are needed for any image.
        with torch.no_grad():
            for p, im in ims.items():
                # Pre-process and add a leading batch dimension of 1.
                x = self.T(im).view(1, 3, 224, 224).to(self.device)
                y = self.model.inference(x)
                y = y.detach().cpu()

                # Age calc.: column 2 is the raw age output, mapped back to
                # years with the model's metadata.
                y_age = y[:, 2]
                age = y_age.item() * (max_age - min_age) + avg_age
                age = round(age, 2)

                # Gender est.: columns 0-1 are the two gender scores; the
                # arg-max index 0 is mapped to 'm', anything else to 'f'.
                y_gender = y[:, :2].softmax(-1)
                _, gender_idx = y_gender.topk(1)
                gender_label = 'm' if gender_idx.item() == 0 else 'f'

                preds[p] = {
                    Capability.AGE: age,
                    Capability.AGEGROUP: _to_age_label(age),
                    Capability.SEX: gender_label,
                }

        return preds

    @staticmethod
    def capabilities() -> list[Capability]:
        """Return the capabilities this estimator can predict.

        Declared as a static method so that it can be called both on the
        class (``MiVOLOv1.capabilities()``) and on an instance.
        """
        return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
|
|
|
|
|
|
def _to_age_label(age):
|
|
base_age = int(age // 10 * 10)
|
|
return "{}-{}".format(base_age, base_age + 9)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load a local dataset, build the estimator from a
    # local checkpoint and run it over every image.
    from facebias import load_dataset

    DATASET_PATH = Path("data/frll_neutral_front")
    MODEL_PATH = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar")

    def _name_prefix(imname):
        # Keep only the part of the image name before the first underscore.
        return imname.split('_')[0]

    dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=_name_prefix)

    print(MiVOLOv1.capabilities())
    model = MiVOLOv1(MODEL_PATH)

    out = model.predict(dataset)
|