use std::collections::{BTreeMap, HashSet};

use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{self, Sse},
        IntoResponse, Response,
    },
    routing::get,
    Router,
};
use axum_extra::extract::Query;
use futures::{
    future,
    stream::{self, Stream, StreamExt as _, TryStreamExt as _},
};

use super::repo::broadcast;
use crate::{
    app::App,
    channel::app::EventsError,
    clock::RequestedAt,
    error::InternalError,
    header::LastEventId,
    repo::{channel, login::Login},
};

#[cfg(test)]
mod test;

// For the purposes of event replay, an "event ID" is a vector of per-channel
// sequence numbers. Replay starts with messages whose sequence number in
// their channel is higher than the sequence recorded in the event ID, or, if
// the channel is not listed in the event ID, from the beginning.
//
// Using a sorted map ensures that there is a canonical representation for
// each event ID.
type EventId = BTreeMap<channel::Id, broadcast::Sequence>;

pub fn router() -> Router<App> {
    Router::new().route("/api/events", get(events))
}

#[derive(Clone, serde::Deserialize)]
struct EventsQuery {
    #[serde(default, rename = "channel")]
    channels: HashSet<channel::Id>,
}

async fn events(
    State(app): State<App>,
    RequestedAt(now): RequestedAt,
    _: Login, // requires auth, but doesn't actually care who you are
    last_event_id: Option<LastEventId<EventId>>,
    Query(query): Query<EventsQuery>,
) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> {
    let resume_at = last_event_id
        .map(LastEventId::into_inner)
        .unwrap_or_default();

    let streams = stream::iter(query.channels)
        .then(|channel| {
            let app = app.clone();
            let resume_at = resume_at.clone();
            async move {
                let resume_at = resume_at.get(&channel).copied();
                let events = app
                    .channels()
                    .events(&channel, &now, resume_at)
                    .await?
                    .map(ChannelEvent::wrap(channel));
                Ok::<_, EventsError>(events)
            }
        })
        .try_collect::<Vec<_>>()
        .await
        // impl From would take more code; this is used once.
        .map_err(ErrorResponse)?;

    // We resume counting from the provided last-event-id mapping, rather than
    // starting from scratch, so that the events in a resumed stream contain
    // the full vector of channel IDs for their event IDs right off the bat,
    // even before any events are actually delivered.
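    //
    // For example (channel names and numbers purely illustrative): a client
    // that reconnects with `Last-Event-ID: {"general":41,"random":7}` seeds
    // the scan below with both entries, so even the first event it receives
    // carries an ID such as `{"general":42,"random":7}` covering every
    // channel it was already tracking, not just the channel that happened to
    // deliver first.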
    let stream = stream::select_all(streams).scan(resume_at, |sequences, event| {
        let (channel, sequence) = event.event_id();
        sequences.insert(channel, sequence);
        let event = ReplayableEvent(sequences.clone(), event);
        future::ready(Some(event))
    });

    Ok(Events(stream))
}

#[derive(Debug)]
struct Events<S>(S);

impl<S> IntoResponse for Events<S>
where
    S: Stream<Item = ReplayableEvent> + Send + 'static,
{
    fn into_response(self) -> Response {
        let Self(stream) = self;
        let stream = stream.map(sse::Event::try_from);
        Sse::new(stream)
            .keep_alive(sse::KeepAlive::default())
            .into_response()
    }
}

#[derive(Debug)]
struct ErrorResponse(EventsError);

impl IntoResponse for ErrorResponse {
    fn into_response(self) -> Response {
        let Self(error) = self;
        match error {
            not_found @ EventsError::ChannelNotFound(_) => {
                (StatusCode::NOT_FOUND, not_found.to_string()).into_response()
            }
            resume_at @ EventsError::ResumeAtError(_) => {
                (StatusCode::BAD_REQUEST, resume_at.to_string()).into_response()
            }
            other => InternalError::from(other).into_response(),
        }
    }
}

#[derive(Debug)]
struct ReplayableEvent(EventId, ChannelEvent);

#[derive(Debug, serde::Serialize)]
struct ChannelEvent {
    channel: channel::Id,
    #[serde(flatten)]
    message: broadcast::Message,
}

impl ChannelEvent {
    fn wrap(channel: channel::Id) -> impl Fn(broadcast::Message) -> Self {
        move |message| Self {
            channel: channel.clone(),
            message,
        }
    }

    fn event_id(&self) -> (channel::Id, broadcast::Sequence) {
        (self.channel.clone(), self.message.sequence)
    }
}

impl TryFrom<ReplayableEvent> for sse::Event {
    type Error = serde_json::Error;

    fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> {
        let ReplayableEvent(id, data) = value;
        let id = serde_json::to_string(&id)?;
        let data = serde_json::to_string_pretty(&data)?;
        let event = Self::default().id(id).data(data);
        Ok(event)
    }
}
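
// Worked example of the conversion above (channel name is illustrative, and
// the exact `data` payload depends on broadcast::Message's Serialize impl): a
// message with sequence 42 in channel "general" becomes an SSE event whose
// `id` is the compact JSON map {"general":42} and whose `data` is the
// pretty-printed ChannelEvent JSON. On reconnect, SSE clients echo the last
// received `id` back in the Last-Event-ID header, which the LastEventId
// extractor turns back into the EventId mapping that `events` uses to resume
// each channel.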