Adding the MiVOLOv1 Face only estimator. Minor refactoring.
This commit is contained in:
@@ -217,8 +217,9 @@ class MiVOLO:
|
||||
def inference(self, model_input: torch.Tensor) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
if self.half:
|
||||
model_input = model_input.half()
|
||||
output = self.model(model_input)
|
||||
output = self.model(model_input.half())
|
||||
else:
|
||||
output = self.model(model_input)
|
||||
return output
|
||||
|
||||
# def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
||||
@@ -245,42 +246,7 @@ class MiVOLO:
|
||||
# # write gender and age results into detected_bboxes
|
||||
# self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
|
||||
|
||||
def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
|
||||
if self.meta.only_age:
|
||||
age_output = output
|
||||
gender_probs, gender_indx = None, None
|
||||
else:
|
||||
age_output = output[:, 2]
|
||||
gender_output = output[:, :2].softmax(-1)
|
||||
gender_probs, gender_indx = gender_output.topk(1)
|
||||
|
||||
assert output.shape[0] == len(faces_inds) == len(bodies_inds)
|
||||
|
||||
# per face
|
||||
for index in range(output.shape[0]):
|
||||
face_ind = faces_inds[index]
|
||||
body_ind = bodies_inds[index]
|
||||
|
||||
# get_age
|
||||
age = age_output[index].item()
|
||||
age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
|
||||
age = round(age, 2)
|
||||
|
||||
detected_bboxes.set_age(face_ind, age)
|
||||
detected_bboxes.set_age(body_ind, age)
|
||||
|
||||
_logger.info(f"\tage: {age}")
|
||||
|
||||
if gender_probs is not None:
|
||||
gender = "male" if gender_indx[index].item() == 0 else "female"
|
||||
gender_score = gender_probs[index].item()
|
||||
|
||||
_logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
|
||||
|
||||
detected_bboxes.set_gender(face_ind, gender, gender_score)
|
||||
detected_bboxes.set_gender(body_ind, gender, gender_score)
|
||||
|
||||
# def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
||||
# def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
||||
|
||||
# crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
|
||||
# (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
|
||||
|
||||
95
src/facebias/estimators/mivolov1.py
Normal file
95
src/facebias/estimators/mivolov1.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from facebias import load_dataset
|
||||
from facebias.estimators import BaseEstimator, Capability
|
||||
from facebias.estimators.mivolo.mi_volo import MiVOLO
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class MiVOLOv1(BaseEstimator):
|
||||
"""Generic wrapper for a local MiVOLO v1-style face only model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
checkpoint_path: Path
|
||||
Path to the saved checkpoint. Note that we expect the checkpoint to be
|
||||
in PyTorch format.
|
||||
|
||||
device: Optional[torch.device]
|
||||
Device to load the checkpoints into. By default is "cpu".
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path: Path,
|
||||
device: Optional[torch.device]=torch.device("cpu")
|
||||
):
|
||||
self.T = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
])
|
||||
|
||||
self.device = device
|
||||
self.model = MiVOLO(checkpoint_path, half=False, device=device)
|
||||
|
||||
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
|
||||
preds = OrderedDict()
|
||||
min_age = self.model.meta.min_age
|
||||
max_age = self.model.meta.max_age
|
||||
avg_age = self.model.meta.avg_age
|
||||
|
||||
for p, im in ims.items():
|
||||
im = self.T(im).view(1, 3, 224, 224).to(self.device)
|
||||
with torch.no_grad():
|
||||
y = self.model.inference(im)
|
||||
|
||||
y = y.cpu().detach()
|
||||
|
||||
# Age calc.
|
||||
y_age = y[:, 2]
|
||||
age = y_age.item() * (max_age - min_age) + avg_age
|
||||
age = round(age, 2)
|
||||
|
||||
# Gender est.
|
||||
y_gender = y[:, :2].softmax(-1)
|
||||
gender_prob, gender_idx = y_gender.topk(1)
|
||||
gender_label = 'm' if gender_idx.item() == 0 else 'f'
|
||||
|
||||
preds[p] = {
|
||||
Capability.AGE: age,
|
||||
Capability.AGEGROUP: _to_age_label(age),
|
||||
Capability.SEX: gender_label,
|
||||
}
|
||||
|
||||
return preds
|
||||
|
||||
def capabilities() -> list[Capability]:
|
||||
return [Capability.AGE, Capability.AGEGROUP, Capability.SEX]
|
||||
|
||||
|
||||
def _to_age_label(age):
|
||||
base_age = int(age // 10 * 10)
|
||||
return "{}-{}".format(base_age, base_age + 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from facebias import load_dataset
|
||||
|
||||
DATASET_PATH = Path("data/frll_neutral_front")
|
||||
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
|
||||
|
||||
print(MiVOLOv1.capabilities())
|
||||
MODEL_PATH = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar")
|
||||
model = MiVOLOv1(MODEL_PATH)
|
||||
|
||||
out = model.predict(dataset)
|
||||
Reference in New Issue
Block a user