summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-13 00:26:03 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-13 02:42:27 -0400
commit067e3da1900d052a416c56e1c047640aa23441ae (patch)
tree8baad4240d2532216f2530f5c974479e557c675a /src
parent5d76d0712e07040d9aeeebccb189d75636a07c7a (diff)
Transmit messages via `/:chan/send` and `/:chan/events`.
Diffstat (limited to 'src')
-rw-r--r--src/app.rs16
-rw-r--r--src/channel/app.rs141
-rw-r--r--src/channel/repo/channels.rs (renamed from src/channel/repo.rs)11
-rw-r--r--src/channel/repo/messages.rs111
-rw-r--r--src/channel/repo/mod.rs2
-rw-r--r--src/channel/routes.rs56
-rw-r--r--src/cli.rs3
-rw-r--r--src/id.rs2
-rw-r--r--src/index/app.rs2
-rw-r--r--src/index/templates.rs2
-rw-r--r--src/login/repo/logins.rs2
11 files changed, 324 insertions, 24 deletions
diff --git a/src/app.rs b/src/app.rs
index 4195fdc..f349fd4 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -1,15 +1,23 @@
use sqlx::sqlite::SqlitePool;
-use crate::{channel::app::Channels, index::app::Index, login::app::Logins};
+use crate::error::BoxedError;
+
+use crate::{
+ channel::app::{Broadcaster, Channels},
+ index::app::Index,
+ login::app::Logins,
+};
#[derive(Clone)]
pub struct App {
db: SqlitePool,
+ broadcaster: Broadcaster,
}
impl App {
- pub fn from(db: SqlitePool) -> Self {
- Self { db }
+ pub async fn from(db: SqlitePool) -> Result<Self, BoxedError> {
+ let broadcaster = Broadcaster::from_database(&db).await?;
+ Ok(Self { db, broadcaster })
}
}
@@ -23,6 +31,6 @@ impl App {
}
pub fn channels(&self) -> Channels {
- Channels::new(&self.db)
+ Channels::new(&self.db, &self.broadcaster)
}
}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 84822cb..7b02300 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,21 +1,152 @@
+use std::collections::{hash_map::Entry, HashMap};
+use std::sync::{Arc, Mutex};
+
+use futures::{
+ stream::{self, StreamExt as _, TryStreamExt as _},
+ Stream,
+};
use sqlx::sqlite::SqlitePool;
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::BroadcastStream;
-use super::repo::Provider as _;
-use crate::error::BoxedError;
+use super::repo::{
+ channels::{Id as ChannelId, Provider as _},
+ messages::{Message, Provider as _},
+};
+use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login};
pub struct Channels<'a> {
db: &'a SqlitePool,
+ broadcaster: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub fn new(db: &'a SqlitePool) -> Self {
- Self { db }
+ pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
+ Self { db, broadcaster }
}
pub async fn create(&self, name: &str) -> Result<(), BoxedError> {
let mut tx = self.db.begin().await?;
- tx.channels().create(name).await?;
+ let channel = tx.channels().create(name).await?;
+ tx.commit().await?;
+
+ self.broadcaster.register_channel(&channel)?;
+ Ok(())
+ }
+
+ pub async fn send(
+ &self,
+ login: &Login,
+ channel: &ChannelId,
+ body: &str,
+ sent_at: &DateTime,
+ ) -> Result<(), BoxedError> {
+ let mut tx = self.db.begin().await?;
+ let message = tx
+ .messages()
+ .create(&login.id, channel, body, sent_at)
+ .await?;
+ tx.commit().await?;
+
+ self.broadcaster.broadcast(channel, message)?;
+ Ok(())
+ }
+
+ pub async fn events(
+ &self,
+ channel: &ChannelId,
+ ) -> Result<impl Stream<Item = Result<Message, BoxedError>>, BoxedError> {
+ let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from);
+
+ let mut tx = self.db.begin().await?;
+ let stored_messages = tx.messages().all(channel).await?;
+ tx.commit().await?;
+
+ let stored_messages = stream::iter(stored_messages.into_iter().map(Ok));
+
+ Ok(stored_messages.chain(live_messages))
+ }
+}
+
+// Clones will share the same senders collection.
+#[derive(Clone)]
+pub struct Broadcaster {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<HashMap<ChannelId, Sender<Message>>>>,
+}
+
+impl Broadcaster {
+ pub async fn from_database(db: &SqlitePool) -> Result<Broadcaster, BoxedError> {
+ let mut tx = db.begin().await?;
+ let channels = tx.channels().all().await?;
tx.commit().await?;
+
+ let channels = channels.iter().map(|c| &c.id);
+ let broadcaster = Broadcaster::new(channels);
+ Ok(broadcaster)
+ }
+
+ fn new<'i>(channels: impl IntoIterator<Item = &'i ChannelId>) -> Self {
+ let senders: HashMap<_, _> = channels
+ .into_iter()
+ .cloned()
+ .map(|id| (id, Self::make_sender()))
+ .collect();
+
+ Self {
+ senders: Arc::new(Mutex::new(senders)),
+ }
+ }
+
+ pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> {
+ match self.senders.lock().unwrap().entry(channel.clone()) {
+ Entry::Occupied(_) => Err(RegisterError::Duplicate),
+ vacant => {
+ vacant.or_insert_with(Self::make_sender);
+ Ok(())
+ }
+ }
+ }
+
+ pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ let _ = tx.send(message);
Ok(())
}
+
+ pub fn listen(&self, channel: &ChannelId) -> Result<BroadcastStream<Message>, BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+ let rx = tx.subscribe();
+ let stream = BroadcastStream::from(rx);
+
+ Ok(stream)
+ }
+
+ fn make_sender() -> Sender<Message> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum RegisterError {
+ #[error("duplicate channel registered")]
+ Duplicate,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum BroadcastError {
+ #[error("requested channel not registered")]
+ Unregistered,
}
diff --git a/src/channel/repo.rs b/src/channel/repo/channels.rs
index a04cac5..6fb0c23 100644
--- a/src/channel/repo.rs
+++ b/src/channel/repo/channels.rs
@@ -25,22 +25,23 @@ pub struct Channel {
impl<'c> Channels<'c> {
/// Create a new channel.
- pub async fn create(&mut self, name: &str) -> Result<(), BoxedError> {
+ pub async fn create(&mut self, name: &str) -> Result<Id, BoxedError> {
let id = Id::generate();
- sqlx::query!(
+ let channel = sqlx::query_scalar!(
r#"
insert
into channel (id, name)
values ($1, $2)
+ returning id as "id: Id"
"#,
id,
name,
)
- .execute(&mut *self.0)
+ .fetch_one(&mut *self.0)
.await?;
- Ok(())
+ Ok(channel)
}
pub async fn all(&mut self) -> Result<Vec<Channel>, BoxedError> {
@@ -62,7 +63,7 @@ impl<'c> Channels<'c> {
}
/// Stable identifier for a [Channel]. Prefixed with `C`.
-#[derive(Debug, sqlx::Type, serde::Deserialize)]
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
#[serde(transparent)]
pub struct Id(BaseId);
diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs
new file mode 100644
index 0000000..bdb0d29
--- /dev/null
+++ b/src/channel/repo/messages.rs
@@ -0,0 +1,111 @@
+use std::fmt;
+
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use super::channels::Id as ChannelId;
+use crate::{
+ clock::DateTime, error::BoxedError, id::Id as BaseId, login::repo::logins::Id as LoginId,
+};
+
+pub trait Provider {
+ fn messages(&mut self) -> Messages;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn messages(&mut self) -> Messages {
+ Messages(self)
+ }
+}
+
+pub struct Messages<'t>(&'t mut SqliteConnection);
+
+#[derive(Clone, Debug, serde::Serialize)]
+pub struct Message {
+ pub id: Id,
+ pub sender: LoginId,
+ pub channel: ChannelId,
+ pub body: String,
+ pub sent_at: DateTime,
+}
+
+impl<'c> Messages<'c> {
+ pub async fn create(
+ &mut self,
+ sender: &LoginId,
+ channel: &ChannelId,
+ body: &str,
+ sent_at: &DateTime,
+ ) -> Result<Message, BoxedError> {
+ let id = Id::generate();
+
+ let message = sqlx::query_as!(
+ Message,
+ r#"
+ insert into message
+ (id, sender, channel, body, sent_at)
+ values ($1, $2, $3, $4, $5)
+ returning
+ id as "id: Id",
+ sender as "sender: LoginId",
+ channel as "channel: ChannelId",
+ body,
+ sent_at as "sent_at: DateTime"
+ "#,
+ id,
+ sender,
+ channel,
+ body,
+ sent_at,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(message)
+ }
+
+ pub async fn all(&mut self, channel: &ChannelId) -> Result<Vec<Message>, BoxedError> {
+ let messages = sqlx::query_as!(
+ Message,
+ r#"
+ select
+ id as "id: Id",
+ sender as "sender: LoginId",
+ channel as "channel: ChannelId",
+ body,
+ sent_at as "sent_at: DateTime"
+ from message
+ where channel = $1
+ order by sent_at asc
+ "#,
+ channel,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(messages)
+ }
+}
+
+/// Stable identifier for a [Message]. Prefixed with `M`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("M")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/channel/repo/mod.rs b/src/channel/repo/mod.rs
new file mode 100644
index 0000000..345897d
--- /dev/null
+++ b/src/channel/repo/mod.rs
@@ -0,0 +1,2 @@
+pub mod channels;
+pub mod messages;
diff --git a/src/channel/routes.rs b/src/channel/routes.rs
index 864f1b3..83c733c 100644
--- a/src/channel/routes.rs
+++ b/src/channel/routes.rs
@@ -1,14 +1,23 @@
use axum::{
- extract::{Form, State},
- response::{IntoResponse, Redirect},
- routing::post,
+ extract::{Form, Path, State},
+ http::StatusCode,
+ response::{
+ sse::{self, Sse},
+ IntoResponse, Redirect,
+ },
+ routing::{get, post},
Router,
};
+use futures::stream::{StreamExt as _, TryStreamExt as _};
-use crate::{app::App, error::InternalError, login::repo::logins::Login};
+use super::repo::channels::Id as ChannelId;
+use crate::{app::App, clock::RequestedAt, error::InternalError, login::repo::logins::Login};
pub fn router() -> Router<App> {
- Router::new().route("/create", post(on_create))
+ Router::new()
+ .route("/create", post(on_create))
+ .route("/:channel/send", post(on_send))
+ .route("/:channel/events", get(on_events))
}
#[derive(serde::Deserialize)]
@@ -25,3 +34,40 @@ async fn on_create(
Ok(Redirect::to("/"))
}
+
+#[derive(serde::Deserialize)]
+struct SendRequest {
+ message: String,
+}
+
+async fn on_send(
+ Path(channel): Path<ChannelId>,
+ RequestedAt(sent_at): RequestedAt,
+ State(app): State<App>,
+ login: Login,
+ Form(form): Form<SendRequest>,
+) -> Result<impl IntoResponse, InternalError> {
+ app.channels()
+ .send(&login, &channel, &form.message, &sent_at)
+ .await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
+
+async fn on_events(
+ Path(channel): Path<ChannelId>,
+ State(app): State<App>,
+ _: Login, // requires auth, but doesn't actually care who you are
+) -> Result<impl IntoResponse, InternalError> {
+ let stream = app
+ .channels()
+ .events(&channel)
+ .await?
+ .map(|msg| match msg {
+ Ok(msg) => Ok(serde_json::to_string(&msg)?),
+ Err(err) => Err(err),
+ })
+ .map_ok(|msg| sse::Event::default().data(&msg));
+
+ Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default()))
+}
diff --git a/src/cli.rs b/src/cli.rs
index e374834..fa7c499 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -28,9 +28,10 @@ impl Args {
sqlx::migrate!().run(&pool).await?;
+ let app = App::from(pool).await?;
let app = routers()
.route_layer(middleware::from_fn(clock::middleware))
- .with_state(App::from(pool));
+ .with_state(app);
let listener = self.listener().await?;
let started_msg = started_msg(&listener)?;
diff --git a/src/id.rs b/src/id.rs
index 4e12f2a..c69b341 100644
--- a/src/id.rs
+++ b/src/id.rs
@@ -27,7 +27,7 @@ pub const ID_SIZE: usize = 15;
//
// By convention, the prefix should be UPPERCASE - note that the alphabet for this
// is entirely lowercase.
-#[derive(Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
#[serde(transparent)]
pub struct Id(String);
diff --git a/src/index/app.rs b/src/index/app.rs
index 79f5a9a..d6eef18 100644
--- a/src/index/app.rs
+++ b/src/index/app.rs
@@ -1,7 +1,7 @@
use sqlx::sqlite::SqlitePool;
use crate::{
- channel::repo::{Channel, Provider as _},
+ channel::repo::channels::{Channel, Provider as _},
error::BoxedError,
};
diff --git a/src/index/templates.rs b/src/index/templates.rs
index fdb750b..38cd93f 100644
--- a/src/index/templates.rs
+++ b/src/index/templates.rs
@@ -1,6 +1,6 @@
use maud::{html, Markup, DOCTYPE};
-use crate::{channel::repo::Channel, login::repo::logins::Login};
+use crate::{channel::repo::channels::Channel, login::repo::logins::Login};
pub fn authenticated<'c>(login: Login, channels: impl IntoIterator<Item = &'c Channel>) -> Markup {
html! {
diff --git a/src/login/repo/logins.rs b/src/login/repo/logins.rs
index 26a5b09..142d8fb 100644
--- a/src/login/repo/logins.rs
+++ b/src/login/repo/logins.rs
@@ -90,7 +90,7 @@ impl<'c> Logins<'c> {
}
/// Stable identifier for a [Login]. Prefixed with `L`.
-#[derive(Debug, sqlx::Type)]
+#[derive(Clone, Debug, sqlx::Type, serde::Serialize)]
#[sqlx(transparent)]
pub struct Id(BaseId);