prediction route complete
This commit is contained in:
parent
49ff2be8bc
commit
cda66aa56b
|
|
@ -368,6 +368,7 @@ dependencies = [
|
||||||
"mongodb",
|
"mongodb",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
"vrd 0.0.7",
|
"vrd 0.0.7",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,5 +13,6 @@ http = "1.1.0"
|
||||||
mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]}
|
mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]}
|
||||||
pyo3 = { version = "0.21.2", features = ["auto-initialize"]}
|
pyo3 = { version = "0.21.2", features = ["auto-initialize"]}
|
||||||
serde = "1.0.203"
|
serde = "1.0.203"
|
||||||
|
serde_json = "1.0.117"
|
||||||
tokio = "1.38.0"
|
tokio = "1.38.0"
|
||||||
vrd = "0.0.7"
|
vrd = "0.0.7"
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,14 @@
|
||||||
pub mod db;
|
pub mod db;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body, response::{IntoResponse, Response}, routing::{get, post}, Json, Router
|
||||||
Json,
|
|
||||||
response::Response,
|
|
||||||
Router,
|
|
||||||
routing::{get, post}
|
|
||||||
};
|
};
|
||||||
use crate::ApiResult;
|
use crate::ApiResult;
|
||||||
use db::{get_users, User};
|
use db::{get_users, User};
|
||||||
use http::{header::HeaderMap, StatusCode};
|
use http::{header::HeaderMap, StatusCode};
|
||||||
use mongodb::bson::{doc, oid::ObjectId};
|
use mongodb::bson::{doc, oid::ObjectId};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
pub fn router() -> Router {
|
pub fn router() -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
|
|
@ -31,7 +28,19 @@ pub fn router() -> Router {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
pub async fn auth(api: &str) -> Result<Option<User>> {
|
||||||
|
let db = get_users().await?;
|
||||||
|
|
||||||
|
let query = doc! {
|
||||||
|
"$expr": { "$eq": ["$_auth._api", ObjectId::parse_str(api)?] }
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(db.find_one(query, None).await?)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
let db = get_users().await?;
|
let db = get_users().await?;
|
||||||
let query = doc! {
|
let query = doc! {
|
||||||
"$expr": { "$eq": ["$username", body.username] }
|
"$expr": { "$eq": ["$username", body.username] }
|
||||||
|
|
@ -41,7 +50,9 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
Some(user) => {
|
Some(user) => {
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
.status(StatusCode::CREATED)
|
.status(StatusCode::CREATED)
|
||||||
.body(Body::from(user.auth.unwrap().salt.unwrap()))?)
|
.body(Json(json!({
|
||||||
|
"_salt": user.auth.unwrap().salt.unwrap()
|
||||||
|
})).into_response().into_body())?)
|
||||||
},
|
},
|
||||||
None => {
|
None => {
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
|
|
@ -51,7 +62,7 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
let db = get_users().await?;
|
let db = get_users().await?;
|
||||||
let api = ObjectId::new();
|
let api = ObjectId::new();
|
||||||
let query = doc! {
|
let query = doc! {
|
||||||
|
|
@ -74,7 +85,9 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
|
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
.status(StatusCode::OK)
|
.status(StatusCode::OK)
|
||||||
.body(Body::from(api.to_string()))?)
|
.body(Json(json!({
|
||||||
|
"_api_key": api.to_string()
|
||||||
|
})).into_response().into_body())?)
|
||||||
},
|
},
|
||||||
None => {
|
None => {
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
|
|
@ -84,7 +97,7 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
||||||
let db = get_users().await?;
|
let db = get_users().await?;
|
||||||
let auth = body.clone().auth.unwrap_or_default();
|
let auth = body.clone().auth.unwrap_or_default();
|
||||||
let query = doc! {
|
let query = doc! {
|
||||||
|
|
@ -114,7 +127,7 @@ pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
||||||
let db = get_users().await?;
|
let db = get_users().await?;
|
||||||
let api = headers["api_key"].to_str()?;
|
let api = headers["api_key"].to_str()?;
|
||||||
|
|
||||||
|
|
@ -127,7 +140,7 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match db.find_one(query.clone(), None).await? {
|
match auth(api).await? {
|
||||||
Some(_user) => {
|
Some(_user) => {
|
||||||
db.update_one(query, update, None).await?;
|
db.update_one(query, update, None).await?;
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
|
|
@ -141,7 +154,6 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post_backup() {} //TODO: Backup
|
pub async fn post_backup() {} //TODO: Backup
|
||||||
|
|
||||||
pub async fn get_restore() {} //TODO: restore
|
pub async fn get_restore() {} //TODO: restore
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,10 @@ use mongodb::{
|
||||||
pub async fn get_db_client() -> Result<Client> {
|
pub async fn get_db_client() -> Result<Client> {
|
||||||
let username = std::env::var("DB_USERNAME")?;
|
let username = std::env::var("DB_USERNAME")?;
|
||||||
let password = std::env::var("DB_PASSWORD")?;
|
let password = std::env::var("DB_PASSWORD")?;
|
||||||
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@db:27017")).await?)
|
let host = std::env::var("DB_HOST")?;
|
||||||
|
let port = std::env::var("DB_PORT")?;
|
||||||
|
|
||||||
|
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@{host}:{port}")).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_database() -> Result<Database> {
|
pub async fn get_database() -> Result<Database> {
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,8 @@ pub type ApiResult = Result<Response, AppError>;
|
||||||
pub async fn run() -> Result<()> {
|
pub async fn run() -> Result<()> {
|
||||||
|
|
||||||
let app = router();
|
let app = router();
|
||||||
|
let port = std::env::var("API_PORT")?;
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
|
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await.unwrap();
|
||||||
axum::serve(listener, app).await.unwrap();
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,7 +28,7 @@ fn router() -> axum::Router {
|
||||||
.nest("/predict", model::router())
|
.nest("/predict", model::router())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AppError(anyhow::Error);
|
pub struct AppError(anyhow::Error);
|
||||||
|
|
||||||
impl IntoResponse for AppError {
|
impl IntoResponse for AppError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
|
|
|
||||||
63
src/model.rs
63
src/model.rs
|
|
@ -1,9 +1,18 @@
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::{Body, Bytes},
|
||||||
response::Response
|
response::{IntoResponse, Json, Response}
|
||||||
};
|
};
|
||||||
use crate::ApiResult;
|
use crate::{
|
||||||
use pyo3::prelude::*;
|
ApiResult,
|
||||||
|
account::auth
|
||||||
|
};
|
||||||
|
use http::{HeaderMap, StatusCode};
|
||||||
|
use pyo3::{
|
||||||
|
prelude::*,
|
||||||
|
types::PyTuple
|
||||||
|
|
||||||
|
};
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
const MODULE: &str = include_str!("model/model.py");
|
const MODULE: &str = include_str!("model/model.py");
|
||||||
const MODEL: &[u8] = include_bytes!("model/model.keras");
|
const MODEL: &[u8] = include_bytes!("model/model.keras");
|
||||||
|
|
@ -15,16 +24,46 @@ pub fn router() -> axum::Router {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_predict() -> ApiResult {
|
#[derive(Serialize, Deserialize)]
|
||||||
//TODO: If api key is correct
|
struct Prediction {
|
||||||
// Extract image from body of req
|
benign_probability: f64,
|
||||||
let result = Python::with_gil(|py| -> PyResult<()> {
|
malignant_probability: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult {
|
||||||
|
let api = headers["api_key"].to_str()?;
|
||||||
|
|
||||||
|
match auth(api).await? {
|
||||||
|
Some(_user) => {
|
||||||
|
let prediction = predict(body).await?;
|
||||||
|
|
||||||
|
Ok(Response::builder()
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.body(Json(prediction).into_response().into_body())?)
|
||||||
|
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
Ok(Response::builder()
|
||||||
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
.body(Body::from("Prediction could not be completed"))?)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn predict(image: Bytes) -> PyResult<Prediction> {
|
||||||
|
Python::with_gil(|py| -> PyResult<Prediction> {
|
||||||
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
|
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
|
||||||
.getattr("predict")?.into();
|
.getattr("predict")?.into();
|
||||||
|
|
||||||
//TODO: Get result and return it
|
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
|
||||||
Ok(())
|
let benign_probability: f64 = results.get_item(0)?.extract()?;
|
||||||
});
|
let malignant_probability: f64 = results.get_item(1)?.extract()?;
|
||||||
|
|
||||||
Ok(Response::builder().body(Body::from(""))?)
|
Ok(Prediction {
|
||||||
|
benign_probability,
|
||||||
|
malignant_probability,
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,5 +26,6 @@ def process_model_bytes(bytes):
|
||||||
def predict(model_bytes, image_bytes):
|
def predict(model_bytes, image_bytes):
|
||||||
image = preprocess_image_bytes(image_bytes)
|
image = preprocess_image_bytes(image_bytes)
|
||||||
model = process_model_bytes(model_bytes)
|
model = process_model_bytes(model_bytes)
|
||||||
|
result = model.predict(image)
|
||||||
|
|
||||||
return model.predict(image)
|
return result[0], result[1]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue