diff --git a/src/web/routes.rs b/src/web/routes.rs index 059c2ee..655f0b1 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -2,7 +2,8 @@ use std::sync::Arc; use axum::{ extract::State, - response::{ErrorResponse, Html, Redirect}, + http::{HeaderMap, HeaderValue}, + response::{ErrorResponse, Html, IntoResponse, Redirect}, routing::{get, post, IntoMakeService}, Json, Router, }; @@ -10,27 +11,64 @@ use serde::Serialize; use tokio::sync::RwLock; use crate::{ - web::templates::load_templates, + task::{collect_local_address, MY_ADDRESS}, types::{AppState, PageContext}, - task::{MY_ADDRESS, collect_local_address}, + web::templates::load_templates, }; -async fn root(State(state): State>>) -> axum::response::Result> { - let state = state.read().await; +enum Response { + Html(Html), + Json(Json), +} +impl IntoResponse for Response +where + T: Serialize, +{ + fn into_response(self) -> axum::response::Response { + match self { + Response::Html(html) => html.into_response(), + Response::Json(json) => json.into_response(), + } + } +} + +#[derive(Serialize, Default)] +struct AddressResult { + address: String, +} + +async fn root( + State(state): State>>, + headers: HeaderMap, +) -> axum::response::Result> { let address_guard = MY_ADDRESS.lock().await; let address = address_guard.clone(); drop(address_guard); - let ctx = match tera::Context::from_serialize(&PageContext { - address, - }) { - Ok(ctx) => ctx, - Err(err) => return Err(ErrorResponse::from(format!("{err}"))), - }; - match state.templates.render("index.html", &ctx) { - Ok(result) => Ok(Html::from(result)), - Err(err) => Err(ErrorResponse::from(format!("{err}"))), + let accept: Vec<&str> = headers + .get("Accept") + .map(|h| h.to_str().unwrap()) + .unwrap_or("text/html") + .split(",") + .collect(); + + tracing::info!(?accept, "accept"); + if accept.contains(&"text/html") { + let state = state.read().await; + + let ctx = match tera::Context::from_serialize(&PageContext { address }) { + Ok(ctx) => ctx, + Err(err) => return Err(ErrorResponse::from(format!("{err}"))), + }; + match state.templates.render("index.html", &ctx) { + Ok(result) => Ok(Response::Html(Html::from(result))), + Err(err) => Err(ErrorResponse::from(format!("{err}"))), + } + } else if accept.contains(&"application/json") { + Ok(Response::Json(Json(AddressResult { address }))) + } else { + unreachable!() } } @@ -40,7 +78,9 @@ struct ReloadResult { error: Option, } -async fn reload(State(state): State>>) -> Result> { +async fn reload( + State(state): State>>, +) -> Result> { let mut state = state.write_owned().await; if let Err(err) = state.templates.full_reload() { return Err(Json(ReloadResult { @@ -59,7 +99,7 @@ async fn reload(State(state): State>>) -> Result Result>, Box> { let templates = load_templates()?; - + let router = Router::new(); #[cfg(feature = "serve-static")]