Added setup and basic module sources.
This commit is contained in:
30
pyproject.toml
Normal file
30
pyproject.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[project]
|
||||
name = "facebias"
|
||||
version = "0.0.1"
|
||||
description = "Face image bias estimation and benchmarking"
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools >= 77.0.3"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
dependencies = {file = ["requirements.txt"]}
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"jedi-language-server>=0.46.0",
|
||||
"black>=25.12.0",
|
||||
"flake8>=7.3.0",
|
||||
"isort>=8.0.1",
|
||||
"mypy>=1.19.0",
|
||||
"pytest>=9.0.0"
|
||||
]
|
||||
65
src/facebias/__init__.py
Normal file
65
src/facebias/__init__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # Set the minimum logging level
|
||||
# format="%(asctime)s - [%(levelname)s] - %(filename)s:%(lineno)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceBox:
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
|
||||
|
||||
def load_dataset(
|
||||
root: Path,
|
||||
meta_path: Optional[Path] = None,
|
||||
imname_proc_fn: Optional[Callable]=None
|
||||
) -> tuple[dict[Path, np.ndarray], dict[str, dict[str, Any]]]:
|
||||
metadata = dict()
|
||||
paths = set([p for p in root.iterdir() if not p.is_dir()])
|
||||
if meta_path is not None and meta_path in paths:
|
||||
# Just read the metadata and avoid needing to test for it in the main loop.
|
||||
metadata = load_metadata(meta_path)
|
||||
paths.remove(meta_path)
|
||||
|
||||
ims = OrderedDict()
|
||||
# TODO(gschardong): Paralellize this?
|
||||
for p in paths:
|
||||
try:
|
||||
im = cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB)
|
||||
except cv2.error:
|
||||
logger.info(f'File "{p}" is not an image.')
|
||||
if meta_path is None:
|
||||
logger.info(f'Trying to read "{p}" as metadata.')
|
||||
if p.suffix == ".csv":
|
||||
metadata = load_metadata(p)
|
||||
logger.info("Metadata read successfully.")
|
||||
else:
|
||||
logger.error(f'Failed to read "{p}" as metadata. Skipping.')
|
||||
elif not metadata:
|
||||
logger.critical("Logic error: Metadata should have been read already.")
|
||||
else:
|
||||
proc_imname = imname_proc_fn(p.name) if imname_proc_fn is not None else p.name
|
||||
ims[proc_imname] = im
|
||||
|
||||
if not metadata:
|
||||
logger.warning("No metadata found.")
|
||||
if meta_path is not None:
|
||||
logger.error(f'Metadata file not found at "{meta_path}".')
|
||||
|
||||
return ims, metadata
|
||||
26
src/facebias/detectors/__init__.py
Normal file
26
src/facebias/detectors/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NoFaceDetectedError(Exception):
|
||||
def __init__(self, message: str = "No face detected in image."):
|
||||
self.message = message
|
||||
super(NoFaceDetectedError, self).__init__(self.message)
|
||||
|
||||
|
||||
class BaseFaceDetector(ABC):
|
||||
"""Base class for face detectors.
|
||||
|
||||
This classe provides the base interface for supported face detection
|
||||
methods. An example is provided in the `MediapipeDetector` class in
|
||||
`facebias.detectors.mediapipe`. Its goal is simply to load the resources
|
||||
and perform face detection, whenever a call to `predict` is performed."""
|
||||
|
||||
@abstractmethod
|
||||
def detect(self, im) -> np.ndarray:
|
||||
"""Detects faces in image `im`. Throws `NoFaceDetectedError`
|
||||
if `im` has no valid face(s)."""
|
||||
pass
|
||||
68
src/facebias/detectors/mediapipe.py
Executable file
68
src/facebias/detectors/mediapipe.py
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from facebias import FaceBox
|
||||
from facebias.detectors import BaseFaceDetector, NoFaceDetectedError
|
||||
|
||||
from mediapipe.tasks.python import BaseOptions, vision
|
||||
|
||||
|
||||
class MediapipeDetector(BaseFaceDetector):
|
||||
"""Uses mediapipe's blazeface model to detect faces in images or video frames.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path: str, PathLike
|
||||
mode: Optional[mediapipe.tasks.python.vision.RunningMode]
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, mode=vision.RunningMode.IMAGE) -> None:
|
||||
super(MediapipeDetector, self).__init__()
|
||||
baseopts = BaseOptions(model_asset_path=model_path)
|
||||
lmopts = vision.FaceDetectorOptions(
|
||||
base_options=baseopts,
|
||||
running_mode=mode,
|
||||
)
|
||||
self.mode = mode
|
||||
self.detector = vision.FaceDetector.create_from_options(lmopts)
|
||||
|
||||
def detect(self, im: np.ndarray, timestamp_ms: float = 0.0) -> list[FaceBox]:
|
||||
"""Returns a list with bounding boxes for each face found in `im`.
|
||||
|
||||
`timestamp_ms` is only used for videos.
|
||||
"""
|
||||
detect_results = None
|
||||
mpim = mp.Image(image_format=mp.ImageFormat.SRGB, data=im)
|
||||
if self.mode == vision.RunningMode.IMAGE:
|
||||
detect_results = self.detector.detect(mpim)
|
||||
else:
|
||||
detect_results = self.detector.detect_for_video(mpim, round(timestamp_ms))
|
||||
|
||||
if detect_results is None or not detect_results.detections:
|
||||
raise NoFaceDetectedError()
|
||||
|
||||
faces = [None for _ in detect_results.detections]
|
||||
for i, d in enumerate(detect_results.detections):
|
||||
bbox = d.bounding_box
|
||||
faces[i] = FaceBox(
|
||||
x1=bbox.origin_y,
|
||||
y1=bbox.origin_x,
|
||||
x2=bbox.origin_y + bbox.height,
|
||||
y2=bbox.origin_x + bbox.width,
|
||||
)
|
||||
|
||||
return faces
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
|
||||
TEST_IM_PATH = Path("data/frll_neutral_front/001_03.jpg")
|
||||
DETECTOR_PATH = Path("models/blaze_face_short_range.tflite")
|
||||
|
||||
detector = MediapipeDetector(str(DETECTOR_PATH))
|
||||
|
||||
im = cv2.cvtColor(cv2.imread(TEST_IM_PATH), cv2.COLOR_BGR2RGB)
|
||||
crop = detector.detect(im)
|
||||
print(crop)
|
||||
17
src/facebias/estimators/__init__.py
Normal file
17
src/facebias/estimators/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractclassmethod, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseEstimator(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, images: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def capabilities() -> list[str]:
|
||||
pass
|
||||
92
src/facebias/estimators/fairface.py
Normal file
92
src/facebias/estimators/fairface.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from facebias import load_dataset
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision import transforms
|
||||
|
||||
from facebias.estimators import BaseEstimator
|
||||
|
||||
|
||||
class FairFace(BaseEstimator):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, checkpoint_path: Optional[Path]=None, device=torch.device("cpu")):
|
||||
self.T = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
])
|
||||
|
||||
self.device = device
|
||||
self.model = torchvision.models.resnet34(pretrained=True)
|
||||
self.model.fc = nn.Linear(self.model.fc.in_features, 18)
|
||||
self.model = self.model.eval().to(device)
|
||||
if checkpoint_path is not None:
|
||||
self.load_state_dict(checkpoint_path)
|
||||
self.model.to(device)
|
||||
|
||||
def load_state_dict(self, checkpoint_path: Path) -> "FairFace":
|
||||
self.model.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
)
|
||||
return self
|
||||
|
||||
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[str, Any]]:
|
||||
preds = OrderedDict()
|
||||
for p, im in ims.items():
|
||||
im = self.T(im).view(1, 3, 224, 224).to(self.device)
|
||||
with torch.no_grad():
|
||||
y = self.model(im)
|
||||
|
||||
y = y.cpu().detach().squeeze().numpy()
|
||||
|
||||
# Ethnicity prediction
|
||||
# y_ethno = y[:4]
|
||||
# ethno_score = np.exp(y_ethno) / np.sum(np.exp(y_ethno))
|
||||
# ethno_pred = np.argmax(ethno_score)
|
||||
|
||||
# Age prediction
|
||||
y_age = y[9:18]
|
||||
age_score = np.exp(y_age) / np.sum(np.exp(y_age))
|
||||
age_pred = np.argmax(age_score)
|
||||
|
||||
# Gender prediction
|
||||
y_sex = y[7:9]
|
||||
sex_score = np.exp(y_sex) / np.sum(np.exp(y_sex))
|
||||
sex_pred = np.argmax(sex_score)
|
||||
sex_label = "m" if not sex_pred else "f"
|
||||
|
||||
preds[p] = {
|
||||
"age_group": _to_age_label(age_pred),
|
||||
"sex": sex_label,
|
||||
}
|
||||
|
||||
return preds
|
||||
|
||||
def capabilities() -> list[str]:
|
||||
return ["age_group", "sex"]
|
||||
|
||||
|
||||
def _to_age_label(age):
|
||||
return "{}-{}".format(age * 10, age * 10 + 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
MODEL_PATH = Path("../../../models/fairface_alldata_4race_20191111.pt")
|
||||
model = FairFace()
|
||||
model.load_state_dict(MODEL_PATH)
|
||||
|
||||
DATASET_PATH = Path("../../../data/frll_neutral_front")
|
||||
dataset, _ = load_dataset(DATASET_PATH, imname_proc_fn=lambda x: x.split('_')[0])
|
||||
|
||||
out = model.predict(dataset)
|
||||
0
src/facebias/estimators/mivolo/__init__.py
Normal file
0
src/facebias/estimators/mivolo/__init__.py
Normal file
105
src/facebias/estimators/mivolo/create_timm_model.py
Normal file
105
src/facebias/estimators/mivolo/create_timm_model.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Code adapted from timm https://github.com/huggingface/pytorch-image-models
|
||||
|
||||
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import timm
|
||||
# register new models
|
||||
from facebias.estimators.mivolo.mivolo_model import * # noqa: F403, F401
|
||||
from timm.layers import set_layer_config
|
||||
from timm.models._factory import parse_model_name
|
||||
from timm.models._helpers import load_state_dict
|
||||
from timm.models._hub import load_model_config_from_hf
|
||||
from timm.models._pretrained import PretrainedCfg
|
||||
from timm.models._registry import (is_model, model_entrypoint,
|
||||
split_model_name_tag)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
|
||||
):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, "load_pretrained"):
|
||||
timm.models._model_builder.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError("Model cannot load numpy checkpoint")
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
if filter_keys:
|
||||
for sd_key in list(state_dict.keys()):
|
||||
for filter_key in filter_keys:
|
||||
if filter_key in sd_key:
|
||||
if sd_key in state_dict:
|
||||
del state_dict[sd_key]
|
||||
|
||||
rep = []
|
||||
if state_dict_map is not None:
|
||||
# 'patch_embed.conv1.' : 'patch_embed.conv.'
|
||||
for state_k in list(state_dict.keys()):
|
||||
for target_k, target_v in state_dict_map.items():
|
||||
if target_v in state_k:
|
||||
target_name = state_k.replace(target_v, target_k)
|
||||
state_dict[target_name] = state_dict[state_k]
|
||||
rep.append(state_k)
|
||||
for r in rep:
|
||||
if r in state_dict:
|
||||
del state_dict[r]
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: bool = False,
|
||||
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||
checkpoint_path: str = "",
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
filter_keys=None,
|
||||
state_dict_map=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a model
|
||||
Lookup model's entrypoint function and pass relevant args to create a new model.
|
||||
"""
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == "hf-hub":
|
||||
assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
if not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
||||
if not is_model(model_name):
|
||||
raise RuntimeError("Unknown model (%s)" % model_name)
|
||||
|
||||
create_fn = model_entrypoint(model_name)
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
model = create_fn(
|
||||
pretrained=pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
|
||||
|
||||
return model
|
||||
116
src/facebias/estimators/mivolo/cross_bottleneck_attn.py
Normal file
116
src/facebias/estimators/mivolo/cross_bottleneck_attn.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Code based on timm https://github.com/huggingface/pytorch-image-models
|
||||
|
||||
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.layers.bottleneck_attn import PosEmbedRel
|
||||
from timm.layers.helpers import make_divisible
|
||||
from timm.layers.mlp import Mlp
|
||||
from timm.layers.trace_utils import _assert
|
||||
from timm.layers.weight_init import trunc_normal_
|
||||
|
||||
|
||||
class CrossBottleneckAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out=None,
|
||||
feat_size=None,
|
||||
stride=1,
|
||||
num_heads=4,
|
||||
dim_head=None,
|
||||
qk_ratio=1.0,
|
||||
qkv_bias=False,
|
||||
scale_pos_embed=False,
|
||||
):
|
||||
super().__init__()
|
||||
assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
|
||||
dim_out = dim_out or dim
|
||||
assert dim_out % num_heads == 0
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
|
||||
self.dim_head_v = dim_out // self.num_heads
|
||||
self.dim_out_qk = num_heads * self.dim_head_qk
|
||||
self.dim_out_v = num_heads * self.dim_head_v
|
||||
self.scale = self.dim_head_qk**-0.5
|
||||
self.scale_pos_embed = scale_pos_embed
|
||||
|
||||
self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
|
||||
self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
|
||||
|
||||
# NOTE I'm only supporting relative pos embedding for now
|
||||
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
|
||||
|
||||
self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
|
||||
mlp_ratio = 4
|
||||
self.mlp = Mlp(
|
||||
in_features=self.dim_out_v * 2,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
out_features=dim_out,
|
||||
drop=0,
|
||||
use_conv=True,
|
||||
)
|
||||
|
||||
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
|
||||
trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
|
||||
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
|
||||
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
|
||||
|
||||
def get_qkv(self, x, qvk_conv):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
|
||||
|
||||
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
|
||||
|
||||
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
|
||||
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
|
||||
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def apply_attn(self, q, k, v, B, H, W, dropout=None):
|
||||
if self.scale_pos_embed:
|
||||
attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
|
||||
else:
|
||||
attn = (q @ k) * self.scale + self.pos_embed(q)
|
||||
attn = attn.softmax(dim=-1)
|
||||
if dropout:
|
||||
attn = dropout(attn)
|
||||
|
||||
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
|
||||
return out
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
dim = int(C / 2)
|
||||
x1 = x[:, :dim, :, :]
|
||||
x2 = x[:, dim:, :, :]
|
||||
|
||||
_assert(H == self.pos_embed.height, "")
|
||||
_assert(W == self.pos_embed.width, "")
|
||||
|
||||
q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
|
||||
q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
|
||||
|
||||
# person to face
|
||||
out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
|
||||
# face to person
|
||||
out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
|
||||
|
||||
x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
|
||||
x_pf = self.norm(x_pf)
|
||||
x_pf = self.mlp(x_pf) # B, dim_out, H, W
|
||||
|
||||
out = self.pool(x_pf)
|
||||
return out
|
||||
313
src/facebias/estimators/mivolo/mi_volo.py
Normal file
313
src/facebias/estimators/mivolo/mi_volo.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from facebias.estimators.mivolo.create_timm_model import create_model
|
||||
from timm.data import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
|
||||
resolve_data_config)
|
||||
|
||||
_logger = logging.getLogger("MiVOLO")
|
||||
has_compile = hasattr(torch, "compile")
|
||||
|
||||
|
||||
def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
|
||||
# Resize and pad image while meeting stride-multiple constraints
|
||||
shape = im.shape[:2] # current shape [height, width]
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
|
||||
return im
|
||||
|
||||
# Scale ratio (new / old)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
||||
r = min(r, 1.0)
|
||||
|
||||
# Compute padding
|
||||
# ratio = r, r # width, height ratios
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
||||
return im
|
||||
|
||||
|
||||
def prepare_classification_images(
|
||||
img_list: List[Optional[np.ndarray]],
|
||||
target_size: int = 224,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
device=None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
prepared_images: List[torch.Tensor] = []
|
||||
|
||||
for img in img_list:
|
||||
if img is None:
|
||||
img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
|
||||
img = F.normalize(img, mean=mean, std=std)
|
||||
img = img.unsqueeze(0)
|
||||
prepared_images.append(img)
|
||||
continue
|
||||
|
||||
img = class_letterbox(img, new_shape=(target_size, target_size))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
img = img / 255.0
|
||||
img = (img - mean) / std
|
||||
img = img.astype(dtype=np.float32)
|
||||
|
||||
img = img.transpose((2, 0, 1))
|
||||
img = np.ascontiguousarray(img)
|
||||
img = torch.from_numpy(img)
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
prepared_images.append(img)
|
||||
|
||||
if len(prepared_images) == 0:
|
||||
return None
|
||||
|
||||
prepared_input = torch.concat(prepared_images)
|
||||
|
||||
if device:
|
||||
prepared_input = prepared_input.to(device)
|
||||
|
||||
return prepared_input
|
||||
|
||||
|
||||
class Meta:
|
||||
def __init__(self):
|
||||
self.min_age = None
|
||||
self.max_age = None
|
||||
self.avg_age = None
|
||||
self.num_classes = None
|
||||
|
||||
self.in_chans = 3
|
||||
self.with_persons_model = False
|
||||
self.disable_faces = False
|
||||
self.use_persons = True
|
||||
self.only_age = False
|
||||
|
||||
self.num_classes_gender = 2
|
||||
self.input_size = 224
|
||||
|
||||
def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
|
||||
|
||||
state = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
self.min_age = state["min_age"]
|
||||
self.max_age = state["max_age"]
|
||||
self.avg_age = state["avg_age"]
|
||||
self.only_age = state["no_gender"]
|
||||
|
||||
only_age = state["no_gender"]
|
||||
|
||||
self.disable_faces = disable_faces
|
||||
if "with_persons_model" in state:
|
||||
self.with_persons_model = state["with_persons_model"]
|
||||
else:
|
||||
self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
|
||||
|
||||
self.num_classes = 1 if only_age else 3
|
||||
self.in_chans = 3 if not self.with_persons_model else 6
|
||||
self.use_persons = use_persons and self.with_persons_model
|
||||
|
||||
if not self.with_persons_model and self.disable_faces:
|
||||
raise ValueError("You can not use disable-faces for faces-only model")
|
||||
if self.with_persons_model and self.disable_faces and not self.use_persons:
|
||||
raise ValueError(
|
||||
"You can not disable faces and persons together. "
|
||||
"Set --with-persons if you want to run with --disable-faces"
|
||||
)
|
||||
self.input_size = state["state_dict"]["pos_embed"].shape[1] * 16
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
attrs = vars(self)
|
||||
attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
|
||||
return ", ".join("%s: %s" % item for item in attrs.items())
|
||||
|
||||
@property
|
||||
def use_person_crops(self) -> bool:
|
||||
return self.with_persons_model and self.use_persons
|
||||
|
||||
@property
|
||||
def use_face_crops(self) -> bool:
|
||||
return not self.disable_faces or not self.with_persons_model
|
||||
|
||||
|
||||
class MiVOLO:
|
||||
def __init__(
|
||||
self,
|
||||
ckpt_path: str,
|
||||
device: str = "cuda",
|
||||
half: bool = True,
|
||||
disable_faces: bool = False,
|
||||
use_persons: bool = False,
|
||||
verbose: bool = False,
|
||||
torchcompile: Optional[str] = None,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.device = torch.device(device)
|
||||
self.half = half and self.device.type != "cpu"
|
||||
|
||||
self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
|
||||
if self.verbose:
|
||||
_logger.info(f"Model meta:\n{str(self.meta)}")
|
||||
|
||||
model_name = f"mivolo_d1_{self.meta.input_size}"
|
||||
self.model = create_model(
|
||||
model_name=model_name,
|
||||
num_classes=self.meta.num_classes,
|
||||
in_chans=self.meta.in_chans,
|
||||
pretrained=False,
|
||||
checkpoint_path=ckpt_path,
|
||||
filter_keys=["fds."],
|
||||
)
|
||||
self.param_count = sum([m.numel() for m in self.model.parameters()])
|
||||
_logger.info(f"Model {model_name} created, param count: {self.param_count}")
|
||||
|
||||
self.data_config = resolve_data_config(
|
||||
model=self.model,
|
||||
verbose=verbose,
|
||||
use_test_size=True,
|
||||
)
|
||||
|
||||
self.data_config["crop_pct"] = 1.0
|
||||
c, h, w = self.data_config["input_size"]
|
||||
assert h == w, "Incorrect data_config"
|
||||
self.input_size = w
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
if torchcompile:
|
||||
assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
|
||||
torch._dynamo.reset()
|
||||
self.model = torch.compile(self.model, backend=torchcompile)
|
||||
|
||||
self.model.eval()
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def warmup(self, batch_size: int, steps=10):
|
||||
if self.meta.with_persons_model:
|
||||
input_size = (6, self.input_size, self.input_size)
|
||||
else:
|
||||
input_size = self.data_config["input_size"]
|
||||
|
||||
input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
|
||||
|
||||
for _ in range(steps):
|
||||
out = self.inference(input) # noqa: F841
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def inference(self, model_input: torch.Tensor) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
if self.half:
|
||||
model_input = model_input.half()
|
||||
output = self.model(model_input)
|
||||
return output
|
||||
|
||||
# def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
||||
# if (
|
||||
# (detected_bboxes.n_objects == 0)
|
||||
# or (not self.meta.use_persons and detected_bboxes.n_faces == 0)
|
||||
# or (self.meta.disable_faces and detected_bboxes.n_persons == 0)
|
||||
# ):
|
||||
# # nothing to process
|
||||
# return
|
||||
|
||||
# faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
|
||||
|
||||
# if faces_input is None and person_input is None:
|
||||
# # nothing to process
|
||||
# return
|
||||
|
||||
# if self.meta.with_persons_model:
|
||||
# model_input = torch.cat((faces_input, person_input), dim=1)
|
||||
# else:
|
||||
# model_input = faces_input
|
||||
# output = self.inference(model_input)
|
||||
|
||||
# # write gender and age results into detected_bboxes
|
||||
# self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
|
||||
|
||||
def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
|
||||
if self.meta.only_age:
|
||||
age_output = output
|
||||
gender_probs, gender_indx = None, None
|
||||
else:
|
||||
age_output = output[:, 2]
|
||||
gender_output = output[:, :2].softmax(-1)
|
||||
gender_probs, gender_indx = gender_output.topk(1)
|
||||
|
||||
assert output.shape[0] == len(faces_inds) == len(bodies_inds)
|
||||
|
||||
# per face
|
||||
for index in range(output.shape[0]):
|
||||
face_ind = faces_inds[index]
|
||||
body_ind = bodies_inds[index]
|
||||
|
||||
# get_age
|
||||
age = age_output[index].item()
|
||||
age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
|
||||
age = round(age, 2)
|
||||
|
||||
detected_bboxes.set_age(face_ind, age)
|
||||
detected_bboxes.set_age(body_ind, age)
|
||||
|
||||
_logger.info(f"\tage: {age}")
|
||||
|
||||
if gender_probs is not None:
|
||||
gender = "male" if gender_indx[index].item() == 0 else "female"
|
||||
gender_score = gender_probs[index].item()
|
||||
|
||||
_logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
|
||||
|
||||
detected_bboxes.set_gender(face_ind, gender, gender_score)
|
||||
detected_bboxes.set_gender(body_ind, gender, gender_score)
|
||||
|
||||
# def prepare_crops_old(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
||||
|
||||
# crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
|
||||
# (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
|
||||
# self.meta.use_person_crops, self.meta.use_face_crops
|
||||
# )
|
||||
|
||||
# if not self.meta.use_face_crops:
|
||||
# assert all(f is None for f in faces_crops)
|
||||
|
||||
# faces_input = prepare_classification_images(
|
||||
# faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
|
||||
# )
|
||||
|
||||
# if not self.meta.use_person_crops:
|
||||
# assert all(p is None for p in bodies_crops)
|
||||
|
||||
# person_input = prepare_classification_images(
|
||||
# bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
|
||||
# )
|
||||
|
||||
# _logger.info(
|
||||
# f"faces_input: {faces_input.shape if faces_input is not None else None}, "
|
||||
# f"person_input: {person_input.shape if person_input is not None else None}"
|
||||
# )
|
||||
|
||||
# return faces_input, person_input, faces_inds, bodies_inds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = MiVOLO("models/volo-v1_model_imdb_age_gender_4.22.pth.tar", half=True, device="cpu")
|
||||
405
src/facebias/estimators/mivolo/mivolo_model.py
Normal file
405
src/facebias/estimators/mivolo/mivolo_model.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Code adapted from timm https://github.com/huggingface/pytorch-image-models
|
||||
|
||||
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from facebias.estimators.mivolo.cross_bottleneck_attn import \
|
||||
CrossBottleneckAttn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_
|
||||
from timm.models._builder import build_model_with_cfg
|
||||
from timm.models._registry import register_model
|
||||
from timm.models.volo import VOLO
|
||||
|
||||
__all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 224, 224),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.96,
|
||||
"interpolation": "bicubic",
|
||||
"fixed_input_size": True,
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
"first_conv": None,
|
||||
"classifier": ("head", "aux_head"),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
"mivolo_d1_224": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
|
||||
),
|
||||
"mivolo_d1_384": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
|
||||
crop_pct=1.0,
|
||||
input_size=(3, 384, 384),
|
||||
),
|
||||
"mivolo_d2_224": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
|
||||
),
|
||||
"mivolo_d2_384": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
|
||||
crop_pct=1.0,
|
||||
input_size=(3, 384, 384),
|
||||
),
|
||||
"mivolo_d3_224": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
|
||||
),
|
||||
"mivolo_d3_448": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
|
||||
crop_pct=1.0,
|
||||
input_size=(3, 448, 448),
|
||||
),
|
||||
"mivolo_d4_224": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
|
||||
),
|
||||
"mivolo_d4_448": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
|
||||
crop_pct=1.15,
|
||||
input_size=(3, 448, 448),
|
||||
),
|
||||
"mivolo_d5_224": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
|
||||
),
|
||||
"mivolo_d5_448": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
|
||||
crop_pct=1.15,
|
||||
input_size=(3, 448, 448),
|
||||
),
|
||||
"mivolo_d5_512": _cfg(
|
||||
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
|
||||
crop_pct=1.15,
|
||||
input_size=(3, 512, 512),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_output_size(input_shape, conv_layer):
|
||||
padding = conv_layer.padding
|
||||
dilation = conv_layer.dilation
|
||||
kernel_size = conv_layer.kernel_size
|
||||
stride = conv_layer.stride
|
||||
|
||||
output_size = [
|
||||
((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
|
||||
]
|
||||
return output_size
|
||||
|
||||
|
||||
def get_output_size_module(input_size, stem):
|
||||
output_size = input_size
|
||||
|
||||
for module in stem:
|
||||
if isinstance(module, nn.Conv2d):
|
||||
output_size = [
|
||||
(
|
||||
(output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
|
||||
// module.stride[i]
|
||||
)
|
||||
+ 1
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
return output_size
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
|
||||
):
|
||||
super().__init__()
|
||||
assert patch_size in [4, 8, 16]
|
||||
assert in_chans in [3, 6]
|
||||
self.with_persons_model = in_chans == 6
|
||||
self.use_cross_attn = True
|
||||
|
||||
if stem_conv:
|
||||
if not self.with_persons_model:
|
||||
self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
|
||||
else:
|
||||
self.conv = True # just to match interface
|
||||
# split
|
||||
self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
|
||||
self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
|
||||
else:
|
||||
self.conv = None
|
||||
|
||||
if self.with_persons_model:
|
||||
|
||||
self.proj1 = nn.Conv2d(
|
||||
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
||||
)
|
||||
self.proj2 = nn.Conv2d(
|
||||
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
||||
)
|
||||
|
||||
stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
|
||||
self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
|
||||
|
||||
self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
|
||||
|
||||
else:
|
||||
self.proj = nn.Conv2d(
|
||||
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
||||
)
|
||||
|
||||
self.patch_dim = img_size // patch_size
|
||||
self.num_patches = self.patch_dim**2
|
||||
|
||||
def create_stem(self, stem_stride, in_chans, hidden_dim):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.conv is not None:
|
||||
if self.with_persons_model:
|
||||
x1 = x[:, :3]
|
||||
x2 = x[:, 3:]
|
||||
|
||||
x1 = self.conv1(x1)
|
||||
x1 = self.proj1(x1)
|
||||
|
||||
x2 = self.conv2(x2)
|
||||
x2 = self.proj2(x2)
|
||||
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.map(x)
|
||||
else:
|
||||
x = self.conv(x)
|
||||
x = self.proj(x) # B, C, H, W
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MiVOLOModel(VOLO):
|
||||
"""
|
||||
Vision Outlooker, the main class of our model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layers,
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool="token",
|
||||
patch_size=8,
|
||||
stem_hidden_dim=64,
|
||||
embed_dims=None,
|
||||
num_heads=None,
|
||||
downsamples=(True, False, False, False),
|
||||
outlook_attention=(True, False, False, False),
|
||||
mlp_ratio=3.0,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
post_layers=("ca", "ca"),
|
||||
use_aux_head=True,
|
||||
use_mix_token=False,
|
||||
pooling_scale=2,
|
||||
):
|
||||
super(MiVOLOModel, self).__init__(
|
||||
layers=layers,
|
||||
img_size=img_size,
|
||||
in_chans=in_chans,
|
||||
num_classes=num_classes,
|
||||
global_pool=global_pool,
|
||||
patch_size=patch_size,
|
||||
stem_hidden_dim=stem_hidden_dim,
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
downsamples=downsamples,
|
||||
outlook_attention=outlook_attention,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=norm_layer,
|
||||
post_layers=post_layers,
|
||||
use_aux_head=use_aux_head,
|
||||
use_mix_token=use_mix_token,
|
||||
pooling_scale=pooling_scale,
|
||||
)
|
||||
|
||||
im_size = img_size[0] if isinstance(img_size, tuple) else img_size
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=im_size,
|
||||
stem_conv=True,
|
||||
stem_stride=2,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
hidden_dim=stem_hidden_dim,
|
||||
embed_dim=embed_dims[0],
|
||||
)
|
||||
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
|
||||
|
||||
# step2: tokens learning in the two stages
|
||||
x = self.forward_tokens(x)
|
||||
|
||||
# step3: post network, apply class attention or not
|
||||
if self.post_network is not None:
|
||||
x = self.forward_cls(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
|
||||
if self.global_pool == "avg":
|
||||
out = x.mean(dim=1)
|
||||
elif self.global_pool == "token":
|
||||
out = x[:, 0]
|
||||
else:
|
||||
out = x
|
||||
if pre_logits:
|
||||
return out
|
||||
|
||||
features = out
|
||||
fds_enabled = hasattr(self, "_fds_forward")
|
||||
if fds_enabled:
|
||||
features = self._fds_forward(features, targets, epoch)
|
||||
|
||||
out = self.head(features)
|
||||
if self.aux_head is not None:
|
||||
# generate classes in all feature tokens, see token labeling
|
||||
aux = self.aux_head(x[:, 1:])
|
||||
out = out + 0.5 * aux.max(1)[0]
|
||||
|
||||
return (out, features) if (fds_enabled and self.training) else out
|
||||
|
||||
def forward(self, x, targets=None, epoch=None):
|
||||
"""simplified forward (without mix token training)"""
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x, targets=targets, epoch=epoch)
|
||||
return x
|
||||
|
||||
|
||||
def _create_mivolo(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get("features_only", None):
|
||||
raise RuntimeError("features_only not implemented for Vision Transformer models.")
|
||||
return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d1_224(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
|
||||
model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d1_384(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
|
||||
model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d2_224(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d2_384(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d3_224(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d3_448(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d4_224(pretrained=False, **kwargs):
|
||||
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d4_448(pretrained=False, **kwargs):
|
||||
"""VOLO-D4 model, Params: 193M"""
|
||||
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
|
||||
model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d5_224(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
layers=(12, 12, 20, 4),
|
||||
embed_dims=(384, 768, 768, 768),
|
||||
num_heads=(12, 16, 16, 16),
|
||||
mlp_ratio=4,
|
||||
stem_hidden_dim=128,
|
||||
**kwargs
|
||||
)
|
||||
model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d5_448(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
layers=(12, 12, 20, 4),
|
||||
embed_dims=(384, 768, 768, 768),
|
||||
num_heads=(12, 16, 16, 16),
|
||||
mlp_ratio=4,
|
||||
stem_hidden_dim=128,
|
||||
**kwargs
|
||||
)
|
||||
model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mivolo_d5_512(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
layers=(12, 12, 20, 4),
|
||||
embed_dims=(384, 768, 768, 768),
|
||||
num_heads=(12, 16, 16, 16),
|
||||
mlp_ratio=4,
|
||||
stem_hidden_dim=128,
|
||||
**kwargs
|
||||
)
|
||||
model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
|
||||
return model
|
||||
Reference in New Issue
Block a user