prediction route complete

This commit is contained in:
Joshua Perry 2024-06-10 18:43:24 +01:00
parent 49ff2be8bc
commit cda66aa56b
7 changed files with 92 additions and 35 deletions

1
Cargo.lock generated
View File

@ -368,6 +368,7 @@ dependencies = [
"mongodb", "mongodb",
"pyo3", "pyo3",
"serde", "serde",
"serde_json",
"tokio", "tokio",
"vrd 0.0.7", "vrd 0.0.7",
] ]

View File

@ -13,5 +13,6 @@ http = "1.1.0"
mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]} mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]}
pyo3 = { version = "0.21.2", features = ["auto-initialize"]} pyo3 = { version = "0.21.2", features = ["auto-initialize"]}
serde = "1.0.203" serde = "1.0.203"
serde_json = "1.0.117"
tokio = "1.38.0" tokio = "1.38.0"
vrd = "0.0.7" vrd = "0.0.7"

View File

@ -1,17 +1,14 @@
pub mod db; pub mod db;
use anyhow::Result;
use axum::{ use axum::{
body::Body, body::Body, response::{IntoResponse, Response}, routing::{get, post}, Json, Router
Json,
response::Response,
Router,
routing::{get, post}
}; };
use crate::ApiResult; use crate::ApiResult;
use db::{get_users, User}; use db::{get_users, User};
use http::{header::HeaderMap, StatusCode}; use http::{header::HeaderMap, StatusCode};
use mongodb::bson::{doc, oid::ObjectId}; use mongodb::bson::{doc, oid::ObjectId};
use serde_json::json;
pub fn router() -> Router { pub fn router() -> Router {
Router::new() Router::new()
@ -31,7 +28,19 @@ pub fn router() -> Router {
) )
} }
pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult { pub async fn auth(api: &str) -> Result<Option<User>> {
let db = get_users().await?;
let query = doc! {
"$expr": { "$eq": ["$_auth._api", ObjectId::parse_str(api)?] }
};
Ok(db.find_one(query, None).await?)
}
async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?; let db = get_users().await?;
let query = doc! { let query = doc! {
"$expr": { "$eq": ["$username", body.username] } "$expr": { "$eq": ["$username", body.username] }
@ -41,7 +50,9 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
Some(user) => { Some(user) => {
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::CREATED) .status(StatusCode::CREATED)
.body(Body::from(user.auth.unwrap().salt.unwrap()))?) .body(Json(json!({
"_salt": user.auth.unwrap().salt.unwrap()
})).into_response().into_body())?)
}, },
None => { None => {
Ok(Response::builder() Ok(Response::builder()
@ -51,7 +62,7 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
} }
} }
pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult { async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?; let db = get_users().await?;
let api = ObjectId::new(); let api = ObjectId::new();
let query = doc! { let query = doc! {
@ -74,7 +85,9 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
Ok(Response::builder() Ok(Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
.body(Body::from(api.to_string()))?) .body(Json(json!({
"_api_key": api.to_string()
})).into_response().into_body())?)
}, },
None => { None => {
Ok(Response::builder() Ok(Response::builder()
@ -84,7 +97,7 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
} }
} }
pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult { async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?; let db = get_users().await?;
let auth = body.clone().auth.unwrap_or_default(); let auth = body.clone().auth.unwrap_or_default();
let query = doc! { let query = doc! {
@ -114,7 +127,7 @@ pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
} }
pub async fn post_sign_out(headers: HeaderMap) -> ApiResult { async fn post_sign_out(headers: HeaderMap) -> ApiResult {
let db = get_users().await?; let db = get_users().await?;
let api = headers["api_key"].to_str()?; let api = headers["api_key"].to_str()?;
@ -127,7 +140,7 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
} }
}; };
match db.find_one(query.clone(), None).await? { match auth(api).await? {
Some(_user) => { Some(_user) => {
db.update_one(query, update, None).await?; db.update_one(query, update, None).await?;
Ok(Response::builder() Ok(Response::builder()
@ -141,7 +154,6 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
}, },
} }
} }
pub async fn post_backup() {} //TODO: Backup pub async fn post_backup() {} //TODO: Backup
pub async fn get_restore() {} //TODO: restore pub async fn get_restore() {} //TODO: restore

View File

@ -7,7 +7,10 @@ use mongodb::{
pub async fn get_db_client() -> Result<Client> { pub async fn get_db_client() -> Result<Client> {
let username = std::env::var("DB_USERNAME")?; let username = std::env::var("DB_USERNAME")?;
let password = std::env::var("DB_PASSWORD")?; let password = std::env::var("DB_PASSWORD")?;
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@db:27017")).await?) let host = std::env::var("DB_HOST")?;
let port = std::env::var("DB_PORT")?;
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@{host}:{port}")).await?)
} }
pub async fn get_database() -> Result<Database> { pub async fn get_database() -> Result<Database> {

View File

@ -13,8 +13,8 @@ pub type ApiResult = Result<Response, AppError>;
pub async fn run() -> Result<()> { pub async fn run() -> Result<()> {
let app = router(); let app = router();
let port = std::env::var("API_PORT")?;
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(); let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await.unwrap();
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();
@ -28,7 +28,7 @@ fn router() -> axum::Router {
.nest("/predict", model::router()) .nest("/predict", model::router())
} }
struct AppError(anyhow::Error); pub struct AppError(anyhow::Error);
impl IntoResponse for AppError { impl IntoResponse for AppError {
fn into_response(self) -> Response { fn into_response(self) -> Response {

View File

@ -1,9 +1,18 @@
use axum::{ use axum::{
body::Body, body::{Body, Bytes},
response::Response response::{IntoResponse, Json, Response}
}; };
use crate::ApiResult; use crate::{
use pyo3::prelude::*; ApiResult,
account::auth
};
use http::{HeaderMap, StatusCode};
use pyo3::{
prelude::*,
types::PyTuple
};
use serde::{Serialize, Deserialize};
const MODULE: &str = include_str!("model/model.py"); const MODULE: &str = include_str!("model/model.py");
const MODEL: &[u8] = include_bytes!("model/model.keras"); const MODEL: &[u8] = include_bytes!("model/model.keras");
@ -15,16 +24,46 @@ pub fn router() -> axum::Router {
) )
} }
async fn get_predict() -> ApiResult { #[derive(Serialize, Deserialize)]
//TODO: If api key is correct struct Prediction {
// Extract image from body of req benign_probability: f64,
let result = Python::with_gil(|py| -> PyResult<()> { malignant_probability: f64,
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")? }
.getattr("predict")?.into();
async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult {
//TODO: Get result and return it let api = headers["api_key"].to_str()?;
Ok(())
}); match auth(api).await? {
Some(_user) => {
Ok(Response::builder().body(Body::from(""))?) let prediction = predict(body).await?;
Ok(Response::builder()
.status(StatusCode::OK)
.body(Json(prediction).into_response().into_body())?)
},
None => {
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Prediction could not be completed"))?)
},
}
}
async fn predict(image: Bytes) -> PyResult<Prediction> {
Python::with_gil(|py| -> PyResult<Prediction> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into();
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
let benign_probability: f64 = results.get_item(0)?.extract()?;
let malignant_probability: f64 = results.get_item(1)?.extract()?;
Ok(Prediction {
benign_probability,
malignant_probability,
})
})
} }

View File

@ -26,5 +26,6 @@ def process_model_bytes(bytes):
def predict(model_bytes, image_bytes): def predict(model_bytes, image_bytes):
image = preprocess_image_bytes(image_bytes) image = preprocess_image_bytes(image_bytes)
model = process_model_bytes(model_bytes) model = process_model_bytes(model_bytes)
result = model.predict(image)
return model.predict(image) return result[0], result[1]