prediction route complete

This commit is contained in:
Joshua Perry 2024-06-10 18:43:24 +01:00
parent 49ff2be8bc
commit cda66aa56b
7 changed files with 92 additions and 35 deletions

1
Cargo.lock generated
View File

@ -368,6 +368,7 @@ dependencies = [
"mongodb",
"pyo3",
"serde",
"serde_json",
"tokio",
"vrd 0.0.7",
]

View File

@ -13,5 +13,6 @@ http = "1.1.0"
mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]}
pyo3 = { version = "0.21.2", features = ["auto-initialize"]}
serde = "1.0.203"
serde_json = "1.0.117"
tokio = "1.38.0"
vrd = "0.0.7"

View File

@ -1,17 +1,14 @@
pub mod db;
use anyhow::Result;
use axum::{
body::Body,
Json,
response::Response,
Router,
routing::{get, post}
body::Body, response::{IntoResponse, Response}, routing::{get, post}, Json, Router
};
use crate::ApiResult;
use db::{get_users, User};
use http::{header::HeaderMap, StatusCode};
use mongodb::bson::{doc, oid::ObjectId};
use serde_json::json;
pub fn router() -> Router {
Router::new()
@ -31,7 +28,19 @@ pub fn router() -> Router {
)
}
pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
pub async fn auth(api: &str) -> Result<Option<User>> {
let db = get_users().await?;
let query = doc! {
"$expr": { "$eq": ["$_auth._api", ObjectId::parse_str(api)?] }
};
Ok(db.find_one(query, None).await?)
}
async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?;
let query = doc! {
"$expr": { "$eq": ["$username", body.username] }
@ -41,7 +50,9 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
Some(user) => {
Ok(Response::builder()
.status(StatusCode::CREATED)
.body(Body::from(user.auth.unwrap().salt.unwrap()))?)
.body(Json(json!({
"_salt": user.auth.unwrap().salt.unwrap()
})).into_response().into_body())?)
},
None => {
Ok(Response::builder()
@ -51,7 +62,7 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
}
}
pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?;
let api = ObjectId::new();
let query = doc! {
@ -74,7 +85,9 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::from(api.to_string()))?)
.body(Json(json!({
"_api_key": api.to_string()
})).into_response().into_body())?)
},
None => {
Ok(Response::builder()
@ -84,7 +97,7 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
}
}
pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
let db = get_users().await?;
let auth = body.clone().auth.unwrap_or_default();
let query = doc! {
@ -114,7 +127,7 @@ pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
}
pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
async fn post_sign_out(headers: HeaderMap) -> ApiResult {
let db = get_users().await?;
let api = headers["api_key"].to_str()?;
@ -127,7 +140,7 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
}
};
match db.find_one(query.clone(), None).await? {
match auth(api).await? {
Some(_user) => {
db.update_one(query, update, None).await?;
Ok(Response::builder()
@ -141,7 +154,6 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
},
}
}
pub async fn post_backup() {} //TODO: Backup
pub async fn get_restore() {} //TODO: restore

View File

@ -7,7 +7,10 @@ use mongodb::{
pub async fn get_db_client() -> Result<Client> {
let username = std::env::var("DB_USERNAME")?;
let password = std::env::var("DB_PASSWORD")?;
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@db:27017")).await?)
let host = std::env::var("DB_HOST")?;
let port = std::env::var("DB_PORT")?;
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@{host}:{port}")).await?)
}
pub async fn get_database() -> Result<Database> {

View File

@ -13,8 +13,8 @@ pub type ApiResult = Result<Response, AppError>;
pub async fn run() -> Result<()> {
let app = router();
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
let port = std::env::var("API_PORT")?;
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await.unwrap();
axum::serve(listener, app).await.unwrap();
@ -28,7 +28,7 @@ fn router() -> axum::Router {
.nest("/predict", model::router())
}
struct AppError(anyhow::Error);
pub struct AppError(anyhow::Error);
impl IntoResponse for AppError {
fn into_response(self) -> Response {

View File

@ -1,9 +1,18 @@
use axum::{
body::Body,
response::Response
body::{Body, Bytes},
response::{IntoResponse, Json, Response}
};
use crate::ApiResult;
use pyo3::prelude::*;
use crate::{
ApiResult,
account::auth
};
use http::{HeaderMap, StatusCode};
use pyo3::{
prelude::*,
types::PyTuple
};
use serde::{Serialize, Deserialize};
const MODULE: &str = include_str!("model/model.py");
const MODEL: &[u8] = include_bytes!("model/model.keras");
@ -15,16 +24,46 @@ pub fn router() -> axum::Router {
)
}
async fn get_predict() -> ApiResult {
//TODO: If api key is correct
// Extract image from body of req
let result = Python::with_gil(|py| -> PyResult<()> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into();
//TODO: Get result and return it
Ok(())
});
Ok(Response::builder().body(Body::from(""))?)
#[derive(Serialize, Deserialize)]
struct Prediction {
benign_probability: f64,
malignant_probability: f64,
}
async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult {
let api = headers["api_key"].to_str()?;
match auth(api).await? {
Some(_user) => {
let prediction = predict(body).await?;
Ok(Response::builder()
.status(StatusCode::OK)
.body(Json(prediction).into_response().into_body())?)
},
None => {
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Prediction could not be completed"))?)
},
}
}
async fn predict(image: Bytes) -> PyResult<Prediction> {
Python::with_gil(|py| -> PyResult<Prediction> {
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
.getattr("predict")?.into();
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
let benign_probability: f64 = results.get_item(0)?.extract()?;
let malignant_probability: f64 = results.get_item(1)?.extract()?;
Ok(Prediction {
benign_probability,
malignant_probability,
})
})
}

View File

@ -26,5 +26,6 @@ def process_model_bytes(bytes):
def predict(model_bytes, image_bytes):
image = preprocess_image_bytes(image_bytes)
model = process_model_bytes(model_bytes)
result = model.predict(image)
return model.predict(image)
return result[0], result[1]