/build/source/nativelink-util/src/proto_stream_utils.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::fmt::Debug; |
16 | | use core::mem; |
17 | | use core::pin::Pin; |
18 | | use core::task::{Context, Poll}; |
19 | | use std::borrow::Cow; |
20 | | use std::sync::Arc; |
21 | | |
22 | | use futures::{Stream, StreamExt}; |
23 | | use nativelink_error::{Error, ResultExt, error_if, make_input_err}; |
24 | | use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest}; |
25 | | use parking_lot::Mutex; |
26 | | use tonic::{Status, Streaming}; |
27 | | |
28 | | use crate::resource_info::ResourceInfo; |
29 | | |
/// Wraps a gRPC ByteStream `Write` request stream, eagerly consuming the first
/// message so the resource metadata is available before the body is streamed.
pub struct WriteRequestStreamWrapper<T> {
    /// Resource info parsed from the first message's `resource_name`.
    pub resource_info: ResourceInfo<'static>,
    /// Total number of payload bytes received so far across all messages.
    pub bytes_received: usize,
    // The wrapped stream that yields subsequent `WriteRequest`s.
    stream: T,
    // The first message, buffered here until the wrapper is first polled.
    first_msg: Option<WriteRequest>,
    /// Set once a message with `finish_write == true` has been yielded.
    pub write_finished: bool,
}
37 | | |
impl<T> Debug for WriteRequestStreamWrapper<T> {
    // Manual impl because `T` (the wrapped stream) is not required to be
    // `Debug`; the `stream` field is intentionally omitted from the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("WriteRequestStreamWrapper")
            .field("resource_info", &self.resource_info)
            .field("bytes_received", &self.bytes_received)
            .field("first_msg", &self.first_msg)
            .field("write_finished", &self.write_finished)
            .finish()
    }
}
48 | | |
impl<T, E> WriteRequestStreamWrapper<T>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
    E: Into<Error>,
{
    /// Constructs a wrapper by pulling the first message off `stream` and
    /// parsing its `resource_name` into an owned [`ResourceInfo`].
    ///
    /// # Errors
    /// Fails if the stream yields an error, ends before producing a first
    /// message, or the first message's resource name cannot be parsed.
    pub async fn from(mut stream: T) -> Result<Self, Error> {
        let first_msg = stream
            .next()
            .await
            .err_tip(|| "Error receiving first message in stream")?
            .err_tip(|| "Expected WriteRequest struct in stream (from)")?;

        // `true` here means the resource name is parsed as an upload resource.
        let resource_info = ResourceInfo::new(&first_msg.resource_name, true)
            .err_tip(|| {
                format!(
                    "Could not extract resource info from first message of stream: {}",
                    first_msg.resource_name
                )
            })?
            .to_owned();

        Ok(Self {
            resource_info,
            bytes_received: 0,
            stream,
            first_msg: Some(first_msg),
            write_finished: false,
        })
    }

    /// Convenience adapter that drives `poll_next` to completion for one item.
    pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> {
        futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
    }

    /// Returns true while the buffered first message has not yet been yielded.
    pub const fn is_first_msg(&self) -> bool {
        self.first_msg.is_some()
    }

    /// Returns whether the first message has `finish_write` set to true.
    /// This indicates a single-shot upload where all data is in one message.
    pub fn is_first_msg_complete(&self) -> bool {
        self.first_msg.as_ref().is_some_and(|msg| msg.finish_write)
    }
}
93 | | |
impl<T, E> Stream for WriteRequestStreamWrapper<T>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    type Item = Result<WriteRequest, Error>;

    /// Yields the buffered first message, then forwards from the wrapped
    /// stream, validating byte counts against `resource_info.expected_size`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // If the stream said that the previous message was the last one, then
        // return a stream EOF (i.e. None).
        if self.write_finished {
            // A short upload is an error: the client set `finish_write` before
            // sending the number of bytes the resource name promised.
            error_if!(
                self.bytes_received != self.resource_info.expected_size,
                "Did not send enough data. Expected {}, but so far received {}",
                self.resource_info.expected_size,
                self.bytes_received
            );
            return Poll::Ready(None);
        }

        // Gets the next message, this is either the cached first or a
        // subsequent message from the wrapped Stream.
        let maybe_message = if let Some(first_msg) = self.first_msg.take() {
            Ok(first_msg)
        } else {
            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(maybe_message)) => maybe_message
                    .err_tip(|| format!("Stream error at byte {}", self.bytes_received)),
                // The inner stream must not end before `finish_write` is seen.
                Poll::Ready(None) => Err(make_input_err!(
                    "Expected WriteRequest struct in stream (got None)"
                )),
            }
        };

        // If we successfully got a message, update our internal state with the
        // message meta data.
        Poll::Ready(Some(maybe_message.and_then(|message| {
            self.write_finished = message.finish_write;
            self.bytes_received += message.data.len();

            // Check that we haven't read past the expected end.
            if self.bytes_received > self.resource_info.expected_size {
                Err(make_input_err!(
                    "Sent too much data. Expected {}, but so far received {}",
                    self.resource_info.expected_size,
                    self.bytes_received
                ))
            } else {
                Ok(message)
            }
        })))
    }
}
148 | | |
/// Represents the state of the first response in a `FirstStream`.
#[derive(Debug)]
pub enum FirstResponseState {
    /// Contains an optional first response that hasn't been consumed yet.
    /// A `None` value indicates the first response was EOF.
    Unused(Option<ReadResponse>),
    /// Indicates the first response has been consumed and future reads should
    /// come from the underlying stream.
    Used,
}
159 | | |
/// This provides a buffer for the first response from GrpcStore.read in order
/// to allow the first read to occur within the retry loop. That means that if
/// the connection establishes fine, but reading the first byte of the file
/// fails we have the ability to retry before returning to the caller.
#[derive(Debug)]
pub struct FirstStream {
    /// The current state of the first response. When in the `Unused` state,
    /// contains an optional response which could be `None` or an EOF.
    /// Once consumed, transitions to the `Used` state.
    state: FirstResponseState,
    /// The stream to get responses from after the first response is consumed.
    stream: Streaming<ReadResponse>,
}
173 | | |
impl FirstStream {
    /// Creates a new `FirstStream` with the given first response and underlying
    /// stream.
    ///
    /// `first_response` of `None` means the upstream returned EOF as its very
    /// first item; the wrapper will yield EOF on the first poll.
    pub const fn new(
        first_response: Option<ReadResponse>,
        stream: Streaming<ReadResponse>,
    ) -> Self {
        Self {
            state: FirstResponseState::Unused(first_response),
            stream,
        }
    }
}
187 | | |
188 | | impl Stream for FirstStream { |
189 | | type Item = Result<ReadResponse, Status>; |
190 | | |
191 | 0 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
192 | 0 | match mem::replace(&mut self.state, FirstResponseState::Used) { |
193 | 0 | FirstResponseState::Unused(first_response) => Poll::Ready(first_response.map(Ok)), |
194 | 0 | FirstResponseState::Used => Pin::new(&mut self.stream).poll_next(cx), |
195 | | } |
196 | 0 | } |
197 | | } |
198 | | |
/// This structure wraps all of the information required to perform a write
/// request on the `GrpcStore`. It stores the last message retrieved which allows
/// the write to resume since the UUID allows upload resume at the server.
#[derive(Debug)]
pub struct WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // Instance name substituted into each outgoing resource name.
    instance_name: String,
    // Holds an error from the inbound stream so the caller can retrieve it
    // after the outbound stream terminates.
    read_stream_error: Option<Error>,
    // The inbound stream of `WriteRequest`s being forwarded upstream.
    read_stream: WriteRequestStreamWrapper<T>,
    // Tonic doesn't appear to report an error until it has taken two messages,
    // therefore we are required to buffer the last two messages.
    cached_messages: [Option<WriteRequest>; 2],
    // When resuming after an error, the previous messages are cloned into this
    // queue upfront to allow them to be served back.
    resume_queue: [Option<WriteRequest>; 2],
    // An optimisation to avoid having to manage resume_queue when it's empty.
    is_resumed: bool,
}
220 | | |
impl<T, E> WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    /// Creates a fresh write state with empty caches and no pending error.
    pub const fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self {
        Self {
            instance_name,
            read_stream_error: None,
            read_stream,
            cached_messages: [None, None],
            resume_queue: [None, None],
            is_resumed: false,
        }
    }

    // Pushes `message` into the two-slot cache, evicting the oldest entry.
    fn push_message(&mut self, message: WriteRequest) {
        self.cached_messages.swap(0, 1);
        self.cached_messages[0] = Some(message);
    }

    // Pops the next replayed message while resuming; `None` when not resuming
    // or when the resume queue has been drained (which also clears the flag).
    const fn resumed_message(&mut self) -> Option<WriteRequest> {
        if self.is_resumed {
            // The resume_queue is a circular buffer, that we have to shift,
            // since its only got two elements its a trivial swap.
            self.resume_queue.swap(0, 1);
            let message = self.resume_queue[0].take();
            if message.is_none() {
                self.is_resumed = false;
            }
            message
        } else {
            None
        }
    }

    /// A resume is only possible when the inbound stream has not errored and
    /// there is either a cached message to replay or the first message has not
    /// been consumed yet.
    pub const fn can_resume(&self) -> bool {
        self.read_stream_error.is_none()
            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
    }

    /// Arms a resume: clones the cached messages into the replay queue so the
    /// next polls serve them before reading new messages.
    pub fn resume(&mut self) {
        self.resume_queue.clone_from(&self.cached_messages);
        self.is_resumed = true;
    }

    /// Takes (and clears) the stored inbound stream error, if any.
    pub const fn take_read_stream_error(&mut self) -> Option<Error> {
        self.read_stream_error.take()
    }
}
271 | | |
/// A wrapper around `WriteState` to allow it to be reclaimed from the underlying
/// write call in the case of failure.
#[derive(Debug)]
pub struct WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    // Shared with the caller so state (errors, caches) survives a failed
    // upstream write call.
    shared_state: Arc<Mutex<WriteState<T, E>>>,
}
282 | | |
impl<T, E> WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    /// Wraps an existing shared write state for use as the outbound stream.
    pub const fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self {
        Self { shared_state }
    }
}
292 | | |
impl<T, E> Stream for WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    type Item = WriteRequest;

    /// Serves replayed messages first (when resuming), then forwards new
    /// requests from the inbound stream, rewriting the instance name and
    /// caching each message for potential resume. Inbound errors are stashed
    /// in `read_stream_error` and end this stream with `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const IS_UPLOAD_TRUE: bool = true;

        // This should be an uncontended lock since write was called.
        let mut local_state = self.shared_state.lock();
        // If this is the first or second call after a failure and we have
        // cached messages, then use the cached write requests.
        let cached_message = local_state.resumed_message();
        if cached_message.is_some() {
            return Poll::Ready(cached_message);
        }
        // Read a new write request from the downstream.
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
        else {
            return Poll::Pending;
        };
        // Update the instance name in the write request and forward it on.
        let result = match maybe_message {
            Some(Ok(mut message)) => {
                if !message.resource_name.is_empty() {
                    // Replace the instance name in the resource name if it is
                    // different from the instance name in the write state.
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
                        Ok(mut resource_name) => {
                            if resource_name.instance_name != local_state.instance_name {
                                resource_name.instance_name =
                                    Cow::Borrowed(&local_state.instance_name);
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
                            }
                        }
                        Err(err) => {
                            // Unparseable resource name: record the error and
                            // terminate the outbound stream.
                            local_state.read_stream_error = Some(err);
                            return Poll::Ready(None);
                        }
                    }
                }
                // Cache the last request in case there is an error to allow
                // the upload to be resumed.
                local_state.push_message(message.clone());
                Some(message)
            }
            Some(Err(err)) => {
                local_state.read_stream_error = Some(err);
                None
            }
            None => None,
        };
        Poll::Ready(result)
    }
}