dermy-api/src/model.rs

69 lines
1.8 KiB
Rust

use axum::{
    body::{Body, Bytes},
    extract::Path,
    response::{IntoResponse, Json, Response},
};
use http::StatusCode;
use pyo3::{
    prelude::*,
    types::{PyBytes, PyTuple},
};
use serde::{Deserialize, Serialize};

use crate::{account::auth, ApiResult};
/// Python source of the inference module, embedded at compile time.
/// `predict` loads it as a module named `model` and calls its `predict` attribute.
const MODULE: &str = include_str!("model/model.py");
/// Serialized Keras model, embedded at compile time. It is passed as the first argument
/// to the Python `predict` function on every call.
const MODEL: &[u8] = include_bytes!("model/model.keras");
/// Builds the router for the prediction endpoint.
///
/// Exposes a single `POST /:api` route, where the `:api` path segment is the credential
/// that `get_predict` hands to `auth`.
pub fn router() -> axum::Router {
    let predict_route = axum::routing::post(get_predict);
    axum::Router::new().route("/:api", predict_route)
}
/// Class probabilities produced by the Python model; serialized as the JSON response body.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct Prediction {
    /// Probability of the benign class (item 0 of the model's output tuple).
    benign_probability: f64,
    /// Probability of the malignant class (item 1 of the model's output tuple).
    malignant_probability: f64,
}
/// Handles `POST /:api`. It authenticates the `:api` path segment, runs the model on the raw
/// request body, and returns the resulting probabilities as JSON.
///
/// Responses:
/// - `200 OK` with a JSON [`Prediction`] body (`Content-Type: application/json`) when `auth`
///   returns a user;
/// - `401 Unauthorized` with a plain-text body when `auth` returns `None`.
///
/// # Errors
///
/// Errors from `auth`, `predict` and the response builder are propagated with `?` into the
/// error type of `ApiResult`. NOTE(review): how they are rendered to the client is decided by
/// `ApiResult`'s error type, which is defined elsewhere; confirm there.
async fn get_predict(Path(api): Path<String>, body: Bytes) -> ApiResult {
    match auth(&api).await? {
        Some(_user) => {
            let prediction = predict(body).await?;
            // Return the full response produced by `Json`, not just its body. Taking only
            // `.into_body()` discarded the `Content-Type: application/json` header.
            Ok((StatusCode::OK, Json(prediction)).into_response())
        }
        // An unrecognised API key is a client authentication failure, not a server error.
        None => Ok(Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .body(Body::from("Invalid API key"))?),
    }
}
async fn predict(image: Bytes) -> PyResult<Prediction> {
Python::with_gil(|py| -> PyResult<Prediction> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into();
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
let benign_probability: f64 = results.get_item(0)?.extract()?;
let malignant_probability: f64 = results.get_item(1)?.extract()?;
Ok(Prediction {
benign_probability,
malignant_probability,
})
})
}