Skip to content

Instantly share code, notes, and snippets.

@vkkhare
Last active March 4, 2020 10:42
Show Gist options
  • Select an option

  • Save vkkhare/2dd7a824c506c4dc64ebe041c7879416 to your computer and use it in GitHub Desktop.

Select an option

Save vkkhare/2dd7a824c506c4dc64ebe041c7879416 to your computer and use it in GitHub Desktop.
Testing PyTorch Mobile for on-device training on Android.
package org.openmined.KotlinSyft
import android.graphics.Bitmap
import org.pytorch.Tensor
import org.pytorch.Module
import org.pytorch.IValue
import org.pytorch.torchvision.TensorImageUtils
import java.io.File
/**
 * Trains a model on-device using two TorchScript modules and a small text dataset.
 *
 * @param forwardPath filesystem path of the module that computes the loss from `(x, y, w)`.
 * @param backwardPath filesystem path of the module that maps `(w, loss, x)` to the next `w`
 *   (presumably one optimizer step — confirm against the exported script).
 * @param x path of a text file holding one comma-separated feature vector per line.
 * @param y path of a text file holding one target value per line, aligned row-by-row with [x].
 * @throws IllegalArgumentException if [x] and [y] contain a different number of rows.
 */
class Classifier(forwardPath: String, backwardPath: String, x: String, y: String) {
    var forward_model: Module = Module.load(forwardPath)
    var backward_model: Module = Module.load(backwardPath)

    /** Feature vectors parsed from `x`, one per non-blank line. */
    val train_input = arrayListOf<FloatArray>()

    /** Targets parsed from `y`; `train_output[i]` is the label for `train_input[i]`. */
    val train_output = arrayListOf<Float>()

    init {
        // Both lists start empty (no placeholder row) so row i of `x` pairs with row i of `y`.
        File(x).useLines { lines ->
            lines.filter { it.isNotBlank() }
                .mapTo(train_input) { line -> line.split(",").map { it.toFloat() }.toFloatArray() }
        }
        File(y).useLines { lines ->
            lines.filter { it.isNotBlank() }.mapTo(train_output) { it.toFloat() }
        }
        require(train_input.size == train_output.size) {
            "Input file has ${train_input.size} rows but target file has ${train_output.size}"
        }
    }

    /** Per-channel normalization mean passed to [TensorImageUtils.bitmapToFloat32Tensor]. */
    var mean = floatArrayOf(0.485f, 0.456f, 0.406f)

    /** Per-channel normalization std passed to [TensorImageUtils.bitmapToFloat32Tensor]. */
    var std = floatArrayOf(0.229f, 0.224f, 0.225f)

    /** Replaces the normalization constants used by image preprocessing. */
    fun setMeanAndStd(mean: FloatArray, std: FloatArray) {
        this.mean = mean
        this.std = std
    }

    /** Scales [bitmap] to `size`×`size` (no filtering) and converts it to a normalized float tensor. */
    private fun preprocess(bitmap: Bitmap, size: Int): Tensor {
        val scaled = Bitmap.createScaledBitmap(bitmap, size, size, false)
        return TensorImageUtils.bitmapToFloat32Tensor(scaled, mean, std)
    }

    /**
     * Returns the index of the first maximum element of [inputs], or -1 if [inputs] is empty.
     * Works for all-negative inputs (the previous version started from 0f and returned -1).
     */
    private fun argMax(inputs: FloatArray): Int {
        if (inputs.isEmpty()) return -1
        var maxIndex = 0
        for (i in 1 until inputs.size) {
            if (inputs[i] > inputs[maxIndex]) maxIndex = i
        }
        return maxIndex
    }

    /**
     * Runs [epochs] passes over every loaded sample, printing the loss of the last sample of
     * each epoch.
     *
     * The weight vector starts as all ones, with one entry per input feature. For each sample,
     * `forward_model(x, y, w)` produces the loss, and `backward_model(w, loss, x)` produces the
     * weight vector used for the next sample.
     *
     * @param epochs number of passes over the dataset; defaults to the original 500.
     * @throws IllegalStateException if no training samples were loaded.
     */
    fun train(epochs: Int = 500) {
        // Guard against the public lists having been mutated out of step since construction.
        val sampleCount = minOf(train_input.size, train_output.size)
        check(sampleCount > 0) { "No training samples loaded" }

        // Derive the feature width from the data instead of hard-coding 10.
        val featureCount = train_input[0].size
        val featureShape = longArrayOf(featureCount.toLong())
        val targetShape = longArrayOf(1)

        var w = IValue.from(Tensor.fromBlob(FloatArray(featureCount) { 1.0f }, featureShape))
        repeat(epochs) {
            var lastLoss = 0.0f
            for (i in 0 until sampleCount) {
                val x = Tensor.fromBlob(train_input[i], featureShape)
                val y = Tensor.fromBlob(floatArrayOf(train_output[i]), targetShape)
                val loss = forward_model.forward(IValue.from(x), IValue.from(y), w)
                lastLoss = loss.toTensor().dataAsFloatArray[0]
                w = backward_model.forward(w, loss, IValue.from(x))
            }
            println("loss $lastLoss")
        }
    }
}
package org.openmined.KotlinSyft
import android.os.Bundle
import com.google.android.material.snackbar.Snackbar
import androidx.appcompat.app.AppCompatActivity
import android.view.Menu
import android.view.MenuItem
import kotlinx.android.synthetic.main.activity_main.*
import android.view.View
/**
 * Demo screen: the `capture` view starts an on-device training run of [Classifier].
 */
class MainActivity : AppCompatActivity() {
    /**
     * Created on first use, which happens on the training thread. This keeps asset copying
     * and model loading off the UI thread. `lazy` is synchronized by default, so the instance
     * is built only once.
     */
    private val classifier: Classifier by lazy {
        Classifier(
            assetPath("forward_model.pt"),
            assetPath("backward_model.pt"),
            assetPath("train.txt"),
            assetPath("output.txt"),
        )
    }

    /** Copies [name] out of the APK assets and returns its path, failing with a clear message. */
    private fun assetPath(name: String): String =
        checkNotNull(Utils.assetFilePath(this, name)) { "Could not copy asset '$name' to internal storage" }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        setSupportActionBar(toolbar)

        val capture: View = findViewById(R.id.capture)
        capture.setOnClickListener {
            // Training is a long blocking loop. Running it on the UI thread would freeze the app
            // (ANR). Disable the button so that only one training run is active at a time.
            capture.isEnabled = false
            Thread({
                try {
                    classifier.train()
                } finally {
                    runOnUiThread { capture.isEnabled = true }
                }
            }, "classifier-training").start()
        }

        fab.setOnClickListener { view ->
            Snackbar.make(view, "Replace with your own action", Snackbar.LENGTH_LONG)
                .setAction("Action", null).show()
        }
    }

    override fun onCreateOptionsMenu(menu: Menu): Boolean {
        // Inflate the menu; this adds items to the action bar if it is present.
        menuInflater.inflate(R.menu.menu_main, menu)
        return true
    }

    override fun onOptionsItemSelected(item: MenuItem): Boolean {
        // Handle action bar item clicks here. The action bar will
        // automatically handle clicks on the Home/Up button, so long
        // as you specify a parent activity in AndroidManifest.xml.
        return when (item.itemId) {
            R.id.action_settings -> true
            else -> super.onOptionsItemSelected(item)
        }
    }
}
package org.openmined.KotlinSyft
import android.content.Context
import android.util.Log
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
object Utils {
    /**
     * Copies the bundled asset [assetName] into [Context.getFilesDir] and returns the absolute
     * path of the copy. This lets callers that need a real filesystem path read the asset.
     *
     * The copy is rewritten on every call, so it always matches the asset shipped in the APK.
     * Asset names containing sub-directories are supported; missing parent directories are
     * created.
     *
     * @return the absolute path of the copied file, or `null` if the asset could not be copied
     *   (the cause is logged).
     */
    fun assetFilePath(context: Context, assetName: String): String? {
        val file = File(context.filesDir, assetName)
        return try {
            file.parentFile?.mkdirs()
            context.assets.open(assetName).use { input ->
                FileOutputStream(file).use { output -> input.copyTo(output, 4 * 1024) }
            }
            file.absolutePath
        } catch (e: IOException) {
            // Don't leave a truncated copy behind. Pass the exception so the stack trace is logged.
            file.delete()
            Log.e("pytorch android", "Error process asset $assetName to file path", e)
            null
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment