Spaces:
Sleeping
Sleeping
File size: 1,848 Bytes
b8c9192 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 | """DeepSeeNet model definition."""
from torch import Tensor, nn
try:
import timm
except ImportError: # pragma: no cover - handled when timm is absent.
timm = None
class DeepSeeNet(nn.Module):
    """DeepSeeNet risk-factor classifier in PyTorch.

    A timm backbone (classification head removed, globally average-pooled
    features) followed by a small MLP head:
    ``num_features -> 256 -> 128 -> n_classes`` with ReLU and dropout between
    layers. The head emits raw logits; no softmax is applied.

    Args:
        n_classes: Number of output classes. Must be at least 1.
        backbone: Any timm model name that supports ``num_classes=0``. The
            default uses InceptionV3.
        pretrained: Load ImageNet weights for the backbone.
        dropout: Dropout probability used by the classifier head.
        freeze_backbone: If true, keep the backbone frozen and train only the
            classifier head. A frozen backbone has gradients disabled *and*
            is kept in eval mode, so normalization layers such as BatchNorm
            use (and do not overwrite) their stored statistics even while the
            model as a whole is in training mode.

    Attributes:
        backbone_name: The timm model name passed as ``backbone``.
        backbone: The timm feature extractor.
        classifier: The MLP head mapping pooled features to logits.
        freeze_backbone: Whether :meth:`train` keeps the backbone in eval
            mode. To unfreeze for fine-tuning, set this to ``False`` and call
            ``backbone.requires_grad_(True)``.

    Raises:
        ValueError: If ``n_classes`` is less than 1.
        ImportError: If ``timm`` is not installed.
    """

    def __init__(
        self,
        n_classes: int = 2,
        backbone: str = "inception_v3",
        pretrained: bool = True,
        dropout: float = 0.5,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        if n_classes < 1:
            raise ValueError("n_classes must be positive")
        if timm is None:
            raise ImportError("timm is required to build DeepSeeNet")
        self.backbone_name = backbone
        self.freeze_backbone = freeze_backbone
        # num_classes=0 drops timm's own classifier so the backbone returns
        # the globally average-pooled feature vector of width num_features.
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        in_features = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),  # raw logits; no softmax here
        )
        if freeze_backbone:
            self.backbone.requires_grad_(False)
            # Modules start in training mode; a frozen backbone must not.
            self.backbone.eval()

    def train(self, mode: bool = True) -> "DeepSeeNet":
        """Set training mode, keeping a frozen backbone in eval mode.

        ``requires_grad_(False)`` alone stops weight updates but not the
        running-statistic updates of BatchNorm layers. Those happen on every
        forward pass in training mode, regardless of gradients. Forcing the
        backbone into eval mode keeps its outputs and buffers fixed while
        the classifier head trains.

        Args:
            mode: ``True`` for training mode, ``False`` for eval mode.

        Returns:
            ``self``, matching :meth:`torch.nn.Module.train`.
        """
        super().train(mode)
        # getattr: instances unpickled from before this attribute existed
        # should still behave like an unfrozen model instead of crashing.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of inputs.

        Args:
            x: Input batch accepted by the backbone. For image backbones this
                is presumably ``(batch, 3, H, W)``; confirm the expected
                resolution for the chosen ``backbone`` (timm's InceptionV3 is
                typically fed 299x299).

        Returns:
            Unnormalized logits, presumably of shape ``(batch, n_classes)``.
        """
        features = self.backbone(x)
        return self.classifier(features)
|