Last active: March 4, 2020 10:42
-
-
Save vkkhare/2dd7a824c506c4dc64ebe041c7879416 to your computer and use it in GitHub Desktop.
Testing PyTorch Mobile for on-device training on Android
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.openmined.KotlinSyft | |
| import android.graphics.Bitmap | |
| import org.pytorch.Tensor | |
| import org.pytorch.Module | |
| import org.pytorch.IValue | |
| import org.pytorch.torchvision.TensorImageUtils | |
| import java.io.File | |
/**
 * Runs on-device training with two TorchScript modules loaded from disk.
 *
 * - [forward_model] is called as `forward(x, y, w)` and returns a loss tensor.
 * - [backward_model] is called as `forward(w, loss, x)` and returns the updated weights.
 *
 * Training data is read eagerly at construction time:
 * - [x]: path to a text file with one sample per line, features separated by commas;
 * - [y]: path to a text file with one float target per line, paired by position with [x].
 * Blank lines in either file are ignored.
 *
 * @throws IllegalArgumentException if the two files contain a different number of samples.
 */
class Classifier(forwardPath: String, backwardPath: String, x: String, y: String) {
    var forward_model: Module = Module.load(forwardPath)
    var backward_model: Module = Module.load(backwardPath)

    /** Feature rows, one per non-blank line of the `x` file. */
    val train_input = arrayListOf<FloatArray>()

    /** Targets, one per non-blank line of the `y` file; `train_output[i]` belongs to `train_input[i]`. */
    val train_output = arrayListOf<Float>()

    init {
        // Skip blank lines (e.g. a trailing newline), which would otherwise make toFloat() throw.
        File(x).forEachLine { line ->
            if (line.isNotBlank()) {
                train_input.add(line.split(",").map { it.toFloat() }.toFloatArray())
            }
        }
        File(y).forEachLine { line ->
            if (line.isNotBlank()) {
                train_output.add(line.toFloat())
            }
        }
        // Inputs and targets are paired by index, so the counts must agree.
        require(train_input.size == train_output.size) {
            "Training data mismatch: ${train_input.size} input rows vs ${train_output.size} targets"
        }
    }

    /** Per-channel normalization used by [preprocess] (defaults appear to be torchvision's ImageNet statistics). */
    var mean = floatArrayOf(0.485f, 0.456f, 0.406f)
    var std = floatArrayOf(0.229f, 0.224f, 0.225f)

    /** Replaces the per-channel [mean] and [std] used by [preprocess]. */
    fun setMeanAndStd(mean: FloatArray, std: FloatArray) {
        this.mean = mean
        this.std = std
    }

    /** Scales [bitmap] to `size x size` and converts it to a normalized float32 tensor. */
    private fun preprocess(bitmap: Bitmap, size: Int): Tensor {
        val scaled = Bitmap.createScaledBitmap(bitmap, size, size, false)
        return TensorImageUtils.bitmapToFloat32Tensor(scaled, this.mean, this.std)
    }

    /**
     * Returns the index of the largest value in [inputs] (the first one on ties),
     * or -1 if [inputs] is empty. Works for negative values too.
     */
    private fun argMax(inputs: FloatArray): Int {
        if (inputs.isEmpty()) return -1
        var maxIndex = 0
        for (i in 1 until inputs.size) {
            if (inputs[i] > inputs[maxIndex]) maxIndex = i
        }
        return maxIndex
    }

    /**
     * Runs [epochs] passes over every training sample, updating the weights after each sample.
     * Prints the loss of the last sample of each epoch.
     *
     * The weight vector starts as all ones and has one entry per feature of the first input row.
     * Does nothing if there is no training data.
     */
    fun train(epochs: Int = 500) {
        if (train_input.isEmpty()) return
        val featureCount = train_input[0].size
        var w = IValue.from(Tensor.fromBlob(FloatArray(featureCount) { 1.0f }, longArrayOf(featureCount.toLong())))
        repeat(epochs) {
            var lastLoss = 0.0f
            for (i in train_input.indices) {
                val row = train_input[i]
                val xValue = IValue.from(Tensor.fromBlob(row, longArrayOf(row.size.toLong())))
                val yValue = IValue.from(Tensor.fromBlob(floatArrayOf(train_output[i]), longArrayOf(1)))
                val loss = forward_model.forward(xValue, yValue, w)
                lastLoss = loss.toTensor().dataAsFloatArray[0]
                w = backward_model.forward(w, loss, xValue)
            }
            println("loss $lastLoss")
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.openmined.KotlinSyft | |
| import android.os.Bundle | |
| import com.google.android.material.snackbar.Snackbar | |
| import androidx.appcompat.app.AppCompatActivity | |
| import android.view.Menu | |
| import android.view.MenuItem | |
| import kotlinx.android.synthetic.main.activity_main.* | |
| import android.view.View | |
/**
 * Main screen. It builds a [Classifier] from bundled assets and starts on-device training
 * when the `capture` view is clicked.
 */
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: Classifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        setSupportActionBar(toolbar)

        // Classifier takes filesystem paths, so each asset is first copied out of the APK.
        classifier = Classifier(
            requireAssetPath("forward_model.pt"),
            requireAssetPath("backward_model.pt"),
            requireAssetPath("train.txt"),
            requireAssetPath("output.txt"),
        )

        val capture: View = findViewById(R.id.capture)
        capture.setOnClickListener { button ->
            // Training runs many forward/backward passes. Doing that on the main thread would freeze
            // the UI (risking an ANR), so it runs on a worker thread. The button stays disabled until
            // the run finishes to prevent overlapping runs.
            button.isEnabled = false
            Thread {
                try {
                    classifier.train()
                } finally {
                    runOnUiThread { button.isEnabled = true }
                }
            }.start()
        }

        fab.setOnClickListener { view ->
            Snackbar.make(view, "Replace with your own action", Snackbar.LENGTH_LONG)
                .setAction("Action", null).show()
        }
    }

    /**
     * Copies [assetName] to internal storage and returns its path.
     * Fails fast with a descriptive message when the copy fails; Utils logs an error in that case.
     */
    private fun requireAssetPath(assetName: String): String =
        checkNotNull(Utils.assetFilePath(this, assetName)) { "Could not copy asset $assetName" }

    override fun onCreateOptionsMenu(menu: Menu): Boolean {
        // Inflate the menu; this adds items to the action bar if it is present.
        menuInflater.inflate(R.menu.menu_main, menu)
        return true
    }

    override fun onOptionsItemSelected(item: MenuItem): Boolean {
        // Handle action bar item clicks here. The action bar will
        // automatically handle clicks on the Home/Up button, so long
        // as you specify a parent activity in AndroidManifest.xml.
        return when (item.itemId) {
            R.id.action_settings -> true
            else -> super.onOptionsItemSelected(item)
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.openmined.KotlinSyft | |
| import android.content.Context | |
| import android.util.Log | |
| import java.io.File | |
| import java.io.FileOutputStream | |
| import java.io.IOException | |
object Utils {
    /**
     * Copies the bundled asset [assetName] into [Context.getFilesDir] and returns the absolute
     * path of the copy. This lets APIs that need a real file path (such as `Module.load`) read it.
     *
     * The asset is copied on every call, overwriting any previous copy.
     *
     * @return the absolute path of the copied file, or `null` if the copy failed (the error is logged).
     */
    fun assetFilePath(context: Context, assetName: String): String? {
        val file = File(context.filesDir, assetName)
        return try {
            // Asset names may contain sub-directories (e.g. "models/x.pt"); create them first.
            file.parentFile?.mkdirs()
            context.assets.open(assetName).use { input ->
                FileOutputStream(file).use { output -> input.copyTo(output, 4 * 1024) }
            }
            file.absolutePath
        } catch (e: IOException) {
            // Don't leave a truncated file behind that could later be mistaken for a valid copy.
            file.delete()
            Log.e("pytorch android", "Error process asset $assetName to file path", e)
            null
        }
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment