From 27efada40e9a52556053d2865092e625ffe488eb Mon Sep 17 00:00:00 2001 From: Guilherme Schardong Date: Thu, 16 Apr 2026 13:59:49 +0100 Subject: [PATCH] Adding the MiVOLOv1 Face only estimator. Minor refactoring. --- src/facebias/estimators/mivolo/mi_volo.py | 42 +--------- src/facebias/estimators/mivolov1.py | 95 +++++++++++++++++++++++ 2 files changed, 99 insertions(+), 38 deletions(-) create mode 100644 src/facebias/estimators/mivolov1.py diff --git a/src/facebias/estimators/mivolo/mi_volo.py b/src/facebias/estimators/mivolo/mi_volo.py index 3077bea..1fec6ef 100644 --- a/src/facebias/estimators/mivolo/mi_volo.py +++ b/src/facebias/estimators/mivolo/mi_volo.py @@ -217,8 +217,9 @@ class MiVOLO: def inference(self, model_input: torch.Tensor) -> torch.Tensor: with torch.no_grad(): if self.half: - model_input = model_input.half() - output = self.model(model_input) + output = self.model(model_input.half()) + else: + output = self.model(model_input) return output # def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): @@ -245,42 +246,7 @@ class MiVOLO: # # write gender and age results into detected_bboxes # self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds) - def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds): - if self.meta.only_age: - age_output = output - gender_probs, gender_indx = None, None - else: - age_output = output[:, 2] - gender_output = output[:, :2].softmax(-1) - gender_probs, gender_indx = gender_output.topk(1) - - assert output.shape[0] == len(faces_inds) == len(bodies_inds) - - # per face - for index in range(output.shape[0]): - face_ind = faces_inds[index] - body_ind = bodies_inds[index] - - # get_age - age = age_output[index].item() - age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age - age = round(age, 2) - - detected_bboxes.set_age(face_ind, age) - detected_bboxes.set_age(body_ind, age) - - _logger.info(f"\tage: {age}") - - if gender_probs is not None: - gender = "male" if gender_indx[index].item() == 0 else "female" - gender_score = gender_probs[index].item() - - _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]") - - detected_bboxes.set_gender(face_ind, gender, gender_score) - detected_bboxes.set_gender(body_ind, gender, gender_score) - - # def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): + # def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): # crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image) # (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies( diff --git a/src/facebias/estimators/mivolov1.py b/src/facebias/estimators/mivolov1.py new file mode 100644 index 0000000..eebc36f --- /dev/null +++ b/src/facebias/estimators/mivolov1.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +from collections import OrderedDict +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from facebias import load_dataset +from facebias.estimators import BaseEstimator, Capability +from facebias.estimators.mivolo.mi_volo import MiVOLO +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision import transforms + + +class MiVOLOv1(BaseEstimator): + """Generic wrapper for a local MiVOLO v1-style face only model. + + Parameters + ---------- + checkpoint_path: Path + Path to the saved checkpoint. Note that we expect the checkpoint to be + in PyTorch format. + + device: Optional[torch.device] + Device to load the checkpoints into. By default is "cpu". + """ + def __init__( + self, + checkpoint_path: Path, + device: Optional[torch.device]=torch.device("cpu") + ): + self.T = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + ]) + + self.device = device + self.model = MiVOLO(checkpoint_path, half=False, device=device) + + def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]: + preds = OrderedDict() + min_age = self.model.meta.min_age + max_age = self.model.meta.max_age + avg_age = self.model.meta.avg_age + + for p, im in ims.items(): + im = self.T(im).view(1, 3, 224, 224).to(self.device) + with torch.no_grad(): + y = self.model.inference(im) + + y = y.cpu().detach() + + # Age calc. + y_age = y[:, 2] + age = y_age.item() * (max_age - min_age) + avg_age + age = round(age, 2) + + # Gender est. + y_gender = y[:, :2].softmax(-1) + gender_prob, gender_idx = y_gender.topk(1) + gender_label = 'm' if gender_idx.item() == 0 else 'f' + + preds[p] = { + Capability.AGE: age, + Capability.AGEGROUP: _to_age_label(age), + Capability.SEX: gender_label, + } + + return preds + + def capabilities() -> list[Capability]: + return [Capability.AGE, Capability.AGEGROUP, Capability.SEX] + + +def _to_age_label(age): + base_age = int(age // 10 * 10) + return "{}-{}".format(base_age, base_age + 9) + + +if __name__ == '__main__': + from facebias import load_dataset + + DATASET_PATH = Path("data/frll_neutral_front") + dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0]) + + print(MiVOLOv1.capabilities()) + MODEL_PATH = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar") + model = MiVOLOv1(MODEL_PATH) + + out = model.predict(dataset)