use std::ops::Deref; use axum::{ extract::FromRequestParts, http::{request::Parts, HeaderName, HeaderValue}, }; use axum_extra::typed_header::TypedHeader; use serde::{de::DeserializeOwned, Serialize}; /// A typed header. When used as a bare extractor, reads from the /// `Last-Event-Id` HTTP header. pub struct LastEventId(pub T); static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); impl headers::Header for LastEventId where T: Serialize + DeserializeOwned, { fn name() -> &'static HeaderName { &LAST_EVENT_ID } fn decode<'i, I>(values: &mut I) -> Result where I: Iterator, { let value = values.next().ok_or_else(headers::Error::invalid)?; let value = value.to_str().map_err(|_| headers::Error::invalid())?; let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; Ok(Self(value)) } fn encode(&self, values: &mut E) where E: Extend, { let Self(value) = self; // Must panic or suppress; the trait provides no other options. let value = serde_json::to_string(value).expect("value can be encoded as JSON"); let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); values.extend(std::iter::once(value)); } } #[async_trait::async_trait] impl FromRequestParts for LastEventId where S: Send + Sync, T: Serialize + DeserializeOwned, { type Rejection = as FromRequestParts>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { // This is purely for ergonomics: it allows `RequestedAt` to be extracted // without having to wrap it in `Extension<>`. Callers _can_ still do that, // but they aren't forced to. let TypedHeader(requested_at) = TypedHeader::::from_request_parts(parts, state).await?; Ok(requested_at) } } impl Deref for LastEventId { type Target = T; fn deref(&self) -> &Self::Target { let Self(header) = self; header } } impl From for LastEventId { fn from(value: T) -> Self { Self(value) } } impl LastEventId { pub fn into_inner(self) -> T { let Self(value) = self; value } }