1use std::collections::HashMap;
19use std::sync::Arc;
20
21use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
22use axum::extract::{Query, State};
23use axum::response::IntoResponse;
24use tokio::sync::broadcast::error::RecvError;
25
26use crate::session::{ScrollbackEvent, Session};
27use crate::terminal::Terminal;
28use crate::web::AppState;
29
// ---- Client -> server command bytes (first byte of every binary frame) ----

/// Input bytes to be written to the terminal.
const CMD_INPUT: u8 = 0x00;
/// Resize request; payload is rows then cols, each a big-endian `u16`.
const CMD_RESIZE: u8 = 0x01;

// ---- Server -> client command bytes ----
// The value space is per direction: `CMD_OUTPUT` shares its value with `CMD_INPUT`.

/// Terminal output bytes.
const CMD_OUTPUT: u8 = 0x00;
/// Id of the session this socket is attached to (UTF-8 bytes of the id string).
const CMD_SESSION_ID: u8 = 0x10;
/// Sent with an empty payload once the shell has exited.
const CMD_SHELL_EXIT: u8 = 0x12;
/// Window size: rows then cols, each a big-endian `u16`.
const CMD_WINDOW_SIZE: u8 = 0x13;
/// Marks the end of the scrollback replay that follows attaching.
const CMD_REPLAY_END: u8 = 0x14;

/// WebSocket close code sent when the requested `sid` does not exist
/// (RFC 6455 reserves 4000-4999 for application use).
const CLOSE_SESSION_NOT_FOUND: u16 = 4404;
48
49async fn send_frame(socket: &mut WebSocket, cmd: u8, payload: &[u8]) -> Result<(), ()> {
51 let mut frame = Vec::with_capacity(1 + payload.len());
52 frame.push(cmd);
53 frame.extend_from_slice(payload);
54 socket
55 .send(Message::Binary(frame.into()))
56 .await
57 .map_err(|_| ())
58}
59
/// Encodes a window size as 4 bytes: `rows` then `cols`, each big-endian.
fn encode_window_size(rows: u16, cols: u16) -> [u8; 4] {
    // Rows occupy the high half, so the big-endian bytes come out as
    // [rows_hi, rows_lo, cols_hi, cols_lo].
    let packed = (u32::from(rows) << 16) | u32::from(cols);
    packed.to_be_bytes()
}
66
67pub async fn ws_handler(
68 ws: WebSocketUpgrade,
69 Query(params): Query<HashMap<String, String>>,
70 State(state): State<AppState>,
71) -> impl IntoResponse {
72 let sid = params.get("sid").cloned();
73 let readonly = params.contains_key("view");
74 ws.on_upgrade(move |socket| handle_socket(socket, state, sid, readonly))
75}
76
/// Why a WebSocket connection could not be bound to a session.
#[derive(Debug)]
enum ResolveError {
    /// The client asked for a session id that is not registered; holds that id.
    NotFound(String),
    /// Spawning the terminal for a new session failed.
    Io(std::io::Error),
}
81
82async fn handle_socket(
83 mut socket: WebSocket,
84 state: AppState,
85 sid: Option<String>,
86 readonly: bool,
87) {
88 let (session, session_id) = match resolve_session(&state, sid.as_deref()) {
90 Ok(result) => result,
91 Err(ResolveError::NotFound(id)) => {
92 tracing::warn!("session {id} not found");
93 let _ = socket
94 .send(Message::Close(Some(CloseFrame {
95 code: CLOSE_SESSION_NOT_FOUND,
96 reason: "session not found".into(),
97 })))
98 .await;
99 return;
100 }
101 Err(ResolveError::Io(e)) => {
102 tracing::error!("failed to create session: {e}");
103 return;
104 }
105 };
106
107 if send_frame(&mut socket, CMD_SESSION_ID, session_id.as_bytes())
109 .await
110 .is_err()
111 {
112 return;
113 }
114
115 let (events, mut output_rx, mut window_size_rx) = session.attach();
116
117 let (rows, cols) = *window_size_rx.borrow_and_update();
118 if send_frame(
119 &mut socket,
120 CMD_WINDOW_SIZE,
121 &encode_window_size(rows, cols),
122 )
123 .await
124 .is_err()
125 {
126 session.detach();
127 return;
128 }
129
130 for event in &events {
132 let ok = match event {
133 ScrollbackEvent::Output(data) => {
134 send_frame(&mut socket, CMD_OUTPUT, data).await.is_ok()
135 }
136 ScrollbackEvent::WindowSize(r, c) => {
137 send_frame(&mut socket, CMD_WINDOW_SIZE, &encode_window_size(*r, *c))
138 .await
139 .is_ok()
140 }
141 };
142 if !ok {
143 session.detach();
144 return;
145 }
146 }
147
148 if send_frame(&mut socket, CMD_REPLAY_END, &[]).await.is_err() {
149 session.detach();
150 return;
151 }
152
153 let mut closed_rx = session.terminal.closed();
155 loop {
156 tokio::select! {
157 result = output_rx.recv() => {
158 match result {
159 Ok(data) => {
160 if send_frame(&mut socket, CMD_OUTPUT, &data).await.is_err() {
161 break;
162 }
163 }
164 Err(RecvError::Lagged(n)) => {
165 tracing::warn!("output lagged {n} messages");
166 continue;
167 }
168 Err(RecvError::Closed) => break,
169 }
170 }
171 msg = socket.recv() => {
172 match msg {
173 Some(Ok(Message::Binary(data))) => {
174 if readonly || data.is_empty() {
175 continue;
176 }
177 handle_client_message(&session, &data).await;
178 }
179 Some(Ok(Message::Close(_))) | None => break,
180 _ => {}
181 }
182 }
183 Ok(()) = window_size_rx.changed() => {
184 let (rows, cols) = *window_size_rx.borrow_and_update();
185 if send_frame(&mut socket, CMD_WINDOW_SIZE, &encode_window_size(rows, cols)).await.is_err() {
186 break;
187 }
188 }
189 _ = closed_rx.changed() => {
190 while let Ok(data) = output_rx.try_recv() {
192 if send_frame(&mut socket, CMD_OUTPUT, &data).await.is_err() {
193 break;
194 }
195 }
196 let _ = send_frame(&mut socket, CMD_SHELL_EXIT, &[]).await;
197 break;
198 }
199 }
200 }
201 session.detach();
202}
203
204fn resolve_session(
205 state: &AppState,
206 sid: Option<&str>,
207) -> Result<(Arc<Session>, String), ResolveError> {
208 if let Some(sid) = sid {
209 return state
210 .sessions
211 .get(sid)
212 .map(|session| {
213 tracing::info!("reattaching to session {sid}");
214 (session, sid.to_owned())
215 })
216 .ok_or_else(|| ResolveError::NotFound(sid.to_owned()));
217 }
218 let (terminal, output_rx) =
219 Terminal::spawn(&state.shell, state.pwd.as_deref()).map_err(ResolveError::Io)?;
220 let session = Session::new(terminal, output_rx, state.scrollback_limit);
221 let id = state.sessions.insert(session.clone());
222 tracing::info!("created new session {id}");
223 Ok((session, id))
224}
225
/// A decoded client -> server frame. `Input` borrows its bytes from the
/// received frame.
#[derive(PartialEq, Debug)]
enum ClientCommand<'a> {
    /// Bytes to write to the terminal (`CMD_INPUT`).
    Input(&'a [u8]),
    /// Requested terminal dimensions (`CMD_RESIZE`).
    Resize {
        rows: u16,
        cols: u16,
    },
    /// Any other leading command byte, kept so it can be logged.
    Unknown(u8),
}
232
233fn parse_client_message(data: &[u8]) -> Option<ClientCommand<'_>> {
234 let (&cmd, payload) = data.split_first()?;
235 match cmd {
236 CMD_INPUT => Some(ClientCommand::Input(payload)),
237 CMD_RESIZE if payload.len() >= 4 => {
238 let rows = u16::from_be_bytes([payload[0], payload[1]]);
239 let cols = u16::from_be_bytes([payload[2], payload[3]]);
240 Some(ClientCommand::Resize { rows, cols })
241 }
242 CMD_RESIZE => None,
243 other => Some(ClientCommand::Unknown(other)),
244 }
245}
246
247async fn handle_client_message(session: &Session, data: &[u8]) {
248 match parse_client_message(data) {
249 Some(ClientCommand::Input(payload)) => {
250 if let Err(e) = session.terminal.write(payload.to_vec()).await {
251 tracing::error!("write to terminal failed: {e}");
252 }
253 }
254 Some(ClientCommand::Resize { rows, cols }) => {
255 if let Err(e) = session.terminal.resize(rows, cols) {
256 tracing::error!("resize failed: {e}");
257 }
258 session.set_window_size(rows, cols);
259 }
260 Some(ClientCommand::Unknown(cmd)) => {
261 tracing::warn!("unknown command: 0x{cmd:02x}");
262 }
263 None => {}
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_input() {
        let data = [0x00, b'h', b'i'];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Input(b"hi"))
        );
    }

    #[test]
    fn test_parse_input_empty_payload() {
        // A bare input command is valid and carries an empty payload.
        assert_eq!(
            parse_client_message(&[CMD_INPUT]),
            Some(ClientCommand::Input(b""))
        );
    }

    #[test]
    fn test_parse_resize() {
        let data = [0x01, 0, 24, 0, 80];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_big_endian() {
        let data = [CMD_RESIZE, 0x12, 0x34, 0xAB, 0xCD];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 0x1234, cols: 0xABCD })
        );
    }

    #[test]
    fn test_parse_resize_ignores_trailing_bytes() {
        let data = [CMD_RESIZE, 0, 24, 0, 80, 0xFF];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Resize { rows: 24, cols: 80 })
        );
    }

    #[test]
    fn test_parse_resize_too_short() {
        let data = [0x01, 0, 24];
        assert_eq!(parse_client_message(&data), None);
    }

    #[test]
    fn test_parse_resize_without_payload() {
        assert_eq!(parse_client_message(&[CMD_RESIZE]), None);
    }

    #[test]
    fn test_parse_empty() {
        assert_eq!(parse_client_message(&[]), None);
    }

    #[test]
    fn test_parse_unknown() {
        let data = [0xFF, 1];
        assert_eq!(
            parse_client_message(&data),
            Some(ClientCommand::Unknown(0xFF))
        );
    }

    #[test]
    fn test_encode_window_size() {
        assert_eq!(encode_window_size(24, 80), [0, 24, 0, 80]);
        assert_eq!(encode_window_size(0x1234, 0xABCD), [0x12, 0x34, 0xAB, 0xCD]);
        assert_eq!(encode_window_size(u16::MAX, 0), [0xFF, 0xFF, 0, 0]);
    }

    #[test]
    fn test_window_size_round_trip() {
        // The server's window-size encoding matches what the resize parser expects.
        let mut frame = vec![CMD_RESIZE];
        frame.extend_from_slice(&encode_window_size(300, 1000));
        assert_eq!(
            parse_client_message(&frame),
            Some(ClientCommand::Resize { rows: 300, cols: 1000 })
        );
    }
}