namespace torch { namespace data { template struct Example { D data; L label; }; template struct Example { D data; }; namespace datasets { // can this just be an enum class? namespace access_policy { // Allows next_batch(size_t batch_size) struct Stream {}; // Allows next_batch(ArrayRef indices) struct Random : Stream {}; } // namespace access_policy template struct Map; // Trait class template < typename S, typename B = std::vector>, typename A = access_policy::Random> struct Dataset { using Self = S; using BatchType = B; using AccessPolicy = A; template Map map(Args&&... args) &&; }; // Map template struct MapBase : Dataset, typename T::OutputType, typename S::AccessPolicy> { MapBase(S&& dataset, T&& transform) : dataset(std::move(dataset)), transform(std::move(transform)) {} S dataset; T transform; }; template struct MapImpl; template struct MapImpl : MapBase { using MapBase::MapBase; typename T::OutputType next(size_t count) { return this->transform(this->dataset.next(count)); } }; template struct MapImpl : MapBase { using MapBase::MapBase; typename T::OutputType next(std::vector&& indices) { return this->transform.apply(this->dataset.next(std::move(indices))); } }; template struct Map : MapImpl { using MapImpl::MapImpl; }; // End Map template template Map Dataset::map(Args&&... args) && { // static_assert( // std::is_same::value, // "Batch type of dataset does not match input type of transform"); return {std::move(*static_cast(this)), TransformType(std::forward(args)...)}; } class MNIST : public Dataset { public: explicit MNIST(const std::string& root_path, bool train = true) : data_(100) {} std::vector> next(std::vector&& indices) { std::vector> examples; for (const auto& index : indices) { examples.push_back(data_[index]); } return examples; } size_t size() const noexcept { return data_.size(); } private: std::vector> data_; }; struct RowBatch { size_t count; }; class HiveDataset : public Dataset { public: HiveDataset() = default; RowBatch next(size_t count) { return {count}; } size_t size() const noexcept { return 12345; } }; } // namespace datasets namespace transforms { template struct Transform { using InputType = I; using OutputType = O; }; template struct TensorTransform : Transform>, std::vector>> { virtual ~TensorTransform() = default; virtual Tensor apply(const Tensor& tensor) = 0; Example apply(Example&& batch) const { for (const auto& example : batch) { apply(example.data); } return std::move(batch); } }; struct Normalize : TensorTransform<> { Normalize(double mean, double stddev) : mean(mean), stddev(stddev) {} Tensor apply(const Tensor& tensor) override { return (tensor - mean) / stddev; } template std::vector> apply(std::vector>&& batch) { for (auto& example : batch) { example.data = apply(example.data); } return std::move(batch); } double mean{0}; double stddev{0}; }; } // namespace transforms } // namespace data } // namespace torch