Added setup and basic module sources.

This commit is contained in:
Guilherme G. Schardong
2026-04-15 22:24:58 +01:00
parent 3e1e491a99
commit 1c11b41329
11 changed files with 1237 additions and 0 deletions

30
pyproject.toml Normal file
View File

@@ -0,0 +1,30 @@
[project]
name = "facebias"
version = "0.0.1"
description = "Face image bias estimation and benchmarking"
license = "MIT"
readme = "README.md"
requires-python = ">= 3.12"
[build-system]
requires = ["setuptools >= 77.0.3"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
[project.optional-dependencies]
dev = [
"jedi-language-server>=0.46.0",
"black>=25.12.0",
"flake8>=7.3.0",
"isort>=8.0.1",
"mypy>=1.19.0",
"pytest>=9.0.0"
]

65
src/facebias/__init__.py Normal file
View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional
import cv2
import numpy as np
logging.basicConfig(
level=logging.DEBUG, # Set the minimum logging level
# format="%(asctime)s - [%(levelname)s] - %(filename)s:%(lineno)s - %(message)s",
)
logger = logging.getLogger(__name__)
@dataclass
class FaceBox:
x1: int
y1: int
x2: int
y2: int
def load_dataset(
root: Path,
meta_path: Optional[Path] = None,
imname_proc_fn: Optional[Callable]=None
) -> tuple[dict[Path, np.ndarray], dict[str, dict[str, Any]]]:
metadata = dict()
paths = set([p for p in root.iterdir() if not p.is_dir()])
if meta_path is not None and meta_path in paths:
# Just read the metadata and avoid needing to test for it in the main loop.
metadata = load_metadata(meta_path)
paths.remove(meta_path)
ims = OrderedDict()
# TODO(gschardong): Paralellize this?
for p in paths:
try:
im = cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB)
except cv2.error:
logger.info(f'File "{p}" is not an image.')
if meta_path is None:
logger.info(f'Trying to read "{p}" as metadata.')
if p.suffix == ".csv":
metadata = load_metadata(p)
logger.info("Metadata read successfully.")
else:
logger.error(f'Failed to read "{p}" as metadata. Skipping.')
elif not metadata:
logger.critical("Logic error: Metadata should have been read already.")
else:
proc_imname = imname_proc_fn(p.name) if imname_proc_fn is not None else p.name
ims[proc_imname] = im
if not metadata:
logger.warning("No metadata found.")
if meta_path is not None:
logger.error(f'Metadata file not found at "{meta_path}".')
return ims, metadata

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
import numpy as np
class NoFaceDetectedError(Exception):
def __init__(self, message: str = "No face detected in image."):
self.message = message
super(NoFaceDetectedError, self).__init__(self.message)
class BaseFaceDetector(ABC):
"""Base class for face detectors.
This classe provides the base interface for supported face detection
methods. An example is provided in the `MediapipeDetector` class in
`facebias.detectors.mediapipe`. Its goal is simply to load the resources
and perform face detection, whenever a call to `predict` is performed."""
@abstractmethod
def detect(self, im) -> np.ndarray:
"""Detects faces in image `im`. Throws `NoFaceDetectedError`
if `im` has no valid face(s)."""
pass

View File

@@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from facebias import FaceBox
from facebias.detectors import BaseFaceDetector, NoFaceDetectedError
from mediapipe.tasks.python import BaseOptions, vision
class MediapipeDetector(BaseFaceDetector):
"""Uses mediapipe's blazeface model to detect faces in images or video frames.
Parameters
----------
model_path: str, PathLike
mode: Optional[mediapipe.tasks.python.vision.RunningMode]
"""
def __init__(self, model_path: str, mode=vision.RunningMode.IMAGE) -> None:
super(MediapipeDetector, self).__init__()
baseopts = BaseOptions(model_asset_path=model_path)
lmopts = vision.FaceDetectorOptions(
base_options=baseopts,
running_mode=mode,
)
self.mode = mode
self.detector = vision.FaceDetector.create_from_options(lmopts)
def detect(self, im: np.ndarray, timestamp_ms: float = 0.0) -> list[FaceBox]:
"""Returns a list with bounding boxes for each face found in `im`.
`timestamp_ms` is only used for videos.
"""
detect_results = None
mpim = mp.Image(image_format=mp.ImageFormat.SRGB, data=im)
if self.mode == vision.RunningMode.IMAGE:
detect_results = self.detector.detect(mpim)
else:
detect_results = self.detector.detect_for_video(mpim, round(timestamp_ms))
if detect_results is None or not detect_results.detections:
raise NoFaceDetectedError()
faces = [None for _ in detect_results.detections]
for i, d in enumerate(detect_results.detections):
bbox = d.bounding_box
faces[i] = FaceBox(
x1=bbox.origin_y,
y1=bbox.origin_x,
x2=bbox.origin_y + bbox.height,
y2=bbox.origin_x + bbox.width,
)
return faces
if __name__ == '__main__':
import cv2
from pathlib import Path
TEST_IM_PATH = Path("data/frll_neutral_front/001_03.jpg")
DETECTOR_PATH = Path("models/blaze_face_short_range.tflite")
detector = MediapipeDetector(str(DETECTOR_PATH))
im = cv2.cvtColor(cv2.imread(TEST_IM_PATH), cv2.COLOR_BGR2RGB)
crop = detector.detect(im)
print(crop)

View File

@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractclassmethod, abstractmethod
from pathlib import Path
from typing import Any
import numpy as np
class BaseEstimator(ABC):
@abstractmethod
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
pass
@abstractclassmethod
def capabilities() -> list[str]:
pass

View File

@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision
from facebias import load_dataset
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from facebias.estimators import BaseEstimator
class FairFace(BaseEstimator):
"""
"""
def __init__(self, checkpoint_path: Optional[Path]=None, device=torch.device("cpu")):
self.T = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
self.device = device
self.model = torchvision.models.resnet34(pretrained=True)
self.model.fc = nn.Linear(self.model.fc.in_features, 18)
self.model = self.model.eval().to(device)
if checkpoint_path is not None:
self.load_state_dict(checkpoint_path)
self.model.to(device)
def load_state_dict(self, checkpoint_path: Path) -> "FairFace":
self.model.load_state_dict(
torch.load(checkpoint_path, map_location=torch.device("cpu"))
)
return self
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
preds = OrderedDict()
for p, im in ims.items():
im = self.T(im).view(1, 3, 224, 224).to(self.device)
with torch.no_grad():
y = self.model(im)
y = y.cpu().detach().squeeze().numpy()
# Ethnicity prediction
# y_ethno = y[:4]
# ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
# ethno_pred = np.argmax(ethno_score)
# Age prediction
y_age = y[9:18]
age_score = np.exp(y_age) / np.sum(np.exp(y_age))
age_pred = np.argmax(age_score)
# Gender prediction
y_sex = y[7:9]
sex_score = np.exp(y_sex) / np.sum(np.exp(y_sex))
sex_pred = np.argmax(sex_score)
sex_label = "m" if not sex_pred else "f"
preds[p] = {
"age_group": _to_age_label(age_pred),
"sex": sex_label,
}
return preds
def capabilities() -> list[str]:
return ["age_group", "sex"]
def _to_age_label(age):
return "{}-{}".format(age * 10, age * 10 + 9)
if __name__ == '__main__':
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
model = FairFace()
model.load_state_dict(MODEL_PATH)
DATASET_PATH = Path("../../../data/frll_neutral_front")
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
out = model.predict(dataset)

View File

@@ -0,0 +1,105 @@
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import os
from typing import Any, Dict, Optional, Union
import timm
# register new models
from facebias.estimators.mivolo.mivolo_model import * # noqa: F403, F401
from timm.layers import set_layer_config
from timm.models._factory import parse_model_name
from timm.models._helpers import load_state_dict
from timm.models._hub import load_model_config_from_hf
from timm.models._pretrained import PretrainedCfg
from timm.models._registry import (is_model, model_entrypoint,
split_model_name_tag)
def load_checkpoint(
model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, "load_pretrained"):
timm.models._model_builder.load_pretrained(checkpoint_path)
else:
raise NotImplementedError("Model cannot load numpy checkpoint")
return
state_dict = load_state_dict(checkpoint_path, use_ema)
if filter_keys:
for sd_key in list(state_dict.keys()):
for filter_key in filter_keys:
if filter_key in sd_key:
if sd_key in state_dict:
del state_dict[sd_key]
rep = []
if state_dict_map is not None:
# 'patch_embed.conv1.' : 'patch_embed.conv.'
for state_k in list(state_dict.keys()):
for target_k, target_v in state_dict_map.items():
if target_v in state_k:
target_name = state_k.replace(target_v, target_k)
state_dict[target_name] = state_dict[state_k]
rep.append(state_k)
for r in rep:
if r in state_dict:
del state_dict[r]
incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
return incompatible_keys
def create_model(
model_name: str,
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: str = "",
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
filter_keys=None,
state_dict_map=None,
**kwargs,
):
"""Create a model
Lookup model's entrypoint function and pass relevant args to create a new model.
"""
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
model_source, model_name = parse_model_name(model_name)
if model_source == "hf-hub":
assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
else:
model_name, pretrained_tag = split_model_name_tag(model_name)
if not pretrained_cfg:
# a valid pretrained_cfg argument takes priority over tag in model name
pretrained_cfg = pretrained_tag
if not is_model(model_name):
raise RuntimeError("Unknown model (%s)" % model_name)
create_fn = model_entrypoint(model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
**kwargs,
)
if checkpoint_path:
load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
return model

View File

@@ -0,0 +1,116 @@
"""
Code based on timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import torch
import torch.nn as nn
from timm.layers.bottleneck_attn import PosEmbedRel
from timm.layers.helpers import make_divisible
from timm.layers.mlp import Mlp
from timm.layers.trace_utils import _assert
from timm.layers.weight_init import trunc_normal_
class CrossBottleneckAttn(nn.Module):
def __init__(
self,
dim,
dim_out=None,
feat_size=None,
stride=1,
num_heads=4,
dim_head=None,
qk_ratio=1.0,
qkv_bias=False,
scale_pos_embed=False,
):
super().__init__()
assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.num_heads = num_heads
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.dim_head_v = dim_out // self.num_heads
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk**-0.5
self.scale_pos_embed = scale_pos_embed
self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
mlp_ratio = 4
self.mlp = Mlp(
in_features=self.dim_out_v * 2,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
out_features=dim_out,
drop=0,
use_conv=True,
)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
def get_qkv(self, x, qvk_conv):
B, C, H, W = x.shape
x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
return q, k, v
def apply_attn(self, q, k, v, B, H, W, dropout=None):
if self.scale_pos_embed:
attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
else:
attn = (q @ k) * self.scale + self.pos_embed(q)
attn = attn.softmax(dim=-1)
if dropout:
attn = dropout(attn)
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
return out
def forward(self, x):
B, C, H, W = x.shape
dim = int(C / 2)
x1 = x[:, :dim, :, :]
x2 = x[:, dim:, :, :]
_assert(H == self.pos_embed.height, "")
_assert(W == self.pos_embed.width, "")
q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
# person to face
out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
# face to person
out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
x_pf = self.norm(x_pf)
x_pf = self.mlp(x_pf) # B, dim_out, H, W
out = self.pool(x_pf)
return out

View File

@@ -0,0 +1,313 @@
import logging
from typing import List, Optional
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from facebias.estimators.mivolo.create_timm_model import create_model
from timm.data import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
resolve_data_config)
_logger = logging.getLogger("MiVOLO")
has_compile = hasattr(torch, "compile")
def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
return im
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
# ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im
def prepare_classification_images(
img_list: List[Optional[np.ndarray]],
target_size: int = 224,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
device=None,
) -> torch.Tensor:
prepared_images: List[torch.Tensor] = []
for img in img_list:
if img is None:
img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
img = F.normalize(img, mean=mean, std=std)
img = img.unsqueeze(0)
prepared_images.append(img)
continue
img = class_letterbox(img, new_shape=(target_size, target_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img / 255.0
img = (img - mean) / std
img = img.astype(dtype=np.float32)
img = img.transpose((2, 0, 1))
img = np.ascontiguousarray(img)
img = torch.from_numpy(img)
img = img.unsqueeze(0)
prepared_images.append(img)
if len(prepared_images) == 0:
return None
prepared_input = torch.concat(prepared_images)
if device:
prepared_input = prepared_input.to(device)
return prepared_input
class Meta:
def __init__(self):
self.min_age = None
self.max_age = None
self.avg_age = None
self.num_classes = None
self.in_chans = 3
self.with_persons_model = False
self.disable_faces = False
self.use_persons = True
self.only_age = False
self.num_classes_gender = 2
self.input_size = 224
def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
state = torch.load(ckpt_path, map_location="cpu")
self.min_age = state["min_age"]
self.max_age = state["max_age"]
self.avg_age = state["avg_age"]
self.only_age = state["no_gender"]
only_age = state["no_gender"]
self.disable_faces = disable_faces
if "with_persons_model" in state:
self.with_persons_model = state["with_persons_model"]
else:
self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
self.num_classes = 1 if only_age else 3
self.in_chans = 3 if not self.with_persons_model else 6
self.use_persons = use_persons and self.with_persons_model
if not self.with_persons_model and self.disable_faces:
raise ValueError("You can not use disable-faces for faces-only model")
if self.with_persons_model and self.disable_faces and not self.use_persons:
raise ValueError(
"You can not disable faces and persons together. "
"Set --with-persons if you want to run with --disable-faces"
)
self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16
return self
def __str__(self):
attrs = vars(self)
attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
return ", ".join("%s: %s" % item for item in attrs.items())
@property
def use_person_crops(self) -> bool:
return self.with_persons_model and self.use_persons
@property
def use_face_crops(self) -> bool:
return not self.disable_faces or not self.with_persons_model
class MiVOLO:
def __init__(
self,
ckpt_path: str,
device: str = "cuda",
half: bool = True,
disable_faces: bool = False,
use_persons: bool = False,
verbose: bool = False,
torchcompile: Optional[str] = None,
):
self.verbose = verbose
self.device = torch.device(device)
self.half = half and self.device.type != "cpu"
self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
if self.verbose:
_logger.info(f"Model meta:\n{str(self.meta)}")
model_name = f"mivolo_d1_{self.meta.input_size}"
self.model = create_model(
model_name=model_name,
num_classes=self.meta.num_classes,
in_chans=self.meta.in_chans,
pretrained=False,
checkpoint_path=ckpt_path,
filter_keys=["fds."],
)
self.param_count = sum([m.numel() for m in self.model.parameters()])
_logger.info(f"Model {model_name} created, param count: {self.param_count}")
self.data_config = resolve_data_config(
model=self.model,
verbose=verbose,
use_test_size=True,
)
self.data_config["crop_pct"] = 1.0
c, h, w = self.data_config["input_size"]
assert h == w, "Incorrect data_config"
self.input_size = w
self.model = self.model.to(self.device)
if torchcompile:
assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
torch._dynamo.reset()
self.model = torch.compile(self.model, backend=torchcompile)
self.model.eval()
if self.half:
self.model = self.model.half()
def warmup(self, batch_size: int, steps=10):
if self.meta.with_persons_model:
input_size = (6, self.input_size, self.input_size)
else:
input_size = self.data_config["input_size"]
input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
for _ in range(steps):
out = self.inference(input) # noqa: F841
if torch.cuda.is_available():
torch.cuda.synchronize()
def inference(self, model_input: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
if self.half:
model_input = model_input.half()
output = self.model(model_input)
return output
# def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
# if (
# (detected_bboxes.n_objects == 0)
# or (not self.meta.use_persons and detected_bboxes.n_faces == 0)
# or (self.meta.disable_faces and detected_bboxes.n_persons == 0)
# ):
# # nothing to process
# return
# faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
# if faces_input is None and person_input is None:
# # nothing to process
# return
# if self.meta.with_persons_model:
# model_input = torch.cat((faces_input, person_input), dim=1)
# else:
# model_input = faces_input
# output = self.inference(model_input)
# # write gender and age results into detected_bboxes
# self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
if self.meta.only_age:
age_output = output
gender_probs, gender_indx = None, None
else:
age_output = output[:, 2]
gender_output = output[:, :2].softmax(-1)
gender_probs, gender_indx = gender_output.topk(1)
assert output.shape[0] == len(faces_inds) == len(bodies_inds)
# per face
for index in range(output.shape[0]):
face_ind = faces_inds[index]
body_ind = bodies_inds[index]
# get_age
age = age_output[index].item()
age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
age = round(age, 2)
detected_bboxes.set_age(face_ind, age)
detected_bboxes.set_age(body_ind, age)
_logger.info(f"\tage: {age}")
if gender_probs is not None:
gender = "male" if gender_indx[index].item() == 0 else "female"
gender_score = gender_probs[index].item()
_logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
detected_bboxes.set_gender(face_ind, gender, gender_score)
detected_bboxes.set_gender(body_ind, gender, gender_score)
# def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
# crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
# (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
# self.meta.use_person_crops, self.meta.use_face_crops
# )
# if not self.meta.use_face_crops:
# assert all(f is None for f in faces_crops)
# faces_input = prepare_classification_images(
# faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
# )
# if not self.meta.use_person_crops:
# assert all(p is None for p in bodies_crops)
# person_input = prepare_classification_images(
# bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
# )
# _logger.info(
# f"faces_input: {faces_input.shape if faces_input is not None else None}, "
# f"person_input: {person_input.shape if person_input is not None else None}"
# )
# return faces_input, person_input, faces_inds, bodies_inds
if __name__ == "__main__":
model = MiVOLO("models/volo-v1_model_imdb_age_gender_4.22.pth.tar", half=True, device="cpu")

View File

@@ -0,0 +1,405 @@
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import torch
import torch.nn as nn
from facebias.estimators.mivolo.cross_bottleneck_attn import \
CrossBottleneckAttn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_
from timm.models._builder import build_model_with_cfg
from timm.models._registry import register_model
from timm.models.volo import VOLO
__all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.96,
"interpolation": "bicubic",
"fixed_input_size": True,
"mean": IMAGENET_DEFAULT_MEAN,
"std": IMAGENET_DEFAULT_STD,
"first_conv": None,
"classifier": ("head", "aux_head"),
**kwargs,
}
default_cfgs = {
"mivolo_d1_224": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
),
"mivolo_d1_384": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
crop_pct=1.0,
input_size=(3, 384, 384),
),
"mivolo_d2_224": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
),
"mivolo_d2_384": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
crop_pct=1.0,
input_size=(3, 384, 384),
),
"mivolo_d3_224": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
),
"mivolo_d3_448": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
crop_pct=1.0,
input_size=(3, 448, 448),
),
"mivolo_d4_224": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
),
"mivolo_d4_448": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
crop_pct=1.15,
input_size=(3, 448, 448),
),
"mivolo_d5_224": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
),
"mivolo_d5_448": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
crop_pct=1.15,
input_size=(3, 448, 448),
),
"mivolo_d5_512": _cfg(
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
crop_pct=1.15,
input_size=(3, 512, 512),
),
}
def get_output_size(input_shape, conv_layer):
padding = conv_layer.padding
dilation = conv_layer.dilation
kernel_size = conv_layer.kernel_size
stride = conv_layer.stride
output_size = [
((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
]
return output_size
def get_output_size_module(input_size, stem):
output_size = input_size
for module in stem:
if isinstance(module, nn.Conv2d):
output_size = [
(
(output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
// module.stride[i]
)
+ 1
for i in range(2)
]
return output_size
class PatchEmbed(nn.Module):
"""Image to Patch Embedding."""
def __init__(
self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
):
super().__init__()
assert patch_size in [4, 8, 16]
assert in_chans in [3, 6]
self.with_persons_model = in_chans == 6
self.use_cross_attn = True
if stem_conv:
if not self.with_persons_model:
self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
else:
self.conv = True # just to match interface
# split
self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
else:
self.conv = None
if self.with_persons_model:
self.proj1 = nn.Conv2d(
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
)
self.proj2 = nn.Conv2d(
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
)
stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
else:
self.proj = nn.Conv2d(
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
)
self.patch_dim = img_size // patch_size
self.num_patches = self.patch_dim**2
def create_stem(self, stem_stride, in_chans, hidden_dim):
return nn.Sequential(
nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
)
def forward(self, x):
if self.conv is not None:
if self.with_persons_model:
x1 = x[:, :3]
x2 = x[:, 3:]
x1 = self.conv1(x1)
x1 = self.proj1(x1)
x2 = self.conv2(x2)
x2 = self.proj2(x2)
x = torch.cat([x1, x2], dim=1)
x = self.map(x)
else:
x = self.conv(x)
x = self.proj(x) # B, C, H, W
return x
class MiVOLOModel(VOLO):
"""
Vision Outlooker, the main class of our model
"""
def __init__(
self,
layers,
img_size=224,
in_chans=3,
num_classes=1000,
global_pool="token",
patch_size=8,
stem_hidden_dim=64,
embed_dims=None,
num_heads=None,
downsamples=(True, False, False, False),
outlook_attention=(True, False, False, False),
mlp_ratio=3.0,
qkv_bias=False,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
post_layers=("ca", "ca"),
use_aux_head=True,
use_mix_token=False,
pooling_scale=2,
):
super(MiVOLOModel, self).__init__(
layers=layers,
img_size=img_size,
in_chans=in_chans,
num_classes=num_classes,
global_pool=global_pool,
patch_size=patch_size,
stem_hidden_dim=stem_hidden_dim,
embed_dims=embed_dims,
num_heads=num_heads,
downsamples=downsamples,
outlook_attention=outlook_attention,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
post_layers=post_layers,
use_aux_head=use_aux_head,
use_mix_token=use_mix_token,
pooling_scale=pooling_scale,
)
im_size = img_size[0] if isinstance(img_size, tuple) else img_size
self.patch_embed = PatchEmbed(
img_size=im_size,
stem_conv=True,
stem_stride=2,
patch_size=patch_size,
in_chans=in_chans,
hidden_dim=stem_hidden_dim,
embed_dim=embed_dims[0],
)
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def forward_features(self, x):
x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
# step2: tokens learning in the two stages
x = self.forward_tokens(x)
# step3: post network, apply class attention or not
if self.post_network is not None:
x = self.forward_cls(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
if self.global_pool == "avg":
out = x.mean(dim=1)
elif self.global_pool == "token":
out = x[:, 0]
else:
out = x
if pre_logits:
return out
features = out
fds_enabled = hasattr(self, "_fds_forward")
if fds_enabled:
features = self._fds_forward(features, targets, epoch)
out = self.head(features)
if self.aux_head is not None:
# generate classes in all feature tokens, see token labeling
aux = self.aux_head(x[:, 1:])
out = out + 0.5 * aux.max(1)[0]
return (out, features) if (fds_enabled and self.training) else out
def forward(self, x, targets=None, epoch=None):
"""simplified forward (without mix token training)"""
x = self.forward_features(x)
x = self.forward_head(x, targets=targets, epoch=epoch)
return x
def _create_mivolo(variant, pretrained=False, **kwargs):
if kwargs.get("features_only", None):
raise RuntimeError("features_only not implemented for Vision Transformer models.")
return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
@register_model
def mivolo_d1_224(pretrained=False, **kwargs):
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d1_384(pretrained=False, **kwargs):
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d2_224(pretrained=False, **kwargs):
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d2_384(pretrained=False, **kwargs):
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d3_224(pretrained=False, **kwargs):
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d3_448(pretrained=False, **kwargs):
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d4_224(pretrained=False, **kwargs):
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d4_448(pretrained=False, **kwargs):
"""VOLO-D4 model, Params: 193M"""
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d5_224(pretrained=False, **kwargs):
model_args = dict(
layers=(12, 12, 20, 4),
embed_dims=(384, 768, 768, 768),
num_heads=(12, 16, 16, 16),
mlp_ratio=4,
stem_hidden_dim=128,
**kwargs
)
model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d5_448(pretrained=False, **kwargs):
model_args = dict(
layers=(12, 12, 20, 4),
embed_dims=(384, 768, 768, 768),
num_heads=(12, 16, 16, 16),
mlp_ratio=4,
stem_hidden_dim=128,
**kwargs
)
model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
return model
@register_model
def mivolo_d5_512(pretrained=False, **kwargs):
model_args = dict(
layers=(12, 12, 20, 4),
embed_dims=(384, 768, 768, 768),
num_heads=(12, 16, 16, 16),
mlp_ratio=4,
stem_hidden_dim=128,
**kwargs
)
model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
return model