Skip to content

Instantly share code, notes, and snippets.

@ottobricks
Created May 6, 2026 11:56
Show Gist options
  • Select an option

  • Save ottobricks/158ce18e90eef164e1dc80cfd1caecea to your computer and use it in GitHub Desktop.

Select an option

Save ottobricks/158ce18e90eef164e1dc80cfd1caecea to your computer and use it in GitHub Desktop.
pytorch_binary_classification.py
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
def main(batch_size: int = 10) -> None:
    """Build train/test ``DataLoader``s over a mock binary-classification dataset.

    Args:
        batch_size: Number of samples per batch for both loaders. Defaults to 10.
    """
    mock_data = get_mock_data()
    # split_train_test returns the Subset objects produced by random_split,
    # not DataFrames, so they are named as sets.
    train_set, test_set = split_train_test(mock_data)
    # Reshuffle the training data every epoch so batches are not presented in
    # the same order each pass; evaluation order can stay fixed.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    # TODO: the loaders are constructed but not yet consumed (no training or
    # evaluation loop exists in this script).
def transform():
    """Return the per-sample transform applied to the mock images.

    ``get_mock_data`` passes this to ``datasets.FakeData``. Converting each
    image to a ``torch.Tensor`` lets ``DataLoader``'s default collate function
    stack samples into batches; without a transform the samples are left as
    PIL images, which the default collate function cannot batch.
    """
    return transforms.ToTensor()
def split_train_test(dataset, *, train_fraction: float = 0.8) -> tuple[Subset, Subset]:
    """Randomly split *dataset* into disjoint train and test subsets.

    Args:
        dataset: A map-style dataset supporting ``len()`` and integer indexing
            (e.g. the ``FakeData`` built by ``get_mock_data``).
        train_fraction: Share of samples assigned to the training subset; the
            remainder goes to the test subset. Defaults to 0.8.

    Returns:
        ``(train_subset, test_subset)`` as ``torch.utils.data.Subset`` objects
        whose lengths sum to ``len(dataset)``.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")
    # Size the split from the dataset actually passed in.
    n_samples = len(dataset)
    train_size = round(train_fraction * n_samples)
    # Derive the test size by subtraction so the two sizes always add up to
    # n_samples, which random_split requires for integer lengths.
    test_size = n_samples - train_size
    train_subset, test_subset = random_split(dataset, [train_size, test_size])
    return train_subset, test_subset
def get_mock_data(
    sample_size: int = 100,
    *,
    image_size: tuple[int, int, int] = (3, 224, 224),
    num_classes: int = 2,
) -> datasets.FakeData:
    """Build a synthetic image-classification dataset for smoke-testing the pipeline.

    Args:
        sample_size: Number of fake samples to generate. Defaults to 100.
        image_size: Shape of each generated image, passed straight through to
            ``FakeData``. Defaults to ``(3, 224, 224)``.
        num_classes: Number of target classes. Defaults to 2 (binary).

    Returns:
        A ``torchvision.datasets.FakeData`` of ``sample_size`` samples, with the
        transform returned by ``transform()`` applied to each image.
    """
    return datasets.FakeData(
        size=sample_size,
        image_size=image_size,
        num_classes=num_classes,
        transform=transform(),
    )
# Entry point: run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment