diff --git a/src/model.rs b/src/model.rs index a0538b1..26a2147 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,4 +1,5 @@ use axum::{ + extract::Path, body::{Body, Bytes}, response::{IntoResponse, Json, Response} }; @@ -6,7 +7,7 @@ use crate::{ ApiResult, account::auth }; -use http::{HeaderMap, StatusCode}; +use http::StatusCode; use pyo3::{ prelude::*, types::PyTuple @@ -19,7 +20,7 @@ const MODEL: &[u8] = include_bytes!("model/model.keras"); pub fn router() -> axum::Router { axum::Router::new() - .route("/", + .route("/:api", axum::routing::post(get_predict) ) } @@ -30,10 +31,8 @@ struct Prediction { malignant_probability: f64, } -async fn get_predict(headers: HeaderMap, body: Bytes) -> ApiResult { - let api = headers["api_key"].to_str()?; - - match auth(api).await? { +async fn get_predict(Path(api): Path, body: Bytes) -> ApiResult { + match auth(&api).await? { Some(_user) => { let prediction = predict(body).await?;