/build/source/nativelink-util/src/proto_stream_utils.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::fmt::Debug; |
16 | | use core::mem; |
17 | | use core::pin::Pin; |
18 | | use core::task::{Context, Poll}; |
19 | | use std::borrow::Cow; |
20 | | use std::sync::Arc; |
21 | | |
22 | | use futures::{Stream, StreamExt}; |
23 | | use nativelink_error::{Error, ResultExt, error_if, make_input_err}; |
24 | | use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest}; |
25 | | use parking_lot::Mutex; |
26 | | use tonic::{Status, Streaming}; |
27 | | |
28 | | use crate::resource_info::ResourceInfo; |
29 | | |
30 | | pub struct WriteRequestStreamWrapper<T> { |
31 | | pub resource_info: ResourceInfo<'static>, |
32 | | pub bytes_received: usize, |
33 | | stream: T, |
34 | | first_msg: Option<WriteRequest>, |
35 | | pub write_finished: bool, |
36 | | } |
37 | | |
38 | | impl<T> Debug for WriteRequestStreamWrapper<T> { |
39 | 11 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
40 | 11 | f.debug_struct("WriteRequestStreamWrapper") |
41 | 11 | .field("resource_info", &self.resource_info) |
42 | 11 | .field("bytes_received", &self.bytes_received) |
43 | 11 | .field("first_msg", &self.first_msg) |
44 | 11 | .field("write_finished", &self.write_finished) |
45 | 11 | .finish() |
46 | 11 | } |
47 | | } |
48 | | |
49 | | impl<T, E> WriteRequestStreamWrapper<T> |
50 | | where |
51 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin, |
52 | | E: Into<Error>, |
53 | | { |
54 | 17 | pub async fn from(mut stream: T) -> Result<Self, Error> { |
55 | 17 | let first_msg = stream
56 | 17 | .next()
57 | 17 | .await
58 | 17 | .err_tip(|| "Error receiving first message in stream")?
59 | 17 | .err_tip(|| "Expected WriteRequest struct in stream")?;
60 | | |
61 | 16 | let resource_info = ResourceInfo::new(&first_msg.resource_name, true) |
62 | 16 | .err_tip(|| {
63 | 0 | format!( |
64 | 0 | "Could not extract resource info from first message of stream: {}", |
65 | | first_msg.resource_name |
66 | | ) |
67 | 0 | })? |
68 | 16 | .to_owned(); |
69 | | |
70 | 16 | Ok(Self { |
71 | 16 | resource_info, |
72 | 16 | bytes_received: 0, |
73 | 16 | stream, |
74 | 16 | first_msg: Some(first_msg), |
75 | 16 | write_finished: false, |
76 | 16 | }) |
77 | 17 | } |
78 | | |
79 | 25 | pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> { |
80 | 126 | futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
81 | 25 | } |
82 | | |
83 | 0 | pub const fn is_first_msg(&self) -> bool { |
84 | 0 | self.first_msg.is_some() |
85 | 0 | } |
86 | | |
87 | | /// Returns whether the first message has `finish_write` set to true. |
88 | | /// This indicates a single-shot upload where all data is in one message. |
89 | 15 | pub fn is_first_msg_complete(&self) -> bool { |
90 | 15 | self.first_msg.as_ref().is_some_and(|msg| msg.finish_write) |
91 | 15 | } |
92 | | } |
93 | | |
94 | | impl<T, E> Stream for WriteRequestStreamWrapper<T> |
95 | | where |
96 | | E: Into<Error>, |
97 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin, |
98 | | { |
99 | | type Item = Result<WriteRequest, Error>; |
100 | | |
101 | 130 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
102 | | // If the stream said that the previous message was the last one, then |
103 | | // return a stream EOF (i.e. None). |
104 | 130 | if self.write_finished {
Branch (104:12): [True: 0, False: 126]
Branch (104:12): [True: 0, False: 0]
Branch (104:12): [Folded - Ignored]
Branch (104:12): [Folded - Ignored]
Branch (104:12): [True: 1, False: 3]
105 | 0 | error_if!( |
106 | 1 | self.bytes_received != self.resource_info.expected_size,
Branch (106:17): [True: 0, False: 0]
Branch (106:17): [True: 0, False: 0]
Branch (106:17): [Folded - Ignored]
Branch (106:17): [Folded - Ignored]
Branch (106:17): [True: 0, False: 1]
107 | | "Did not send enough data. Expected {}, but so far received {}", |
108 | 0 | self.resource_info.expected_size, |
109 | 0 | self.bytes_received |
110 | | ); |
111 | 1 | return Poll::Ready(None); |
112 | 129 | } |
113 | | |
114 | | // Gets the next message; this is either the cached first message or a
115 | | // subsequent message from the wrapped Stream.
116 | 129 | let maybe_message = if let Some(first_msg) = self.first_msg.take() {
Branch (116:36): [True: 15, False: 111]
Branch (116:36): [True: 0, False: 0]
Branch (116:36): [Folded - Ignored]
Branch (116:36): [Folded - Ignored]
Branch (116:36): [True: 1, False: 2]
117 | 16 | Ok(first_msg) |
118 | | } else { |
119 | 113 | match Pin::new(&mut self.stream).poll_next(cx) { |
120 | 101 | Poll::Pending => return Poll::Pending, |
121 | 9 | Poll::Ready(Some(maybe_message)) => maybe_message |
122 | 9 | .err_tip(|| format!("Stream error at byte {}", self.bytes_received)),
123 | 3 | Poll::Ready(None) => Err(make_input_err!("Expected WriteRequest struct in stream")), |
124 | | } |
125 | | }; |
126 | | |
127 | | // If we successfully got a message, update our internal state with the
128 | | // message metadata.
129 | 28 | Poll::Ready(Some(maybe_message.and_then(|message| {
130 | 25 | self.write_finished = message.finish_write; |
131 | 25 | self.bytes_received += message.data.len(); |
132 | | |
133 | | // Check that we haven't read past the expected end. |
134 | 25 | if self.bytes_received > self.resource_info.expected_size {
Branch (134:16): [True: 2, False: 20]
Branch (134:16): [True: 0, False: 0]
Branch (134:16): [Folded - Ignored]
Branch (134:16): [Folded - Ignored]
Branch (134:16): [True: 0, False: 3]
135 | 2 | Err(make_input_err!( |
136 | 2 | "Sent too much data. Expected {}, but so far received {}", |
137 | 2 | self.resource_info.expected_size, |
138 | 2 | self.bytes_received |
139 | 2 | )) |
140 | | } else { |
141 | 23 | Ok(message) |
142 | | } |
143 | 25 | }))) |
144 | 130 | } |
145 | | } |
146 | | |
147 | | /// Represents the state of the first response in a `FirstStream`. |
148 | | #[derive(Debug)] |
149 | | pub enum FirstResponseState { |
150 | | /// Contains an optional first response that hasn't been consumed yet. |
151 | | /// A `None` value indicates the first response was EOF. |
152 | | Unused(Option<ReadResponse>), |
153 | | /// Indicates the first response has been consumed and future reads should |
154 | | /// come from the underlying stream. |
155 | | Used, |
156 | | } |
157 | | |
158 | | /// This provides a buffer for the first response from GrpcStore.read in order |
159 | | /// to allow the first read to occur within the retry loop. That means that if
160 | | /// the connection establishes fine but reading the first byte of the file
161 | | /// fails, we have the ability to retry before returning to the caller.
162 | | #[derive(Debug)] |
163 | | pub struct FirstStream { |
164 | | /// The current state of the first response. When in the `Unused` state, |
165 | | /// contains an optional response which could be `None` or an EOF. |
166 | | /// Once consumed, transitions to the `Used` state. |
167 | | state: FirstResponseState, |
168 | | /// The stream to get responses from after the first response is consumed. |
169 | | stream: Streaming<ReadResponse>, |
170 | | } |
171 | | |
172 | | impl FirstStream { |
173 | | /// Creates a new `FirstStream` with the given first response and underlying |
174 | | /// stream. |
175 | 0 | pub const fn new( |
176 | 0 | first_response: Option<ReadResponse>, |
177 | 0 | stream: Streaming<ReadResponse>, |
178 | 0 | ) -> Self { |
179 | 0 | Self { |
180 | 0 | state: FirstResponseState::Unused(first_response), |
181 | 0 | stream, |
182 | 0 | } |
183 | 0 | } |
184 | | } |
185 | | |
186 | | impl Stream for FirstStream { |
187 | | type Item = Result<ReadResponse, Status>; |
188 | | |
189 | 0 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
190 | 0 | match mem::replace(&mut self.state, FirstResponseState::Used) { |
191 | 0 | FirstResponseState::Unused(first_response) => Poll::Ready(first_response.map(Ok)), |
192 | 0 | FirstResponseState::Used => Pin::new(&mut self.stream).poll_next(cx), |
193 | | } |
194 | 0 | } |
195 | | } |
196 | | |
197 | | /// This structure wraps all of the information required to perform a write |
198 | | /// request on the `GrpcStore`. It stores the last message retrieved, which allows
199 | | /// the write to resume, since the UUID allows the upload to resume at the server.
200 | | #[derive(Debug)] |
201 | | pub struct WriteState<T, E> |
202 | | where |
203 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
204 | | E: Into<Error> + 'static, |
205 | | { |
206 | | instance_name: String, |
207 | | read_stream_error: Option<Error>, |
208 | | read_stream: WriteRequestStreamWrapper<T>, |
209 | | // Tonic doesn't appear to report an error until it has taken two messages;
210 | | // therefore we are required to buffer the last two messages.
211 | | cached_messages: [Option<WriteRequest>; 2], |
212 | | // When resuming after an error, the previous messages are cloned into this |
213 | | // queue upfront to allow them to be served back. |
214 | | resume_queue: [Option<WriteRequest>; 2], |
215 | | // An optimisation to avoid having to manage resume_queue when it's empty. |
216 | | is_resumed: bool, |
217 | | } |
218 | | |
219 | | impl<T, E> WriteState<T, E> |
220 | | where |
221 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
222 | | E: Into<Error> + 'static, |
223 | | { |
224 | 1 | pub const fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self { |
225 | 1 | Self { |
226 | 1 | instance_name, |
227 | 1 | read_stream_error: None, |
228 | 1 | read_stream, |
229 | 1 | cached_messages: [None, None], |
230 | 1 | resume_queue: [None, None], |
231 | 1 | is_resumed: false, |
232 | 1 | } |
233 | 1 | } |
234 | | |
235 | 3 | fn push_message(&mut self, message: WriteRequest) { |
236 | 3 | self.cached_messages.swap(0, 1); |
237 | 3 | self.cached_messages[0] = Some(message); |
238 | 3 | } |
239 | | |
240 | 4 | const fn resumed_message(&mut self) -> Option<WriteRequest> { |
241 | 4 | if self.is_resumed {
Branch (241:12): [True: 0, False: 0]
Branch (241:12): [True: 0, False: 0]
Branch (241:12): [Folded - Ignored]
Branch (241:12): [Folded - Ignored]
Branch (241:12): [True: 0, False: 4]
242 | | // The resume_queue is a circular buffer that we have to shift;
243 | | // since it only has two elements, it's a trivial swap.
244 | 0 | self.resume_queue.swap(0, 1); |
245 | 0 | let message = self.resume_queue[0].take(); |
246 | 0 | if message.is_none() {
Branch (246:16): [True: 0, False: 0]
Branch (246:16): [True: 0, False: 0]
Branch (246:16): [Folded - Ignored]
Branch (246:16): [Folded - Ignored]
Branch (246:16): [True: 0, False: 0]
247 | 0 | self.is_resumed = false; |
248 | 0 | } |
249 | 0 | message |
250 | | } else { |
251 | 4 | None |
252 | | } |
253 | 4 | } |
254 | | |
255 | 0 | pub const fn can_resume(&self) -> bool { |
256 | 0 | self.read_stream_error.is_none()
Branch (256:9): [True: 0, False: 0]
Branch (256:9): [True: 0, False: 0]
Branch (256:9): [Folded - Ignored]
Branch (256:9): [Folded - Ignored]
257 | 0 | && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
Branch (257:17): [True: 0, False: 0]
Branch (257:17): [True: 0, False: 0]
Branch (257:17): [Folded - Ignored]
Branch (257:17): [Folded - Ignored]
258 | 0 | } |
259 | | |
260 | 0 | pub fn resume(&mut self) { |
261 | 0 | self.resume_queue.clone_from(&self.cached_messages); |
262 | 0 | self.is_resumed = true; |
263 | 0 | } |
264 | | |
265 | 1 | pub const fn take_read_stream_error(&mut self) -> Option<Error> { |
266 | 1 | self.read_stream_error.take() |
267 | 1 | } |
268 | | } |
269 | | |
270 | | /// A wrapper around `WriteState` to allow it to be reclaimed from the underlying |
271 | | /// write call in the case of failure. |
272 | | #[derive(Debug)] |
273 | | pub struct WriteStateWrapper<T, E> |
274 | | where |
275 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
276 | | E: Into<Error> + 'static, |
277 | | { |
278 | | shared_state: Arc<Mutex<WriteState<T, E>>>, |
279 | | } |
280 | | |
281 | | impl<T, E> WriteStateWrapper<T, E> |
282 | | where |
283 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
284 | | E: Into<Error> + 'static, |
285 | | { |
286 | 1 | pub const fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self { |
287 | 1 | Self { shared_state } |
288 | 1 | } |
289 | | } |
290 | | |
291 | | impl<T, E> Stream for WriteStateWrapper<T, E> |
292 | | where |
293 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
294 | | E: Into<Error> + 'static, |
295 | | { |
296 | | type Item = WriteRequest; |
297 | | |
298 | 4 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
299 | | const IS_UPLOAD_TRUE: bool = true; |
300 | | |
301 | | // This should be an uncontended lock since write was called. |
302 | 4 | let mut local_state = self.shared_state.lock(); |
303 | | // If this is the first or second call after a failure and we have |
304 | | // cached messages, then use the cached write requests. |
305 | 4 | let cached_message = local_state.resumed_message(); |
306 | 4 | if cached_message.is_some() {
Branch (306:12): [True: 0, False: 0]
Branch (306:12): [True: 0, False: 0]
Branch (306:12): [Folded - Ignored]
Branch (306:12): [Folded - Ignored]
Branch (306:12): [True: 0, False: 4]
307 | 0 | return Poll::Ready(cached_message); |
308 | 4 | } |
309 | | // Read a new write request from the downstream. |
310 | 4 | let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
Branch (310:13): [True: 0, False: 0]
Branch (310:13): [True: 0, False: 0]
Branch (310:13): [Folded - Ignored]
Branch (310:13): [Folded - Ignored]
Branch (310:13): [True: 4, False: 0]
311 | | else { |
312 | 0 | return Poll::Pending; |
313 | | }; |
314 | | // Update the instance name in the write request and forward it on. |
315 | 4 | let result = match maybe_message {
316 | 3 | Some(Ok(mut message)) => { |
317 | 3 | if !message.resource_name.is_empty() {
Branch (317:20): [True: 0, False: 0]
Branch (317:20): [True: 0, False: 0]
Branch (317:20): [Folded - Ignored]
Branch (317:20): [Folded - Ignored]
Branch (317:20): [True: 1, False: 2]
318 | | // Replace the instance name in the resource name if it is |
319 | | // different from the instance name in the write state. |
320 | 1 | match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) { |
321 | 1 | Ok(mut resource_name) => { |
322 | 1 | if resource_name.instance_name != local_state.instance_name {
Branch (322:32): [True: 0, False: 0]
Branch (322:32): [True: 0, False: 0]
Branch (322:32): [Folded - Ignored]
Branch (322:32): [Folded - Ignored]
Branch (322:32): [True: 0, False: 1]
323 | 0 | resource_name.instance_name = |
324 | 0 | Cow::Borrowed(&local_state.instance_name); |
325 | 0 | message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE); |
326 | 1 | } |
327 | | } |
328 | 0 | Err(err) => { |
329 | 0 | local_state.read_stream_error = Some(err); |
330 | 0 | return Poll::Ready(None); |
331 | | } |
332 | | } |
333 | 2 | } |
334 | | // Cache the last request in case there is an error to allow |
335 | | // the upload to be resumed. |
336 | 3 | local_state.push_message(message.clone()); |
337 | 3 | Some(message) |
338 | | } |
339 | 0 | Some(Err(err)) => { |
340 | 0 | local_state.read_stream_error = Some(err); |
341 | 0 | None |
342 | | } |
343 | 1 | None => None, |
344 | | }; |
345 | 4 | Poll::Ready(result) |
346 | 4 | } |
347 | | } |
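For orientation, the following is a minimal, hypothetical sketch (not part of the coverage report) of how a caller might drive WriteRequestStreamWrapper, based only on the signatures shown in the listing above. The helper name drain_upload is invented for illustration.

    async fn drain_upload<T, E>(stream: T) -> Result<usize, Error>
    where
        T: Stream<Item = Result<WriteRequest, E>> + Unpin,
        E: Into<Error>,
    {
        // `from` consumes and caches the first message so that resource_info is
        // available before any data is forwarded.
        let mut wrapper = WriteRequestStreamWrapper::from(stream).await?;
        // `next` replays the cached first message, then polls the wrapped stream,
        // applying the expected_size checks seen in poll_next above.
        while let Some(request) = wrapper.next().await {
            let _request = request?;
        }
        Ok(wrapper.bytes_received)
    }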