Coverage Report

Created: 2026-02-23 10:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/proto_stream_utils.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::fmt::Debug;
16
use core::mem;
17
use core::pin::Pin;
18
use core::task::{Context, Poll};
19
use std::borrow::Cow;
20
use std::sync::Arc;
21
22
use futures::{Stream, StreamExt};
23
use nativelink_error::{Error, ResultExt, error_if, make_input_err};
24
use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest};
25
use parking_lot::Mutex;
26
use tonic::{Status, Streaming};
27
28
use crate::resource_info::ResourceInfo;
29
30
pub struct WriteRequestStreamWrapper<T> {
31
    pub resource_info: ResourceInfo<'static>,
32
    pub bytes_received: usize,
33
    stream: T,
34
    first_msg: Option<WriteRequest>,
35
    pub write_finished: bool,
36
}
37
38
impl<T> Debug for WriteRequestStreamWrapper<T> {
39
11
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
40
11
        f.debug_struct("WriteRequestStreamWrapper")
41
11
            .field("resource_info", &self.resource_info)
42
11
            .field("bytes_received", &self.bytes_received)
43
11
            .field("first_msg", &self.first_msg)
44
11
            .field("write_finished", &self.write_finished)
45
11
            .finish()
46
11
    }
47
}
48
49
impl<T, E> WriteRequestStreamWrapper<T>
50
where
51
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
52
    E: Into<Error>,
53
{
54
17
    pub async fn from(mut stream: T) -> Result<Self, Error> {
55
17
        let 
first_msg16
= stream
56
17
            .next()
57
17
            .await
58
17
            .err_tip(|| "Error receiving first message in stream")
?0
59
17
            .err_tip(|| "Expected WriteRequest struct in stream (from)")
?1
;
60
61
16
        let resource_info = ResourceInfo::new(&first_msg.resource_name, true)
62
16
            .err_tip(|| 
{0
63
0
                format!(
64
0
                    "Could not extract resource info from first message of stream: {}",
65
                    first_msg.resource_name
66
                )
67
0
            })?
68
16
            .to_owned();
69
70
16
        Ok(Self {
71
16
            resource_info,
72
16
            bytes_received: 0,
73
16
            stream,
74
16
            first_msg: Some(first_msg),
75
16
            write_finished: false,
76
16
        })
77
17
    }
78
79
25
    pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> {
80
126
        
futures::future::poll_fn25
(|cx| Pin::new(&mut *self).poll_next(cx)).
await25
81
25
    }
82
83
0
    pub const fn is_first_msg(&self) -> bool {
84
0
        self.first_msg.is_some()
85
0
    }
86
87
    /// Returns whether the first message has `finish_write` set to true.
88
    /// This indicates a single-shot upload where all data is in one message.
89
15
    pub fn is_first_msg_complete(&self) -> bool {
90
15
        self.first_msg.as_ref().is_some_and(|msg| msg.finish_write)
91
15
    }
92
}
93
94
impl<T, E> Stream for WriteRequestStreamWrapper<T>
95
where
96
    E: Into<Error>,
97
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
98
{
99
    type Item = Result<WriteRequest, Error>;
100
101
130
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
102
        // If the stream said that the previous message was the last one, then
103
        // return a stream EOF (i.e. None).
104
130
        if self.write_finished {
  Branch (104:12): [True: 0, False: 126]
  Branch (104:12): [True: 0, False: 0]
  Branch (104:12): [Folded - Ignored]
  Branch (104:12): [Folded - Ignored]
  Branch (104:12): [True: 1, False: 3]
105
0
            error_if!(
106
1
                self.bytes_received != self.resource_info.expected_size,
  Branch (106:17): [True: 0, False: 0]
  Branch (106:17): [True: 0, False: 0]
  Branch (106:17): [Folded - Ignored]
  Branch (106:17): [Folded - Ignored]
  Branch (106:17): [True: 0, False: 1]
107
                "Did not send enough data. Expected {}, but so far received {}",
108
0
                self.resource_info.expected_size,
109
0
                self.bytes_received
110
            );
111
1
            return Poll::Ready(None);
112
129
        }
113
114
        // Gets the next message, this is either the cached first or a
115
        // subsequent message from the wrapped Stream.
116
129
        let 
maybe_message28
= if let Some(
first_msg16
) = self.first_msg.take() {
  Branch (116:36): [True: 15, False: 111]
  Branch (116:36): [True: 0, False: 0]
  Branch (116:36): [Folded - Ignored]
  Branch (116:36): [Folded - Ignored]
  Branch (116:36): [True: 1, False: 2]
117
16
            Ok(first_msg)
118
        } else {
119
113
            match Pin::new(&mut self.stream).poll_next(cx) {
120
101
                Poll::Pending => return Poll::Pending,
121
9
                Poll::Ready(Some(maybe_message)) => maybe_message
122
9
                    .err_tip(|| format!(
"Stream error at byte {}"0
,
self.bytes_received0
)),
123
3
                Poll::Ready(None) => Err(make_input_err!(
124
3
                    "Expected WriteRequest struct in stream (got None)"
125
3
                )),
126
            }
127
        };
128
129
        // If we successfully got a message, update our internal state with the
130
        // message meta data.
131
28
        Poll::Ready(Some(maybe_message.and_then(|message| 
{25
132
25
            self.write_finished = message.finish_write;
133
25
            self.bytes_received += message.data.len();
134
135
            // Check that we haven't read past the expected end.
136
25
            if self.bytes_received > self.resource_info.expected_size {
  Branch (136:16): [True: 2, False: 20]
  Branch (136:16): [True: 0, False: 0]
  Branch (136:16): [Folded - Ignored]
  Branch (136:16): [Folded - Ignored]
  Branch (136:16): [True: 0, False: 3]
137
2
                Err(make_input_err!(
138
2
                    "Sent too much data. Expected {}, but so far received {}",
139
2
                    self.resource_info.expected_size,
140
2
                    self.bytes_received
141
2
                ))
142
            } else {
143
23
                Ok(message)
144
            }
145
25
        })))
146
130
    }
147
}
148
149
/// Represents the state of the first response in a `FirstStream`.
150
#[derive(Debug)]
151
pub enum FirstResponseState {
152
    /// Contains an optional first response that hasn't been consumed yet.
153
    /// A `None` value indicates the first response was EOF.
154
    Unused(Option<ReadResponse>),
155
    /// Indicates the first response has been consumed and future reads should
156
    /// come from the underlying stream.
157
    Used,
158
}
159
160
/// This provides a buffer for the first response from GrpcStore.read in order
161
/// to allow the first read to occur within the retry loop.  That means that if
162
/// the connection establishes fine, but reading the first byte of the file
163
// fails, we have the ability to retry before returning to the caller.
164
#[derive(Debug)]
165
pub struct FirstStream {
166
    /// The current state of the first response. When in the `Unused` state,
167
    /// contains an optional response which could be `None` or an EOF.
168
    /// Once consumed, transitions to the `Used` state.
169
    state: FirstResponseState,
170
    /// The stream to get responses from after the first response is consumed.
171
    stream: Streaming<ReadResponse>,
172
}
173
174
impl FirstStream {
175
    /// Creates a new `FirstStream` with the given first response and underlying
176
    /// stream.
177
0
    pub const fn new(
178
0
        first_response: Option<ReadResponse>,
179
0
        stream: Streaming<ReadResponse>,
180
0
    ) -> Self {
181
0
        Self {
182
0
            state: FirstResponseState::Unused(first_response),
183
0
            stream,
184
0
        }
185
0
    }
186
}
187
188
impl Stream for FirstStream {
189
    type Item = Result<ReadResponse, Status>;
190
191
0
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192
0
        match mem::replace(&mut self.state, FirstResponseState::Used) {
193
0
            FirstResponseState::Unused(first_response) => Poll::Ready(first_response.map(Ok)),
194
0
            FirstResponseState::Used => Pin::new(&mut self.stream).poll_next(cx),
195
        }
196
0
    }
197
}
198
199
/// This structure wraps all of the information required to perform a write
200
/// request on the `GrpcStore`.  It stores the last message retrieved which allows
201
/// the write to resume since the UUID allows upload resume at the server.
202
#[derive(Debug)]
203
pub struct WriteState<T, E>
204
where
205
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
206
    E: Into<Error> + 'static,
207
{
208
    instance_name: String,
209
    read_stream_error: Option<Error>,
210
    read_stream: WriteRequestStreamWrapper<T>,
211
    // Tonic doesn't appear to report an error until it has taken two messages,
212
    // therefore we are required to buffer the last two messages.
213
    cached_messages: [Option<WriteRequest>; 2],
214
    // When resuming after an error, the previous messages are cloned into this
215
    // queue upfront to allow them to be served back.
216
    resume_queue: [Option<WriteRequest>; 2],
217
    // An optimisation to avoid having to manage resume_queue when it's empty.
218
    is_resumed: bool,
219
}
220
221
impl<T, E> WriteState<T, E>
222
where
223
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
224
    E: Into<Error> + 'static,
225
{
226
1
    pub const fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self {
227
1
        Self {
228
1
            instance_name,
229
1
            read_stream_error: None,
230
1
            read_stream,
231
1
            cached_messages: [None, None],
232
1
            resume_queue: [None, None],
233
1
            is_resumed: false,
234
1
        }
235
1
    }
236
237
3
    fn push_message(&mut self, message: WriteRequest) {
238
3
        self.cached_messages.swap(0, 1);
239
3
        self.cached_messages[0] = Some(message);
240
3
    }
241
242
4
    const fn resumed_message(&mut self) -> Option<WriteRequest> {
243
4
        if self.is_resumed {
  Branch (243:12): [True: 0, False: 0]
  Branch (243:12): [True: 0, False: 0]
  Branch (243:12): [Folded - Ignored]
  Branch (243:12): [Folded - Ignored]
  Branch (243:12): [True: 0, False: 4]
244
            // The resume_queue is a circular buffer that we have to shift;
245
            // since it's only got two elements, it's a trivial swap.
246
0
            self.resume_queue.swap(0, 1);
247
0
            let message = self.resume_queue[0].take();
248
0
            if message.is_none() {
  Branch (248:16): [True: 0, False: 0]
  Branch (248:16): [True: 0, False: 0]
  Branch (248:16): [Folded - Ignored]
  Branch (248:16): [Folded - Ignored]
  Branch (248:16): [True: 0, False: 0]
249
0
                self.is_resumed = false;
250
0
            }
251
0
            message
252
        } else {
253
4
            None
254
        }
255
4
    }
256
257
0
    pub const fn can_resume(&self) -> bool {
258
0
        self.read_stream_error.is_none()
  Branch (258:9): [True: 0, False: 0]
  Branch (258:9): [True: 0, False: 0]
  Branch (258:9): [Folded - Ignored]
  Branch (258:9): [Folded - Ignored]
259
0
            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
  Branch (259:17): [True: 0, False: 0]
  Branch (259:17): [True: 0, False: 0]
  Branch (259:17): [Folded - Ignored]
  Branch (259:17): [Folded - Ignored]
260
0
    }
261
262
0
    pub fn resume(&mut self) {
263
0
        self.resume_queue.clone_from(&self.cached_messages);
264
0
        self.is_resumed = true;
265
0
    }
266
267
1
    pub const fn take_read_stream_error(&mut self) -> Option<Error> {
268
1
        self.read_stream_error.take()
269
1
    }
270
}
271
272
/// A wrapper around `WriteState` to allow it to be reclaimed from the underlying
273
/// write call in the case of failure.
274
#[derive(Debug)]
275
pub struct WriteStateWrapper<T, E>
276
where
277
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
278
    E: Into<Error> + 'static,
279
{
280
    shared_state: Arc<Mutex<WriteState<T, E>>>,
281
}
282
283
impl<T, E> WriteStateWrapper<T, E>
284
where
285
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
286
    E: Into<Error> + 'static,
287
{
288
1
    pub const fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self {
289
1
        Self { shared_state }
290
1
    }
291
}
292
293
impl<T, E> Stream for WriteStateWrapper<T, E>
294
where
295
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
296
    E: Into<Error> + 'static,
297
{
298
    type Item = WriteRequest;
299
300
4
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301
        const IS_UPLOAD_TRUE: bool = true;
302
303
        // This should be an uncontended lock since write was called.
304
4
        let mut local_state = self.shared_state.lock();
305
        // If this is the first or second call after a failure and we have
306
        // cached messages, then use the cached write requests.
307
4
        let cached_message = local_state.resumed_message();
308
4
        if cached_message.is_some() {
  Branch (308:12): [True: 0, False: 0]
  Branch (308:12): [True: 0, False: 0]
  Branch (308:12): [Folded - Ignored]
  Branch (308:12): [Folded - Ignored]
  Branch (308:12): [True: 0, False: 4]
309
0
            return Poll::Ready(cached_message);
310
4
        }
311
        // Read a new write request from the downstream.
312
4
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
  Branch (312:13): [True: 0, False: 0]
  Branch (312:13): [True: 0, False: 0]
  Branch (312:13): [Folded - Ignored]
  Branch (312:13): [Folded - Ignored]
  Branch (312:13): [True: 4, False: 0]
313
        else {
314
0
            return Poll::Pending;
315
        };
316
        // Update the instance name in the write request and forward it on.
317
4
        let result = match 
maybe_message3
{
318
3
            Some(Ok(mut message)) => {
319
3
                if !message.resource_name.is_empty() {
  Branch (319:20): [True: 0, False: 0]
  Branch (319:20): [True: 0, False: 0]
  Branch (319:20): [Folded - Ignored]
  Branch (319:20): [Folded - Ignored]
  Branch (319:20): [True: 1, False: 2]
320
                    // Replace the instance name in the resource name if it is
321
                    // different from the instance name in the write state.
322
1
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
323
1
                        Ok(mut resource_name) => {
324
1
                            if resource_name.instance_name != local_state.instance_name {
  Branch (324:32): [True: 0, False: 0]
  Branch (324:32): [True: 0, False: 0]
  Branch (324:32): [Folded - Ignored]
  Branch (324:32): [Folded - Ignored]
  Branch (324:32): [True: 0, False: 1]
325
0
                                resource_name.instance_name =
326
0
                                    Cow::Borrowed(&local_state.instance_name);
327
0
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
328
1
                            }
329
                        }
330
0
                        Err(err) => {
331
0
                            local_state.read_stream_error = Some(err);
332
0
                            return Poll::Ready(None);
333
                        }
334
                    }
335
2
                }
336
                // Cache the last request in case there is an error to allow
337
                // the upload to be resumed.
338
3
                local_state.push_message(message.clone());
339
3
                Some(message)
340
            }
341
0
            Some(Err(err)) => {
342
0
                local_state.read_stream_error = Some(err);
343
0
                None
344
            }
345
1
            None => None,
346
        };
347
4
        Poll::Ready(result)
348
4
    }
349
}