Adding the MiVOLOv1 Face only estimator. Minor refactoring.

This commit is contained in:
2026-04-16 13:59:49 +01:00
parent bceea16708
commit 27efada40e
2 changed files with 99 additions and 38 deletions

View File

@@ -217,8 +217,9 @@ class MiVOLO:
def inference(self, model_input: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
if self.half:
model_input = model_input.half()
output = self.model(model_input)
output = self.model(model_input.half())
else:
output = self.model(model_input)
return output
# def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
@@ -245,42 +246,7 @@ class MiVOLO:
# # write gender and age results into detected_bboxes
# self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
if self.meta.only_age:
age_output = output
gender_probs, gender_indx = None, None
else:
age_output = output[:, 2]
gender_output = output[:, :2].softmax(-1)
gender_probs, gender_indx = gender_output.topk(1)
assert output.shape[0] == len(faces_inds) == len(bodies_inds)
# per face
for index in range(output.shape[0]):
face_ind = faces_inds[index]
body_ind = bodies_inds[index]
# get_age
age = age_output[index].item()
age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
age = round(age, 2)
detected_bboxes.set_age(face_ind, age)
detected_bboxes.set_age(body_ind, age)
_logger.info(f"\tage: {age}")
if gender_probs is not None:
gender = "male" if gender_indx[index].item() == 0 else "female"
gender_score = gender_probs[index].item()
_logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
detected_bboxes.set_gender(face_ind, gender, gender_score)
detected_bboxes.set_gender(body_ind, gender, gender_score)
# def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
# def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
# crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
# (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(

View File

@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision
from facebias import load_dataset
from facebias.estimators import BaseEstimator, Capability
from facebias.estimators.mivolo.mi_volo import MiVOLO
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
class MiVOLOv1(BaseEstimator):
"""Generic wrapper for a local MiVOLO v1-style face only model.
Parameters
----------
checkpoint_path: Path
Path to the saved checkpoint. Note that we expect the checkpoint to be
in PyTorch format.
device: Optional[torch.device]
Device to load the checkpoints into. By default is "cpu".
"""
def __init__(
self,
checkpoint_path: Path,
device: Optional[torch.device]=torch.device("cpu")
):
self.T = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
self.device = device
self.model = MiVOLO(checkpoint_path, half=False, device=device)
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
preds = OrderedDict()
min_age = self.model.meta.min_age
max_age = self.model.meta.max_age
avg_age = self.model.meta.avg_age
for p, im in ims.items():
im = self.T(im).view(1, 3, 224, 224).to(self.device)
with torch.no_grad():
y = self.model.inference(im)
y = y.cpu().detach()
# Age calc.
y_age = y[:, 2]
age = y_age.item() * (max_age - min_age) + avg_age
age = round(age, 2)
# Gender est.
y_gender = y[:, :2].softmax(-1)
gender_prob, gender_idx = y_gender.topk(1)
gender_label = 'm' if gender_idx.item() == 0 else 'f'
preds[p] = {
Capability.AGE: age,
Capability.AGEGROUP: _to_age_label(age),
Capability.SEX: gender_label,
}
return preds
def capabilities() -> list[Capability]:
return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
def _to_age_label(age):
base_age = int(age // 10 * 10)
return "{}-{}".format(base_age, base_age + 9)
if __name__ == '__main__':
from facebias import load_dataset
DATASET_PATH = Path("data/frll_neutral_front")
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
print(MiVOLOv1.capabilities())
MODEL_PATH = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar")
model = MiVOLOv1(MODEL_PATH)
out = model.predict(dataset)