diff --git a/Cargo.toml b/Cargo.toml index 14bc6dd..4464963 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,13 +8,10 @@ edition = "2021" [dependencies] anyhow = "1.0.86" axum = "0.7.5" -axum_session = "0.14.0" -axum_session_auth = "0.14.0" -axum_session_mongo = "0.1.0" -candle-nn = "0.5.1" chrono = "0.4.38" http = "1.1.0" mongodb = { version = "2.8.2", features = ["bson-chrono-0_4", "tokio-runtime"]} +pyo3 = { version = "0.21.2", features = ["auto-initialize"]} serde = "1.0.203" tokio = "1.38.0" vrd = "0.0.7" diff --git a/src/account.rs b/src/account.rs index 1c00f57..a1c5215 100644 --- a/src/account.rs +++ b/src/account.rs @@ -7,12 +7,11 @@ use axum::{ Router, routing::{get, post} }; -use crate::AppError; +use crate::ApiResult; use db::{get_users, User}; use http::{header::HeaderMap, StatusCode}; use mongodb::bson::{doc, oid::ObjectId}; -type ApiResult = Result; pub fn router() -> Router { Router::new() diff --git a/src/account/db.rs b/src/account/db.rs index f3d2b37..26f7fbc 100644 --- a/src/account/db.rs +++ b/src/account/db.rs @@ -1,6 +1,4 @@ use anyhow::Result; -use axum::async_trait; -use axum_session_auth::Authentication; use serde::{Deserialize, Serialize}; use mongodb::{ bson::{doc, oid::ObjectId, DateTime}, diff --git a/src/lib.rs b/src/lib.rs index fe9e1d3..07d1814 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,11 +2,13 @@ mod account; mod model; use anyhow::Result; -use axum::response::{IntoResponse, Response}; -use axum_session::{SessionConfig, SessionStore}; -use axum_session_mongo::SessionMongoPool; -use http::StatusCode; -use mongodb::Client; +use axum::{ + body::Body, + response::{IntoResponse, Response} +}; +use http::{StatusCode, Uri}; + +pub type ApiResult = Result; pub async fn run() -> Result<()> { @@ -22,7 +24,8 @@ pub async fn run() -> Result<()> { fn router() -> axum::Router { axum::Router::new() .nest("/account", account::router()) - //.nest("/predict", model::router()) + .fallback(not_found) + .nest("/predict", model::router()) } struct AppError(anyhow::Error); @@ -45,3 +48,9 @@ where Self(err.into()) } } + +async fn not_found(uri: Uri) -> ApiResult { + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::from(format!("The route {uri} does not exist")))?) +} diff --git a/src/model.rs b/src/model.rs index 7351299..c6a45ac 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,3 +1,13 @@ +use axum::{ + body::Body, + response::Response +}; +use crate::ApiResult; +use pyo3::prelude::*; + +const MODULE: &str = include_str!("model/model.py"); +const MODEL: &[u8] = include_bytes!("model/model.keras"); + pub fn router() -> axum::Router { axum::Router::new() .route("/", @@ -5,4 +15,16 @@ pub fn router() -> axum::Router { ) } -async fn get_predict() {} +async fn get_predict() -> ApiResult { + //TODO: If api key is correct + // Extract image from body of req + let result = Python::with_gil(|py| -> PyResult<()> { + let predict: Py = PyModule::from_code_bound(py, MODULE, "model.py", "model")? + .getattr("predict")?.into(); + //TODO: Return result + Ok(()) + }); + + //TODO: Return results + Ok(Response::builder().body(Body::from(""))?) +}