Removed commented code, added docstrings and references.
This commit is contained in:
@@ -10,10 +10,7 @@ from typing import Any, Callable, Optional
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # Set the minimum logging level
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -83,15 +80,6 @@ def load_dataset(
|
||||
im = cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB)
|
||||
except cv2.error:
|
||||
logger.info(f'File "{p}" is not an image. Skipping.')
|
||||
# if meta_path is None:
|
||||
# logger.info(f'Trying to read "{p}" as metadata.')
|
||||
# if p.suffix == ".csv":
|
||||
# metadata = load_metadata(p)
|
||||
# logger.info("Metadata read successfully.")
|
||||
# else:
|
||||
# logger.error(f'Failed to read "{p}" as metadata. Skipping.')
|
||||
# elif not metadata:
|
||||
# logger.critical("Logic error: Metadata should have been read already.")
|
||||
else:
|
||||
proc_imname = imname_proc_fn(p.name) if imname_proc_fn is not None else p.name
|
||||
ims[proc_imname] = im
|
||||
|
||||
@@ -14,9 +14,31 @@ from torchvision import transforms
|
||||
|
||||
|
||||
class FairFace(BaseEstimator):
|
||||
"""FairFace trained model for age-group, sex and ethnicity estimation.
|
||||
|
||||
Note that the ethnicity estimation is disabled, since each model estimates
|
||||
a completely different set of classes. Additionally, there is an argument
|
||||
that ethnicity is extremelly difficult to estimate given a single picture
|
||||
of a person, thus it is better to estimate skin color, which this model
|
||||
does not do.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
checkpoint_path: Optional[Path]
|
||||
Path to the model weights stored as a PyTorch file.
|
||||
|
||||
device: Optional[torch.device]
|
||||
The device to load the model into. By default is `"cpu"`
|
||||
|
||||
Reference
|
||||
---------
|
||||
FairFace repository: https://github.com/dchen236/FairFace
|
||||
"""
|
||||
"""
|
||||
def __init__(self, checkpoint_path: Optional[Path]=None, device=torch.device("cpu")):
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path: Optional[Path]=None,
|
||||
device: Optional[torch.device]=torch.device("cpu")
|
||||
):
|
||||
self.T = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((224, 224)),
|
||||
@@ -30,12 +52,12 @@ class FairFace(BaseEstimator):
|
||||
self.model = self.model.eval().to(device)
|
||||
if checkpoint_path is not None:
|
||||
self.load_state_dict(checkpoint_path)
|
||||
self.model.to(device)
|
||||
|
||||
def load_state_dict(self, checkpoint_path: Path) -> "FairFace":
|
||||
self.model.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
)
|
||||
self.model.to(self.device)
|
||||
return self
|
||||
|
||||
def predict(self, ims: dict[Path, np.ndarray]) -> dict[Path, dict[Capability, Any]]:
|
||||
|
||||
@@ -26,6 +26,10 @@ class MiVOLOv1(BaseEstimator):
|
||||
|
||||
device: Optional[torch.device]
|
||||
Device to load the checkpoints into. By default is "cpu".
|
||||
|
||||
Reference
|
||||
---------
|
||||
MiVOLO repository: https://github.com/wildchlamydia/mivolo
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user