Created
August 19, 2020 08:53
-
-
Save BloodAxe/70265dcab73c9a078a0928223c244c39 to your computer and use it in GitHub Desktop.
Revisions
-
BloodAxe revised this gist
Aug 19, 2020. 1 changed file with 1 addition and 0 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,4 @@ # https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373 class TrainingValidationDataset(Dataset): def __init__( self, -
BloodAxe created this gist
Aug 19, 2020. There are no files selected for viewing.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
# Original file line number Diff line number Diff line change @@ -0,0 +1,63 @@
class TrainingValidationDataset(Dataset):
    """Dataset that reads an image from disk and returns a dict-style sample.

    Each sample carries the image file's base name, its quality value and,
    when available, the payload bits and the modification type/flag, plus
    every transformed entry whose key is listed in ``features``.
    """

    def __init__(
        self,
        images: Union[List, np.ndarray],
        targets: Optional[Union[List, np.ndarray]],
        quality: Union[List, np.ndarray],
        bits: Optional[Union[List, np.ndarray]],
        transform: Union[A.Compose, A.BasicTransform],
        features: List[str],
    ):
        """
        :param images: Image file names; each one is passed to ``cv2.imread``.
        :param targets: Per-image targets (converted with ``int``), or None
            when no targets are available.
        :param quality: Per-image quality values (converted with ``int``).
        :param bits: Per-image payload values (converted to a float32
            tensor), or None.
        :param transform: Callable invoked as ``transform(**data)``; it must
            return a mapping.
        :param features: Keys handed to ``compute_features`` and used to pick
            which transformed entries end up in the sample.
        :raises ValueError: If ``targets`` is given and its length differs
            from the length of ``images``.
        """
        if targets is not None:
            if len(images) != len(targets):
                raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")

        self.images = images
        self.targets = targets
        self.transform = transform
        self.features = features
        self.quality = quality
        self.bits = bits

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __repr__(self):
        """Summarize the dataset: size, target/quality histograms, features."""
        # targets may legitimately be None (see __init__); np.bincount(None)
        # raises, so report None instead of a histogram in that case.
        targets_hist = None if self.targets is None else np.bincount(self.targets)
        return (
            f"TrainingValidationDataset(len={len(self)}, targets_hist={targets_hist}, "
            f"qf={np.bincount(self.quality)}, features={self.features})"
        )

    def __getitem__(self, index):
        """Load, augment and package the sample at ``index``.

        :raises FileNotFoundError: If ``cv2.imread`` returns None for the file.
        """
        image_fname = self.images[index]
        try:
            image = cv2.imread(image_fname)
            if image is None:
                raise FileNotFoundError(image_fname)
        except Exception as e:
            print("Cannot read image ", image_fname, "at index", index)
            print(e)
            # Re-raise: continuing would leave ``image`` as None (or unbound)
            # and only fail later with an unrelated, harder-to-read error.
            raise

        qf = self.quality[index]
        data = {}
        data["image"] = image
        data.update(compute_features(image, image_fname, self.features))

        data = self.transform(**data)

        sample = {INPUT_IMAGE_ID_KEY: os.path.basename(self.images[index]), INPUT_IMAGE_QF_KEY: int(qf)}

        if self.bits is not None:
            sample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)

        if self.targets is not None:
            target = int(self.targets[index])
            sample[INPUT_TRUE_MODIFICATION_TYPE] = target
            # The flag is 1.0 for any positive target and 0.0 otherwise.
            sample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()

        # Only the requested features are converted and returned.
        for key, value in data.items():
            if key in self.features:
                sample[key] = tensor_from_rgb_image(value)

        return sample