Coverage Report

Created: 2026-05-23 21:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/buf_channel.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::pin::Pin;
16
use core::sync::atomic::{AtomicBool, Ordering};
17
use core::task::Poll;
18
use std::collections::VecDeque;
19
use std::sync::Arc;
20
21
use bytes::{Bytes, BytesMut};
22
use futures::task::Context;
23
use futures::{Future, Stream, TryFutureExt};
24
use nativelink_error::{Code, Error, ResultExt, error_if, make_err, make_input_err};
25
use tokio::sync::mpsc;
26
use tracing::warn;
27
28
/// Empty buffer used as the in-band EOF marker: the reader substitutes it when
/// the underlying channel reports that all senders are gone, and returns it to
/// callers to signal end of stream.
const ZERO_DATA: Bytes = Bytes::new();
29
30
/// Create a channel pair that can be used to transport buffer objects around to
31
/// different components. This wrapper is used because the streams give some
32
/// utility like managing EOF in a more friendly way, ensure if no EOF is received
33
/// it will send an error to the receiver channel before shutting down and count
34
/// the number of bytes sent.
35
#[must_use]
36
13.3k
pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
37
    // We allow up to 2 items in the buffer at any given time. There is no major
38
    // reason behind this magic number other than thinking it will be nice to give
39
    // a little time for another thread to wake up and consume data if another
40
    // thread is pumping large amounts of data into the channel.
41
13.3k
    let (tx, rx) = mpsc::channel(2);
42
13.3k
    let eof_sent = Arc::new(AtomicBool::new(false));
43
13.3k
    (
44
13.3k
        DropCloserWriteHalf {
45
13.3k
            tx: Some(tx),
46
13.3k
            bytes_written: 0,
47
13.3k
            eof_sent: eof_sent.clone(),
48
13.3k
        },
49
13.3k
        DropCloserReadHalf {
50
13.3k
            rx,
51
13.3k
            queued_data: VecDeque::new(),
52
13.3k
            last_err: None,
53
13.3k
            eof_sent,
54
13.3k
            bytes_received: 0,
55
13.3k
            recent_data: Vec::new(),
56
13.3k
            max_recent_data_size: 0,
57
13.3k
        },
58
13.3k
    )
59
13.3k
}
60
61
/// Writer half of the pair.
62
#[derive(Debug)]
63
pub struct DropCloserWriteHalf {
64
    /// Sending side of the underlying channel. Becomes `None` once the stream
    /// is closed, either by `send_eof()` or after a send failed because the
    /// receiver disconnected.
    tx: Option<mpsc::Sender<Bytes>>,
65
    /// Running total of bytes successfully handed to the channel.
    bytes_written: u64,
66
    /// Flag shared with the read half; set to `true` by `send_eof()`.
    eof_sent: Arc<AtomicBool>,
67
}
68
69
impl DropCloserWriteHalf {
70
    /// Sends data over the channel to the receiver.
71
132k
    pub fn send(&mut self, buf: Bytes) -> impl Future<Output = Result<(), Error>> + '_ {
72
132k
        self.send_get_bytes_on_error(buf).map_err(|err| err.0)
73
132k
    }
74
75
    /// Sends data over the channel to the receiver.
76
    #[inline]
77
132k
    async fn send_get_bytes_on_error(&mut self, buf: Bytes) -> Result<(), (Error, Bytes)> {
78
132k
        let tx = match self
79
132k
            .tx
80
132k
            .as_ref()
81
132k
            .ok_or_else(|| 
make_err!0
(
Code::Internal0
, "Tried to send while stream is closed"))
82
        {
83
132k
            Ok(tx) => tx,
84
0
            Err(e) => return Err((e, buf)),
85
        };
86
132k
        let Ok(buf_len) = u64::try_from(buf.len()) else {
87
0
            return Err((
88
0
                make_err!(Code::Internal, "Could not convert usize to u64"),
89
0
                buf,
90
0
            ));
91
        };
92
132k
        if buf_len == 0 {
93
0
            return Err((
94
0
                make_input_err!("Cannot send EOF in send(). Instead use send_eof()"),
95
0
                buf,
96
0
            ));
97
132k
        }
98
132k
        if let Err(
err0
) = tx.send(buf).await {
99
            // Close our channel.
100
0
            self.tx = None;
101
0
            return Err((
102
0
                make_err!(
103
0
                    Code::Internal,
104
0
                    "Failed to write to data, receiver disconnected"
105
0
                ),
106
0
                err.0,
107
0
            ));
108
132k
        }
109
132k
        self.bytes_written += buf_len;
110
132k
        Ok(())
111
132k
    }
112
113
    /// Binds a reader and a writer together. This will send all the data from the reader
114
    /// to the writer until an EOF is received.
115
    /// This will always read one message ahead to ensure that if an error happens
116
    /// on the EOF message it will not forward on the last payload message and instead
117
    /// forward on the error.
118
5
    pub async fn bind_buffered(&mut self, reader: &mut DropCloserReadHalf) -> Result<(), Error> {
119
        loop {
120
110
            let chunk = reader
121
110
                .recv()
122
110
                .await
123
110
                .err_tip(|| "In DropCloserWriteHalf::bind_buffered::recv")
?0
;
124
110
            if chunk.is_empty() {
125
4
                self.send_eof()
126
4
                    .err_tip(|| "In DropCloserWriteHalf::bind_buffered::send_eof")
?0
;
127
4
                break; // EOF.
128
106
            }
129
            // Always read one message ahead so if we get an error on our EOF
130
            // we forward it on to the reader.
131
106
            if reader.peek().await.is_err() {
132
                // Read our next message for good book keeping.
133
0
                drop(
134
1
                    reader
135
1
                        .recv()
136
1
                        .await
137
1
                        .err_tip(|| "In DropCloserWriteHalf::bind_buffered::peek::eof")?,
138
                );
139
0
                return Err(make_err!(
140
0
                    Code::Internal,
141
0
                    "DropCloserReadHalf::peek() said error, but when data received said Ok. This should never happen."
142
0
                ));
143
105
            }
144
105
            match self.send_get_bytes_on_error(chunk).await {
145
105
                Ok(()) => {}
146
0
                Err(e) => {
147
0
                    reader.queued_data.push_front(e.1);
148
0
                    return Err(e.0).err_tip(|| "In DropCloserWriteHalf::bind_buffered::send");
149
                }
150
            }
151
        }
152
4
        Ok(())
153
5
    }
154
155
    /// Sends an EOF (End of File) message to the receiver which will gracefully let the
156
    /// stream know it has no more data. This will close the stream.
157
13.2k
    pub fn send_eof(&mut self) -> Result<(), Error> {
158
        // Flag that we have sent the EOF.
159
13.2k
        let eof_was_sent = self.eof_sent.swap(true, Ordering::Release);
160
13.2k
        if eof_was_sent {
161
1
            warn!(
162
                "Stream already closed when eof already was sent. This is often ok for retry was triggered, but should not happen on happy path."
163
            );
164
1
            return Ok(());
165
13.2k
        }
166
167
        // Now close our stream.
168
13.2k
        self.tx = None;
169
13.2k
        Ok(())
170
13.2k
    }
171
172
    /// Returns the number of bytes written so far. This does not mean the receiver received
173
    /// all of the bytes written to the stream so far.
174
    #[must_use]
175
174
    pub const fn get_bytes_written(&self) -> u64 {
176
174
        self.bytes_written
177
174
    }
178
179
    /// Returns if the pipe was broken. This is good for determining if the reader broke the
180
    /// pipe or the writer broke the pipe, since this will only return true if the pipe was
181
    /// broken by the writer.
182
    #[must_use]
183
2
    pub const fn is_pipe_broken(&self) -> bool {
184
2
        self.tx.is_none()
185
2
    }
186
}
187
188
/// Reader half of the pair.
189
#[derive(Debug)]
190
pub struct DropCloserReadHalf {
191
    /// Receiving side of the underlying channel.
    rx: mpsc::Receiver<Bytes>,
192
    /// Number of bytes received over the stream. Reset to zero when the
    /// stream errors out and after a successful `try_reset_stream()`.
193
    bytes_received: u64,
194
    /// Flag shared with the write half; checked when the channel closes to
    /// tell a clean EOF apart from a sender that was dropped early.
    eof_sent: Arc<AtomicBool>,
195
    /// If there was an error in the stream, this will be set to the last error.
196
    last_err: Option<Error>,
197
    /// If not empty, this is the data that needs to be sent out before
198
    /// data from the underlying channel is sent.
199
    queued_data: VecDeque<Bytes>,
200
    /// As data is being read from the stream, this buffer will be filled
201
    /// with the most recent data. Once `max_recent_data_size` is reached
202
    /// this buffer will be cleared and no longer be populated.
203
    /// This is useful if the caller wants to reset the reader to before
204
    /// any of the data was received if possible (eg: something failed and
205
    /// we want to retry).
206
    recent_data: Vec<Bytes>,
207
    /// Amount of data to keep in the `recent_data` buffer before clearing it
208
    /// and no longer populating it.
209
    max_recent_data_size: u64,
210
}
211
212
impl DropCloserReadHalf {
213
    /// Returns if the stream has data ready.
214
0
    pub fn is_empty(&self) -> bool {
215
0
        self.rx.is_empty()
216
0
    }
217
218
145k
    fn recv_inner(&mut self, chunk: Bytes) -> Result<Bytes, Error> {
219
        // `queued_data` is allowed to have empty bytes that represent EOF
220
145k
        if chunk.is_empty() {
221
12.7k
            if !self.eof_sent.load(Ordering::Acquire) {
222
31
                let err = make_err!(Code::Internal, "Sender dropped before sending EOF");
223
31
                self.queued_data.clear();
224
31
                self.recent_data.clear();
225
31
                self.bytes_received = 0;
226
31
                self.last_err = Some(err.clone());
227
31
                return Err(err);
228
12.7k
            }
229
230
12.7k
            self.maybe_populate_recent_data(&ZERO_DATA);
231
12.7k
            return Ok(ZERO_DATA);
232
132k
        }
233
234
132k
        self.bytes_received += chunk.len() as u64;
235
132k
        self.maybe_populate_recent_data(&chunk);
236
132k
        Ok(chunk)
237
145k
    }
238
239
    /// Try to receive a chunk of data, returning `None` if none is available.
240
169k
    pub fn try_recv(&mut self) -> Option<Result<Bytes, Error>> {
241
169k
        if let Some(
err2
) = &self.last_err {
242
2
            return Some(Err(err.clone()));
243
169k
        }
244
169k
        self.queued_data.pop_front().map(Ok)
245
169k
    }
246
247
    /// Receive a chunk of data, waiting asynchronously until some is available.
248
169k
    pub async fn recv(&mut self) -> Result<Bytes, Error> {
249
169k
        if let Some(
result23.7k
) = self.try_recv() {
250
23.7k
            result
251
        } else {
252
            // `None` here indicates EOF, which we represent as Zero data
253
145k
            let 
data145k
= self.rx.recv().await.
unwrap_or145k
(
ZERO_DATA145k
);
254
145k
            self.recv_inner(data)
255
        }
256
169k
    }
257
258
145k
    fn maybe_populate_recent_data(&mut self, chunk: &Bytes) {
259
145k
        if self.max_recent_data_size == 0 {
260
30.0k
            return; // Fast path.
261
115k
        }
262
115k
        if self.bytes_received > self.max_recent_data_size {
263
62.9k
            if !self.recent_data.is_empty() {
264
1
                self.recent_data.clear();
265
62.9k
            }
266
62.9k
            return;
267
52.5k
        }
268
52.5k
        self.recent_data.push(chunk.clone());
269
145k
    }
270
271
    /// Sets the maximum size of the `recent_data` buffer. If the number of bytes
272
    /// received exceeds this size, the `recent_data` buffer will be cleared and
273
    /// no longer populated.
274
7
    pub const fn set_max_recent_data_size(&mut self, size: u64) {
275
7
        self.max_recent_data_size = size;
276
7
    }
277
278
    /// Attempts to reset the stream to before any data was received. This will
279
    /// only work if the number of bytes received is less than `max_recent_data_size`.
280
    ///
281
    /// On error the state of the stream is undefined and the caller should not
282
    /// attempt to use the stream again.
283
1
    pub fn try_reset_stream(&mut self) -> Result<(), Error> {
284
1
        if self.bytes_received > self.max_recent_data_size {
285
0
            return Err(make_err!(
286
0
                Code::Internal,
287
0
                "Cannot reset stream, max_recent_data_size exceeded"
288
0
            ));
289
1
        }
290
1
        let mut data_sum = 0;
291
1
        for 
chunk0
in self.recent_data.drain(..).rev() {
292
0
            data_sum += chunk.len() as u64;
293
0
            self.queued_data.push_front(chunk);
294
0
        }
295
1
        assert!(self.recent_data.is_empty(), "Recent_data should be empty");
296
        // Ensure the sum of the bytes in recent_data is equal to the bytes_received.
297
0
        error_if!(
298
1
            data_sum != self.bytes_received,
299
            "Sum of recent_data bytes does not equal bytes_received"
300
        );
301
1
        self.bytes_received = 0;
302
1
        Ok(())
303
1
    }
304
305
    /// Drains the reader until an EOF is received, but sends data to the void.
306
1.00k
    pub async fn drain(&mut self) -> Result<(), Error> {
307
        loop {
308
1.01k
            if self
309
1.01k
                .recv()
310
1.01k
                .await
311
1.01k
                .err_tip(|| "Failed to drain in buf_channel::drain")
?0
312
1.01k
                .is_empty()
313
            {
314
1.00k
                break; // EOF.
315
4
            }
316
        }
317
1.00k
        Ok(())
318
1.00k
    }
319
320
    /// Peek the next set of bytes in the stream without consuming them.
321
13.1k
    pub async fn peek(&mut self) -> Result<&Bytes, Error> {
322
13.1k
        if self.queued_data.is_empty() {
323
13.1k
            let 
chunk13.1k
= self.recv().await.err_tip(|| "In buf_channel::peek")
?5
;
324
13.1k
            self.queued_data.push_front(chunk);
325
0
        }
326
13.1k
        Ok(self
327
13.1k
            .queued_data
328
13.1k
            .front()
329
13.1k
            .expect("Should have data in the queue"))
330
13.1k
    }
331
332
    /// The number of bytes received over this stream so far.
333
0
    pub const fn get_bytes_received(&self) -> u64 {
334
0
        self.bytes_received
335
0
    }
336
337
    /// Takes exactly `size` number of bytes from the stream and returns them.
338
    /// This means the stream will keep polling until either an EOF is received or
339
    /// `size` bytes are received and concat them all together then return them.
340
    /// This method is optimized to reduce copies when possible.
341
    /// If `size` is None, it will take all the bytes in the stream.
342
34.1k
    pub async fn consume(&mut self, size: Option<usize>) -> Result<Bytes, Error> {
343
34.1k
        let size = size.unwrap_or(usize::MAX);
344
1.65k
        let first_chunk = {
345
34.1k
            let 
mut chunk34.1k
= self
346
34.1k
                .recv()
347
34.1k
                .await
348
34.1k
                .err_tip(|| "During first read of buf_channel::take()")
?18
;
349
34.1k
            if chunk.is_empty() {
350
375
                return Ok(chunk); // EOF.
351
33.7k
            }
352
33.7k
            if chunk.len() > size {
353
20.8k
                let remaining = chunk.split_off(size);
354
20.8k
                self.queued_data.push_front(remaining);
355
                // No need to read EOF if we are a partial chunk.
356
20.8k
                return Ok(chunk);
357
12.9k
            }
358
            // Try to read our EOF to ensure our sender did not error out.
359
12.9k
            match self.peek().await {
360
12.9k
                Ok(peeked_chunk) => {
361
12.9k
                    if peeked_chunk.is_empty() || 
chunk1.84k
.len() == size {
362
11.2k
                        return Ok(chunk);
363
1.65k
                    }
364
                }
365
4
                Err(e) => {
366
4
                    return Err(e).err_tip(|| "Failed to check if next chunk is EOF")?;
367
                }
368
            }
369
1.65k
            chunk
370
        };
371
1.65k
        let mut output = BytesMut::new();
372
1.65k
        output.extend_from_slice(&first_chunk);
373
374
        loop {
375
120k
            let 
mut chunk120k
= self
376
120k
                .recv()
377
120k
                .await
378
120k
                .err_tip(|| "During next read of buf_channel::take()")
?1
;
379
120k
            if chunk.is_empty() {
380
727
                break; // EOF.
381
119k
            }
382
119k
            if output.len() + chunk.len() > size {
383
104
                // Slice off the extra data and put it back into the queue. We are done.
384
104
                let remaining = chunk.split_off(size - output.len());
385
104
                self.queued_data.push_front(remaining);
386
119k
            }
387
119k
            output.extend_from_slice(&chunk);
388
119k
            if output.len() == size {
389
930
                break; // We are done.
390
118k
            }
391
        }
392
1.65k
        Ok(output.freeze())
393
34.1k
    }
394
}
395
396
impl Stream for DropCloserReadHalf {
397
    type Item = Result<Bytes, std::io::Error>;
398
399
    // TODO(palfrey) This is not very efficient as we are creating a new future on every
400
    // poll() call. It might be better to use a waker.
401
133
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
402
133
        Box::pin(self.recv())
403
133
            .as_mut()
404
133
            .poll(cx)
405
133
            .map(|result| match 
result98
{
406
97
                Ok(bytes) => {
407
97
                    if bytes.is_empty() {
408
24
                        return None;
409
73
                    }
410
73
                    Some(Ok(bytes))
411
                }
412
1
                Err(e) => Some(Err(e.to_std_err())),
413
98
            })
414
133
    }
415
}