Removed commented code, added docstrings and references.

This commit is contained in:
2026-04-16 17:35:18 +01:00
parent b828450a09
commit 8b241e09cd
3 changed files with 30 additions and 16 deletions

View File

@@ -10,10 +10,7 @@ from typing import Any, Callable, Optional
import cv2
import numpy as np
logging.basicConfig(
level=logging.DEBUG, # Set the minimum logging level
)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -83,15 +80,6 @@ def load_dataset(
im = cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB)
except cv2.error:
logger.info(f'File "{p}" is not an image. Skipping.')
# if meta_path is None:
# logger.info(f'Trying to read "{p}" as metadata.')
# if p.suffix == ".csv":
# metadata = load_metadata(p)
# logger.info("Metadata read successfully.")
# else:
# logger.error(f'Failed to read "{p}" as metadata. Skipping.')
# elif not metadata:
# logger.critical("Logic error: Metadata should have been read already.")
else:
proc_imname = imname_proc_fn(p.name) if imname_proc_fn is not None else p.name
ims[proc_imname] = im

View File

@@ -14,9 +14,31 @@ from torchvision import transforms
class FairFace(BaseEstimator):
"""FairFace trained model for age-group, sex and ethnicity estimation.
Note that the ethnicity estimation is disabled, since each model estimates
a completely different set of classes. Additionally, there is an argument
that ethnicity is extremelly difficult to estimate given a single picture
of a person, thus it is better to estimate skin color, which this model
does not do.
Parameters
----------
checkpoint_path: Optional[Path]
Path to the model weights stored as a PyTorch file.
device: Optional[torch.device]
The device to load the model into. By default is `"cpu"`
Reference
---------
FairFace repository: https://github.com/dchen236/FairFace
"""
"""
def __init__(self, checkpoint_path: Optional[Path]=None, device=torch.device("cpu")):
def __init__(
self,
checkpoint_path: Optional[Path]=None,
device: Optional[torch.device]=torch.device("cpu")
):
self.T = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
@@ -30,12 +52,12 @@ class FairFace(BaseEstimator):
self.model = self.model.eval().to(device)
if checkpoint_path is not None:
self.load_state_dict(checkpoint_path)
self.model.to(device)
def load_state_dict(self, checkpoint_path: Path) -> "FairFace":
self.model.load_state_dict(
torch.load(checkpoint_path, map_location=torch.device("cpu"))
)
self.model.to(self.device)
return self
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:

View File

@@ -26,6 +26,10 @@ class MiVOLOv1(BaseEstimator):
device: Optional[torch.device]
Device to load the checkpoints into. By default is "cpu".
Reference
---------
MiVOLO repository: https://github.com/wildchlamydia/mivolo
"""
def __init__(
self,