From 1c11b413290b5e107e425b72139857991b11d35e Mon Sep 17 00:00:00 2001 From: "Guilherme G. Schardong" Date: Wed, 15 Apr 2026 22:24:58 +0100 Subject: [PATCH] Added setup and basic module sources. --- pyproject.toml | 30 ++ src/facebias/__init__.py | 65 +++ src/facebias/detectors/__init__.py | 26 ++ src/facebias/detectors/mediapipe.py | 68 +++ src/facebias/estimators/__init__.py | 17 + src/facebias/estimators/fairface.py | 92 ++++ src/facebias/estimators/mivolo/__init__.py | 0 .../estimators/mivolo/create_timm_model.py | 105 +++++ .../mivolo/cross_bottleneck_attn.py | 116 +++++ src/facebias/estimators/mivolo/mi_volo.py | 313 ++++++++++++++ .../estimators/mivolo/mivolo_model.py | 405 ++++++++++++++++++ 11 files changed, 1237 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/facebias/__init__.py create mode 100644 src/facebias/detectors/__init__.py create mode 100755 src/facebias/detectors/mediapipe.py create mode 100644 src/facebias/estimators/__init__.py create mode 100644 src/facebias/estimators/fairface.py create mode 100644 src/facebias/estimators/mivolo/__init__.py create mode 100644 src/facebias/estimators/mivolo/create_timm_model.py create mode 100644 src/facebias/estimators/mivolo/cross_bottleneck_attn.py create mode 100644 src/facebias/estimators/mivolo/mi_volo.py create mode 100644 src/facebias/estimators/mivolo/mivolo_model.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..97fd1f9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "facebias" +version = "0.0.1" +description = "Face image bias estimation and benchmarking" +license = "MIT" +readme = "README.md" +requires-python = ">= 3.12" + +[build-system] +requires = ["setuptools >= 77.0.3"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + 
+[project.optional-dependencies] +dev = [ + "jedi-language-server>=0.46.0", + "black>=25.12.0", + "flake8>=7.3.0", + "isort>=8.0.1", + "mypy>=1.19.0", + "pytest>=9.0.0" +] diff --git a/src/facebias/__init__.py b/src/facebias/__init__.py new file mode 100644 index 0000000..243d57b --- /dev/null +++ b/src/facebias/__init__.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +import logging +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional + +import cv2 +import numpy as np + +logging.basicConfig( + level=logging.DEBUG, # Set the minimum logging level + # format="%(asctime)s - [%(levelname)s] - %(filename)s:%(lineno)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +@dataclass +class FaceBox: + x1: int + y1: int + x2: int + y2: int + + +def load_dataset( + root: Path, + meta_path: Optional[Path] = None, + imname_proc_fn: Optional[Callable]=None +) -> tuple[dict[Path, np.ndarray], dict[str, dict[str, Any]]]: + metadata = dict() + paths = set([p for p in root.iterdir() if not p.is_dir()]) + if meta_path is not None and meta_path in paths: + # Just read the metadata and avoid needing to test for it in the main loop. + metadata = load_metadata(meta_path) + paths.remove(meta_path) + + ims = OrderedDict() + # TODO(gschardong): Parallelize this? + for p in paths: + try: + im = cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) + except cv2.error: + logger.info(f'File "{p}" is not an image.') + if meta_path is None: + logger.info(f'Trying to read "{p}" as metadata.') + if p.suffix == ".csv": + metadata = load_metadata(p) + logger.info("Metadata read successfully.") + else: + logger.error(f'Failed to read "{p}" as metadata. 
Skipping.') + elif not metadata: + logger.critical("Logic error: Metadata should have been read already.") + else: + proc_imname = imname_proc_fn(p.name) if imname_proc_fn is not None else p.name + ims[proc_imname] = im + + if not metadata: + logger.warning("No metadata found.") + if meta_path is not None: + logger.error(f'Metadata file not found at "{meta_path}".') + + return ims, metadata diff --git a/src/facebias/detectors/__init__.py b/src/facebias/detectors/__init__.py new file mode 100644 index 0000000..e659dac --- /dev/null +++ b/src/facebias/detectors/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +from abc import ABC, abstractmethod + +import numpy as np + + +class NoFaceDetectedError(Exception): + def __init__(self, message: str = "No face detected in image."): + self.message = message + super(NoFaceDetectedError, self).__init__(self.message) + + +class BaseFaceDetector(ABC): + """Base class for face detectors. + + This class provides the base interface for supported face detection + methods. An example is provided in the `MediapipeDetector` class in + `facebias.detectors.mediapipe`. Its goal is simply to load the resources + and perform face detection, whenever a call to `detect` is performed.""" + + @abstractmethod + def detect(self, im) -> np.ndarray: + """Detects faces in image `im`. Throws `NoFaceDetectedError` + if `im` has no valid face(s).""" + pass diff --git a/src/facebias/detectors/mediapipe.py b/src/facebias/detectors/mediapipe.py new file mode 100755 index 0000000..2d52ac1 --- /dev/null +++ b/src/facebias/detectors/mediapipe.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from facebias import FaceBox +from facebias.detectors import BaseFaceDetector, NoFaceDetectedError + +from mediapipe.tasks.python import BaseOptions, vision + + +class MediapipeDetector(BaseFaceDetector): + """Uses mediapipe's blazeface model to detect faces in images or video frames. 
+ + Parameters + ---------- + model_path: str, PathLike + mode: Optional[mediapipe.tasks.python.vision.RunningMode] + """ + + def __init__(self, model_path: str, mode=vision.RunningMode.IMAGE) -> None: + super(MediapipeDetector, self).__init__() + baseopts = BaseOptions(model_asset_path=model_path) + lmopts = vision.FaceDetectorOptions( + base_options=baseopts, + running_mode=mode, + ) + self.mode = mode + self.detector = vision.FaceDetector.create_from_options(lmopts) + + def detect(self, im: np.ndarray, timestamp_ms: float = 0.0) -> list[FaceBox]: + """Returns a list with bounding boxes for each face found in `im`. + + `timestamp_ms` is only used for videos. + """ + detect_results = None + mpim = mp.Image(image_format=mp.ImageFormat.SRGB, data=im) + if self.mode == vision.RunningMode.IMAGE: + detect_results = self.detector.detect(mpim) + else: + detect_results = self.detector.detect_for_video(mpim, round(timestamp_ms)) + + if detect_results is None or not detect_results.detections: + raise NoFaceDetectedError() + + faces = [None for _ in detect_results.detections] + for i, d in enumerate(detect_results.detections): + bbox = d.bounding_box + faces[i] = FaceBox( + x1=bbox.origin_y, + y1=bbox.origin_x, + x2=bbox.origin_y + bbox.height, + y2=bbox.origin_x + bbox.width, + ) + + return faces + + +if __name__ == '__main__': + import cv2 + from pathlib import Path + + TEST_IM_PATH = Path("data/frll_neutral_front/001_03.jpg") + DETECTOR_PATH = Path("models/blaze_face_short_range.tflite") + + detector = MediapipeDetector(str(DETECTOR_PATH)) + + im = cv2.cvtColor(cv2.imread(TEST_IM_PATH), cv2.COLOR_BGR2RGB) + crop = detector.detect(im) + print(crop) diff --git a/src/facebias/estimators/__init__.py b/src/facebias/estimators/__init__.py new file mode 100644 index 0000000..1146c34 --- /dev/null +++ b/src/facebias/estimators/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +from abc import ABC, abstractclassmethod, abstractmethod +from pathlib import Path +from typing 
import Any + +import numpy as np + + +class BaseEstimator(ABC): + @abstractmethod + def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]: + pass + + @abstractclassmethod + def capabilities() -> list[str]: + pass diff --git a/src/facebias/estimators/fairface.py b/src/facebias/estimators/fairface.py new file mode 100644 index 0000000..7795b76 --- /dev/null +++ b/src/facebias/estimators/fairface.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +import logging +from collections import OrderedDict +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from facebias import load_dataset +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision import transforms + +from facebias.estimators import BaseEstimator + + +class FairFace(BaseEstimator): + """ + """ + def __init__(self, checkpoint_path: Optional[Path]=None, device=torch.device("cpu")): + self.T = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + ]) + + self.device = device + self.model = torchvision.models.resnet34(pretrained=True) + self.model.fc = nn.Linear(self.model.fc.in_features, 18) + self.model = self.model.eval().to(device) + if checkpoint_path is not None: + self.load_state_dict(checkpoint_path) + self.model.to(device) + + def load_state_dict(self, checkpoint_path: Path) -> "FairFace": + self.model.load_state_dict( + torch.load(checkpoint_path, map_location=torch.device("cpu")) + ) + return self + + def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]: + preds = OrderedDict() + for p, im in ims.items(): + im = self.T(im).view(1, 3, 224, 224).to(self.device) + with torch.no_grad(): + y = self.model(im) + + y = y.cpu().detach().squeeze().numpy() + + # Ethnicity prediction + # y_ethno = y[:4] + # ethno_score = 
np.exp(y_ethno) / np.sum(np.exp(y_ethno)) + # ethno_pred = np.argmax(ethno_score) + + # Age prediction + y_age = y[9:18] + age_score = np.exp(y_age) / np.sum(np.exp(y_age)) + age_pred = np.argmax(age_score) + + # Gender prediction + y_sex = y[7:9] + sex_score = np.exp(y_sex) / np.sum(np.exp(y_sex)) + sex_pred = np.argmax(sex_score) + sex_label = "m" if not sex_pred else "f" + + preds[p] = { + "age_group": _to_age_label(age_pred), + "sex": sex_label, + } + + return preds + + def capabilities() -> list[str]: + return ["age_group", "sex"] + + +def _to_age_label(age): + return "{}-{}".format(age * 10, age * 10 + 9) + + +if __name__ == '__main__': + MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt") + model = FairFace() + model.load_state_dict(MODEL_PATH) + + DATASET_PATH = Path("../../../data/frll_neutral_front") + dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0]) + + out = model.predict(dataset) diff --git a/src/facebias/estimators/mivolo/__init__.py b/src/facebias/estimators/mivolo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/facebias/estimators/mivolo/create_timm_model.py b/src/facebias/estimators/mivolo/create_timm_model.py new file mode 100644 index 0000000..68cd981 --- /dev/null +++ b/src/facebias/estimators/mivolo/create_timm_model.py @@ -0,0 +1,105 @@ +""" +Code adapted from timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import os +from typing import Any, Dict, Optional, Union + +import timm +# register new models +from facebias.estimators.mivolo.mivolo_model import * # noqa: F403, F401 +from timm.layers import set_layer_config +from timm.models._factory import parse_model_name +from timm.models._helpers import load_state_dict +from timm.models._hub import load_model_config_from_hf +from timm.models._pretrained import PretrainedCfg +from timm.models._registry import 
(is_model, model_entrypoint, + split_model_name_tag) + + +def load_checkpoint( + model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None +): + if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, "load_pretrained"): + timm.models._model_builder.load_pretrained(checkpoint_path) + else: + raise NotImplementedError("Model cannot load numpy checkpoint") + return + state_dict = load_state_dict(checkpoint_path, use_ema) + if filter_keys: + for sd_key in list(state_dict.keys()): + for filter_key in filter_keys: + if filter_key in sd_key: + if sd_key in state_dict: + del state_dict[sd_key] + + rep = [] + if state_dict_map is not None: + # 'patch_embed.conv1.' : 'patch_embed.conv.' + for state_k in list(state_dict.keys()): + for target_k, target_v in state_dict_map.items(): + if target_v in state_k: + target_name = state_k.replace(target_v, target_k) + state_dict[target_name] = state_dict[state_k] + rep.append(state_k) + for r in rep: + if r in state_dict: + del state_dict[r] + + incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: bool = False, + pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + checkpoint_path: str = "", + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + filter_keys=None, + state_dict_map=None, + **kwargs, +): + """Create a model + Lookup model's entrypoint function and pass relevant args to create a new model. + """ + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. 
Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + model_source, model_name = parse_model_name(model_name) + if model_source == "hf-hub": + assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub." + # For model names specified in the form `hf-hub:path/architecture_name@revision`, + # load model weights + pretrained_cfg from Hugging Face hub. + pretrained_cfg, model_name = load_model_config_from_hf(model_name) + else: + model_name, pretrained_tag = split_model_name_tag(model_name) + if not pretrained_cfg: + # a valid pretrained_cfg argument takes priority over tag in model name + pretrained_cfg = pretrained_tag + + if not is_model(model_name): + raise RuntimeError("Unknown model (%s)" % model_name) + + create_fn = model_entrypoint(model_name) + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + model = create_fn( + pretrained=pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay, + **kwargs, + ) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map) + + return model diff --git a/src/facebias/estimators/mivolo/cross_bottleneck_attn.py b/src/facebias/estimators/mivolo/cross_bottleneck_attn.py new file mode 100644 index 0000000..44976bf --- /dev/null +++ b/src/facebias/estimators/mivolo/cross_bottleneck_attn.py @@ -0,0 +1,116 @@ +""" +Code based on timm https://github.com/huggingface/pytorch-image-models + +Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich +""" + +import torch +import torch.nn as nn +from timm.layers.bottleneck_attn import PosEmbedRel +from timm.layers.helpers import make_divisible +from timm.layers.mlp import Mlp +from timm.layers.trace_utils import _assert +from 
timm.layers.weight_init import trunc_normal_ + + +class CrossBottleneckAttn(nn.Module): + def __init__( + self, + dim, + dim_out=None, + feat_size=None, + stride=1, + num_heads=4, + dim_head=None, + qk_ratio=1.0, + qkv_bias=False, + scale_pos_embed=False, + ): + super().__init__() + assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required" + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + + self.num_heads = num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk**-0.5 + self.scale_pos_embed = scale_pos_embed + + self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size]) + mlp_ratio = 4 + self.mlp = Mlp( + in_features=self.dim_out_v * 2, + hidden_features=int(dim * mlp_ratio), + act_layer=nn.GELU, + out_features=dim_out, + drop=0, + use_conv=True, + ) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def get_qkv(self, x, qvk_conv): + B, C, H, W = x.shape + + x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + + q = 
q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) + + return q, k, v + + def apply_attn(self, q, k, v, B, H, W, dropout=None): + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) + if dropout: + attn = dropout(attn) + + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + return out + + def forward(self, x): + B, C, H, W = x.shape + + dim = int(C / 2) + x1 = x[:, :dim, :, :] + x2 = x[:, dim:, :, :] + + _assert(H == self.pos_embed.height, "") + _assert(W == self.pos_embed.width, "") + + q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f) + q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p) + + # person to face + out_f = self.apply_attn(q_f, k_p, v_p, B, H, W) + # face to person + out_p = self.apply_attn(q_p, k_f, v_f, B, H, W) + + x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W + x_pf = self.norm(x_pf) + x_pf = self.mlp(x_pf) # B, dim_out, H, W + + out = self.pool(x_pf) + return out diff --git a/src/facebias/estimators/mivolo/mi_volo.py b/src/facebias/estimators/mivolo/mi_volo.py new file mode 100644 index 0000000..3077bea --- /dev/null +++ b/src/facebias/estimators/mivolo/mi_volo.py @@ -0,0 +1,313 @@ +import logging +from typing import List, Optional + +import cv2 +import numpy as np +import torch +import torchvision.transforms.functional as F +from facebias.estimators.mivolo.create_timm_model import create_model +from timm.data import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, + resolve_data_config) + +_logger = logging.getLogger("MiVOLO") +has_compile = hasattr(torch, "compile") + + +def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True): + # Resize and pad image while meeting 
stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]: + return im + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + # ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im + + +def prepare_classification_images( + img_list: List[Optional[np.ndarray]], + target_size: int = 224, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + device=None, +) -> torch.Tensor: + + prepared_images: List[torch.Tensor] = [] + + for img in img_list: + if img is None: + img = torch.zeros((3, target_size, target_size), dtype=torch.float32) + img = F.normalize(img, mean=mean, std=std) + img = img.unsqueeze(0) + prepared_images.append(img) + continue + + img = class_letterbox(img, new_shape=(target_size, target_size)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = img / 255.0 + img = (img - mean) / std + img = img.astype(dtype=np.float32) + + img = img.transpose((2, 0, 1)) + img = np.ascontiguousarray(img) + img = torch.from_numpy(img) + img = img.unsqueeze(0) + + prepared_images.append(img) + + if len(prepared_images) == 0: + return None + + prepared_input = torch.concat(prepared_images) + + if device: + prepared_input = prepared_input.to(device) 
+ + return prepared_input + + +class Meta: + def __init__(self): + self.min_age = None + self.max_age = None + self.avg_age = None + self.num_classes = None + + self.in_chans = 3 + self.with_persons_model = False + self.disable_faces = False + self.use_persons = True + self.only_age = False + + self.num_classes_gender = 2 + self.input_size = 224 + + def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta": + + state = torch.load(ckpt_path, map_location="cpu") + + self.min_age = state["min_age"] + self.max_age = state["max_age"] + self.avg_age = state["avg_age"] + self.only_age = state["no_gender"] + + only_age = state["no_gender"] + + self.disable_faces = disable_faces + if "with_persons_model" in state: + self.with_persons_model = state["with_persons_model"] + else: + self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False + + self.num_classes = 1 if only_age else 3 + self.in_chans = 3 if not self.with_persons_model else 6 + self.use_persons = use_persons and self.with_persons_model + + if not self.with_persons_model and self.disable_faces: + raise ValueError("You can not use disable-faces for faces-only model") + if self.with_persons_model and self.disable_faces and not self.use_persons: + raise ValueError( + "You can not disable faces and persons together. 
" + "Set --with-persons if you want to run with --disable-faces" + ) + self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16 + return self + + def __str__(self): + attrs = vars(self) + attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops}) + return ", ".join("%s: %s" % item for item in attrs.items()) + + @property + def use_person_crops(self) -> bool: + return self.with_persons_model and self.use_persons + + @property + def use_face_crops(self) -> bool: + return not self.disable_faces or not self.with_persons_model + + +class MiVOLO: + def __init__( + self, + ckpt_path: str, + device: str = "cuda", + half: bool = True, + disable_faces: bool = False, + use_persons: bool = False, + verbose: bool = False, + torchcompile: Optional[str] = None, + ): + self.verbose = verbose + self.device = torch.device(device) + self.half = half and self.device.type != "cpu" + + self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons) + if self.verbose: + _logger.info(f"Model meta:\n{str(self.meta)}") + + model_name = f"mivolo_d1_{self.meta.input_size}" + self.model = create_model( + model_name=model_name, + num_classes=self.meta.num_classes, + in_chans=self.meta.in_chans, + pretrained=False, + checkpoint_path=ckpt_path, + filter_keys=["fds."], + ) + self.param_count = sum([m.numel() for m in self.model.parameters()]) + _logger.info(f"Model {model_name} created, param count: {self.param_count}") + + self.data_config = resolve_data_config( + model=self.model, + verbose=verbose, + use_test_size=True, + ) + + self.data_config["crop_pct"] = 1.0 + c, h, w = self.data_config["input_size"] + assert h == w, "Incorrect data_config" + self.input_size = w + + self.model = self.model.to(self.device) + + if torchcompile: + assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly." 
+ torch._dynamo.reset() + self.model = torch.compile(self.model, backend=torchcompile) + + self.model.eval() + if self.half: + self.model = self.model.half() + + def warmup(self, batch_size: int, steps=10): + if self.meta.with_persons_model: + input_size = (6, self.input_size, self.input_size) + else: + input_size = self.data_config["input_size"] + + input = torch.randn((batch_size,) + tuple(input_size)).to(self.device) + + for _ in range(steps): + out = self.inference(input) # noqa: F841 + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def inference(self, model_input: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + if self.half: + model_input = model_input.half() + output = self.model(model_input) + return output + + # def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): + # if ( + # (detected_bboxes.n_objects == 0) + # or (not self.meta.use_persons and detected_bboxes.n_faces == 0) + # or (self.meta.disable_faces and detected_bboxes.n_persons == 0) + # ): + # # nothing to process + # return + + # faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes) + + # if faces_input is None and person_input is None: + # # nothing to process + # return + + # if self.meta.with_persons_model: + # model_input = torch.cat((faces_input, person_input), dim=1) + # else: + # model_input = faces_input + # output = self.inference(model_input) + + # # write gender and age results into detected_bboxes + # self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds) + + def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds): + if self.meta.only_age: + age_output = output + gender_probs, gender_indx = None, None + else: + age_output = output[:, 2] + gender_output = output[:, :2].softmax(-1) + gender_probs, gender_indx = gender_output.topk(1) + + assert output.shape[0] == len(faces_inds) == len(bodies_inds) + + # per face + for index in range(output.shape[0]): + face_ind = 
faces_inds[index] + body_ind = bodies_inds[index] + + # get_age + age = age_output[index].item() + age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age + age = round(age, 2) + + detected_bboxes.set_age(face_ind, age) + detected_bboxes.set_age(body_ind, age) + + _logger.info(f"\tage: {age}") + + if gender_probs is not None: + gender = "male" if gender_indx[index].item() == 0 else "female" + gender_score = gender_probs[index].item() + + _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]") + + detected_bboxes.set_gender(face_ind, gender, gender_score) + detected_bboxes.set_gender(body_ind, gender, gender_score) + + # def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult): + + # crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image) + # (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies( + # self.meta.use_person_crops, self.meta.use_face_crops + # ) + + # if not self.meta.use_face_crops: + # assert all(f is None for f in faces_crops) + + # faces_input = prepare_classification_images( + # faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device + # ) + + # if not self.meta.use_person_crops: + # assert all(p is None for p in bodies_crops) + + # person_input = prepare_classification_images( + # bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device + # ) + + # _logger.info( + # f"faces_input: {faces_input.shape if faces_input is not None else None}, " + # f"person_input: {person_input.shape if person_input is not None else None}" + # ) + + # return faces_input, person_input, faces_inds, bodies_inds + + +if __name__ == "__main__": + model = MiVOLO("models/volo-v1_model_imdb_age_gender_4.22.pth.tar", half=True, device="cpu") diff --git a/src/facebias/estimators/mivolo/mivolo_model.py b/src/facebias/estimators/mivolo/mivolo_model.py new file mode 100644 index 0000000..cd92f2c --- 
# File: src/facebias/estimators/mivolo/mivolo_model.py
"""
MiVOLO model definition: a timm VOLO backbone with a replacement patch-embedding stage.

Code adapted from timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""

import torch
import torch.nn as nn
from facebias.estimators.mivolo.cross_bottleneck_attn import \
    CrossBottleneckAttn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_
from timm.models._builder import build_model_with_cfg
from timm.models._registry import register_model
from timm.models.volo import VOLO

__all__ = ["MiVOLOModel"]  # model_registry will add each entrypoint fn to this


def _cfg(url="", **kwargs):
    """Build a pretrained-config dict for one model variant.

    Args:
        url: Location of the checkpoint for this variant.
        **kwargs: Entries that override (or extend) the defaults below.

    Returns:
        A new dict holding the defaults with ``kwargs`` applied on top.
    """
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.96,
        "interpolation": "bicubic",
        "fixed_input_size": True,
        "mean": IMAGENET_DEFAULT_MEAN,
        "std": IMAGENET_DEFAULT_STD,
        "first_conv": None,
        "classifier": ("head", "aux_head"),
        **kwargs,
    }


# Per-variant pretrained configs, keyed by entrypoint name.
# NOTE(review): every URL points at the sail-sg/volo release assets, i.e. these look
# like the upstream VOLO checkpoints rather than MiVOLO-specific ones -- confirm.
default_cfgs = {
    "mivolo_d1_224": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
    ),
    "mivolo_d1_384": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
        crop_pct=1.0,
        input_size=(3, 384, 384),
    ),
    "mivolo_d2_224": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
    ),
    "mivolo_d2_384": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
        crop_pct=1.0,
        input_size=(3, 384, 384),
    ),
    "mivolo_d3_224": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
    ),
    "mivolo_d3_448": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
        crop_pct=1.0,
        input_size=(3, 448, 448),
    ),
    "mivolo_d4_224": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
    ),
    "mivolo_d4_448": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
        crop_pct=1.15,
        input_size=(3, 448, 448),
    ),
    "mivolo_d5_224": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
    ),
    "mivolo_d5_448": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
        crop_pct=1.15,
        input_size=(3, 448, 448),
    ),
    "mivolo_d5_512": _cfg(
        url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
        crop_pct=1.15,
        input_size=(3, 512, 512),
    ),
}


def get_output_size(input_shape, conv_layer):
    """Compute the spatial output size of a 2-D convolution.

    Applies, per spatial axis, the usual convolution arithmetic
    ``(in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1``.

    Args:
        input_shape: Indexable with the two spatial sizes at positions 0 and 1.
        conv_layer: Object exposing ``padding``, ``dilation``, ``kernel_size`` and
            ``stride``, each indexable per spatial axis (string padding modes are
            therefore not supported).

    Returns:
        A two-element list with the output size along each spatial axis.
    """
    padding = conv_layer.padding
    dilation = conv_layer.dilation
    kernel_size = conv_layer.kernel_size
    stride = conv_layer.stride

    return [
        ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1
        for i in range(2)
    ]


def get_output_size_module(input_size, stem):
    """Compute the spatial output size after every ``nn.Conv2d`` in ``stem``.

    Non-convolution modules are skipped, i.e. treated as size-preserving.

    Args:
        input_size: Indexable with the two spatial input sizes.
        stem: Iterable of modules (e.g. an ``nn.Sequential``).

    Returns:
        ``input_size`` itself when ``stem`` contains no ``nn.Conv2d``; otherwise a
        two-element list with the size after the last convolution.
    """
    output_size = input_size
    for module in stem:
        if isinstance(module, nn.Conv2d):
            # Reuse the single-layer helper instead of repeating its formula.
            output_size = get_output_size(output_size, module)
    return output_size


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    With ``in_chans == 3`` the input goes through one optional conv stem and one
    strided projection. With ``in_chans == 6`` the input is split into its first
    and last three channels; each half gets its own stem and projection, the two
    results are concatenated along the channel axis and passed through a
    ``CrossBottleneckAttn`` module.
    NOTE(review): the two halves are presumably two different crops of the same
    subject (cf. ``with_persons_model``) -- confirm against the caller.
    """

    def __init__(
        self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
    ):
        """Create the stem(s) and projection(s).

        Args:
            img_size: Side length of the (square) input image.
            stem_conv: Whether to build the conv stem(s).
            stem_stride: Stride of the first stem convolution; also divides the
                projection kernel/stride.
            patch_size: Overall patch size; must be 4, 8 or 16.
            in_chans: Number of input channels; must be 3 or 6.
            hidden_dim: Channel width of the stem output / projection input.
            embed_dim: Channel width of the projection output.

        Raises:
            AssertionError: If ``patch_size`` or ``in_chans`` is unsupported.
            ValueError: If ``in_chans == 6`` is requested without ``stem_conv``.
        """
        super().__init__()
        assert patch_size in [4, 8, 16]
        assert in_chans in [3, 6]
        self.with_persons_model = in_chans == 6
        self.use_cross_attn = True

        if self.with_persons_model and not stem_conv:
            # The two-branch path sizes its attention map from ``self.conv1``, which
            # only exists with a stem. Fail early with a clear message instead of an
            # AttributeError halfway through construction.
            raise ValueError("PatchEmbed with in_chans=6 requires stem_conv=True.")

        if stem_conv:
            if not self.with_persons_model:
                self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
            else:
                self.conv = True  # just to match interface
                # One independent stem per 3-channel half of the input.
                self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
                self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
        else:
            self.conv = None

        proj_kernel = patch_size // stem_stride
        if self.with_persons_model:
            self.proj1 = nn.Conv2d(hidden_dim, embed_dim, kernel_size=proj_kernel, stride=proj_kernel)
            self.proj2 = nn.Conv2d(hidden_dim, embed_dim, kernel_size=proj_kernel, stride=proj_kernel)

            # Spatial size of one branch after stem + projection; handed to the
            # attention module as its feature-map size.
            stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
            self.proj_output_size = get_output_size(stem_out_shape, self.proj1)

            self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
        else:
            self.proj = nn.Conv2d(hidden_dim, embed_dim, kernel_size=proj_kernel, stride=proj_kernel)

        self.patch_dim = img_size // patch_size
        self.num_patches = self.patch_dim**2

    def create_stem(self, stem_stride, in_chans, hidden_dim):
        """Return a three-layer Conv-BatchNorm-ReLU stem.

        Only the first convolution (7x7) is strided by ``stem_stride``; the two
        following 3x3 convolutions keep the spatial size.
        """
        # The original "112x112" remarks correspond to a 224x224 input with stem_stride=2.
        return nn.Sequential(
            nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Embed ``x`` (channels on dim 1) into a ``B, C, H, W`` patch feature map."""
        if self.conv is None:
            # NOTE(review): without a stem the input is returned untouched and
            # ``self.proj`` is never applied. Behaviour kept as-is; this looks
            # unintended -- confirm before relying on ``stem_conv=False``.
            return x

        if self.with_persons_model:
            # First three channels -> branch 1, remaining channels -> branch 2.
            x1 = self.proj1(self.conv1(x[:, :3]))
            x2 = self.proj2(self.conv2(x[:, 3:]))
            return self.map(torch.cat([x1, x2], dim=1))

        return self.proj(self.conv(x))  # B, C, H, W


class MiVOLOModel(VOLO):
    """
    Vision Outlooker (VOLO) subclass, the main model class.

    Builds the parent VOLO network unchanged and then swaps its ``patch_embed``
    for the local ``PatchEmbed`` (stem enabled, stem stride 2), which also
    supports 6-channel input.
    """

    def __init__(
        self,
        layers,
        img_size=224,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        patch_size=8,
        stem_hidden_dim=64,
        embed_dims=None,
        num_heads=None,
        downsamples=(True, False, False, False),
        outlook_attention=(True, False, False, False),
        mlp_ratio=3.0,
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        post_layers=("ca", "ca"),
        use_aux_head=True,
        use_mix_token=False,
        pooling_scale=2,
    ):
        """Construct the VOLO backbone and install the custom patch embedding.

        All arguments are forwarded verbatim to ``VOLO.__init__``. ``img_size``
        may be an int or a tuple/list whose first element is used as the side
        length for ``PatchEmbed``. ``embed_dims`` must be indexable, since its
        first entry is the patch-embedding width.
        """
        super().__init__(
            layers=layers,
            img_size=img_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            patch_size=patch_size,
            stem_hidden_dim=stem_hidden_dim,
            embed_dims=embed_dims,
            num_heads=num_heads,
            downsamples=downsamples,
            outlook_attention=outlook_attention,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            post_layers=post_layers,
            use_aux_head=use_aux_head,
            use_mix_token=use_mix_token,
            pooling_scale=pooling_scale,
        )

        # PatchEmbed does integer arithmetic on a scalar side length.
        im_size = img_size[0] if isinstance(img_size, (tuple, list)) else img_size
        self.patch_embed = PatchEmbed(
            img_size=im_size,
            stem_conv=True,
            stem_stride=2,
            patch_size=patch_size,
            in_chans=in_chans,
            hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0],
        )

        # Re-run initialisation so the freshly installed patch embedding is covered too.
        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def forward_features(self, x):
        """Run patch embedding, the token stages and the optional post network."""
        # step 1: patch embedding
        x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C

        # step 2: tokens learning in the two stages
        x = self.forward_tokens(x)

        # step 3: post network, apply class attention or not
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
        """Pool the token sequence and apply the classification head(s).

        Args:
            x: Token tensor; dim 1 is the token axis and index 0 on it is used as
                the class token when ``global_pool == "token"``.
            pre_logits: If true, return the pooled features before any head.
            targets: Forwarded to ``self._fds_forward`` when that hook exists.
            epoch: Forwarded to ``self._fds_forward`` when that hook exists.

        Returns:
            The pooled features if ``pre_logits``; otherwise the head output, or
            the tuple ``(output, features)`` when the instance has a
            ``_fds_forward`` attribute and is in training mode.
        """
        if self.global_pool == "avg":
            out = x.mean(dim=1)
        elif self.global_pool == "token":
            out = x[:, 0]
        else:
            out = x
        if pre_logits:
            return out

        features = out
        # Optional hook: only used when something attached ``_fds_forward`` to the model.
        fds_enabled = hasattr(self, "_fds_forward")
        if fds_enabled:
            features = self._fds_forward(features, targets, epoch)

        out = self.head(features)
        if self.aux_head is not None:
            # generate classes in all feature tokens, see token labeling
            aux = self.aux_head(x[:, 1:])
            out = out + 0.5 * aux.max(1)[0]

        return (out, features) if (fds_enabled and self.training) else out

    def forward(self, x, targets=None, epoch=None):
        """simplified forward (without mix token training)"""
        x = self.forward_features(x)
        x = self.forward_head(x, targets=targets, epoch=epoch)
        return x


def _create_mivolo(variant, pretrained=False, **kwargs):
    """Instantiate ``MiVOLOModel`` for ``variant`` through timm's model builder.

    Raises:
        RuntimeError: If a truthy ``features_only`` is passed in ``kwargs``.
    """
    if kwargs.get("features_only", None):
        raise RuntimeError("features_only not implemented for Vision Transformer models.")
    return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)


# Architecture hyper-parameters shared by all input resolutions of one variant.
_D1_ARGS = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12))
_D2_ARGS = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16))
_D3_ARGS = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16))
_D4_ARGS = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16))
_D5_ARGS = dict(
    layers=(12, 12, 20, 4),
    embed_dims=(384, 768, 768, 768),
    num_heads=(12, 16, 16, 16),
    mlp_ratio=4,
    stem_hidden_dim=128,
)

# In every entrypoint below ``dict(**ARGS, **kwargs)`` is used on purpose: like the
# original ``dict(layers=..., **kwargs)`` it raises TypeError when a caller tries to
# pass one of the fixed architecture arguments again.


@register_model
def mivolo_d1_224(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d1_224`` variant."""
    model_args = dict(**_D1_ARGS, **kwargs)
    return _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)


@register_model
def mivolo_d1_384(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d1_384`` variant."""
    model_args = dict(**_D1_ARGS, **kwargs)
    return _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)


@register_model
def mivolo_d2_224(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d2_224`` variant."""
    model_args = dict(**_D2_ARGS, **kwargs)
    return _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)


@register_model
def mivolo_d2_384(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d2_384`` variant."""
    model_args = dict(**_D2_ARGS, **kwargs)
    return _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)


@register_model
def mivolo_d3_224(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d3_224`` variant."""
    model_args = dict(**_D3_ARGS, **kwargs)
    return _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)


@register_model
def mivolo_d3_448(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d3_448`` variant."""
    model_args = dict(**_D3_ARGS, **kwargs)
    return _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)


@register_model
def mivolo_d4_224(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d4_224`` variant."""
    model_args = dict(**_D4_ARGS, **kwargs)
    return _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)


@register_model
def mivolo_d4_448(pretrained=False, **kwargs):
    """VOLO-D4 model, Params: 193M"""
    model_args = dict(**_D4_ARGS, **kwargs)
    return _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)


@register_model
def mivolo_d5_224(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d5_224`` variant."""
    model_args = dict(**_D5_ARGS, **kwargs)
    return _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)


@register_model
def mivolo_d5_448(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d5_448`` variant."""
    model_args = dict(**_D5_ARGS, **kwargs)
    return _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)


@register_model
def mivolo_d5_512(pretrained=False, **kwargs):
    """Entrypoint for the ``mivolo_d5_512`` variant."""
    model_args = dict(**_D5_ARGS, **kwargs)
    return _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)