1use std::path::Path;
8use std::process::Child;
9use std::sync::{Arc, Mutex};
10
11use nix::sys::signal::{self, Signal};
12use nix::unistd::Pid;
13use tokio::io::Interest;
14use tokio::io::unix::AsyncFd;
15use tokio::sync::{broadcast, mpsc, watch};
16
17use crate::pty::PtyMaster;
18
/// Capacity of the output broadcast channel, counted in chunks (`Vec<u8>`
/// messages produced by one PTY read each), not bytes.
const OUTPUT_CHANNEL_SIZE: usize = 64;
/// Capacity of the input mpsc channel, counted in queued `Terminal::write`
/// payloads; `write` awaits when the queue is full.
const INPUT_CHANNEL_SIZE: usize = 256;
/// Size in bytes of the stack buffer used for each read from the PTY master.
const READ_BUF_SIZE: usize = 4096;
22
/// A shell running on a pseudo-terminal, driven by two background tokio tasks:
/// one forwards queued input to the PTY master, the other publishes everything
/// read from the master to broadcast subscribers.
pub struct Terminal {
    /// Queue consumed by the background writer task (fed by `Terminal::write`).
    input_tx: mpsc::Sender<Vec<u8>>,
    /// Output fan-out; the background reader sends each chunk it reads here and
    /// `Terminal::subscribe` hands out new receivers from it.
    output_tx: broadcast::Sender<Vec<u8>>,
    /// PTY master registered with the tokio reactor, shared with both tasks.
    fd: Arc<AsyncFd<std::os::fd::OwnedFd>>,
    /// The spawned shell; taken out and sent `SIGHUP` when the terminal drops.
    child: Mutex<Option<Child>>,
    /// Becomes `true` once the background reader task has finished.
    closed_rx: watch::Receiver<bool>,
}
34
35impl Terminal {
36 pub fn spawn(
41 shell: &str,
42 pwd: Option<&Path>,
43 ) -> std::io::Result<(Self, broadcast::Receiver<Vec<u8>>)> {
44 let PtyMaster { master, mut child } = PtyMaster::spawn(shell, pwd)?;
45
46 let async_fd = match AsyncFd::with_interest(master, Interest::READABLE | Interest::WRITABLE)
47 {
48 Ok(fd) => fd,
49 Err(e) => {
50 let _ = child.kill();
51 let _ = child.wait();
52 return Err(e);
53 }
54 };
55 let fd = Arc::new(async_fd);
56
57 let (input_tx, input_rx) = mpsc::channel(INPUT_CHANNEL_SIZE);
58 let (output_tx, output_rx) = broadcast::channel(OUTPUT_CHANNEL_SIZE);
59 let (closed_tx, closed_rx) = watch::channel(false);
60
61 let read_fd = fd.clone();
62 let read_tx = output_tx.clone();
63 tokio::spawn(async move {
64 read_loop(read_fd, read_tx).await;
65 let _ = closed_tx.send(true);
66 });
67
68 let write_fd = fd.clone();
69 tokio::spawn(async move {
70 write_loop(write_fd, input_rx).await;
71 });
72
73 let terminal = Terminal {
74 input_tx,
75 output_tx,
76 fd,
77 child: Mutex::new(Some(child)),
78 closed_rx,
79 };
80 Ok((terminal, output_rx))
81 }
82
83 pub fn subscribe(&self) -> broadcast::Receiver<Vec<u8>> {
85 self.output_tx.subscribe()
86 }
87
88 pub fn closed(&self) -> watch::Receiver<bool> {
91 self.closed_rx.clone()
92 }
93
94 pub async fn write(&self, data: Vec<u8>) -> Result<(), String> {
96 self.input_tx.send(data).await.map_err(|e| e.to_string())
97 }
98
99 pub fn resize(&self, rows: u16, cols: u16) -> std::io::Result<()> {
101 crate::pty::set_window_size(&*self.fd, rows, cols)
102 }
103}
104
105impl Drop for Terminal {
106 fn drop(&mut self) {
107 if let Some(mut child) = self.child.get_mut().unwrap().take() {
108 let pid = Pid::from_raw(child.id() as i32);
109 let _ = signal::kill(pid, Signal::SIGHUP);
110 std::thread::spawn(move || {
114 let _ = child.wait();
115 });
116 }
117 }
118}
119
120async fn read_loop(fd: Arc<AsyncFd<std::os::fd::OwnedFd>>, tx: broadcast::Sender<Vec<u8>>) {
121 let mut buf = [0u8; READ_BUF_SIZE];
122 loop {
123 let mut ready = match fd.readable().await {
124 Ok(r) => r,
125 Err(e) => {
126 tracing::debug!("pty read await error: {}", e);
127 break;
128 }
129 };
130
131 match nix::unistd::read(&*fd, &mut buf) {
132 Ok(0) => break,
133 Ok(n) => {
134 if tx.send(buf[..n].to_vec()).is_err() {
135 break;
136 }
137 ready.retain_ready();
138 }
139 Err(nix::errno::Errno::EAGAIN) => {
140 ready.clear_ready();
141 }
142 Err(e) => {
143 tracing::debug!("pty read error: {}", e);
144 break;
145 }
146 }
147 }
148}
149
150async fn write_loop(fd: Arc<AsyncFd<std::os::fd::OwnedFd>>, mut rx: mpsc::Receiver<Vec<u8>>) {
151 while let Some(data) = rx.recv().await {
152 let mut written = 0;
153 while written < data.len() {
154 let mut ready = match fd.writable().await {
155 Ok(r) => r,
156 Err(e) => {
157 tracing::debug!("pty write await error: {}", e);
158 return;
159 }
160 };
161 match nix::unistd::write(&*fd, &data[written..]) {
162 Ok(n) => {
163 written += n;
164 ready.retain_ready();
165 }
166 Err(nix::errno::Errno::EAGAIN) => {
167 ready.clear_ready();
168 }
169 Err(e) => {
170 tracing::debug!("pty write error: {}", e);
171 return;
172 }
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::broadcast::error::RecvError;
    use tokio::time::{Duration, timeout};

    #[tokio::test]
    async fn test_write_and_read_output() {
        let (terminal, mut rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");

        // The PTY echoes the typed line back. The marker must therefore appear
        // only in the command's *output*: `printf` joins two pieces that are
        // never adjacent in the typed text.
        terminal
            .write(b"printf '%s_%s\\n' hello test_marker\n".to_vec())
            .await
            .unwrap();

        let mut collected = String::new();
        let deadline = Duration::from_secs(3);
        let _ = timeout(deadline, async {
            loop {
                match rx.recv().await {
                    Ok(data) => {
                        collected.push_str(&String::from_utf8_lossy(&data));
                        if collected.contains("hello_test_marker") {
                            break;
                        }
                    }
                    // Falling behind only skips older chunks; keep reading.
                    Err(RecvError::Lagged(_)) => {}
                    Err(RecvError::Closed) => break,
                }
            }
        })
        .await;

        assert!(
            collected.contains("hello_test_marker"),
            "expected output to contain marker, got: {collected}"
        );
    }

    #[tokio::test]
    async fn test_resize() {
        let (terminal, _rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");
        terminal.resize(50, 132).expect("resize should succeed");
    }

    #[tokio::test]
    async fn test_closed_on_exit() {
        let (terminal, _rx) = Terminal::spawn("/bin/sh", None).expect("spawn /bin/sh");
        let mut closed = terminal.closed();

        terminal.write(b"exit\n".to_vec()).await.unwrap();

        let deadline = Duration::from_secs(10);
        let result = timeout(deadline, closed.wait_for(|&v| v)).await;
        assert!(
            result.is_ok(),
            "closed signal should be received after exit"
        );
    }
}