tty_web/web/
ws.rs

1//! WebSocket handler implementing the tty-web binary protocol.
2//!
3//! # Wire protocol
4//!
5//! All WebSocket messages are **binary frames**. The first byte is the command,
6//! the rest is the payload.
7//!
8//! | Direction | Cmd | Payload | Description |
9//! |-----------|------|---------|-------------|
10//! | client → server | `0x00` | raw bytes | Terminal input |
11//! | client → server | `0x01` | rows(u16 BE) + cols(u16 BE) | Resize |
12//! | server → client | `0x00` | raw bytes | Terminal output |
13//! | server → client | `0x10` | UUID string | Session ID |
14//! | server → client | `0x12` | — | Shell exited |
15//! | server → client | `0x13` | rows(u16 BE) + cols(u16 BE) | Window size |
16//! | server → client | `0x14` | — | Replay end |
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
22use axum::extract::{Query, State};
23use axum::response::IntoResponse;
24use tokio::sync::broadcast::error::RecvError;
25
26use crate::session::{ScrollbackEvent, Session};
27use crate::terminal::Terminal;
28use crate::web::AppState;
29
// ----- Client → Server command bytes -----

/// Client → Server: terminal input; the payload bytes are forwarded to the terminal.
const CMD_INPUT: u8 = 0x00;
/// Client → Server: resize request; payload is rows then cols, each a big-endian `u16`.
const CMD_RESIZE: u8 = 0x01;

// ----- Server → Client command bytes -----

/// Server → Client: a chunk of terminal output.
const CMD_OUTPUT: u8 = 0x00;
/// Server → Client: the session ID, sent as a string as the first handshake frame.
const CMD_SESSION_ID: u8 = 0x10;
/// Server → Client: the shell process has exited (empty payload).
const CMD_SHELL_EXIT: u8 = 0x12;
/// Server → Client: window size; payload is rows then cols, each a big-endian `u16`.
const CMD_WINDOW_SIZE: u8 = 0x13;
/// Server → Client: marks the end of the scrollback replay (empty payload).
const CMD_REPLAY_END: u8 = 0x14;

/// WebSocket close code used when the session named by the `sid` query parameter
/// does not exist (4000–4999 is the application-defined range of RFC 6455).
const CLOSE_SESSION_NOT_FOUND: u16 = 4404;
48
49/// Send a protocol frame (command byte + payload) over the WebSocket.
50async fn send_frame(socket: &mut WebSocket, cmd: u8, payload: &[u8]) -> Result<(), ()> {
51    let mut frame = Vec::with_capacity(1 + payload.len());
52    frame.push(cmd);
53    frame.extend_from_slice(payload);
54    socket
55        .send(Message::Binary(frame.into()))
56        .await
57        .map_err(|_| ())
58}
59
/// Packs a terminal size into its 4-byte wire form.
///
/// The layout is `rows` followed by `cols`, each a big-endian `u16`.
fn encode_window_size(rows: u16, cols: u16) -> [u8; 4] {
    // Rows occupy the high half of the packed word, so big-endian output yields
    // [rows_hi, rows_lo, cols_hi, cols_lo].
    let packed = (u32::from(rows) << 16) | u32::from(cols);
    packed.to_be_bytes()
}
66
67pub async fn ws_handler(
68    ws: WebSocketUpgrade,
69    Query(params): Query<HashMap<String, String>>,
70    State(state): State<AppState>,
71) -> impl IntoResponse {
72    let sid = params.get("sid").cloned();
73    let readonly = params.contains_key("view");
74    ws.on_upgrade(move |socket| handle_socket(socket, state, sid, readonly))
75}
76
/// Reasons `resolve_session` could not hand a session to a new connection.
enum ResolveError {
    /// The client asked for this session ID, but no such session is registered.
    NotFound(String),
    /// Spawning the terminal for a brand-new session failed.
    Io(std::io::Error),
}
81
82async fn handle_socket(
83    mut socket: WebSocket,
84    state: AppState,
85    sid: Option<String>,
86    readonly: bool,
87) {
88    // Resolve or create session
89    let (session, session_id) = match resolve_session(&state, sid.as_deref()) {
90        Ok(result) => result,
91        Err(ResolveError::NotFound(id)) => {
92            tracing::warn!("session {id} not found");
93            let _ = socket
94                .send(Message::Close(Some(CloseFrame {
95                    code: CLOSE_SESSION_NOT_FOUND,
96                    reason: "session not found".into(),
97                })))
98                .await;
99            return;
100        }
101        Err(ResolveError::Io(e)) => {
102            tracing::error!("failed to create session: {e}");
103            return;
104        }
105    };
106
107    // Handshake: session ID → window size → replay events → replay end
108    if send_frame(&mut socket, CMD_SESSION_ID, session_id.as_bytes())
109        .await
110        .is_err()
111    {
112        return;
113    }
114
115    let (events, mut output_rx, mut window_size_rx) = session.attach();
116
117    let (rows, cols) = *window_size_rx.borrow_and_update();
118    if send_frame(
119        &mut socket,
120        CMD_WINDOW_SIZE,
121        &encode_window_size(rows, cols),
122    )
123    .await
124    .is_err()
125    {
126        session.detach();
127        return;
128    }
129
130    // Replay scrollback events
131    for event in &events {
132        let ok = match event {
133            ScrollbackEvent::Output(data) => {
134                send_frame(&mut socket, CMD_OUTPUT, data).await.is_ok()
135            }
136            ScrollbackEvent::WindowSize(r, c) => {
137                send_frame(&mut socket, CMD_WINDOW_SIZE, &encode_window_size(*r, *c))
138                    .await
139                    .is_ok()
140            }
141        };
142        if !ok {
143            session.detach();
144            return;
145        }
146    }
147
148    if send_frame(&mut socket, CMD_REPLAY_END, &[]).await.is_err() {
149        session.detach();
150        return;
151    }
152
153    // Main loop: bridge WebSocket ↔ session
154    let mut closed_rx = session.terminal.closed();
155    loop {
156        tokio::select! {
157            result = output_rx.recv() => {
158                match result {
159                    Ok(data) => {
160                        if send_frame(&mut socket, CMD_OUTPUT, &data).await.is_err() {
161                            break;
162                        }
163                    }
164                    Err(RecvError::Lagged(n)) => {
165                        tracing::warn!("output lagged {n} messages");
166                        continue;
167                    }
168                    Err(RecvError::Closed) => break,
169                }
170            }
171            msg = socket.recv() => {
172                match msg {
173                    Some(Ok(Message::Binary(data))) => {
174                        if readonly || data.is_empty() {
175                            continue;
176                        }
177                        handle_client_message(&session, &data).await;
178                    }
179                    Some(Ok(Message::Close(_))) | None => break,
180                    _ => {}
181                }
182            }
183            Ok(()) = window_size_rx.changed() => {
184                let (rows, cols) = *window_size_rx.borrow_and_update();
185                if send_frame(&mut socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols)).await.is_err() {
186                    break;
187                }
188            }
189            _ = closed_rx.changed() => {
190                // Drain buffered output before sending exit
191                while let Ok(data) = output_rx.try_recv() {
192                    if send_frame(&mut socket, CMD_OUTPUT, &data).await.is_err() {
193                        break;
194                    }
195                }
196                let _ = send_frame(&mut socket, CMD_SHELL_EXIT, &[]).await;
197                break;
198            }
199        }
200    }
201    session.detach();
202}
203
204fn resolve_session(
205    state: &AppState,
206    sid: Option<&str>,
207) -> Result<(Arc<Session>, String), ResolveError> {
208    if let Some(sid) = sid {
209        return state
210            .sessions
211            .get(sid)
212            .map(|session| {
213                tracing::info!("reattaching to session {sid}");
214                (session, sid.to_owned())
215            })
216            .ok_or_else(|| ResolveError::NotFound(sid.to_owned()));
217    }
218    let (terminal, output_rx) =
219        Terminal::spawn(&state.shell, state.pwd.as_deref()).map_err(ResolveError::Io)?;
220    let session = Session::new(terminal, output_rx, state.scrollback_limit);
221    let id = state.sessions.insert(session.clone());
222    tracing::info!("created new session {id}");
223    Ok((session, id))
224}
225
/// One decoded client → server frame; payload slices borrow from the received message.
#[derive(Debug, PartialEq)]
enum ClientCommand<'a> {
    /// Raw bytes destined for the terminal (`CMD_INPUT`).
    Input(&'a [u8]),
    /// Requested terminal dimensions (`CMD_RESIZE`).
    Resize { rows: u16, cols: u16 },
    /// A command byte this server does not recognise.
    Unknown(u8),
}
232
233fn parse_client_message(data: &[u8]) -> Option<ClientCommand<'_>> {
234    let (&cmd, payload) = data.split_first()?;
235    match cmd {
236        CMD_INPUT => Some(ClientCommand::Input(payload)),
237        CMD_RESIZE if payload.len() >= 4 => {
238            let rows = u16::from_be_bytes([payload[0], payload[1]]);
239            let cols = u16::from_be_bytes([payload[2], payload[3]]);
240            Some(ClientCommand::Resize { rows, cols })
241        }
242        CMD_RESIZE => None,
243        other => Some(ClientCommand::Unknown(other)),
244    }
245}
246
247async fn handle_client_message(session: &Session, data: &[u8]) {
248    match parse_client_message(data) {
249        Some(ClientCommand::Input(payload)) => {
250            if let Err(e) = session.terminal.write(payload.to_vec()).await {
251                tracing::error!("write to terminal failed: {e}");
252            }
253        }
254        Some(ClientCommand::Resize { rows, cols }) => {
255            if let Err(e) = session.terminal.resize(rows, cols) {
256                tracing::error!("resize failed: {e}");
257            }
258            session.set_window_size(rows, cols);
259        }
260        Some(ClientCommand::Unknown(cmd)) => {
261            tracing::warn!("unknown command: 0x{cmd:02x}");
262        }
263        None => {}
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_input() {
        let data = [0x00, b'h', b'i'];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Input(b"hi"))
        );
    }

    #[test]
    fn test_parse_input_empty_payload() {
        assert_eq!(
            parse_client_message(&[CMD_INPUT]),
            Some(ClientCommand::Input(b""))
        );
    }

    #[test]
    fn test_parse_resize() {
        let data = [0x01, 0, 24, 0, 80];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_is_big_endian() {
        // Non-zero high bytes catch a byte-order mix-up that 24x80 would not.
        let data = [0x01, 0x01, 0x02, 0x03, 0x04];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize {
                rows: 0x0102,
                cols: 0x0304
            })
        );
    }

    #[test]
    fn test_parse_resize_ignores_trailing_bytes() {
        let data = [0x01, 0, 24, 0, 80, 0xAA, 0xBB];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_too_short() {
        let data = [0x01, 0, 24];
        assert_eq!(parse_client_message(&data), None);
    }

    #[test]
    fn test_parse_resize_without_payload() {
        assert_eq!(parse_client_message(&[CMD_RESIZE]), None);
    }

    #[test]
    fn test_parse_empty() {
        assert_eq!(parse_client_message(&[]), None);
    }

    #[test]
    fn test_parse_unknown() {
        let data = [0xFF, 1];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Unknown(0xFF))
        );
    }

    #[test]
    fn test_encode_window_size() {
        assert_eq!(encode_window_size(24, 80), [0, 24, 0, 80]);
        assert_eq!(encode_window_size(0x1234, 0xABCD), [0x12, 0x34, 0xAB, 0xCD]);
    }

    #[test]
    fn test_window_size_round_trip() {
        // The server's encoding and the client resize decoding share one layout.
        let mut frame = vec![CMD_RESIZE];
        frame.extend_from_slice(&encode_window_size(300, 1000));
        assert_eq!(
            parse_client_message(&frame),
            Some(ClientCommand::Resize {
                rows: 300,
                cols: 1000
            })
        );
    }
}