From cda66aa56bec51e4b31b6e16761e5b69ae95dcc3 Mon Sep 17 00:00:00 2001 From: r0r-5chach Date: Mon, 10 Jun 2024 18:43:24 +0100 Subject: [PATCH] prediction route complete --- Cargo.lock | 1 + Cargo.toml | 1 + src/account.rs | 40 +++++++++++++++++--------- src/account/db.rs | 5 +++- src/lib.rs | 6 ++-- src/model.rs | 71 +++++++++++++++++++++++++++++++++++----------- src/model/model.py | 3 +- 7 files changed, 92 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ee8503..62a7e21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -368,6 +368,7 @@ dependencies = [ "mongodb", "pyo3", "serde", + "serde_json", "tokio", "vrd 0.0.7", ] diff --git a/Cargo.toml b/Cargo.toml index 4b0ff45..12159fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,5 +13,6 @@ http = "1.1.0" mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]} pyo3 = { version = "0.21.2", features = ["auto-initialize"]} serde = "1.0.203" +serde_json = "1.0.117" tokio = "1.38.0" vrd = "0.0.7" diff --git a/src/account.rs b/src/account.rs index a1c5215..c95863b 100644 --- a/src/account.rs +++ b/src/account.rs @@ -1,17 +1,14 @@ pub mod db; +use anyhow::Result; use axum::{ - body::Body, - Json, - response::Response, - Router, - routing::{get, post} + body::Body, response::{IntoResponse, Response}, routing::{get, post}, Json, Router }; use crate::ApiResult; use db::{get_users, User}; use http::{header::HeaderMap, StatusCode}; use mongodb::bson::{doc, oid::ObjectId}; - +use serde_json::json; pub fn router() -> Router { Router::new() @@ -31,7 +28,19 @@ pub fn router() -> Router { ) } -pub async fn get_sign_in(Json(body): Json) -> ApiResult { +pub async fn auth(api: &str) -> Result> { + let db = get_users().await?; + + let query = doc! { + "$expr": { "$eq": ["$_auth._api", ObjectId::parse_str(api)?] } + }; + + Ok(db.find_one(query, None).await?) + +} + + +async fn get_sign_in(Json(body): Json) -> ApiResult { let db = get_users().await?; let query = doc! { "$expr": { "$eq": ["$username", body.username] } @@ -41,7 +50,9 @@ pub async fn get_sign_in(Json(body): Json) -> ApiResult { Some(user) => { Ok(Response::builder() .status(StatusCode::CREATED) - .body(Body::from(user.auth.unwrap().salt.unwrap()))?) + .body(Json(json!({ + "_salt": user.auth.unwrap().salt.unwrap() + })).into_response().into_body())?) }, None => { Ok(Response::builder() @@ -51,7 +62,7 @@ pub async fn get_sign_in(Json(body): Json) -> ApiResult { } } -pub async fn post_sign_in(Json(body): Json) -> ApiResult { +async fn post_sign_in(Json(body): Json) -> ApiResult { let db = get_users().await?; let api = ObjectId::new(); let query = doc! { @@ -74,7 +85,9 @@ pub async fn post_sign_in(Json(body): Json) -> ApiResult { Ok(Response::builder() .status(StatusCode::OK) - .body(Body::from(api.to_string()))?) + .body(Json(json!({ + "_api_key": api.to_string() + })).into_response().into_body())?) }, None => { Ok(Response::builder() @@ -84,7 +97,7 @@ pub async fn post_sign_in(Json(body): Json) -> ApiResult { } } -pub async fn post_sign_up(Json(body): Json) -> ApiResult { +async fn post_sign_up(Json(body): Json) -> ApiResult { let db = get_users().await?; let auth = body.clone().auth.unwrap_or_default(); let query = doc! { @@ -114,7 +127,7 @@ pub async fn post_sign_up(Json(body): Json) -> ApiResult { } -pub async fn post_sign_out(headers: HeaderMap) -> ApiResult { +async fn post_sign_out(headers: HeaderMap) -> ApiResult { let db = get_users().await?; let api = headers["api_key"].to_str()?; @@ -127,7 +140,7 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult { } }; - match db.find_one(query.clone(), None).await? { + match auth(api).await? { Some(_user) => { db.update_one(query, update, None).await?; Ok(Response::builder() @@ -141,7 +154,6 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult { }, } } - pub async fn post_backup() {} //TODO: Backup pub async fn get_restore() {} //TODO: restore diff --git a/src/account/db.rs b/src/account/db.rs index 81f6b5a..184646c 100644 --- a/src/account/db.rs +++ b/src/account/db.rs @@ -7,7 +7,10 @@ use mongodb::{ pub async fn get_db_client() -> Result { let username = std::env::var("DB_USERNAME")?; let password = std::env::var("DB_PASSWORD")?; - Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@db:27017")).await?) + let host = std::env::var("DB_HOST")?; + let port = std::env::var("DB_PORT")?; + + Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@{host}:{port}")).await?) } pub async fn get_database() -> Result { diff --git a/src/lib.rs b/src/lib.rs index 07d1814..aaea8d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,8 +13,8 @@ pub type ApiResult = Result; pub async fn run() -> Result<()> { let app = router(); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(); + let port = std::env::var("API_PORT")?; + let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await.unwrap(); axum::serve(listener, app).await.unwrap(); @@ -28,7 +28,7 @@ fn router() -> axum::Router { .nest("/predict", model::router()) } -struct AppError(anyhow::Error); +pub struct AppError(anyhow::Error); impl IntoResponse for AppError { fn into_response(self) -> Response { diff --git a/src/model.rs b/src/model.rs index 0a66f00..a0538b1 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,9 +1,18 @@ use axum::{ - body::Body, - response::Response + body::{Body, Bytes}, + response::{IntoResponse, Json, Response} }; -use crate::ApiResult; -use pyo3::prelude::*; +use crate::{ + ApiResult, + account::auth +}; +use http::{HeaderMap, StatusCode}; +use pyo3::{ + prelude::*, + types::PyTuple + +}; +use serde::{Serialize, Deserialize}; const MODULE: &str = include_str!("model/model.py"); const MODEL: &[u8] = include_bytes!("model/model.keras"); @@ -15,16 +24,46 @@ pub fn router() -> axum::Router { ) } -async fn get_predict() -> ApiResult { - //TODO: If api key is correct - // Extract image from body of req - let result = Python::with_gil(|py| -> PyResult<()> { - let predict: Py = PyModule::from_code_bound(py, MODULE, "model.py", "model")? - .getattr("predict")?.into(); - - //TODO: Get result and return it - Ok(()) - }); - - Ok(Response::builder().body(Body::from(""))?) +#[derive(Serialize, Deserialize)] +struct Prediction { + benign_probability: f64, + malignant_probability: f64, +} + +async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult { + let api = headers["api_key"].to_str()?; + + match auth(api).await? { + Some(_user) => { + let prediction = predict(body).await?; + + Ok(Response::builder() + .status(StatusCode::OK) + .body(Json(prediction).into_response().into_body())?) + + }, + None => { + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from("Prediction could not be completed"))?) + }, + } + + +} + +async fn predict(image: Bytes) -> PyResult { + Python::with_gil(|py| -> PyResult { + let predict: Py = PyModule::from_code_bound(py, MODULE, "model.py", "model")? + .getattr("predict")?.into(); + + let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?; + let benign_probability: f64 = results.get_item(0)?.extract()?; + let malignant_probability: f64 = results.get_item(1)?.extract()?; + + Ok(Prediction { + benign_probability, + malignant_probability, + }) + }) } diff --git a/src/model/model.py b/src/model/model.py index b6d751d..9caf123 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -26,5 +26,6 @@ def process_model_bytes(bytes): def predict(model_bytes, image_bytes): image = preprocess_image_bytes(image_bytes) model = process_model_bytes(model_bytes) + result = model.predict(image) - return model.predict(image) + return result[0], result[1]