Skip to content

Instantly share code, notes, and snippets.

@pengyuan-zhou
Created August 23, 2020 19:51
Show Gist options
  • Select an option

  • Save pengyuan-zhou/79a12d67c5578a82c942f90c8de99c77 to your computer and use it in GitHub Desktop.

Select an option

Save pengyuan-zhou/79a12d67c5578a82c942f90c8de99c77 to your computer and use it in GitHub Desktop.

Revisions

  1. pengyuan-zhou created this gist Aug 23, 2020.
    74 changes: 74 additions & 0 deletions LocalAllinoneDataSource.kt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,74 @@
    package org.openmined.syft.demo.federated.datasource

    import android.content.res.Resources
    import org.openmined.syft.demo.R
    import org.openmined.syft.demo.federated.domain.Batch
    import java.io.BufferedReader
    import java.io.InputStreamReader

    private var DATASIZE: Int? = null
    private var LABELSIZE: Int? = null

    class LocalAllinoneDataSource constructor(
    private val resources: Resources
    ) {
    private var dataReader = returnDataReader()
    var numLabel = 1

    fun loadDataBatch(batchSize: Int): Pair<Batch, Batch> {
    val trainInput = arrayListOf<List<Float>>()
    val labels = arrayListOf<List<Float>>()
    for (idx in 0..batchSize)
    readSample(trainInput, labels)

    DATASIZE = trainInput[0].size
    LABELSIZE = labels[0].size

    val trainingData = Batch(
    trainInput.flatten().toFloatArray(),
    longArrayOf(trainInput.size.toLong(), DATASIZE!!.toLong())
    )
    val trainingLabel = Batch(
    labels.flatten().toFloatArray(),
    longArrayOf(labels.size.toLong(), LABELSIZE!!.toLong())
    )
    return Pair(trainingData, trainingLabel)
    }

    private fun readSample(
    trainInput: ArrayList<List<Float>>,
    labels: ArrayList<List<Float>>
    ) {
    val sample = readLine()

    trainInput.add(
    sample.drop(numLabel).map { it.trim().toFloat() }
    )
    labels.add(
    sample.take(numLabel).map { it.trim().toFloat() }
    )
    }

    private fun readLine(): List<String> {
    var x = dataReader.readLine()?.split(",")
    if (x == null) {
    restartReader()
    x = dataReader.readLine()?.split(",")
    }
    if (x == null)
    throw Exception("cannot read from dataset file")
    return x
    }

    private fun restartReader() {
    dataReader.close()
    dataReader = returnDataReader()
    }


    private fun returnDataReader() = BufferedReader(
    InputStreamReader(
    resources.openRawResource(R.raw.allinone)
    )
    )
    }