use axum::{ extract::Path, body::{Body, Bytes}, response::{IntoResponse, Json, Response} }; use crate::{ ApiResult, account::auth }; use http::StatusCode; use pyo3::{ prelude::*, types::PyTuple }; use serde::{Serialize, Deserialize}; const MODULE: &str = include_str!("model/model.py"); const MODEL: &[u8] = include_bytes!("model/model.keras"); pub fn router() -> axum::Router { axum::Router::new() .route("/:api", axum::routing::post(get_predict) ) } #[derive(Serialize, Deserialize)] struct Prediction { benign_probability: f64, malignant_probability: f64, } async fn get_predict(Path(api): Path, body: Bytes) -> ApiResult { match auth(&api).await? { Some(_user) => { let prediction = predict(body).await?; Ok(Response::builder() .status(StatusCode::OK) .body(Json(prediction).into_response().into_body())?) }, None => { Ok(Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from("Prediction could not be completed"))?) }, } } async fn predict(image: Bytes) -> PyResult { Python::with_gil(|py| -> PyResult { let predict: Py = PyModule::from_code_bound(py, MODULE, "model.py", "model")? .getattr("predict")?.into(); let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?; let benign_probability: f64 = results.get_item(0)?.extract()?; let malignant_probability: f64 = results.get_item(1)?.extract()?; Ok(Prediction { benign_probability, malignant_probability, }) }) }