Files
facebias/benchmark.py

76 lines
2.4 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import mediapipe as mp
import numpy as np
import pandas as pd
import torch.nn as nn
from facebias import load_dataset
from facebias.detectors import get_face_boxes
from facebias.detectors.mediapipe import MediapipeDetector
from facebias.estimators.fairface import FairFace
from facebias.estimators.mivolov1 import MiVOLOv1
from facebias.metrics import calc_metrics_per_subgroup, calc_metrics
# Configure the root logger so INFO-level (and higher) messages are emitted.
logging.basicConfig(level=logging.INFO)
# Script-specific logger whose name is "facebias:" followed by this file's path.
logger = logging.getLogger("facebias:" + __file__)
if __name__ == "__main__":
import os
logger.info(os.getcwd())
DATASET_PATH = Path("data/facing2-train/")
METADATA_PATH = DATASET_PATH / "meta-w-age.csv"
DETECTOR_PATH = Path("models/blaze_face_full_range.tflite")
# TEST_IMS: list[str] = ["10", "12", "14", "9"]
detector = MediapipeDetector(str(DETECTOR_PATH))
imdict, meta = load_dataset(
DATASET_PATH, meta_path=METADATA_PATH, imname_proc_fn=lambda x: x.split("_")[0]
)
for im, feats in meta.items():
age = int(feats["age"])
d = age // 10 * 10
feats["age_group"] = "{}-{}".format(d, d + 9)
feats["age"] = float(feats["age"])
face_bboxes = get_face_boxes(imdict, detector)
# for t in TEST_IMS:
# logger.info("-- {} - {}".format(t, meta[str(t)]))
print(FairFace.capabilities())
model_ff = FairFace(Path("models/fairface_alldata_4race_20191111.pt"), device="cpu")
preds_ff = model_ff.predict(imdict)
metrics_ff = calc_model_performance(meta, preds_ff)
# logger.info("FairFace -- Test Images")
# for t in TEST_IMS:
# logger.info("--{} - {}".format(t, preds_ff[str(t)]))
metrics_ff_groups = calc_metrics_per_subgroup(meta, preds_ff)
for k, v in metrics_ff_groups.items():
for kv, vv in v.items():
print(k, kv, vv)
print(MiVOLOv1.capabilities())
model_mv = MiVOLOv1(
Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar"), device="cpu"
)
preds_mv = model_mv.predict(imdict)
# logger.info("MiVOLOv1(Face Only) -- Test Images")
# for t in TEST_IMS:
# logger.info("{} - {}".format(t, preds_mv[str(t)]))
metrics_mv = calc_model_performance(meta, preds_mv)
metrics_mv_groups = calc_metrics_per_subgroup(meta, preds_mv)
for k, v in metrics_mv_groups.items():
for kv, vv in v.items():
print(k, kv, vv)