68 lines
2.3 KiB
Java
68 lines
2.3 KiB
Java
package xyz.r0r5chach.dermy.models;
|
|
|
|
import android.content.res.AssetFileDescriptor;
|
|
import android.content.res.AssetManager;
|
|
import android.graphics.Bitmap;
|
|
|
|
import org.tensorflow.lite.Interpreter;
|
|
|
|
|
|
import java.io.FileInputStream;
|
|
import java.io.IOException;
|
|
import java.nio.ByteBuffer;
|
|
import java.nio.ByteOrder;
|
|
import java.nio.MappedByteBuffer;
|
|
import java.nio.channels.FileChannel;
|
|
import java.nio.channels.FileChannel.MapMode;
|
|
|
|
public class Model {
|
|
protected final Interpreter modelInterpreter;
|
|
protected final String modelPath;
|
|
|
|
public Model(AssetManager assetManager, String modelPath) throws IOException {
|
|
modelInterpreter = new Interpreter(loadModel(assetManager));
|
|
this.modelPath = modelPath;
|
|
}
|
|
|
|
public float[] runInference(Bitmap image, int classes) {
|
|
ByteBuffer inputBuffer = preprocessImage(image);
|
|
float[][] output = new float[1][classes];
|
|
modelInterpreter.run(inputBuffer, output);
|
|
|
|
return output[0];
|
|
}
|
|
|
|
public MappedByteBuffer loadModel(AssetManager assetManager) throws IOException {
|
|
AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath);
|
|
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); //FIXME: add try-with-resources
|
|
|
|
FileChannel fileChannel = inputStream.getChannel();
|
|
long startOffset = fileDescriptor.getStartOffset();
|
|
long declaredLength = fileDescriptor.getDeclaredLength();
|
|
|
|
return fileChannel.map(MapMode.READ_ONLY, startOffset, declaredLength);
|
|
}
|
|
|
|
public static ByteBuffer preprocessImage(Bitmap image) {
|
|
ByteBuffer buffer = ByteBuffer.allocateDirect(4 * 224 * 224 * 3);
|
|
buffer.order(ByteOrder.nativeOrder());
|
|
|
|
int[] intValues = new int[244 * 244];
|
|
|
|
image.getPixels(intValues, 0, image.getWidth(), 0, 0, image.getWidth(), image.getHeight());
|
|
|
|
int pixel = 0;
|
|
for (int i = 0; i < 224; ++i) {
|
|
for (int j = 0; j < 224; ++j) {
|
|
int val = intValues[pixel++];
|
|
buffer.putFloat((((val >> 16) & 0xFF) - 127) / 128.0f);
|
|
buffer.putFloat((((val >> 8) & 0xFF) - 127) / 128.0f);
|
|
buffer.putFloat(((val & 0xFF) - 127) / 128.0f);
|
|
}
|
|
}
|
|
|
|
return buffer;
|
|
}
|
|
|
|
}
|