Coverage Report

Created: 2025-03-08 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/proto_stream_utils.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::borrow::Cow;
16
use std::fmt::Debug;
17
use std::pin::Pin;
18
use std::sync::Arc;
19
use std::task::{Context, Poll};
20
21
use futures::{Stream, StreamExt};
22
use nativelink_error::{error_if, make_input_err, Error, ResultExt};
23
use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest};
24
use parking_lot::Mutex;
25
use tonic::{Status, Streaming};
26
27
use crate::resource_info::ResourceInfo;
28
29
/// Wraps a stream of [`WriteRequest`] messages, eagerly caching the first
/// message so that the resource information parsed from it is available to
/// callers before the stream is driven.
pub struct WriteRequestStreamWrapper<T> {
    /// Resource info parsed from the first message's `resource_name`.
    pub resource_info: ResourceInfo<'static>,
    /// Running total of `data` payload bytes seen across all yielded messages.
    pub bytes_received: usize,
    // Underlying stream that produces the second and subsequent messages.
    stream: T,
    // First message, cached at construction and yielded before anything from
    // `stream`; `None` once it has been consumed.
    first_msg: Option<WriteRequest>,
    // Set from `finish_write` of the most recently yielded message; once true
    // the stream reports EOF.
    write_finished: bool,
}
36
37
impl<T> Debug for WriteRequestStreamWrapper<T> {
38
14
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39
14
        f.debug_struct("WriteRequestStreamWrapper")
40
14
            .field("resource_info", &self.resource_info)
41
14
            .field("bytes_received", &self.bytes_received)
42
14
            .field("first_msg", &self.first_msg)
43
14
            .field("write_finished", &self.write_finished)
44
14
            .finish()
45
14
    }
46
}
47
48
impl<T, E> WriteRequestStreamWrapper<T>
49
where
50
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
51
    E: Into<Error>,
52
{
53
16
    pub async fn from(mut stream: T) -> Result<WriteRequestStreamWrapper<T>, Error> {
54
16
        let 
first_msg15
= stream
55
16
            .next()
56
16
            .await
57
16
            .err_tip(|| 
"Error receiving first message in stream"0
)
?0
58
16
            .err_tip(|| 
"Expected WriteRequest struct in stream"1
)
?1
;
59
60
15
        let resource_info = ResourceInfo::new(&first_msg.resource_name, true)
61
15
            .err_tip(|| {
62
0
                format!(
63
0
                    "Could not extract resource info from first message of stream: {}",
64
0
                    first_msg.resource_name
65
0
                )
66
15
            })
?0
67
15
            .to_owned();
68
15
69
15
        Ok(WriteRequestStreamWrapper {
70
15
            resource_info,
71
15
            bytes_received: 0,
72
15
            stream,
73
15
            first_msg: Some(first_msg),
74
15
            write_finished: false,
75
15
        })
76
16
    }
77
78
24
    pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> {
79
125
        futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)
).await24
80
24
    }
81
82
0
    pub const fn is_first_msg(&self) -> bool {
83
0
        self.first_msg.is_some()
84
0
    }
85
}
86
87
impl<T, E> Stream for WriteRequestStreamWrapper<T>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    type Item = Result<WriteRequest, Error>;

    /// Yields the cached first message, then messages from the inner stream,
    /// tracking `bytes_received` and enforcing `expected_size` bounds.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // If the stream said that the previous message was the last one, then
        // return a stream EOF (i.e. None).
        if self.write_finished {
            // Before reporting EOF, verify the total bytes match the size the
            // resource name promised.
            // NOTE(review): `error_if!` presumably bails out of this function
            // on mismatch — confirm its expansion is compatible with a
            // `Poll<...>` return type.
            error_if!(
                self.bytes_received != self.resource_info.expected_size,
                "Did not send enough data. Expected {}, but so far received {}",
                self.resource_info.expected_size,
                self.bytes_received
            );
            return Poll::Ready(None);
        }

        // Gets the next message, this is either the cached first or a
        // subsequent message from the wrapped Stream.
        let maybe_message = if let Some(first_msg) = self.first_msg.take() {
            Ok(first_msg)
        } else {
            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(maybe_message)) => maybe_message
                    .err_tip(|| format!("Stream error at byte {}", self.bytes_received)),
                // Inner stream ended without a prior `finish_write` message:
                // treat the truncated upload as an input error.
                Poll::Ready(None) => Err(make_input_err!("Expected WriteRequest struct in stream")),
            }
        };

        // If we successfully got a message, update our internal state with the
        // message meta data.
        Poll::Ready(Some(maybe_message.and_then(|message| {
            self.write_finished = message.finish_write;
            self.bytes_received += message.data.len();

            // Check that we haven't read past the expected end.
            if self.bytes_received > self.resource_info.expected_size {
                Err(make_input_err!(
                    "Sent too much data. Expected {}, but so far received {}",
                    self.resource_info.expected_size,
                    self.bytes_received
                ))
            } else {
                Ok(message)
            }
        })))
    }
}
139
140
/// This provides a buffer for the first response from GrpcStore.read in order
/// to allow the first read to occur within the retry loop.  That means that if
/// the connection establishes fine, but reading the first byte of the file
/// fails we have the ability to retry before returning to the caller.
pub struct FirstStream {
    /// Contains the first response from the stream (which could be an EOF,
    /// hence the nested Option).  This should be populated on creation and
    /// returned as the first result from the stream.  Subsequent reads from the
    /// stream will use the encapsulated stream.
    // The outer Option becomes None once the cached response has been yielded.
    first_response: Option<Option<ReadResponse>>,
    /// The stream to get responses from when `first_response` is None.
    stream: Streaming<ReadResponse>,
}
153
154
impl FirstStream {
155
0
    pub const fn new(
156
0
        first_response: Option<ReadResponse>,
157
0
        stream: Streaming<ReadResponse>,
158
0
    ) -> Self {
159
0
        Self {
160
0
            first_response: Some(first_response),
161
0
            stream,
162
0
        }
163
0
    }
164
}
165
166
impl Stream for FirstStream {
167
    type Item = Result<ReadResponse, Status>;
168
169
0
    fn poll_next(
170
0
        mut self: Pin<&mut Self>,
171
0
        cx: &mut std::task::Context<'_>,
172
0
    ) -> std::task::Poll<Option<Self::Item>> {
173
0
        if let Some(first_response) = self.first_response.take() {
  Branch (173:16): [True: 0, False: 0]
  Branch (173:16): [Folded - Ignored]
174
0
            return std::task::Poll::Ready(first_response.map(Ok));
175
0
        }
176
0
        Pin::new(&mut self.stream).poll_next(cx)
177
0
    }
178
}
179
180
/// This structure wraps all of the information required to perform a write
/// request on the `GrpcStore`.  It stores the last message retrieved which allows
/// the write to resume since the UUID allows upload resume at the server.
pub struct WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // Instance name substituted into outgoing resource names when the
    // incoming name carries a different instance.
    instance_name: String,
    // Holds an error hit while reading `read_stream`; surfaced later via
    // `take_read_stream_error`.
    read_stream_error: Option<Error>,
    // The wrapped stream of incoming `WriteRequest`s being forwarded.
    read_stream: WriteRequestStreamWrapper<T>,
    // Tonic doesn't appear to report an error until it has taken two messages,
    // therefore we are required to buffer the last two messages.
    cached_messages: [Option<WriteRequest>; 2],
    // When resuming after an error, the previous messages are cloned into this
    // queue upfront to allow them to be served back.
    resume_queue: [Option<WriteRequest>; 2],
    // An optimisation to avoid having to manage resume_queue when it's empty.
    is_resumed: bool,
}
200
201
impl<T, E> WriteState<T, E>
202
where
203
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
204
    E: Into<Error> + 'static,
205
{
206
1
    pub const fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self {
207
1
        Self {
208
1
            instance_name,
209
1
            read_stream_error: None,
210
1
            read_stream,
211
1
            cached_messages: [None, None],
212
1
            resume_queue: [None, None],
213
1
            is_resumed: false,
214
1
        }
215
1
    }
216
217
3
    fn push_message(&mut self, message: WriteRequest) {
218
3
        self.cached_messages.swap(0, 1);
219
3
        self.cached_messages[0] = Some(message);
220
3
    }
221
222
4
    fn resumed_message(&mut self) -> Option<WriteRequest> {
223
4
        if self.is_resumed {
  Branch (223:12): [True: 0, False: 0]
  Branch (223:12): [True: 0, False: 0]
  Branch (223:12): [Folded - Ignored]
  Branch (223:12): [Folded - Ignored]
  Branch (223:12): [True: 0, False: 4]
224
            // The resume_queue is a circular buffer, that we have to shift,
225
            // since its only got two elements its a trivial swap.
226
0
            self.resume_queue.swap(0, 1);
227
0
            let message = self.resume_queue[0].take();
228
0
            if message.is_none() {
  Branch (228:16): [True: 0, False: 0]
  Branch (228:16): [True: 0, False: 0]
  Branch (228:16): [Folded - Ignored]
  Branch (228:16): [Folded - Ignored]
  Branch (228:16): [True: 0, False: 0]
229
0
                self.is_resumed = false;
230
0
            }
231
0
            message
232
        } else {
233
4
            None
234
        }
235
4
    }
236
237
0
    pub const fn can_resume(&self) -> bool {
238
0
        self.read_stream_error.is_none()
  Branch (238:9): [True: 0, False: 0]
  Branch (238:9): [True: 0, False: 0]
  Branch (238:9): [Folded - Ignored]
  Branch (238:9): [Folded - Ignored]
239
0
            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
  Branch (239:17): [True: 0, False: 0]
  Branch (239:17): [True: 0, False: 0]
  Branch (239:17): [Folded - Ignored]
  Branch (239:17): [Folded - Ignored]
240
0
    }
241
242
0
    pub fn resume(&mut self) {
243
0
        self.resume_queue.clone_from(&self.cached_messages);
244
0
        self.is_resumed = true;
245
0
    }
246
247
1
    pub fn take_read_stream_error(&mut self) -> Option<Error> {
248
1
        self.read_stream_error.take()
249
1
    }
250
}
251
252
/// A wrapper around `WriteState` to allow it to be reclaimed from the underlying
/// write call in the case of failure.
pub struct WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // Shared with the caller so the state outlives a failed write call;
    // locked on every poll of this wrapper.
    shared_state: Arc<Mutex<WriteState<T, E>>>,
}
261
262
impl<T, E> WriteStateWrapper<T, E>
263
where
264
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
265
    E: Into<Error> + 'static,
266
{
267
1
    pub const fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self {
268
1
        Self { shared_state }
269
1
    }
270
}
271
272
impl<T, E> Stream for WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    type Item = WriteRequest;

    /// Serves resume-queue messages first, then forwards messages from the
    /// wrapped stream, rewriting each resource name's instance name to match
    /// this state's and caching the last two messages for resume. Errors are
    /// stashed in `read_stream_error` and reported as end-of-stream here.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const IS_UPLOAD_TRUE: bool = true;

        // This should be an uncontended lock since write was called.
        let mut local_state = self.shared_state.lock();
        // If this is the first or second call after a failure and we have
        // cached messages, then use the cached write requests.
        let cached_message = local_state.resumed_message();
        if cached_message.is_some() {
            return Poll::Ready(cached_message);
        }
        // Read a new write request from the downstream.
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
        else {
            return Poll::Pending;
        };
        // Update the instance name in the write request and forward it on.
        let result = match maybe_message {
            Some(Ok(mut message)) => {
                // An empty resource name (allowed on non-first messages) is
                // forwarded untouched.
                if !message.resource_name.is_empty() {
                    // Replace the instance name in the resource name if it is
                    // different from the instance name in the write state.
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
                        Ok(mut resource_name) => {
                            if resource_name.instance_name != local_state.instance_name {
                                resource_name.instance_name =
                                    Cow::Borrowed(&local_state.instance_name);
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
                            }
                        }
                        // An unparsable resource name is recorded as a stream
                        // error and the stream is ended.
                        Err(err) => {
                            local_state.read_stream_error = Some(err);
                            return Poll::Ready(None);
                        }
                    }
                }
                // Cache the last request in case there is an error to allow
                // the upload to be resumed.
                local_state.push_message(message.clone());
                Some(message)
            }
            // Stream errors are stashed for `take_read_stream_error` and
            // surfaced to the consumer as EOF.
            Some(Err(err)) => {
                local_state.read_stream_error = Some(err);
                None
            }
            None => None,
        };
        Poll::Ready(result)
    }
}