# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373
class TrainingValidationDataset(Dataset):
    """Map-style dataset that loads an image from disk and builds a sample dict.

    Each item is a dict holding the image id (file basename), its quality
    value, optional payload bits / modification targets, and one tensor per
    requested feature.
    """

    def __init__(
        self,
        images: Union[List, np.ndarray],
        targets: Optional[Union[List, np.ndarray]],
        quality: Union[List, np.ndarray],
        bits: Optional[Union[List, np.ndarray]],
        # Quoted so the annotation is not evaluated when the class is defined.
        transform: "Union[A.Compose, A.BasicTransform]",
        features: List[str],
    ):
        """
        :param images: Image file paths; each one is passed to ``cv2.imread``.
        :param targets: Optional integer modification type per image. When
            given, it must have the same length as ``images``.
        :param quality: Per-image quality value, emitted as ``int``.
        :param bits: Optional per-image payload bits, emitted as a float32 tensor.
        :param transform: Callable invoked as ``transform(**data)`` on the
            dict holding the image and the computed features.
        :param features: Feature names. They are forwarded to
            ``compute_features`` and select which transformed entries are
            converted to tensors and returned.
        :raises ValueError: If ``targets`` is given and its length differs
            from ``images``.
        """
        if targets is not None:
            if len(images) != len(targets):
                raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")

        self.images = images
        self.targets = targets
        self.transform = transform
        self.features = features
        self.quality = quality
        self.bits = bits

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __repr__(self):
        """Return a summary with the dataset size and label/quality histograms."""
        # targets is optional (see __init__); np.bincount(None) would raise.
        targets_hist = None if self.targets is None else np.bincount(self.targets)
        return f"TrainingValidationDataset(len={len(self)}, targets_hist={targets_hist}, qf={np.bincount(self.quality)}, features={self.features})"

    def __getitem__(self, index):
        """Load, featurize and transform the image at ``index``.

        :param index: Position of the sample in ``images``.
        :return: Dict with the image id, quality value, optional payload bits
            and targets, plus one tensor per requested feature.
        :raises FileNotFoundError: If ``cv2.imread`` returns ``None`` for the path.
        """
        image_fname = self.images[index]
        try:
            image = cv2.imread(image_fname)
            if image is None:
                raise FileNotFoundError(image_fname)
        except Exception as e:
            print("Cannot read image ", image_fname, "at index", index)
            print(e)
            # Re-raise: continuing would leave `image` as None (or unbound)
            # and fail further down with a misleading error.
            raise

        qf = self.quality[index]

        data = {"image": image}
        data.update(compute_features(image, image_fname, self.features))
        data = self.transform(**data)

        sample = {
            INPUT_IMAGE_ID_KEY: os.path.basename(image_fname),
            INPUT_IMAGE_QF_KEY: int(qf),
        }

        if self.bits is not None:
            sample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)

        if self.targets is not None:
            target = int(self.targets[index])
            sample[INPUT_TRUE_MODIFICATION_TYPE] = target
            # Binary flag: any target above zero counts as modified.
            sample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()

        # Only the requested features are converted and returned.
        for key, value in data.items():
            if key in self.features:
                sample[key] = tensor_from_rgb_image(value)

        return sample