Skip to content

Instantly share code, notes, and snippets.

@ProAek11
Forked from witchapong/pytorch_data.py
Created March 10, 2020 18:18
Show Gist options
  • Select an option

  • Save ProAek11/2f53f0ca0fd393e1a002eb6ae438cf83 to your computer and use it in GitHub Desktop.

Select an option

Save ProAek11/2f53f0ca0fd393e1a002eb6ae438cf83 to your computer and use it in GitHub Desktop.
class StockDataset(Dataset):
    """Map-style dataset pairing flattened sequence windows with categorical features.

    Each row of ``df_seq`` holds ``seq_len * feat_num`` input values followed
    by ``target_len`` target values. On access, the input part is reshaped
    row-major to ``(seq_len, feat_num)``, so the columns are expected to be
    ordered step by step (all features of step 0, then all of step 1, ...).
    Row ``i`` of ``df_cat`` is paired with row ``i`` of ``df_seq`` by position.

    Args:
        df_seq: DataFrame with ``seq_len * feat_num + target_len`` columns;
            the last ``target_len`` columns are the targets.
        feat_num: Number of features per time step.
        seq_len: Number of time steps per input window.
        target_len: Number of trailing target columns in ``df_seq`` (may be 0).
        df_cat: DataFrame of categorical features, one row per row of
            ``df_seq``; values are cast to int64 on access.
        device: Device for the returned tensors. When ``None``, a module-level
            ``device`` global is used if one exists (the original behaviour);
            otherwise PyTorch's default device is used.

    Raises:
        ValueError: If ``target_len`` is negative, if ``df_seq`` does not have
            ``seq_len * feat_num + target_len`` columns, or if ``df_cat`` and
            ``df_seq`` have different numbers of rows.
    """

    def __init__(self, df_seq, feat_num, seq_len, target_len, df_cat, device=None):
        if target_len < 0:
            raise ValueError(f"target_len must be non-negative, got {target_len}")
        # Split at an explicit column position: slicing with ``-target_len``
        # would select no input columns at all when target_len == 0.
        split = df_seq.shape[1] - target_len
        if split != seq_len * feat_num:
            raise ValueError(
                f"df_seq has {df_seq.shape[1]} columns; expected "
                f"seq_len * feat_num + target_len = "
                f"{seq_len * feat_num + target_len}"
            )
        if len(df_cat) != len(df_seq):
            raise ValueError(
                f"df_cat has {len(df_cat)} rows but df_seq has {len(df_seq)}; "
                "rows are paired by position and must line up"
            )

        # Keep the window geometry on the instance instead of reading globals.
        self.seq_len = seq_len
        self.feat_num = feat_num
        self.device = device

        # SEQUENTIAL PART
        self.df_seq = df_seq.iloc[:, :split]
        self.target = df_seq.iloc[:, split:]
        # CATEGORICAL PART
        self.df_cat = df_cat

    def _resolve_device(self):
        """Return the device on which ``__getitem__`` should create tensors."""
        if self.device is not None:
            return self.device
        # Backward compatibility: the original implementation looked up a
        # module-level ``device`` name at access time. None -> torch default.
        return globals().get("device")

    def __getitem__(self, index):
        """Return ``(window, categorical, target)`` tensors for row ``index``.

        ``window`` is float32 with shape ``(seq_len, feat_num)``;
        ``categorical`` is int64 with one entry per ``df_cat`` column;
        ``target`` is float32 with ``target_len`` entries.
        """
        device = self._resolve_device()
        # Convert rows via explicit-dtype NumPy arrays rather than passing a
        # pandas Series to torch.tensor (which would index it by label).
        window = self.df_seq.iloc[index].to_numpy(dtype="float32")
        categorical = self.df_cat.iloc[index].to_numpy(dtype="int64")
        target = self.target.iloc[index].to_numpy(dtype="float32")
        return (
            torch.tensor(window.reshape(self.seq_len, self.feat_num),
                         dtype=torch.float, device=device),
            torch.tensor(categorical, dtype=torch.long, device=device),
            torch.tensor(target, dtype=torch.float, device=device),
        )

    def __len__(self):
        """Return the number of samples (rows of ``df_seq``)."""
        return self.df_seq.shape[0]
# Window geometry: each row of the df_seq_* frames holds seq_len * feat_num
# input values followed by target_len target values.
feat_num = 6
seq_len = 30
target_len = 5
bs = 32

train_ds = StockDataset(
    df_seq_train,
    feat_num=feat_num,
    seq_len=seq_len,
    target_len=target_len,
    df_cat=df_dt_feat_train,
)
# shuffle=False keeps samples in their original row order.
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=False)

test_ds = StockDataset(
    df_seq_test,
    feat_num=feat_num,
    seq_len=seq_len,
    target_len=target_len,
    df_cat=df_dt_feat_test,
)
# Serve the whole test split as one batch. DataLoader rejects batch_size=0,
# so clamp to 1 for an empty split (the loader then simply yields nothing).
test_dl = DataLoader(test_ds, batch_size=max(len(test_ds), 1), shuffle=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment