/build/source/nativelink-util/src/proto_stream_utils.rs
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::fmt::Debug;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::borrow::Cow;
use std::sync::Arc;

use futures::{Stream, StreamExt};
use nativelink_error::{Error, ResultExt, error_if, make_input_err};
use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest};
use parking_lot::Mutex;
use tonic::{Status, Streaming};

use crate::resource_info::ResourceInfo;

pub struct WriteRequestStreamWrapper<T> {
    pub resource_info: ResourceInfo<'static>,
    pub bytes_received: usize,
    stream: T,
    first_msg: Option<WriteRequest>,
    write_finished: bool,
}

impl<T> Debug for WriteRequestStreamWrapper<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("WriteRequestStreamWrapper")
            .field("resource_info", &self.resource_info)
            .field("bytes_received", &self.bytes_received)
            .field("first_msg", &self.first_msg)
            .field("write_finished", &self.write_finished)
            .finish()
    }
}

impl<T, E> WriteRequestStreamWrapper<T>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
    E: Into<Error>,
{
    pub async fn from(mut stream: T) -> Result<Self, Error> {
        let first_msg = stream
            .next()
            .await
            .err_tip(|| "Error receiving first message in stream")?
            .err_tip(|| "Expected WriteRequest struct in stream")?;

        let resource_info = ResourceInfo::new(&first_msg.resource_name, true)
            .err_tip(|| {
                format!(
                    "Could not extract resource info from first message of stream: {}",
                    first_msg.resource_name
                )
            })?
            .to_owned();

        Ok(Self {
            resource_info,
            bytes_received: 0,
            stream,
            first_msg: Some(first_msg),
            write_finished: false,
        })
    }

    pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> {
        futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
    }

    pub const fn is_first_msg(&self) -> bool {
        self.first_msg.is_some()
    }
}

impl<T, E> Stream for WriteRequestStreamWrapper<T>
where
    E: Into<Error>,
    T: Stream<Item = Result<WriteRequest, E>> + Unpin,
{
    type Item = Result<WriteRequest, Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // If the stream said that the previous message was the last one, then
        // return a stream EOF (i.e. None).
        if self.write_finished {
            error_if!(
                self.bytes_received != self.resource_info.expected_size,
                "Did not send enough data. Expected {}, but so far received {}",
                self.resource_info.expected_size,
                self.bytes_received
            );
            return Poll::Ready(None);
        }

        // Gets the next message; this is either the cached first message or a
        // subsequent message from the wrapped Stream.
        let maybe_message = if let Some(first_msg) = self.first_msg.take() {
            Ok(first_msg)
        } else {
            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(maybe_message)) => maybe_message
                    .err_tip(|| format!("Stream error at byte {}", self.bytes_received)),
                Poll::Ready(None) => Err(make_input_err!("Expected WriteRequest struct in stream")),
            }
        };

        // If we successfully got a message, update our internal state with the
        // message metadata.
        Poll::Ready(Some(maybe_message.and_then(|message| {
            self.write_finished = message.finish_write;
            self.bytes_received += message.data.len();

            // Check that we haven't read past the expected end.
            if self.bytes_received > self.resource_info.expected_size {
                Err(make_input_err!(
                    "Sent too much data. Expected {}, but so far received {}",
                    self.resource_info.expected_size,
                    self.bytes_received
                ))
            } else {
                Ok(message)
            }
        })))
    }
}
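
// A minimal usage sketch, assuming the wrapper is handed a tonic
// `Streaming<WriteRequest>` from a ByteStream `Write` handler (the `request`
// binding below is hypothetical and not part of this module):
//
//     let mut stream = WriteRequestStreamWrapper::from(request.into_inner())
//         .await
//         .err_tip(|| "Could not unwrap first stream message")?;
//     while let Some(write_request) = stream.next().await {
//         let write_request = write_request?;
//         // write_request.data holds the next chunk; write_request.finish_write
//         // marks the final message of the upload.
//     }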

/// Represents the state of the first response in a `FirstStream`.
#[derive(Debug)]
pub enum FirstResponseState {
    /// Contains an optional first response that hasn't been consumed yet.
    /// A `None` value indicates the first response was EOF.
    Unused(Option<ReadResponse>),
    /// Indicates the first response has been consumed and future reads should
    /// come from the underlying stream.
    Used,
}

/// This provides a buffer for the first response from GrpcStore.read in order
/// to allow the first read to occur within the retry loop. That means that if
/// the connection is established fine but reading the first byte of the file
/// fails, we can retry before returning to the caller.
#[derive(Debug)]
pub struct FirstStream {
    /// The current state of the first response. When in the `Unused` state,
    /// this contains an optional response, which may be `None` to indicate EOF.
    /// Once consumed, it transitions to the `Used` state.
    state: FirstResponseState,
    /// The stream to get responses from after the first response is consumed.
    stream: Streaming<ReadResponse>,
}

impl FirstStream {
    /// Creates a new `FirstStream` with the given first response and underlying
    /// stream.
    pub const fn new(
        first_response: Option<ReadResponse>,
        stream: Streaming<ReadResponse>,
    ) -> Self {
        Self {
            state: FirstResponseState::Unused(first_response),
            stream,
        }
    }
}

impl Stream for FirstStream {
    type Item = Result<ReadResponse, Status>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match mem::replace(&mut self.state, FirstResponseState::Used) {
            FirstResponseState::Unused(first_response) => Poll::Ready(first_response.map(Ok)),
            FirstResponseState::Used => Pin::new(&mut self.stream).poll_next(cx),
        }
    }
}
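
// A minimal sketch of the retry pattern this type exists for (the `retry_loop`,
// `make_read_request`, and `client` names below are hypothetical, not APIs of
// this module): the first response is pulled inside the retried section, and
// only then is the rest of the stream handed back to the caller.
//
//     retry_loop(|| async {
//         let mut stream = client.read(make_read_request()).await?.into_inner();
//         let first_response = stream.message().await?;
//         Ok(FirstStream::new(first_response, stream))
//     })
//     .await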

/// This structure wraps all of the information required to perform a write
/// request on the `GrpcStore`. It stores the last messages retrieved, which
/// allows the write to resume, since the UUID allows upload resumption at the
/// server.
#[derive(Debug)]
pub struct WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    instance_name: String,
    read_stream_error: Option<Error>,
    read_stream: WriteRequestStreamWrapper<T>,
    // Tonic doesn't appear to report an error until it has taken two messages;
    // therefore we are required to buffer the last two messages.
    cached_messages: [Option<WriteRequest>; 2],
    // When resuming after an error, the previous messages are cloned into this
    // queue upfront to allow them to be served back.
    resume_queue: [Option<WriteRequest>; 2],
    // An optimisation to avoid having to manage resume_queue when it's empty.
    is_resumed: bool,
}

impl<T, E> WriteState<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    pub const fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self {
        Self {
            instance_name,
            read_stream_error: None,
            read_stream,
            cached_messages: [None, None],
            resume_queue: [None, None],
            is_resumed: false,
        }
    }

    fn push_message(&mut self, message: WriteRequest) {
        self.cached_messages.swap(0, 1);
        self.cached_messages[0] = Some(message);
    }

    const fn resumed_message(&mut self) -> Option<WriteRequest> {
        if self.is_resumed {
            // The resume_queue is a circular buffer that we have to shift;
            // since it only has two elements, this is a trivial swap.
            self.resume_queue.swap(0, 1);
            let message = self.resume_queue[0].take();
            if message.is_none() {
                self.is_resumed = false;
            }
            message
        } else {
            None
        }
    }

    pub const fn can_resume(&self) -> bool {
        self.read_stream_error.is_none()
            && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg())
    }

    pub fn resume(&mut self) {
        self.resume_queue.clone_from(&self.cached_messages);
        self.is_resumed = true;
    }

    pub const fn take_read_stream_error(&mut self) -> Option<Error> {
        self.read_stream_error.take()
    }
}
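
// A minimal sketch of how the resume bookkeeping is intended to be driven by a
// caller after a failed write attempt (the surrounding retry loop and the
// `local_state` binding are hypothetical):
//
//     if let Some(err) = local_state.take_read_stream_error() {
//         return Err(err); // The incoming stream itself failed; don't retry.
//     }
//     if local_state.can_resume() {
//         local_state.resume(); // Replay the cached messages on the next attempt.
//     } else {
//         return Err(make_input_err!("Write failed and cannot be resumed"));
//     }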

/// A wrapper around `WriteState` to allow it to be reclaimed from the underlying
/// write call in the case of failure.
#[derive(Debug)]
pub struct WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    shared_state: Arc<Mutex<WriteState<T, E>>>,
}

impl<T, E> WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    pub const fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self {
        Self { shared_state }
    }
}

impl<T, E> Stream for WriteStateWrapper<T, E>
where
    T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static,
    E: Into<Error> + 'static,
{
    type Item = WriteRequest;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        const IS_UPLOAD_TRUE: bool = true;

        // This should be an uncontended lock since write was called.
        let mut local_state = self.shared_state.lock();
        // If this is the first or second call after a failure and we have
        // cached messages, then use the cached write requests.
        let cached_message = local_state.resumed_message();
        if cached_message.is_some() {
            return Poll::Ready(cached_message);
        }
        // Read a new write request from the downstream.
        let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx)
        else {
            return Poll::Pending;
        };
        // Update the instance name in the write request and forward it on.
        let result = match maybe_message {
            Some(Ok(mut message)) => {
                if !message.resource_name.is_empty() {
                    // Replace the instance name in the resource name if it is
                    // different from the instance name in the write state.
                    match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) {
                        Ok(mut resource_name) => {
                            if resource_name.instance_name != local_state.instance_name {
                                resource_name.instance_name =
                                    Cow::Borrowed(&local_state.instance_name);
                                message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE);
                            }
                        }
                        Err(err) => {
                            local_state.read_stream_error = Some(err);
                            return Poll::Ready(None);
                        }
                    }
                }
                // Cache the last request in case there is an error to allow
                // the upload to be resumed.
                local_state.push_message(message.clone());
                Some(message)
            }
            Some(Err(err)) => {
                local_state.read_stream_error = Some(err);
                None
            }
            None => None,
        };
        Poll::Ready(result)
    }
}
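
// A minimal sketch of wiring the pieces together for a GrpcStore-style write
// (the `client`, `instance_name`, and `read_stream` bindings are hypothetical;
// only the types defined in this module are real). The state is shared so it
// can be reclaimed and resumed if the write call fails partway through:
//
//     let shared_state = Arc::new(Mutex::new(WriteState::new(instance_name, read_stream)));
//     let result = client
//         .write(WriteStateWrapper::new(shared_state.clone()))
//         .await;
//     if result.is_err() && shared_state.lock().can_resume() {
//         shared_state.lock().resume();
//         // ...retry the write with a fresh WriteStateWrapper...
//     }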