package xyz.r0r5chach.dermy.models;

import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;

/**
 * Wraps a TensorFlow Lite model stored in the app's assets and exposes
 * single-image classification inference over 224x224 RGB bitmaps.
 */
public class Model {
    /** Input width and height in pixels expected by the model. */
    private static final int INPUT_SIZE = 224;
    /** Number of color channels fed to the model (R, G, B). */
    private static final int CHANNELS = 3;
    /** Size of one float in bytes, used to size the direct input buffer. */
    private static final int BYTES_PER_FLOAT = 4;

    protected final Interpreter modelInterpreter;
    protected final String modelPath;

    /**
     * Loads the .tflite model at {@code modelPath} from assets and creates an
     * interpreter for it.
     *
     * @param assetManager asset manager used to resolve the model file
     * @param modelPath    asset-relative path of the .tflite model file
     * @throws IOException if the model asset cannot be opened or memory-mapped
     */
    public Model(AssetManager assetManager, String modelPath) throws IOException {
        // BUGFIX: the field must be assigned BEFORE loadModel() runs, because
        // loadModel() reads this.modelPath. The original called loadModel()
        // first, so openFd(null) was invoked on a still-unset path.
        this.modelPath = modelPath;
        modelInterpreter = new Interpreter(loadModel(assetManager));
    }

    /**
     * Runs the model on a single bitmap and returns the raw output scores.
     *
     * @param image   input bitmap; must be exactly 224x224 pixels
     * @param classes number of output classes the model produces
     * @return the model's output vector of length {@code classes}
     */
    public float[] runInference(Bitmap image, int classes) {
        ByteBuffer inputBuffer = preprocessImage(image);
        float[][] output = new float[1][classes];
        modelInterpreter.run(inputBuffer, output);
        return output[0];
    }

    /**
     * Memory-maps the model file directly out of the APK assets.
     *
     * @param assetManager asset manager used to open {@code modelPath}
     * @return a read-only mapping of the model bytes
     * @throws IOException if the asset cannot be opened or mapped
     */
    public MappedByteBuffer loadModel(AssetManager assetManager) throws IOException {
        // try-with-resources closes the descriptor, stream, and channel even
        // on failure (resolves the original FIXME / descriptor leak). The
        // returned MappedByteBuffer stays valid after the channel is closed.
        try (AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath);
             FileInputStream inputStream =
                     new FileInputStream(fileDescriptor.getFileDescriptor());
             FileChannel fileChannel = inputStream.getChannel()) {
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    /**
     * Converts a bitmap into the model's direct float input buffer.
     *
     * <p>Each pixel's R, G, and B channels are normalized from [0, 255] to
     * approximately [-1, 1) via {@code (channel - 127) / 128}.
     *
     * @param image source bitmap; must be exactly 224x224 pixels
     * @return a direct, native-order buffer of 224*224*3 floats in RGB order
     * @throws IllegalArgumentException if the bitmap is not 224x224
     */
    public static ByteBuffer preprocessImage(Bitmap image) {
        // Fail fast on wrong-sized input: the original silently assumed the
        // caller resized the bitmap, and crashed or produced garbage otherwise.
        if (image.getWidth() != INPUT_SIZE || image.getHeight() != INPUT_SIZE) {
            throw new IllegalArgumentException(
                    "Expected " + INPUT_SIZE + "x" + INPUT_SIZE + " bitmap, got "
                            + image.getWidth() + "x" + image.getHeight());
        }
        ByteBuffer buffer =
                ByteBuffer.allocateDirect(BYTES_PER_FLOAT * INPUT_SIZE * INPUT_SIZE * CHANNELS);
        buffer.order(ByteOrder.nativeOrder());
        // BUGFIX: the original allocated new int[244 * 244] — a typo for 224.
        int[] intValues = new int[INPUT_SIZE * INPUT_SIZE];
        image.getPixels(intValues, 0, image.getWidth(), 0, 0,
                image.getWidth(), image.getHeight());
        int pixel = 0;
        for (int i = 0; i < INPUT_SIZE; ++i) {
            for (int j = 0; j < INPUT_SIZE; ++j) {
                // Each int is packed ARGB; alpha is discarded.
                int val = intValues[pixel++];
                buffer.putFloat((((val >> 16) & 0xFF) - 127) / 128.0f);
                buffer.putFloat((((val >> 8) & 0xFF) - 127) / 128.0f);
                buffer.putFloat(((val & 0xFF) - 127) / 128.0f);
            }
        }
        return buffer;
    }
}