Skip to main content

tty_web/web/
ws.rs

1//! WebSocket handler implementing the tty-web binary protocol.
2//!
3//! # Wire protocol
4//!
5//! All WebSocket messages are **binary frames**. The first byte is the command,
6//! the rest is the payload.
7//!
8//! | Direction | Cmd | Payload | Description |
9//! |-----------|------|---------|-------------|
10//! | client → server | `0x00` | raw bytes | Terminal input |
11//! | client → server | `0x01` | rows(u16 BE) + cols(u16 BE) | Resize |
12//! | server → client | `0x00` | raw bytes | Terminal output |
13//! | server → client | `0x10` | UUID string | Session ID |
14//! | server → client | `0x12` | — | Shell exited |
15//! | server → client | `0x13` | rows(u16 BE) + cols(u16 BE) | Window size |
16//! | server → client | `0x14` | — | Replay end |
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
22use axum::extract::{Query, State};
23use axum::response::IntoResponse;
24use tokio::sync::broadcast::error::RecvError;
25
26use crate::session::{ScrollbackEvent, Session};
27use crate::terminal::Terminal;
28use crate::web::AppState;
29
/// Client → Server: terminal input (payload: raw bytes).
const CMD_INPUT: u8 = 0x00;
/// Client → Server: resize (4-byte payload: rows u16 BE, cols u16 BE).
const CMD_RESIZE: u8 = 0x01;

/// Server → Client: terminal output (payload: raw bytes).
///
/// Shares the value `0x00` with [`CMD_INPUT`]; the two are told apart only by
/// the direction in which the frame travels.
const CMD_OUTPUT: u8 = 0x00;
/// Server → Client: session UUID string.
const CMD_SESSION_ID: u8 = 0x10;
/// Server → Client: shell process exited (no payload).
const CMD_SHELL_EXIT: u8 = 0x12;
/// Server → Client: current PTY window size (4-byte payload: rows u16 BE, cols u16 BE).
const CMD_WINDOW_SIZE: u8 = 0x13;
/// Server → Client: end of scrollback replay (no payload).
const CMD_REPLAY_END: u8 = 0x14;

/// WebSocket close code: requested session not found.
///
/// Sent in the close frame when a client asks for a `sid` that is not known.
const CLOSE_SESSION_NOT_FOUND: u16 = 4404;
48
49/// Send a protocol frame (command byte + payload) over the WebSocket.
50async fn send_frame(socket: &mut WebSocket, cmd: u8, payload: &[u8]) -> Result<(), ()> {
51    let mut frame = Vec::with_capacity(1 + payload.len());
52    frame.push(cmd);
53    frame.extend_from_slice(payload);
54    socket
55        .send(Message::Binary(frame.into()))
56        .await
57        .map_err(|_| ())
58}
59
/// Serialize a window size into its 4-byte wire form.
///
/// Layout: `rows` as a big-endian `u16`, then `cols` as a big-endian `u16`.
fn encode_window_size(rows: u16, cols: u16) -> [u8; 4] {
    let [rows_hi, rows_lo] = rows.to_be_bytes();
    let [cols_hi, cols_lo] = cols.to_be_bytes();
    [rows_hi, rows_lo, cols_hi, cols_lo]
}
66
67pub async fn ws_handler(
68    ws: WebSocketUpgrade,
69    Query(params): Query<HashMap<String, String>>,
70    State(state): State<AppState>,
71) -> impl IntoResponse {
72    let sid = params.get("sid").cloned();
73    let readonly = params.contains_key("view");
74    ws.on_upgrade(move |socket| handle_socket(socket, state, sid, readonly))
75}
76
/// Reasons [`resolve_session`] can fail.
enum ResolveError {
    /// The requested session ID is not known; carries that ID.
    NotFound(String),
    /// Spawning the terminal for a new session failed.
    Io(std::io::Error),
}
81
82async fn handle_socket(
83    mut socket: WebSocket,
84    state: AppState,
85    sid: Option<String>,
86    readonly: bool,
87) {
88    // Resolve or create session
89    let session = match resolve_session(&state, sid.as_deref()) {
90        Ok(result) => result,
91        Err(ResolveError::NotFound(id)) => {
92            tracing::warn!("session {id} not found");
93            let _ = socket
94                .send(Message::Close(Some(CloseFrame {
95                    code: CLOSE_SESSION_NOT_FOUND,
96                    reason: "session not found".into(),
97                })))
98                .await;
99            return;
100        }
101        Err(ResolveError::Io(e)) => {
102            tracing::error!("failed to create session: {e}");
103            return;
104        }
105    };
106
107    handle_session(&mut socket, &session, readonly).await;
108}
109
110/// Drive the tty-web binary protocol on an already-resolved session.
111///
112/// Performs the full handshake (session ID → window size → scrollback replay →
113/// replay-end marker), then bridges WebSocket I/O with the terminal until the
114/// client disconnects or the shell exits. Calls [`Session::attach`] /
115/// [`Session::detach`] automatically.
116///
117/// This is the main building block for embedding tty-web in other applications
118/// that manage session creation themselves.
119pub async fn handle_session(socket: &mut WebSocket, session: &Arc<Session>, readonly: bool) {
120    // Handshake: session ID → window size → replay events → replay end
121    if send_frame(socket, CMD_SESSION_ID, session.id().as_bytes())
122        .await
123        .is_err()
124    {
125        return;
126    }
127
128    let (events, mut output_rx, mut window_size_rx) = session.attach();
129
130    let (rows, cols) = *window_size_rx.borrow_and_update();
131    if send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols))
132        .await
133        .is_err()
134    {
135        session.detach();
136        return;
137    }
138
139    // Replay scrollback events
140    for event in &events {
141        let ok = match event {
142            ScrollbackEvent::Output(data) => send_frame(socket, CMD_OUTPUT, data).await.is_ok(),
143            ScrollbackEvent::WindowSize(r, c) => {
144                send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(*r, *c))
145                    .await
146                    .is_ok()
147            }
148        };
149        if !ok {
150            session.detach();
151            return;
152        }
153    }
154
155    if send_frame(socket, CMD_REPLAY_END, &[]).await.is_err() {
156        session.detach();
157        return;
158    }
159
160    // Main loop: bridge WebSocket ↔ session
161    let mut closed_rx = session.terminal.closed();
162    loop {
163        tokio::select! {
164            result = output_rx.recv() => {
165                match result {
166                    Ok(data) => {
167                        if send_frame(socket, CMD_OUTPUT, &data).await.is_err() {
168                            break;
169                        }
170                    }
171                    Err(RecvError::Lagged(n)) => {
172                        tracing::warn!("output lagged {n} messages");
173                        continue;
174                    }
175                    Err(RecvError::Closed) => break,
176                }
177            }
178            msg = socket.recv() => {
179                match msg {
180                    Some(Ok(Message::Binary(data))) => {
181                        if readonly || data.is_empty() {
182                            continue;
183                        }
184                        handle_client_message(session, &data).await;
185                    }
186                    Some(Ok(Message::Close(_))) | None => break,
187                    _ => {}
188                }
189            }
190            Ok(()) = window_size_rx.changed() => {
191                let (rows, cols) = *window_size_rx.borrow_and_update();
192                if send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols)).await.is_err() {
193                    break;
194                }
195            }
196            _ = closed_rx.changed() => {
197                // Drain buffered output before sending exit
198                while let Ok(data) = output_rx.try_recv() {
199                    if send_frame(socket, CMD_OUTPUT, &data).await.is_err() {
200                        break;
201                    }
202                }
203                let _ = send_frame(socket, CMD_SHELL_EXIT, &[]).await;
204                break;
205            }
206        }
207    }
208    session.detach();
209}
210
211fn resolve_session(state: &AppState, sid: Option<&str>) -> Result<Arc<Session>, ResolveError> {
212    if let Some(sid) = sid {
213        return state
214            .sessions
215            .get(sid)
216            .inspect(|_| tracing::info!("reattaching to session {sid}"))
217            .ok_or_else(|| ResolveError::NotFound(sid.to_owned()));
218    }
219    let (terminal, output_rx) =
220        Terminal::spawn(&state.shell, state.pwd.as_deref()).map_err(ResolveError::Io)?;
221    let session = Session::new(
222        terminal,
223        output_rx,
224        state.scrollback_limit,
225        state.orphan_timeout,
226    );
227    tracing::info!("created new session {}", session.id());
228    state.sessions.insert(session.clone());
229    Ok(session)
230}
231
/// A decoded client → server message; `Input` borrows its payload from the
/// frame it was parsed from.
#[derive(Debug, PartialEq)]
enum ClientCommand<'a> {
    /// Terminal input: the bytes that followed the command byte.
    Input(&'a [u8]),
    /// Resize request carrying the new window size.
    Resize { rows: u16, cols: u16 },
    /// A command byte that is neither input nor resize; carries that byte.
    Unknown(u8),
}
238
239fn parse_client_message(data: &[u8]) -> Option<ClientCommand<'_>> {
240    let (&cmd, payload) = data.split_first()?;
241    match cmd {
242        CMD_INPUT => Some(ClientCommand::Input(payload)),
243        CMD_RESIZE if payload.len() >= 4 => {
244            let rows = u16::from_be_bytes([payload[0], payload[1]]);
245            let cols = u16::from_be_bytes([payload[2], payload[3]]);
246            Some(ClientCommand::Resize { rows, cols })
247        }
248        CMD_RESIZE => None,
249        other => Some(ClientCommand::Unknown(other)),
250    }
251}
252
253async fn handle_client_message(session: &Session, data: &[u8]) {
254    match parse_client_message(data) {
255        Some(ClientCommand::Input(payload)) => {
256            if let Err(e) = session.terminal.write(payload.to_vec()).await {
257                tracing::error!("write to terminal failed: {e}");
258            }
259        }
260        Some(ClientCommand::Resize { rows, cols }) => {
261            if let Err(e) = session.terminal.resize(rows, cols) {
262                tracing::error!("resize failed: {e}");
263            }
264            session.set_window_size(rows, cols);
265        }
266        Some(ClientCommand::Unknown(cmd)) => {
267            tracing::warn!("unknown command: 0x{cmd:02x}");
268        }
269        None => {}
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// An input frame yields the bytes after the command byte.
    #[test]
    fn test_parse_input() {
        let frame = [0x00, b'h', b'i'];
        let expected = Some(ClientCommand::Input(b"hi"));
        assert_eq!(parse_client_message(&frame), expected);
    }

    /// A resize frame decodes rows then cols as big-endian `u16`s.
    #[test]
    fn test_parse_resize() {
        let frame = [0x01, 0, 24, 0, 80];
        let expected = Some(ClientCommand::Resize { rows: 24, cols: 80 });
        assert_eq!(parse_client_message(&frame), expected);
    }

    /// A resize frame with fewer than 4 payload bytes is rejected.
    #[test]
    fn test_parse_resize_too_short() {
        let frame = [0x01, 0, 24];
        assert_eq!(parse_client_message(&frame), None);
    }

    /// A frame with no command byte is rejected.
    #[test]
    fn test_parse_empty() {
        let frame: [u8; 0] = [];
        assert_eq!(parse_client_message(&frame), None);
    }

    /// An unrecognized command byte is reported, not rejected.
    #[test]
    fn test_parse_unknown() {
        let frame = [0xFF, 1];
        let expected = Some(ClientCommand::Unknown(0xFF));
        assert_eq!(parse_client_message(&frame), expected);
    }
}