"""Data pipeline: caption dataset, batch collation, and train/valid dataloaders."""

import multiprocessing as mp
import pathlib
from typing import Any

import datasets
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from src import config
from src import tokenizer as tk


class CaptionDataset(Dataset):
    """Pairs pre-downloaded images with their short captions."""

    def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
        self.dataset = dataset
        self.img_path = img_path

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = self.dataset[idx]
        # The local filename is the last path segment of the original URL.
        image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
        return {"image": image, "caption": item["short_caption"]}


class CollateFn:
    """Turns a list of dataset items into a single batch of tensors."""

    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose) -> None:
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
        tokenized_text = self.tokenizer([item["caption"] for item in batch])

        # Merge the image tensor with the tokenizer's output dict.
        return {
            "images": stacked_images,
            **tokenized_text,
        }
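
# Shape sketch: a collated batch maps "images" to a (B, C, H, W) float tensor
# and spreads the tokenizer's outputs alongside it; with a Hugging Face-style
# tokenizer these are typically "input_ids" and "attention_mask" of shape
# (B, max_len).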


def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        # Heuristic: use roughly a third of the visible CPUs for data loading.
        "num_workers": mp.cpu_count() // 3,
        "collate_fn": collate_fn,
    }
    train_loader = DataLoader(
        train_ds,
        shuffle=True,
        drop_last=True,  # keep the batch size constant during training
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader
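
# Note: mp.cpu_count() // 3 can round down to 0 (all loading in the main
# process) on small machines; tune per host, and consider passing
# persistent_workers=True if worker startup dominates epoch time.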


def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
    dataset: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )
    # A fixed seed keeps the train/valid split reproducible across runs.
    train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
    train_ds = CaptionDataset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
    valid_ds = CaptionDataset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
    collate_fn = CollateFn(tokenizer, transform)

    return _get_dataloaders(
        train_ds=train_ds,
        valid_ds=valid_ds,
        training_config=hyper_parameters,
        collate_fn=collate_fn,
    )
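
# Assumed config surface, inferred from usage in this module: TrainerConfig
# exposes batch_size, _data_config.dataset (a Hugging Face dataset id), and
# _model_config.text_model / _model_config.max_len; config.IMAGE_DOWNLOAD_PATH
# holds pre-fetched images named by URL basename.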


if __name__ == "__main__":
    import os

    from tqdm.auto import tqdm

    # Silence the Hugging Face tokenizers fork warning once DataLoader workers spawn.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    tokenizer = tk.Tokenizer(
        hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
    )
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)

    # Smoke test: inspect one batch, then time a full pass over the training set.
    batch = next(iter(train_dl))
    print({k: v.shape for k, v in batch.items()})

    for batch in tqdm(train_dl):
        pass