use axum::{ body::{Body, Bytes}, response::{IntoResponse, Json, Response} }; use crate::{ ApiResult, account::auth }; use http::{HeaderMap, StatusCode}; use pyo3::{ prelude::*, types::PyTuple }; use serde::{Serialize, Deserialize}; const MODULE: &str = include_str!("model/model.py"); const MODEL: &[u8] = include_bytes!("model/model.keras"); pub fn router() -> axum::Router { axum::Router::new() .route("/", axum::routing::post(get_predict) ) } #[derive(Serialize, Deserialize)] struct Prediction { benign_probability: f64, malignant_probability: f64, } async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult { let api = headers["api_key"].to_str()?; match auth(api).await? { Some(_user) => { let prediction = predict(body).await?; Ok(Response::builder() .status(StatusCode::OK) .body(Json(prediction).into_response().into_body())?) }, None => { Ok(Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from("Prediction could not be completed"))?) }, } } async fn predict(image: Bytes) -> PyResult { Python::with_gil(|py| -> PyResult { let predict: Py = PyModule::from_code_bound(py, MODULE, "model.py", "model")? .getattr("predict")?.into(); let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?; let benign_probability: f64 = results.get_item(0)?.extract()?; let malignant_probability: f64 = results.get_item(1)?.extract()?; Ok(Prediction { benign_probability, malignant_probability, }) }) }