2024-06-08 15:36:06 +00:00
|
|
|
use axum::{
|
2024-06-10 17:43:24 +00:00
|
|
|
body::{Body, Bytes},
|
|
|
|
|
response::{IntoResponse, Json, Response}
|
2024-06-08 15:36:06 +00:00
|
|
|
};
|
2024-06-10 17:43:24 +00:00
|
|
|
use crate::{
|
|
|
|
|
ApiResult,
|
|
|
|
|
account::auth
|
|
|
|
|
};
|
|
|
|
|
use http::{HeaderMap, StatusCode};
|
|
|
|
|
use pyo3::{
|
|
|
|
|
prelude::*,
|
|
|
|
|
types::PyTuple
|
|
|
|
|
|
|
|
|
|
};
|
|
|
|
|
use serde::{Serialize, Deserialize};
|
2024-06-08 15:36:06 +00:00
|
|
|
|
|
|
|
|
const MODULE: &str = include_str!("model/model.py");
|
|
|
|
|
const MODEL: &[u8] = include_bytes!("model/model.keras");
|
|
|
|
|
|
2024-06-06 18:21:12 +00:00
|
|
|
pub fn router() -> axum::Router {
|
|
|
|
|
axum::Router::new()
|
|
|
|
|
.route("/",
|
|
|
|
|
axum::routing::post(get_predict)
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
|
2024-06-10 17:43:24 +00:00
|
|
|
#[derive(Serialize, Deserialize)]
|
|
|
|
|
struct Prediction {
|
|
|
|
|
benign_probability: f64,
|
|
|
|
|
malignant_probability: f64,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult {
|
|
|
|
|
let api = headers["api_key"].to_str()?;
|
|
|
|
|
|
|
|
|
|
match auth(api).await? {
|
|
|
|
|
Some(_user) => {
|
|
|
|
|
let prediction = predict(body).await?;
|
|
|
|
|
|
|
|
|
|
Ok(Response::builder()
|
|
|
|
|
.status(StatusCode::OK)
|
|
|
|
|
.body(Json(prediction).into_response().into_body())?)
|
|
|
|
|
|
|
|
|
|
},
|
|
|
|
|
None => {
|
|
|
|
|
Ok(Response::builder()
|
|
|
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
|
|
|
.body(Body::from("Prediction could not be completed"))?)
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
2024-06-09 17:22:13 +00:00
|
|
|
|
2024-06-10 17:43:24 +00:00
|
|
|
async fn predict(image: Bytes) -> PyResult<Prediction> {
|
|
|
|
|
Python::with_gil(|py| -> PyResult<Prediction> {
|
|
|
|
|
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
|
|
|
|
|
.getattr("predict")?.into();
|
|
|
|
|
|
|
|
|
|
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
|
|
|
|
|
let benign_probability: f64 = results.get_item(0)?.extract()?;
|
|
|
|
|
let malignant_probability: f64 = results.get_item(1)?.extract()?;
|
2024-06-08 15:36:06 +00:00
|
|
|
|
2024-06-10 17:43:24 +00:00
|
|
|
Ok(Prediction {
|
|
|
|
|
benign_probability,
|
|
|
|
|
malignant_probability,
|
|
|
|
|
})
|
|
|
|
|
})
|
2024-06-08 15:36:06 +00:00
|
|
|
}
|