/build/source/nativelink-util/src/proto_stream_utils.rs
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};
use nativelink_error::{error_if, make_input_err, Error, ResultExt};
use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest};
use parking_lot::Mutex;
use tonic::{Status, Streaming};

use crate::resource_info::ResourceInfo;

#[derive(Debug)]
pub struct WriteRequestStreamWrapper<T, E>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    pub resource_info: ResourceInfo<'static>,
    pub bytes_received: usize,
    stream: T,
    first_msg: Option<WriteRequest>,
    write_finished: bool,
}

impl<T, E> WriteRequestStreamWrapper<T, E>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    pub async fn from(mut stream: T) -> Result<WriteRequestStreamWrapper<T, E>, Error> {
        let first_msg = stream
            .next()
            .await
            .err_tip(|| "Error receiving first message in stream")?
            .err_tip(|| "Expected WriteRequest struct in stream")?;

        let resource_info = ResourceInfo::new(&first_msg.resource_name, true)
            .err_tip(|| {
                format!(
                    "Could not extract resource info from first message of stream: {}",
                    first_msg.resource_name
                )
            })?
            .to_owned();

        Ok(WriteRequestStreamWrapper {
            resource_info,
            bytes_received: 0,
            stream,
            first_msg: Some(first_msg),
            write_finished: false,
        })
    }

    pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> {
        futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
    }

    pub fn is_first_msg(&self) -> bool {
        self.first_msg.is_some()
    }
}

impl<T, E> Stream for WriteRequestStreamWrapper<T, E>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    type Item = Result<WriteRequest, Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // If the stream said that the previous message was the last one, then
        // return a stream EOF (i.e. None).
        if self.write_finished {
            error_if!(
                self.bytes_received != self.resource_info.expected_size,
                "Did not send enough data. Expected {}, but so far received {}",
                self.resource_info.expected_size,
                self.bytes_received
            );
            return Poll::Ready(None);
        }

        // Get the next message; this is either the cached first message or a
        // subsequent message from the wrapped Stream.
        let maybe_message = if let Some(first_msg) = self.first_msg.take() {
            Ok(first_msg)
        } else {
            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(maybe_message)) => maybe_message
                    .err_tip(|| format!("Stream error at byte {}", self.bytes_received)),
                Poll::Ready(None) => Err(make_input_err!("Expected WriteRequest struct in stream")),
            }
        };

        // If we successfully got a message, update our internal state with the
        // message metadata.
        Poll::Ready(Some(maybe_message.and_then(|message| {
            self.write_finished = message.finish_write;
            self.bytes_received += message.data.len();

            // Check that we haven't read past the expected end.
            if self.bytes_received > self.resource_info.expected_size {
                Err(make_input_err!(
                    "Sent too much data. Expected {}, but so far received {}",
                    self.resource_info.expected_size,
                    self.bytes_received
                ))
            } else {
                Ok(message)
            }
        })))
    }
}
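
// Illustrative usage sketch (not part of the original module): how a caller
// might wrap an incoming ByteStream write stream and drain it. The function
// name and the error-tip text are assumptions for illustration only; everything
// else uses items defined or imported in this file.
#[allow(dead_code)]
async fn example_drain_write_stream<T, E>(stream: T) -> Result<usize, Error>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
    E: Into<Error>,
{
    // Wrapping eagerly pulls the first message so the resource name (and the
    // expected size derived from it) is known before any data is handled.
    let mut wrapper = WriteRequestStreamWrapper::from(stream)
        .await
        .err_tip(|| "In example_drain_write_stream")?;
    // Drain the remaining messages; the wrapper enforces that the total byte
    // count matches resource_info.expected_size and errors otherwise.
    while let Some(request) = wrapper.next().await {
        let _request = request?;
    }
    Ok(wrapper.bytes_received)
}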

/// This provides a buffer for the first response from GrpcStore.read in order
/// to allow the first read to occur within the retry loop. That means that if
/// the connection establishes fine, but reading the first byte of the file
/// fails, we have the ability to retry before returning to the caller.
pub struct FirstStream {
    /// Contains the first response from the stream (which could be an EOF,
    /// hence the nested Option). This should be populated on creation and
    /// returned as the first result from the stream. Subsequent reads from the
    /// stream will use the encapsulated stream.
    first_response: Option<Option<ReadResponse>>,
    /// The stream to get responses from when first_response is None.
    stream: Streaming<ReadResponse>,
}

impl FirstStream {
    pub fn new(first_response: Option<ReadResponse>, stream: Streaming<ReadResponse>) -> Self {
        Self {
            first_response: Some(first_response),
            stream,
        }
    }
}

impl Stream for FirstStream {
    type Item = Result<ReadResponse, Status>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        if let Some(first_response) = self.first_response.take() {
            return std::task::Poll::Ready(first_response.map(Ok));
        }
        Pin::new(&mut self.stream).poll_next(cx)
    }
}
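
// Illustrative usage sketch (not part of the original module): pre-reading the
// first response before handing the stream back, which is what makes
// FirstStream useful inside a retry loop. The function name is an assumption
// for illustration only; the real retry logic lives in the GrpcStore read path.
#[allow(dead_code)]
async fn example_buffer_first_response(
    mut stream: Streaming<ReadResponse>,
) -> Result<FirstStream, Status> {
    // Pulling the first message here means a failure surfaces while the caller
    // is still inside its retry loop and can simply retry the read.
    let first_response = stream.message().await?;
    // first_response may be None (an immediate EOF); FirstStream replays that
    // as the end of the stream.
    Ok(FirstStream::new(first_response, stream))
}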

/// This structure wraps all of the information required to perform a write
/// request on the GrpcStore. It stores the last message retrieved, which allows
/// the write to resume since the UUID allows upload resume at the server.
pub struct WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    instance_name: String,
    read_stream_error: Option<Error>,
    read_stream: WriteRequestStreamWrapper<T, E>,
    // Tonic doesn't appear to report an error until it has taken two messages;
    // therefore we are required to buffer the last two messages.
    cached_messages: [Option<WriteRequest>; 2],
    // When resuming after an error, the previous messages are cloned into this
    // queue upfront to allow them to be served back.
    resume_queue: [Option<WriteRequest>; 2],
    // An optimisation to avoid having to manage resume_queue when it's empty.
    is_resumed: bool,
}

impl<T, E> WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    pub fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T, E>) -> Self {
        Self {
            instance_name,
            read_stream_error: None,
            read_stream,
            cached_messages: [None, None],
            resume_queue: [None, None],
            is_resumed: false,
        }
    }

    fn push_message(&mut self, message: WriteRequest) {
        self.cached_messages.swap(0, 1);
        self.cached_messages[0] = Some(message);
    }

    fn resumed_message(&mut self) -> Option<WriteRequest> {
        if self.is_resumed {
            // The resume_queue is a circular buffer that we have to shift;
            // since it's only got two elements, it's a trivial swap.
            self.resume_queue.swap(0, 1);
            let message = self.resume_queue[0].take();
            if message.is_none() {
                self.is_resumed = false;
            }
            message
        } else {
            None
        }
    }

    pub fn can_resume(&self) -> bool {
        self.read_stream_error.is_none()
            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
    }

    pub fn resume(&mut self) {
        self.resume_queue.clone_from(&self.cached_messages);
        self.is_resumed = true;
    }

    pub fn take_read_stream_error(&mut self) -> Option<Error> {
        self.read_stream_error.take()
    }
}
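
// Illustrative usage sketch (not part of the original module): the checks a
// retrying writer might perform against WriteState after a failed upload
// attempt. The function name and error text are assumptions for illustration
// only; the actual retry policy belongs to the caller (e.g. GrpcStore.write).
#[allow(dead_code)]
fn example_try_resume<T, E>(state: &mut WriteState<T, E>) -> Result<(), Error>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // An error on the read side cannot be retried here; hand it back.
    if let Some(err) = state.take_read_stream_error() {
        return Err(err);
    }
    // Otherwise replay the cached messages on the next polls of the wrapper.
    if state.can_resume() {
        state.resume();
        Ok(())
    } else {
        Err(make_input_err!("Write cannot be resumed"))
    }
}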

/// A wrapper around WriteState to allow it to be reclaimed from the underlying
/// write call in the case of failure.
pub struct WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    shared_state: Arc<Mutex<WriteState<T, E>>>,
}

impl<T, E> WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    pub fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self {
        Self { shared_state }
    }
}

impl<T, E> Stream for WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    type Item = WriteRequest;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const IS_UPLOAD_TRUE: bool = true;

        // This should be an uncontended lock since write was called.
        let mut local_state = self.shared_state.lock();
        // If this is the first or second call after a failure and we have
        // cached messages, then use the cached write requests.
        let cached_message = local_state.resumed_message();
        if cached_message.is_some() {
            return Poll::Ready(cached_message);
        }
        // Read a new write request from the downstream.
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
        else {
            return Poll::Pending;
        };
        // Update the instance name in the write request and forward it on.
        let result = match maybe_message {
            Some(Ok(mut message)) => {
                if !message.resource_name.is_empty() {
                    // Replace the instance name in the resource name if it is
                    // different from the instance name in the write state.
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
                        Ok(mut resource_name) => {
                            if resource_name.instance_name != local_state.instance_name {
                                resource_name.instance_name =
                                    Cow::Borrowed(&local_state.instance_name);
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
                            }
                        }
                        Err(err) => {
                            local_state.read_stream_error = Some(err);
                            return Poll::Ready(None);
                        }
                    }
                }
                // Cache the last request in case there is an error to allow
                // the upload to be resumed.
                local_state.push_message(message.clone());
                Some(message)
            }
            Some(Err(err)) => {
                local_state.read_stream_error = Some(err);
                None
            }
            None => None,
        };
        Poll::Ready(result)
    }
}
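
// Illustrative usage sketch (not part of the original module): wiring the
// pieces together the way a GrpcStore-style writer might before handing the
// resulting stream to the ByteStream Write RPC. The function name and the
// error-tip text are assumptions for illustration only.
#[allow(dead_code)]
async fn example_build_write_stream<T, E>(
    instance_name: String,
    stream: T,
) -> Result<(Arc<Mutex<WriteState<T, E>>>, WriteStateWrapper<T, E>), Error>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // Pull the first message so resource info is validated up front.
    let read_stream = WriteRequestStreamWrapper::from(stream)
        .await
        .err_tip(|| "In example_build_write_stream")?;
    // Share the state so it can be inspected (and resumed) if the RPC fails.
    let shared_state = Arc::new(Mutex::new(WriteState::new(instance_name, read_stream)));
    let write_stream = WriteStateWrapper::new(shared_state.clone());
    Ok((shared_state, write_stream))
}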