1use std::collections::HashMap;
19use std::sync::Arc;
20
21use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
22use axum::extract::{Query, State};
23use axum::response::IntoResponse;
24use tokio::sync::broadcast::error::RecvError;
25
26use crate::session::{ScrollbackEvent, Session};
27use crate::terminal::Terminal;
28use crate::web::AppState;
29
// Wire protocol: every binary WebSocket frame is `[command byte, payload...]`.
// Client→server and server→client command bytes are separate namespaces
// (e.g. CMD_INPUT and CMD_OUTPUT are both 0x00); direction disambiguates.

// --- client → server ---

/// Raw bytes to write to the terminal; the payload is forwarded as-is.
const CMD_INPUT: u8 = 0x00;
/// Resize request; payload is `rows` then `cols`, each a big-endian `u16`.
const CMD_RESIZE: u8 = 0x01;

// --- server → client ---

/// Terminal output bytes (both replayed scrollback and live output).
const CMD_OUTPUT: u8 = 0x00;
/// The session id as UTF-8 bytes; the first frame sent on every connection.
const CMD_SESSION_ID: u8 = 0x10;
/// Empty payload; sent once the terminal signals it has closed.
const CMD_SHELL_EXIT: u8 = 0x12;
/// 4-byte payload: `rows` then `cols`, each a big-endian `u16`.
const CMD_WINDOW_SIZE: u8 = 0x13;
/// Empty payload; marks the end of the scrollback replay.
const CMD_REPLAY_END: u8 = 0x14;

/// WebSocket close code sent when `sid` names an unknown session. Codes
/// 4000–4999 are reserved for private use; 4404 presumably mirrors HTTP 404.
const CLOSE_SESSION_NOT_FOUND: u16 = 4404;
48
49async fn send_frame(socket: &mut WebSocket, cmd: u8, payload: &[u8]) -> Result<(), ()> {
51 let mut frame = Vec::with_capacity(1 + payload.len());
52 frame.push(cmd);
53 frame.extend_from_slice(payload);
54 socket
55 .send(Message::Binary(frame.into()))
56 .await
57 .map_err(|_| ())
58}
59
/// Packs a window size into the 4-byte `CMD_WINDOW_SIZE` payload: `rows`
/// first, then `cols`, each as a big-endian `u16`.
fn encode_window_size(rows: u16, cols: u16) -> [u8; 4] {
    // Rows occupy the high half, so big-endian serialization emits them first.
    let packed = (u32::from(rows) << 16) | u32::from(cols);
    packed.to_be_bytes()
}
66
67pub async fn ws_handler(
68 ws: WebSocketUpgrade,
69 Query(params): Query<HashMap<String, String>>,
70 State(state): State<AppState>,
71) -> impl IntoResponse {
72 let sid = params.get("sid").cloned();
73 let readonly = params.contains_key("view");
74 ws.on_upgrade(move |socket| handle_socket(socket, state, sid, readonly))
75}
76
/// Reasons [`resolve_session`] could not produce a session.
#[derive(Debug)]
enum ResolveError {
    /// A session id was supplied but no such session is registered; carries
    /// the requested id.
    NotFound(String),
    /// Spawning a new terminal for a fresh session failed.
    Io(std::io::Error),
}
81
82async fn handle_socket(
83 mut socket: WebSocket,
84 state: AppState,
85 sid: Option<String>,
86 readonly: bool,
87) {
88 let session = match resolve_session(&state, sid.as_deref()) {
90 Ok(result) => result,
91 Err(ResolveError::NotFound(id)) => {
92 tracing::warn!("session {id} not found");
93 let _ = socket
94 .send(Message::Close(Some(CloseFrame {
95 code: CLOSE_SESSION_NOT_FOUND,
96 reason: "session not found".into(),
97 })))
98 .await;
99 return;
100 }
101 Err(ResolveError::Io(e)) => {
102 tracing::error!("failed to create session: {e}");
103 return;
104 }
105 };
106
107 handle_session(&mut socket, &session, readonly).await;
108}
109
110pub async fn handle_session(socket: &mut WebSocket, session: &Arc<Session>, readonly: bool) {
120 if send_frame(socket, CMD_SESSION_ID, session.id().as_bytes())
122 .await
123 .is_err()
124 {
125 return;
126 }
127
128 let (events, mut output_rx, mut window_size_rx) = session.attach();
129
130 let (rows, cols) = *window_size_rx.borrow_and_update();
131 if send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols))
132 .await
133 .is_err()
134 {
135 session.detach();
136 return;
137 }
138
139 for event in &events {
141 let ok = match event {
142 ScrollbackEvent::Output(data) => send_frame(socket, CMD_OUTPUT, data).await.is_ok(),
143 ScrollbackEvent::WindowSize(r, c) => {
144 send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(*r, *c))
145 .await
146 .is_ok()
147 }
148 };
149 if !ok {
150 session.detach();
151 return;
152 }
153 }
154
155 if send_frame(socket, CMD_REPLAY_END, &[]).await.is_err() {
156 session.detach();
157 return;
158 }
159
160 let mut closed_rx = session.terminal.closed();
162 loop {
163 tokio::select! {
164 result = output_rx.recv() => {
165 match result {
166 Ok(data) => {
167 if send_frame(socket, CMD_OUTPUT, &data).await.is_err() {
168 break;
169 }
170 }
171 Err(RecvError::Lagged(n)) => {
172 tracing::warn!("output lagged {n} messages");
173 continue;
174 }
175 Err(RecvError::Closed) => break,
176 }
177 }
178 msg = socket.recv() => {
179 match msg {
180 Some(Ok(Message::Binary(data))) => {
181 if readonly || data.is_empty() {
182 continue;
183 }
184 handle_client_message(session, &data).await;
185 }
186 Some(Ok(Message::Close(_))) | None => break,
187 _ => {}
188 }
189 }
190 Ok(()) = window_size_rx.changed() => {
191 let (rows, cols) = *window_size_rx.borrow_and_update();
192 if send_frame(socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols)).await.is_err() {
193 break;
194 }
195 }
196 _ = closed_rx.changed() => {
197 while let Ok(data) = output_rx.try_recv() {
199 if send_frame(socket, CMD_OUTPUT, &data).await.is_err() {
200 break;
201 }
202 }
203 let _ = send_frame(socket, CMD_SHELL_EXIT, &[]).await;
204 break;
205 }
206 }
207 }
208 session.detach();
209}
210
211fn resolve_session(state: &AppState, sid: Option<&str>) -> Result<Arc<Session>, ResolveError> {
212 if let Some(sid) = sid {
213 return state
214 .sessions
215 .get(sid)
216 .inspect(|_| tracing::info!("reattaching to session {sid}"))
217 .ok_or_else(|| ResolveError::NotFound(sid.to_owned()));
218 }
219 let (terminal, output_rx) =
220 Terminal::spawn(&state.shell, state.pwd.as_deref()).map_err(ResolveError::Io)?;
221 let session = Session::new(
222 terminal,
223 output_rx,
224 state.scrollback_limit,
225 state.orphan_timeout,
226 );
227 tracing::info!("created new session {}", session.id());
228 state.sessions.insert(session.clone());
229 Ok(session)
230}
231
/// A decoded client→server frame, borrowing its payload from the raw frame.
#[derive(PartialEq, Debug)]
enum ClientCommand<'a> {
    /// `CMD_INPUT`: bytes to write to the terminal (possibly empty).
    Input(&'a [u8]),
    /// `CMD_RESIZE`: the requested terminal size.
    Resize { rows: u16, cols: u16 },
    /// Any other command byte, kept so it can be logged.
    Unknown(u8),
}
238
239fn parse_client_message(data: &[u8]) -> Option<ClientCommand<'_>> {
240 let (&cmd, payload) = data.split_first()?;
241 match cmd {
242 CMD_INPUT => Some(ClientCommand::Input(payload)),
243 CMD_RESIZE if payload.len() >= 4 => {
244 let rows = u16::from_be_bytes([payload[0], payload[1]]);
245 let cols = u16::from_be_bytes([payload[2], payload[3]]);
246 Some(ClientCommand::Resize { rows, cols })
247 }
248 CMD_RESIZE => None,
249 other => Some(ClientCommand::Unknown(other)),
250 }
251}
252
253async fn handle_client_message(session: &Session, data: &[u8]) {
254 match parse_client_message(data) {
255 Some(ClientCommand::Input(payload)) => {
256 if let Err(e) = session.terminal.write(payload.to_vec()).await {
257 tracing::error!("write to terminal failed: {e}");
258 }
259 }
260 Some(ClientCommand::Resize { rows, cols }) => {
261 if let Err(e) = session.terminal.resize(rows, cols) {
262 tracing::error!("resize failed: {e}");
263 }
264 session.set_window_size(rows, cols);
265 }
266 Some(ClientCommand::Unknown(cmd)) => {
267 tracing::warn!("unknown command: 0x{cmd:02x}");
268 }
269 None => {}
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_input() {
        let data = [CMD_INPUT, b'h', b'i'];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Input(b"hi"))
        );
    }

    #[test]
    fn test_parse_input_empty_payload() {
        // A bare input command is valid and carries no bytes.
        assert_eq!(
            parse_client_message(&[CMD_INPUT]),
            Some(ClientCommand::Input(b""))
        );
    }

    #[test]
    fn test_parse_resize() {
        let data = [CMD_RESIZE, 0, 24, 0, 80];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_is_big_endian() {
        let data = [CMD_RESIZE, 0x01, 0x02, 0x03, 0x04];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize {
                rows: 0x0102,
                cols: 0x0304
            })
        );
    }

    #[test]
    fn test_parse_resize_ignores_trailing_bytes() {
        let data = [CMD_RESIZE, 0, 24, 0, 80, 0xAA];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_too_short() {
        let data = [CMD_RESIZE, 0, 24];
        assert_eq!(parse_client_message(&data), None);
    }

    #[test]
    fn test_parse_empty() {
        assert_eq!(parse_client_message(&[]), None);
    }

    #[test]
    fn test_parse_unknown() {
        let data = [0xFF, 1];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Unknown(0xFF))
        );
    }

    #[test]
    fn test_encode_window_size() {
        assert_eq!(encode_window_size(24, 80), [0, 24, 0, 80]);
        assert_eq!(encode_window_size(0x0102, 0x0304), [0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn test_window_size_round_trip() {
        // The server's CMD_WINDOW_SIZE payload and the client's CMD_RESIZE
        // payload share one layout; keep them in agreement.
        let mut frame = vec![CMD_RESIZE];
        frame.extend_from_slice(&encode_window_size(u16::MAX, 1));
        assert_eq!(
            parse_client_message(&frame),
            Some(ClientCommand::Resize {
                rows: u16::MAX,
                cols: 1
            })
        );
    }
}