dermy-app/app/src/main/java/xyz/r0r5chach/dermy/models/Model.java

68 lines
2.3 KiB
Java

package xyz.r0r5chach.dermy.models;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import org.tensorflow.lite.Interpreter;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.util.Objects;
/**
 * Wrapper around a TensorFlow Lite {@link Interpreter} for an image model stored in the app's
 * assets.
 *
 * <p>Images are converted into a {@code 224 x 224 x 3} float buffer by
 * {@link #preprocessImage(Bitmap)}, and {@link #runInference(Bitmap, int)} returns the first row
 * of the model's output. Call {@link #close()} to release the interpreter when done.
 */
public class Model implements AutoCloseable {
    /** Width and height, in pixels, of the square image fed to the model. */
    private static final int INPUT_SIZE = 224;
    /** Colour channels written per pixel (red, green, blue). */
    private static final int CHANNELS = 3;
    /** Size in bytes of one {@code float} written into the input buffer. */
    private static final int BYTES_PER_FLOAT = 4;

    protected final Interpreter modelInterpreter;
    protected final String modelPath;

    /**
     * Maps the model file at {@code modelPath} out of {@code assetManager} and creates an
     * interpreter for it.
     *
     * @param assetManager asset manager the model file is read from
     * @param modelPath    path of the model file inside the assets
     * @throws IOException          if the asset cannot be opened or memory-mapped
     * @throws NullPointerException if either argument is {@code null}
     */
    public Model(AssetManager assetManager, String modelPath) throws IOException {
        Objects.requireNonNull(assetManager, "assetManager");
        // loadModel() reads this.modelPath, so the field must be assigned before it is called.
        this.modelPath = Objects.requireNonNull(modelPath, "modelPath");
        this.modelInterpreter = new Interpreter(loadModel(assetManager));
    }

    /**
     * Runs the model on {@code image} and returns the values of the first output row.
     *
     * @param image   image to run through the model; scaled to 224x224 if it has another size
     * @param classes number of values in the model's output row
     * @return the {@code classes} output values produced for {@code image}
     * @throws IllegalArgumentException if {@code classes} is not positive
     */
    public float[] runInference(Bitmap image, int classes) {
        if (classes <= 0) {
            throw new IllegalArgumentException("classes must be positive: " + classes);
        }
        ByteBuffer inputBuffer = preprocessImage(image);
        float[][] output = new float[1][classes];
        modelInterpreter.run(inputBuffer, output);
        return output[0];
    }

    /**
     * Memory-maps the model file at {@link #modelPath} from the given assets.
     *
     * <p>The file descriptor and stream are closed before returning; a mapping, once created, does
     * not depend on the channel that produced it, so the returned buffer remains valid.
     *
     * @param assetManager asset manager the model file is read from
     * @return a read-only mapping of the model file's bytes
     * @throws IOException if the asset cannot be opened or mapped
     */
    public MappedByteBuffer loadModel(AssetManager assetManager) throws IOException {
        try (AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath);
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
            FileChannel fileChannel = inputStream.getChannel();
            // The asset may live inside the APK, so map only this asset's byte range.
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    /**
     * Converts {@code image} into a direct, native-order buffer of {@code 224 * 224 * 3} floats.
     *
     * <p>Pixels are written row by row, each as red, green, blue, with every channel mapped from
     * {@code [0, 255]} to {@code (c - 127) / 128}. Images that are not already 224x224 are scaled
     * (with filtering) first. The returned buffer is rewound to position 0.
     *
     * @param image source image
     * @return the model input buffer
     * @throws NullPointerException if {@code image} is {@code null}
     */
    public static ByteBuffer preprocessImage(Bitmap image) {
        Objects.requireNonNull(image, "image");
        // The pixel read below assumes exactly INPUT_SIZE x INPUT_SIZE pixels with a row stride of
        // INPUT_SIZE; scale only when the bitmap does not already match.
        Bitmap input = image.getWidth() == INPUT_SIZE && image.getHeight() == INPUT_SIZE
                ? image
                : Bitmap.createScaledBitmap(image, INPUT_SIZE, INPUT_SIZE, true);
        int[] pixels = new int[INPUT_SIZE * INPUT_SIZE];
        input.getPixels(pixels, 0, INPUT_SIZE, 0, 0, INPUT_SIZE, INPUT_SIZE);

        ByteBuffer buffer =
                ByteBuffer.allocateDirect(BYTES_PER_FLOAT * INPUT_SIZE * INPUT_SIZE * CHANNELS);
        buffer.order(ByteOrder.nativeOrder());
        for (int pixel : pixels) {
            // getPixels() yields packed ARGB ints; the alpha byte is ignored.
            buffer.putFloat(normalize((pixel >> 16) & 0xFF)); // red
            buffer.putFloat(normalize((pixel >> 8) & 0xFF));  // green
            buffer.putFloat(normalize(pixel & 0xFF));         // blue
        }
        // Reset the position so consumers read the buffer from its first element.
        buffer.rewind();
        return buffer;
    }

    /** Maps an 8-bit channel value in {@code [0, 255]} to {@code (value - 127) / 128}. */
    private static float normalize(int channel) {
        return (channel - 127) / 128.0f;
    }

    /** Releases the resources held by the underlying TensorFlow Lite interpreter. */
    @Override
    public void close() {
        modelInterpreter.close();
    }
}