#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Evaluate FairFace and MiVOLOv1 estimators on the facing2 training split.

Loads the images and metadata, bins each subject's age into a fixed-width
age group, runs the MediaPipe face detector, then runs each estimator and
reports its overall metrics and its per-subgroup metrics.
"""
import logging
import os
from pathlib import Path

import mediapipe as mp
import numpy as np
import pandas as pd
import torch.nn as nn

from facebias import load_dataset
from facebias.detectors import get_face_boxes
from facebias.detectors.mediapipe import MediapipeDetector
from facebias.estimators.fairface import FairFace
from facebias.estimators.mivolov1 import MiVOLOv1
from facebias.metrics import calc_metrics_per_subgroup, calc_metrics

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(f"facebias:{__file__}")

DATASET_PATH = Path("data/facing2-train/")
METADATA_PATH = DATASET_PATH / "meta-w-age.csv"
DETECTOR_PATH = Path("models/blaze_face_full_range.tflite")
FAIRFACE_WEIGHTS = Path("models/fairface_alldata_4race_20191111.pt")
MIVOLO_WEIGHTS = Path("models/volo-v1_model_imdb_age_gender_4.22.pth.tar")


def _add_age_groups(meta, bin_width: int = 10) -> None:
    """Add an ``"age_group"`` label to every record and store ``"age"`` as float.

    *meta* is mutated in place. It is expected to map an image id to a
    mutable dict-like record with an ``"age"`` entry. With the default bin
    width, an age of 37 is labelled ``"30-39"``.

    Args:
        meta: Mapping of image id -> per-image feature record.
        bin_width: Width of each age bucket, in years.
    """
    for _, feats in meta.items():
        # Parse through float so decimal strings like "23.5" are accepted;
        # int() then truncates, as the original int() conversion did.
        age = float(feats["age"])
        lower = int(age) // bin_width * bin_width
        feats["age_group"] = "{}-{}".format(lower, lower + bin_width - 1)
        feats["age"] = age


def _evaluate(estimator_cls, weights_path: Path, imdict, meta, device: str = "cpu"):
    """Run one estimator over *imdict* and report its metrics.

    Prints the estimator's capabilities, builds it from *weights_path*,
    predicts on every image, logs the overall metrics and prints each
    per-subgroup metric as ``<group> <metric> <value>``.

    Returns:
        Tuple of ``(predictions, overall_metrics, subgroup_metrics)``.
    """
    print(estimator_cls.capabilities())
    model = estimator_cls(weights_path, device=device)
    preds = model.predict(imdict)

    # NOTE(review): calc_metrics is assumed to take (meta, preds), like the
    # undefined calc_model_performance it replaces; confirm in facebias.metrics.
    metrics = calc_metrics(meta, preds)
    logger.info("%s overall metrics: %s", estimator_cls.__name__, metrics)

    metrics_groups = calc_metrics_per_subgroup(meta, preds)
    for group, group_metrics in metrics_groups.items():
        for metric_name, value in group_metrics.items():
            print(group, metric_name, value)
    return preds, metrics, metrics_groups


def main() -> None:
    """Load the dataset, detect faces and evaluate both estimators."""
    logger.info(os.getcwd())

    detector = MediapipeDetector(str(DETECTOR_PATH))
    imdict, meta = load_dataset(
        DATASET_PATH,
        meta_path=METADATA_PATH,
        # Image names look like "<id>_..."; key records by the id prefix.
        imname_proc_fn=lambda x: x.split("_")[0],
    )
    _add_age_groups(meta)

    # NOTE(review): the detected boxes are never consumed below; the
    # estimators receive the raw images. Kept so the detector still runs;
    # confirm whether the boxes should be passed on or this call dropped.
    get_face_boxes(imdict, detector)

    _evaluate(FairFace, FAIRFACE_WEIGHTS, imdict, meta)
    _evaluate(MiVOLOv1, MIVOLO_WEIGHTS, imdict, meta)


if __name__ == "__main__":
    main()