updated inference code

This commit is contained in:
Joshua Perry 2024-06-09 18:22:13 +01:00
parent 1ee1f001f8
commit 49ff2be8bc
2 changed files with 25 additions and 7 deletions

View File

@ -21,10 +21,10 @@ async fn get_predict() -> ApiResult {
let result = Python::with_gil(|py| -> PyResult<()> { let result = Python::with_gil(|py| -> PyResult<()> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")? let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into(); .getattr("predict")?.into();
//TODO: Return result
//TODO: Get result and return it
Ok(()) Ok(())
}); });
//TODO: Return results
Ok(Response::builder().body(Body::from(""))?) Ok(Response::builder().body(Body::from(""))?)
} }

View File

@ -1,12 +1,30 @@
import io import io
import tensorflow as tf import numpy as np
import os
from PIL import Image from PIL import Image
import tensorflow as tf
MODEL_NAME = "model.keras"
def preprocess_image_bytes(bytes):
image = Image.open(io.BytesIO(bytes))
image = np.array(image)
image = np.expand_dims(image, axis=0)
return image
def process_model_bytes(bytes):
open(MODEL_NAME, "wb").write(bytes)
model = tf.keras.models.load_model(MODEL_NAME)
os.remove(MODEL_NAME)
return model
def predict(model_bytes, image_bytes): def predict(model_bytes, image_bytes):
model_file = io.BytesIO(model_bytes) image = preprocess_image_bytes(image_bytes)
model = tf.keras.models.load(model_file) model = process_model_bytes(model_bytes)
image = Image.open(io.BytesIO(image_bytes))
return model.predict(image) return model.predict(image)