Skip to content

Instantly share code, notes, and snippets.

@vkkhare
Created April 1, 2020 07:55
Show Gist options
  • Select an option

  • Save vkkhare/9869be9ee9525fb9574f84d05356e723 to your computer and use it in GitHub Desktop.

Select an option

Save vkkhare/9869be9ee9525fb9574f84d05356e723 to your computer and use it in GitHub Desktop.
pytorch testing module
// Builds every argument for one forward() call on a TorchScript module by hand
// and runs it. All tensors are created from in-memory arrays via Tensor.fromBlob.

// Sizes shared by the tensor shapes below.
val inputSize = 784
val hiddenSize = 392
val outputSize = 10

val script = Module.load("torchscript.pt")

// Single-element arguments, each shaped 1x1: a Long holding 1 and a Float holding 0.01.
val batchSizeTensor = Tensor.fromBlob(longArrayOf(1), longArrayOf(1, 1))
val batchSize = IValue.from(batchSizeTensor)
val lrTensor = Tensor.fromBlob(floatArrayOf(0.01f), longArrayOf(1, 1))
val lr = IValue.from(lrTensor)

// First weight matrix (392 x 784): random floats divided by sqrt(784).
val inputScale = sqrt(inputSize.toFloat())
val w1Data = FloatArray(hiddenSize * inputSize) { Random.nextFloat() / inputScale }
val w1Shape = longArrayOf(hiddenSize.toLong(), inputSize.toLong())
val w1 = IValue.from(Tensor.fromBlob(w1Data, w1Shape))

// First bias (1 x 392): all zeros. FloatArray(n) is zero-filled by default.
val b1Shape = longArrayOf(1, hiddenSize.toLong())
val b1 = IValue.from(Tensor.fromBlob(FloatArray(hiddenSize), b1Shape))

// Second weight matrix (10 x 392): random floats divided by sqrt(392).
val hiddenScale = sqrt(hiddenSize.toFloat())
val w2Data = FloatArray(outputSize * hiddenSize) { Random.nextFloat() / hiddenScale }
val w2Shape = longArrayOf(outputSize.toLong(), hiddenSize.toLong())
val w2 = IValue.from(Tensor.fromBlob(w2Data, w2Shape))

// Second bias (1 x 10): all zeros.
val b2Shape = longArrayOf(1, outputSize.toLong())
val b2 = IValue.from(Tensor.fromBlob(FloatArray(outputSize), b2Shape))

// Random input row (1 x 784), scaled by the same sqrt(784) divisor as w1.
val xData = FloatArray(inputSize) { Random.nextFloat() / inputScale }
val x = IValue.from(Tensor.fromBlob(xData, longArrayOf(1, inputSize.toLong())))

// Random single-value target (1 x 1).
// NOTE(review): presumably the label the module trains against; confirm the
// expected shape and dtype against the TorchScript signature.
val yData = FloatArray(1) { Random.nextFloat() }
val y = IValue.from(Tensor.fromBlob(yData, longArrayOf(1, 1)))

// Argument order must match the module's forward signature; the result is read as a tuple.
val output = script.forward(x, y, batchSize, lr, w1, b1, w2, b2).toTuple()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment