tty_web/
terminal.rs

//! High-level terminal abstraction over a PTY.
//!
//! [`Terminal`] takes ownership of a [`PtyMaster`]'s master fd and child
//! process and drives async read/write loops via tokio. Output is fanned out
//! through a broadcast channel so multiple subscribers (e.g. WebSocket
//! clients) can receive the same stream.
6
7use std::path::Path;
8use std::process::Child;
9use std::sync::{Arc, Mutex};
10
11use nix::sys::signal::{self, Signal};
12use nix::unistd::Pid;
13use tokio::io::Interest;
14use tokio::io::unix::AsyncFd;
15use tokio::sync::{broadcast, mpsc, watch};
16
17use crate::pty::PtyMaster;
18
/// Capacity (in chunks) of the broadcast channel carrying PTY output.
const OUTPUT_CHANNEL_SIZE: usize = 64;
/// Capacity (in queued writes) of the channel feeding the PTY write loop.
const INPUT_CHANNEL_SIZE: usize = 256;
/// Size in bytes of the buffer used for each read from the PTY master.
const READ_BUF_SIZE: usize = 4 * 1024;
22
/// Async terminal backed by a real PTY.
///
/// Spawns two background tasks (read loop and write loop) that bridge the PTY
/// fd with tokio channels. Sends `SIGHUP` to the child process on drop.
///
/// Note: after `Drop::drop` runs, the fields below are dropped in declaration
/// order, so `input_tx` is released first.
pub struct Terminal {
    /// Queue feeding the write loop; bytes sent here are written to the PTY.
    input_tx: mpsc::Sender<Vec<u8>>,
    /// Broadcast sender for PTY output; kept so new subscribers can be created.
    output_tx: broadcast::Sender<Vec<u8>>,
    /// PTY master fd, shared with the read and write loops (also used for resizing).
    fd: Arc<AsyncFd<std::os::fd::OwnedFd>>,
    /// The spawned shell; taken out in `Drop` so it can be signalled and reaped.
    child: Mutex<Option<Child>>,
    /// Becomes `true` once the read loop task has finished.
    closed_rx: watch::Receiver<bool>,
}
34
35impl Terminal {
36    /// Spawn a shell process and return the terminal plus an initial output
37    /// receiver.
38    ///
39    /// If `pwd` is provided, the shell starts in that directory.
40    pub fn spawn(
41        shell: &str,
42        pwd: Option<&Path>,
43    ) -> std::io::Result<(Self, broadcast::Receiver<Vec<u8>>)> {
44        let PtyMaster { master, mut child } = PtyMaster::spawn(shell, pwd)?;
45
46        let async_fd = match AsyncFd::with_interest(master, Interest::READABLE | Interest::WRITABLE)
47        {
48            Ok(fd) => fd,
49            Err(e) => {
50                let _ = child.kill();
51                let _ = child.wait();
52                return Err(e);
53            }
54        };
55        let fd = Arc::new(async_fd);
56
57        let (input_tx, input_rx) = mpsc::channel(INPUT_CHANNEL_SIZE);
58        let (output_tx, output_rx) = broadcast::channel(OUTPUT_CHANNEL_SIZE);
59        let (closed_tx, closed_rx) = watch::channel(false);
60
61        let read_fd = fd.clone();
62        let read_tx = output_tx.clone();
63        tokio::spawn(async move {
64            read_loop(read_fd, read_tx).await;
65            let _ = closed_tx.send(true);
66        });
67
68        let write_fd = fd.clone();
69        tokio::spawn(async move {
70            write_loop(write_fd, input_rx).await;
71        });
72
73        let terminal = Terminal {
74            input_tx,
75            output_tx,
76            fd,
77            child: Mutex::new(Some(child)),
78            closed_rx,
79        };
80        Ok((terminal, output_rx))
81    }
82
83    /// Subscribe to the terminal output broadcast channel.
84    pub fn subscribe(&self) -> broadcast::Receiver<Vec<u8>> {
85        self.output_tx.subscribe()
86    }
87
88    /// Returns a watch receiver that becomes `true` when the PTY read
89    /// loop exits (shell died / PTY closed).
90    pub fn closed(&self) -> watch::Receiver<bool> {
91        self.closed_rx.clone()
92    }
93
94    /// Queue bytes to be written to the PTY.
95    pub async fn write(&self, data: Vec<u8>) -> Result<(), String> {
96        self.input_tx.send(data).await.map_err(|e| e.to_string())
97    }
98
99    /// Set the PTY window size (rows x cols).
100    pub fn resize(&self, rows: u16, cols: u16) -> std::io::Result<()> {
101        crate::pty::set_window_size(&*self.fd, rows, cols)
102    }
103}
104
105impl Drop for Terminal {
106    fn drop(&mut self) {
107        if let Some(mut child) = self.child.get_mut().unwrap().take() {
108            let pid = Pid::from_raw(child.id() as i32);
109            let _ = signal::kill(pid, Signal::SIGHUP);
110            // Reap the child on a dedicated OS thread so we never block
111            // the tokio runtime (which would deadlock current_thread tests
112            // and stall multi_thread ones).
113            std::thread::spawn(move || {
114                let _ = child.wait();
115            });
116        }
117    }
118}
119
120async fn read_loop(fd: Arc<AsyncFd<std::os::fd::OwnedFd>>, tx: broadcast::Sender<Vec<u8>>) {
121    let mut buf = [0u8; READ_BUF_SIZE];
122    loop {
123        let mut ready = match fd.readable().await {
124            Ok(r) => r,
125            Err(e) => {
126                tracing::debug!("pty read await error: {}", e);
127                break;
128            }
129        };
130
131        match nix::unistd::read(&*fd, &mut buf) {
132            Ok(0) => break,
133            Ok(n) => {
134                if tx.send(buf[..n].to_vec()).is_err() {
135                    break;
136                }
137                ready.retain_ready();
138            }
139            Err(nix::errno::Errno::EAGAIN) => {
140                ready.clear_ready();
141            }
142            Err(e) => {
143                tracing::debug!("pty read error: {}", e);
144                break;
145            }
146        }
147    }
148}
149
150async fn write_loop(fd: Arc<AsyncFd<std::os::fd::OwnedFd>>, mut rx: mpsc::Receiver<Vec<u8>>) {
151    while let Some(data) = rx.recv().await {
152        let mut written = 0;
153        while written < data.len() {
154            let mut ready = match fd.writable().await {
155                Ok(r) => r,
156                Err(e) => {
157                    tracing::debug!("pty write await error: {}", e);
158                    return;
159                }
160            };
161            match nix::unistd::write(&*fd, &data[written..]) {
162                Ok(n) => {
163                    written += n;
164                    ready.retain_ready();
165                }
166                Err(nix::errno::Errno::EAGAIN) => {
167                    ready.clear_ready();
168                }
169                Err(e) => {
170                    tracing::debug!("pty write error: {}", e);
171                    return;
172                }
173            }
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, timeout};

    #[tokio::test]
    async fn test_write_and_read_output() {
        let (terminal, mut rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");

        // With terminal echo enabled, the typed command line is reflected back
        // through the PTY. Splitting the marker with an empty quoted string
        // means only the command's actual output contains the contiguous
        // marker, so this test really checks that the shell ran the command.
        terminal
            .write(b"echo hello_test_''marker\n".to_vec())
            .await
            .unwrap();

        let mut collected = String::new();
        let deadline = Duration::from_secs(3);
        let _ = timeout(deadline, async {
            loop {
                match rx.recv().await {
                    Ok(data) => {
                        collected.push_str(&String::from_utf8_lossy(&data));
                        if collected.contains("hello_test_marker") {
                            break;
                        }
                    }
                    // Falling behind only skips some chunks; keep reading.
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
        })
        .await;

        assert!(
            collected.contains("hello_test_marker"),
            "expected output to contain marker, got: {collected}"
        );
    }

    #[tokio::test]
    async fn test_resize() {
        let (terminal, _rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");
        terminal.resize(50, 132).expect("resize should succeed");
    }

    #[tokio::test]
    async fn test_closed_on_exit() {
        let (terminal, _rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");
        let mut closed = terminal.closed();

        terminal.write(b"exit\n".to_vec()).await.unwrap();

        let deadline = Duration::from_secs(10);
        let result = timeout(deadline, closed.wait_for(|&v| v)).await;
        // Require both "no timeout" and "the watch actually observed `true`"
        // (as opposed to the sender being dropped without signalling).
        assert!(
            matches!(result, Ok(Ok(_))),
            "closed signal should be received after exit"
        );
    }
}
230}