| from typing import Optional |
| from omegaconf import DictConfig |
| import pytorch_lightning as L |
| import torch.utils.data as torchdata |
| from .torch import collate, worker_init_fn |
|
|
|
|
def get_dataset(name):
    """Map a dataset identifier to its data-module class.

    Each dataset's package is imported only when that dataset is requested,
    so the dependencies of the other datasets are never loaded.

    Args:
        name: Dataset identifier: ``"mapillary"``, ``"nuscenes"`` or
            ``"kitti"``.

    Returns:
        The data-module class (not an instance) registered for ``name``.

    Raises:
        NotImplementedError: If ``name`` matches none of the known datasets.
    """

    def _load_mapillary():
        from .mapillary.data_module import MapillaryDataModule

        return MapillaryDataModule

    def _load_nuscenes():
        from .nuscenes.data_module import NuScenesData

        return NuScenesData

    def _load_kitti():
        from .kitti.data_module import BEVKitti360Data

        return BEVKitti360Data

    # Ordered (identifier, loader) pairs. A linear scan with ``==`` is used
    # rather than a dict lookup, so arbitrary (even unhashable) inputs fall
    # through to the error below.
    registry = (
        ("mapillary", _load_mapillary),
        ("nuscenes", _load_nuscenes),
        ("kitti", _load_kitti),
    )
    for known_name, load_cls in registry:
        if name == known_name:
            return load_cls()
    raise NotImplementedError(f"Dataset {name} not implemented.")
|
|
|
|
class GenericDataModule(L.LightningDataModule):
    """Lightning data module that delegates to a dataset-specific module.

    The concrete module class is chosen from ``cfg.name`` via ``get_dataset``
    and instantiated with the same ``cfg``. Lightning's ``prepare_data`` and
    ``setup`` hooks are forwarded to it, and data loaders are built from the
    datasets it returns together with the per-stage settings in
    ``cfg["loading"]``.
    """

    def __init__(self, cfg: DictConfig):
        """Store ``cfg`` and instantiate the data module it names.

        Args:
            cfg: Data configuration; ``cfg.name`` selects the dataset, and
                the whole ``cfg`` is passed to that dataset's module.
        """
        super().__init__()
        self.cfg = cfg
        self.data_module = get_dataset(cfg.name)(cfg)

    def prepare_data(self) -> None:
        """Forward Lightning's ``prepare_data`` hook to the wrapped module."""
        self.data_module.prepare_data()

    def setup(self, stage: Optional[str] = None):
        """Forward Lightning's ``setup`` hook (with ``stage``) to the wrapped module."""
        self.data_module.setup(stage)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        """Build a ``torch.utils.data.DataLoader`` for ``stage``.

        Args:
            stage: Split name. It is used both to fetch the dataset from the
                wrapped module and to look up ``cfg["loading"][stage]``, which
                must provide ``batch_size`` and ``num_workers``.
            shuffle: Request shuffling. When no ``sampler`` is given, the
                ``"train"`` stage is shuffled regardless of this flag.
            num_workers: Overrides ``cfg["loading"][stage]["num_workers"]``
                when not ``None``.
            sampler: Optional custom sampler. When provided, the implicit
                train-stage shuffling is not applied.

        Returns:
            The configured ``DataLoader``. It uses this package's ``collate``
            and ``worker_init_fn``.
        """
        dataset = self.data_module.dataset(stage)
        cfg = self.cfg["loading"][stage]
        if num_workers is None:
            num_workers = cfg["num_workers"]
        # DataLoader raises ValueError when a sampler is combined with
        # shuffle=True. The implicit "train" shuffle is therefore applied
        # only when no sampler is supplied. An explicit shuffle=True is
        # still forwarded, so a genuine conflict is reported rather than
        # silently ignored.
        if sampler is None:
            shuffle = shuffle or stage == "train"
        return torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=True,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )

    def train_dataloader(self, **kwargs):
        """Return the ``"train"`` loader; ``kwargs`` are forwarded to ``dataloader``."""
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        """Return the ``"val"`` loader; ``kwargs`` are forwarded to ``dataloader``."""
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        """Return the ``"test"`` loader; ``kwargs`` are forwarded to ``dataloader``."""
        return self.dataloader("test", **kwargs)