1use std::collections::HashMap;
12use std::collections::VecDeque;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::{Arc, Mutex, RwLock, Weak};
15use std::time::Instant;
16
17use tokio::sync::{broadcast, watch};
18
19use crate::terminal::Terminal;
20
/// How long a session must have had zero attached clients before
/// `Session::is_orphaned` reports it as orphaned.
const ORPHAN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
23
/// Everything `Session::attach` hands to a newly attached client.
///
/// In order:
/// 1. a snapshot of the scrollback events recorded so far,
/// 2. a receiver obtained from `Terminal::subscribe`,
/// 3. a watch receiver for the window size, stored as `(rows, cols)`.
pub type AttachResult = (
    Vec<ScrollbackEvent>,
    broadcast::Receiver<Vec<u8>>,
    watch::Receiver<(u16, u16)>,
);
31
/// One entry in a session's scrollback history.
///
/// `Eq` is derived alongside `PartialEq` because every field is itself `Eq`,
/// which lets the type be used wherever total equality is required.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ScrollbackEvent {
    /// A chunk of raw bytes received from the terminal's output channel.
    Output(Vec<u8>),
    /// A window resize, recorded as `(rows, cols)`.
    WindowSize(u16, u16),
}
43
44impl ScrollbackEvent {
45 fn byte_cost(&self) -> usize {
47 match self {
48 Self::Output(data) => data.len(),
49 Self::WindowSize(_, _) => 4,
50 }
51 }
52}
53
/// A terminal together with its recorded scrollback and client bookkeeping.
pub struct Session {
    /// The terminal this session wraps.
    pub terminal: Terminal,
    /// Recorded events, oldest at the front.
    scrollback: Mutex<VecDeque<ScrollbackEvent>>,
    /// Running total of `byte_cost()` over the events held in `scrollback`.
    /// `push_scrollback` locks `scrollback` first, then this.
    scrollback_bytes: Mutex<usize>,
    /// Byte budget above which the oldest events are evicted.
    scrollback_limit: usize,
    /// Number of currently attached clients.
    clients: AtomicUsize,
    /// Set when the client count drops to zero; cleared again by `attach`.
    detached_at: Mutex<Option<Instant>>,
    /// Publishes the window size as `(rows, cols)`.
    window_size: watch::Sender<(u16, u16)>,
}
67
68impl Session {
69 pub fn new(
71 terminal: Terminal,
72 output_rx: broadcast::Receiver<Vec<u8>>,
73 scrollback_limit: usize,
74 ) -> Arc<Self> {
75 let (ws_tx, _) = watch::channel((24, 80));
76 let session = Arc::new(Self {
77 terminal,
78 scrollback: Mutex::new(VecDeque::new()),
79 scrollback_bytes: Mutex::new(0),
80 scrollback_limit,
81 clients: AtomicUsize::new(0),
82 detached_at: Mutex::new(None),
83 window_size: ws_tx,
84 });
85
86 let weak: Weak<Session> = Arc::downgrade(&session);
88 let mut rx = output_rx;
89 tokio::spawn(async move {
90 loop {
91 match rx.recv().await {
92 Ok(data) => {
93 let Some(s) = weak.upgrade() else {
94 break;
95 };
96 s.push_scrollback(ScrollbackEvent::Output(data));
97 }
98 Err(broadcast::error::RecvError::Lagged(_)) => {
99 continue;
100 }
101 Err(broadcast::error::RecvError::Closed) => break,
102 }
103 }
104 });
105
106 session
107 }
108
109 fn push_scrollback(&self, event: ScrollbackEvent) {
112 let cost = event.byte_cost();
113 let mut sb = self.scrollback.lock().unwrap();
114 let mut bytes = self.scrollback_bytes.lock().unwrap();
115 *bytes += cost;
116 sb.push_back(event);
117 while *bytes > self.scrollback_limit {
118 if let Some(old) = sb.pop_front() {
119 *bytes -= old.byte_cost();
120 } else {
121 break;
122 }
123 }
124 }
125
126 pub fn attach(&self) -> AttachResult {
130 self.clients.fetch_add(1, Ordering::Relaxed);
131 *self.detached_at.lock().unwrap() = None;
132 let sb = self.scrollback.lock().unwrap();
133 let rx = self.terminal.subscribe();
134 let ws_rx = self.window_size.subscribe();
135 let events: Vec<ScrollbackEvent> = sb.iter().cloned().collect();
136 (events, rx, ws_rx)
137 }
138
139 pub fn set_window_size(&self, rows: u16, cols: u16) {
142 let _ = self.window_size.send((rows, cols));
143 self.push_scrollback(ScrollbackEvent::WindowSize(rows, cols));
144 }
145
146 pub fn detach(&self) {
148 if self.clients.fetch_sub(1, Ordering::Relaxed) == 1 {
149 *self.detached_at.lock().unwrap() = Some(Instant::now());
150 }
151 }
152
153 fn is_orphaned(&self) -> bool {
154 self.clients.load(Ordering::Relaxed) == 0
155 && self
156 .detached_at
157 .lock()
158 .unwrap()
159 .is_some_and(|t| t.elapsed() >= ORPHAN_TIMEOUT)
160 }
161}
162
/// Shared registry of sessions, keyed by the id string that `insert` returns.
pub struct SessionStore {
    /// Sessions currently registered in the store.
    sessions: RwLock<HashMap<String, Arc<Session>>>,
}
167
168impl SessionStore {
169 pub fn new() -> Arc<Self> {
171 Arc::new(Self {
172 sessions: RwLock::new(HashMap::new()),
173 })
174 }
175
176 pub fn insert(self: &Arc<Self>, session: Arc<Session>) -> String {
180 let id = uuid::Uuid::new_v4().to_string();
181 self.sessions
182 .write()
183 .unwrap()
184 .insert(id.clone(), session.clone());
185
186 let store = Arc::downgrade(self);
188 let sid = id.clone();
189 let closed_rx = session.terminal.closed();
190 tokio::spawn(async move {
191 loop {
192 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
193 let Some(store) = store.upgrade() else { return };
194 let should_remove = {
195 let sessions = store.sessions.read().unwrap();
196 match sessions.get(&sid) {
197 Some(s) => {
198 s.is_orphaned()
199 || (*closed_rx.borrow() && s.clients.load(Ordering::Relaxed) == 0)
200 }
201 None => return,
202 }
203 };
204 if should_remove {
205 store.sessions.write().unwrap().remove(&sid);
206 tracing::info!("removed session {sid}");
207 return;
208 }
209 }
210 });
211
212 id
213 }
214
215 pub fn get(&self, id: &str) -> Option<Arc<Session>> {
217 self.sessions.read().unwrap().get(id).cloned()
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Scrollback budget shared by the tests that are not about eviction.
    const TEST_SCROLLBACK_LIMIT: usize = 256 * 1024;

    /// Starts `/bin/sh` and wraps it in a session using the shared test budget.
    fn spawn_session() -> Arc<Session> {
        let (shell, shell_output) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");
        Session::new(shell, shell_output, TEST_SCROLLBACK_LIMIT)
    }

    /// Reads the session's current attached-client count.
    fn client_count(session: &Session) -> usize {
        session.clients.load(Ordering::Relaxed)
    }

    /// Attaching and detaching adjusts the client counter by one each time.
    #[tokio::test]
    async fn test_attach_detach_clients() {
        let session = spawn_session();

        // Keep the returned receivers alive for the whole test.
        let _first = session.attach();
        assert_eq!(client_count(&session), 1);

        let _second = session.attach();
        assert_eq!(client_count(&session), 2);

        session.detach();
        assert_eq!(client_count(&session), 1);
    }

    /// A session with an attached client is never orphaned.
    #[tokio::test]
    async fn test_not_orphaned_with_clients() {
        let session = spawn_session();
        let _client = session.attach();
        assert!(!session.is_orphaned());
    }

    /// The orphan timeout has not elapsed right after the last detach.
    #[tokio::test]
    async fn test_not_orphaned_immediately_after_detach() {
        let session = spawn_session();
        let _client = session.attach();
        session.detach();
        assert!(!session.is_orphaned());
    }

    /// Backdating the detach time past the timeout makes the session orphaned.
    #[tokio::test]
    async fn test_orphaned_after_timeout() {
        let session = spawn_session();
        let _client = session.attach();
        session.detach();

        let long_ago = Instant::now() - ORPHAN_TIMEOUT - std::time::Duration::from_secs(1);
        *session.detached_at.lock().unwrap() = Some(long_ago);

        assert!(session.is_orphaned());
    }

    /// Output produced by the shell ends up in the scrollback snapshot.
    #[tokio::test]
    async fn test_scrollback_captures_output() {
        let session = spawn_session();

        let command = b"echo scrollback_test_marker\n".to_vec();
        session.terminal.write(command).await.unwrap();

        // Give the shell and the recorder task time to process the command.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        let (events, _rx, _ws) = session.attach();
        let has_marker = events.iter().any(|event| {
            matches!(
                event,
                ScrollbackEvent::Output(bytes)
                    if String::from_utf8_lossy(bytes).contains("scrollback_test_marker")
            )
        });
        assert!(has_marker, "scrollback should contain Output with marker");
    }

    /// Inserted sessions can be fetched by id; unknown ids yield `None`.
    #[tokio::test]
    async fn test_session_store_insert_and_get() {
        let store = SessionStore::new();
        let id = store.insert(spawn_session());

        assert!(store.get(&id).is_some());
        assert!(store.get("nonexistent").is_none());
    }

    /// Exceeding the byte budget evicts the oldest events as whole units.
    #[tokio::test]
    async fn test_scrollback_eviction_removes_whole_events() {
        let (terminal, output_rx) = Terminal::spawn("/bin/sh", None).expect("spawn");
        let session = Session::new(terminal, output_rx, 10);

        let chunks: [&[u8]; 3] = [b"aaaaa", b"bbbbb", b"ccc"];
        for chunk in chunks {
            session.push_scrollback(ScrollbackEvent::Output(chunk.to_vec()));
        }

        // Same lock order as `push_scrollback`: scrollback first, then bytes.
        let sb = session.scrollback.lock().unwrap();
        let bytes = *session.scrollback_bytes.lock().unwrap();
        assert!(bytes <= 10, "bytes {bytes} should be within limit");
        assert!(
            sb.iter()
                .all(|event| matches!(event, ScrollbackEvent::Output(_))),
            "all events should be Output"
        );
        assert_ne!(
            sb.front(),
            Some(&ScrollbackEvent::Output(b"aaaaa".to_vec())),
            "oldest event should have been evicted"
        );
    }

    /// Resizing records a `WindowSize` event in the scrollback.
    #[tokio::test]
    async fn test_set_window_size_records_event() {
        let (terminal, output_rx) = Terminal::spawn("/bin/sh", None).expect("spawn");
        let session = Session::new(terminal, output_rx, TEST_SCROLLBACK_LIMIT);

        session.set_window_size(40, 120);

        let recorded = session.scrollback.lock().unwrap();
        let has_ws = recorded
            .iter()
            .any(|event| matches!(event, ScrollbackEvent::WindowSize(40, 120)));
        assert!(has_ws, "scrollback should contain WindowSize(40, 120)");
    }
}