prediction route complete
This commit is contained in:
parent
49ff2be8bc
commit
cda66aa56b
|
|
@ -368,6 +368,7 @@ dependencies = [
|
|||
"mongodb",
|
||||
"pyo3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"vrd 0.0.7",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,5 +13,6 @@ http = "1.1.0"
|
|||
mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]}
|
||||
pyo3 = { version = "0.21.2", features = ["auto-initialize"]}
|
||||
serde = "1.0.203"
|
||||
serde_json = "1.0.117"
|
||||
tokio = "1.38.0"
|
||||
vrd = "0.0.7"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
pub mod db;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{
|
||||
body::Body,
|
||||
Json,
|
||||
response::Response,
|
||||
Router,
|
||||
routing::{get, post}
|
||||
body::Body, response::{IntoResponse, Response}, routing::{get, post}, Json, Router
|
||||
};
|
||||
use crate::ApiResult;
|
||||
use db::{get_users, User};
|
||||
use http::{header::HeaderMap, StatusCode};
|
||||
use mongodb::bson::{doc, oid::ObjectId};
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
|
|
@ -31,7 +28,19 @@ pub fn router() -> Router {
|
|||
)
|
||||
}
|
||||
|
||||
pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||
pub async fn auth(api: &str) -> Result<Option<User>> {
|
||||
let db = get_users().await?;
|
||||
|
||||
let query = doc! {
|
||||
"$expr": { "$eq": ["$_auth._api", ObjectId::parse_str(api)?] }
|
||||
};
|
||||
|
||||
Ok(db.find_one(query, None).await?)
|
||||
|
||||
}
|
||||
|
||||
|
||||
async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||
let db = get_users().await?;
|
||||
let query = doc! {
|
||||
"$expr": { "$eq": ["$username", body.username] }
|
||||
|
|
@ -41,7 +50,9 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
|||
Some(user) => {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::CREATED)
|
||||
.body(Body::from(user.auth.unwrap().salt.unwrap()))?)
|
||||
.body(Json(json!({
|
||||
"_salt": user.auth.unwrap().salt.unwrap()
|
||||
})).into_response().into_body())?)
|
||||
},
|
||||
None => {
|
||||
Ok(Response::builder()
|
||||
|
|
@ -51,7 +62,7 @@ pub async fn get_sign_in(Json(body): Json<User>) -> ApiResult {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||
async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
||||
let db = get_users().await?;
|
||||
let api = ObjectId::new();
|
||||
let query = doc! {
|
||||
|
|
@ -74,7 +85,9 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
|||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::from(api.to_string()))?)
|
||||
.body(Json(json!({
|
||||
"_api_key": api.to_string()
|
||||
})).into_response().into_body())?)
|
||||
},
|
||||
None => {
|
||||
Ok(Response::builder()
|
||||
|
|
@ -84,7 +97,7 @@ pub async fn post_sign_in(Json(body): Json<User>) -> ApiResult {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
||||
async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
||||
let db = get_users().await?;
|
||||
let auth = body.clone().auth.unwrap_or_default();
|
||||
let query = doc! {
|
||||
|
|
@ -114,7 +127,7 @@ pub async fn post_sign_up(Json(body): Json<User>) -> ApiResult {
|
|||
}
|
||||
|
||||
|
||||
pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
||||
async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
||||
let db = get_users().await?;
|
||||
let api = headers["api_key"].to_str()?;
|
||||
|
||||
|
|
@ -127,7 +140,7 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
|||
}
|
||||
};
|
||||
|
||||
match db.find_one(query.clone(), None).await? {
|
||||
match auth(api).await? {
|
||||
Some(_user) => {
|
||||
db.update_one(query, update, None).await?;
|
||||
Ok(Response::builder()
|
||||
|
|
@ -141,7 +154,6 @@ pub async fn post_sign_out(headers: HeaderMap) -> ApiResult {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn post_backup() {} //TODO: Backup
|
||||
|
||||
pub async fn get_restore() {} //TODO: restore
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ use mongodb::{
|
|||
pub async fn get_db_client() -> Result<Client> {
|
||||
let username = std::env::var("DB_USERNAME")?;
|
||||
let password = std::env::var("DB_PASSWORD")?;
|
||||
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@db:27017")).await?)
|
||||
let host = std::env::var("DB_HOST")?;
|
||||
let port = std::env::var("DB_PORT")?;
|
||||
|
||||
Ok(Client::with_uri_str(format!("mongodb://{username}:{password}@{host}:{port}")).await?)
|
||||
}
|
||||
|
||||
pub async fn get_database() -> Result<Database> {
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ pub type ApiResult = Result<Response, AppError>;
|
|||
pub async fn run() -> Result<()> {
|
||||
|
||||
let app = router();
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
|
||||
let port = std::env::var("API_PORT")?;
|
||||
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ fn router() -> axum::Router {
|
|||
.nest("/predict", model::router())
|
||||
}
|
||||
|
||||
struct AppError(anyhow::Error);
|
||||
pub struct AppError(anyhow::Error);
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
|
|
|
|||
71
src/model.rs
71
src/model.rs
|
|
@ -1,9 +1,18 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
response::Response
|
||||
body::{Body, Bytes},
|
||||
response::{IntoResponse, Json, Response}
|
||||
};
|
||||
use crate::ApiResult;
|
||||
use pyo3::prelude::*;
|
||||
use crate::{
|
||||
ApiResult,
|
||||
account::auth
|
||||
};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use pyo3::{
|
||||
prelude::*,
|
||||
types::PyTuple
|
||||
|
||||
};
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
const MODULE: &str = include_str!("model/model.py");
|
||||
const MODEL: &[u8] = include_bytes!("model/model.keras");
|
||||
|
|
@ -15,16 +24,46 @@ pub fn router() -> axum::Router {
|
|||
)
|
||||
}
|
||||
|
||||
async fn get_predict() -> ApiResult {
|
||||
//TODO: If api key is correct
|
||||
// Extract image from body of req
|
||||
let result = Python::with_gil(|py| -> PyResult<()> {
|
||||
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
|
||||
.getattr("predict")?.into();
|
||||
|
||||
//TODO: Get result and return it
|
||||
Ok(())
|
||||
});
|
||||
|
||||
Ok(Response::builder().body(Body::from(""))?)
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Prediction {
|
||||
benign_probability: f64,
|
||||
malignant_probability: f64,
|
||||
}
|
||||
|
||||
async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult {
|
||||
let api = headers["api_key"].to_str()?;
|
||||
|
||||
match auth(api).await? {
|
||||
Some(_user) => {
|
||||
let prediction = predict(body).await?;
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Json(prediction).into_response().into_body())?)
|
||||
|
||||
},
|
||||
None => {
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Prediction could not be completed"))?)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
async fn predict(image: Bytes) -> PyResult<Prediction> {
|
||||
Python::with_gil(|py| -> PyResult<Prediction> {
|
||||
let predict: Py<PyAny> = PyModule::from_code_bound(py, MODULE, "model.py", "model")?
|
||||
.getattr("predict")?.into();
|
||||
|
||||
let results: &PyTuple = predict.call1(py, (MODEL, image.into_py(py)))?.extract(py)?;
|
||||
let benign_probability: f64 = results.get_item(0)?.extract()?;
|
||||
let malignant_probability: f64 = results.get_item(1)?.extract()?;
|
||||
|
||||
Ok(Prediction {
|
||||
benign_probability,
|
||||
malignant_probability,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,5 +26,6 @@ def process_model_bytes(bytes):
|
|||
def predict(model_bytes, image_bytes):
|
||||
image = preprocess_image_bytes(image_bytes)
|
||||
model = process_model_bytes(model_bytes)
|
||||
result = model.predict(image)
|
||||
|
||||
return model.predict(image)
|
||||
return result[0], result[1]
|
||||
|
|
|
|||
Loading…
Reference in New Issue