dermy-api/src/model.rs

69 lines
1.8 KiB
Rust
Raw Normal View History

2024-06-08 15:36:06 +00:00
use axum::{
2024-07-11 10:42:56 +00:00
extract::Path,
2024-06-10 17:43:24 +00:00
body::{Body, Bytes},
response::{IntoResponse, Json, Response}
2024-06-08 15:36:06 +00:00
};
2024-06-10 17:43:24 +00:00
use crate::{
ApiResult,
account::auth
};
2024-07-11 10:42:56 +00:00
use http::StatusCode;
2024-06-10 17:43:24 +00:00
use pyo3::{
prelude::*,
types::PyTuple
};
use serde::{Serialize, Deserialize};
2024-06-08 15:36:06 +00:00
/// Source of the Python helper module, embedded at compile time from `model/model.py`
/// (path relative to this source file). It must define a `predict` callable, which the
/// `predict` function below looks up by name.
const MODULE: &str = include_str!("model/model.py");
/// Serialized Keras model, embedded at compile time from `model/model.keras` (path
/// relative to this source file) and passed as the first argument of every call to the
/// Python `predict` function.
const MODEL: &[u8] = include_bytes!("model/model.keras");
2024-06-06 18:21:12 +00:00
pub fn router() -> axum::Router {
axum::Router::new()
2024-07-11 10:42:56 +00:00
.route("/:api",
2024-06-06 18:21:12 +00:00
axum::routing::post(get_predict)
)
}
2024-06-10 17:43:24 +00:00
/// Output of the model for one image: the two class probabilities returned by the
/// Python `predict` function, serialized as the JSON body of a successful response.
///
/// NOTE(review): values are taken from Python as-is; they are presumably in `[0, 1]`
/// and sum to about 1, but nothing here validates that.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct Prediction {
    /// Probability that the lesion is benign (first item of the Python result tuple).
    benign_probability: f64,
    /// Probability that the lesion is malignant (second item of the Python result tuple).
    malignant_probability: f64,
}
2024-07-11 10:42:56 +00:00
async fn get_predict(Path(api): Path<String>, body: Bytes) -> ApiResult {
match auth(&api).await? {
2024-06-10 17:43:24 +00:00
Some(_user) => {
let prediction = predict(body).await?;
Ok(Response::builder()
.status(StatusCode::OK)
.body(Json(prediction).into_response().into_body())?)
},
None => {
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Prediction could not be completed"))?)
},
}
}
2024-06-09 17:22:13 +00:00
2024-06-10 17:43:24 +00:00
async fn predict(image: Bytes) -> PyResult<Prediction> {
Python::with_gil(|py| -> PyResult<Prediction> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into();
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
let benign_probability: f64 = results.get_item(0)?.extract()?;
let malignant_probability: f64 = results.get_item(1)?.extract()?;
2024-06-08 15:36:06 +00:00
2024-06-10 17:43:24 +00:00
Ok(Prediction {
benign_probability,
malignant_probability,
})
})
2024-06-08 15:36:06 +00:00
}