summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/events.rs46
-rw-r--r--src/header.rs58
-rw-r--r--src/lib.rs1
3 files changed, 68 insertions, 37 deletions
diff --git a/src/events.rs b/src/events.rs
index 38d53fc..9b5901e 100644
--- a/src/events.rs
+++ b/src/events.rs
@@ -1,6 +1,5 @@
use axum::{
extract::State,
- http::{HeaderName, HeaderValue},
response::{
sse::{self, Sse},
IntoResponse,
@@ -8,7 +7,7 @@ use axum::{
routing::get,
Router,
};
-use axum_extra::{extract::Query, typed_header::TypedHeader};
+use axum_extra::extract::Query;
use chrono::{format::SecondsFormat, DateTime};
use futures::{
future,
@@ -19,6 +18,7 @@ use crate::{
app::App,
channel::repo::broadcast,
error::{BoxedError, InternalError},
+ header::LastEventId,
repo::{channel, login::Login},
};
@@ -35,11 +35,10 @@ struct EventsQuery {
async fn on_events(
State(app): State<App>,
_: Login, // requires auth, but doesn't actually care who you are
- last_event_id: Option<TypedHeader<LastEventId>>,
+ last_event_id: Option<LastEventId>,
Query(query): Query<EventsQuery>,
) -> Result<impl IntoResponse, InternalError> {
let resume_at = last_event_id
- .map(|TypedHeader(header)| header)
.map(|LastEventId(header)| header)
.map(|header| DateTime::parse_from_rfc3339(&header))
.transpose()?
@@ -53,10 +52,7 @@ async fn on_events(
.channels()
.events(&channel, resume_at.as_ref())
.await?
- .map_ok(move |message| ChannelEvent {
- channel: channel.clone(),
- message,
- });
+ .map_ok(ChannelEvent::wrap(channel));
Ok::<_, BoxedError>(events)
}
@@ -89,35 +85,11 @@ struct ChannelEvent<M> {
message: M,
}
-pub struct LastEventId(pub String);
-
-static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id");
-
-impl headers::Header for LastEventId {
- fn name() -> &'static HeaderName {
- &LAST_EVENT_ID
- }
-
- fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
- where
- I: Iterator<Item = &'i HeaderValue>,
- {
- let value = values.next().ok_or_else(headers::Error::invalid)?;
- if let Ok(value) = value.to_str() {
- Ok(Self(value.into()))
- } else {
- Err(headers::Error::invalid())
+impl<M> ChannelEvent<M> {
+ fn wrap(channel: channel::Id) -> impl Fn(M) -> Self {
+ move |message| Self {
+ channel: channel.clone(),
+ message,
}
}
-
- fn encode<E>(&self, values: &mut E)
- where
- E: Extend<HeaderValue>,
- {
- let Self(value) = self;
- // Must panic or suppress; the trait provides no other options.
- let value = HeaderValue::from_str(value).expect("LastEventId is a valid header value");
-
- values.extend(std::iter::once(value));
- }
}
diff --git a/src/header.rs b/src/header.rs
new file mode 100644
index 0000000..904e29d
--- /dev/null
+++ b/src/header.rs
@@ -0,0 +1,58 @@
+use axum::{
+ extract::FromRequestParts,
+ http::{request::Parts, HeaderName, HeaderValue},
+};
+use axum_extra::typed_header::TypedHeader;
+
+/// A typed header. When used as a bare extractor, reads from the
+/// `Last-Event-Id` HTTP header.
+pub struct LastEventId(pub String);
+
+static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id");
+
+impl headers::Header for LastEventId {
+ fn name() -> &'static HeaderName {
+ &LAST_EVENT_ID
+ }
+
+ fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
+ where
+ I: Iterator<Item = &'i HeaderValue>,
+ {
+ let value = values.next().ok_or_else(headers::Error::invalid)?;
+ if let Ok(value) = value.to_str() {
+ Ok(Self(value.into()))
+ } else {
+ Err(headers::Error::invalid())
+ }
+ }
+
+ fn encode<E>(&self, values: &mut E)
+ where
+ E: Extend<HeaderValue>,
+ {
+ let Self(value) = self;
+ // Must panic or suppress; the trait provides no other options.
+ let value = HeaderValue::from_str(value).expect("LastEventId is a valid header value");
+
+ values.extend(std::iter::once(value));
+ }
+}
+
+#[async_trait::async_trait]
+impl<S> FromRequestParts<S> for LastEventId
+where
+ S: Send + Sync,
+{
+ type Rejection = <TypedHeader<Self> as FromRequestParts<S>>::Rejection;
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+ // This is purely for ergonomics: it allows `RequestedAt` to be extracted
+ // without having to wrap it in `Extension<>`. Callers _can_ still do that,
+ // but they aren't forced to.
+ let TypedHeader(requested_at) =
+ TypedHeader::<Self>::from_request_parts(parts, state).await?;
+
+ Ok(requested_at)
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index f71ef95..8b6a78f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,6 +4,7 @@ pub mod cli;
mod clock;
mod error;
mod events;
+mod header;
mod id;
mod index;
mod login;