/build/source/nativelink-util/src/proto_stream_utils.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::borrow::Cow; |
16 | | use std::fmt::Debug; |
17 | | use std::pin::Pin; |
18 | | use std::sync::Arc; |
19 | | use std::task::{Context, Poll}; |
20 | | |
21 | | use futures::{Stream, StreamExt}; |
22 | | use nativelink_error::{error_if, make_input_err, Error, ResultExt}; |
23 | | use nativelink_proto::google::bytestream::{ReadResponse, WriteRequest}; |
24 | | use parking_lot::Mutex; |
25 | | use tonic::{Status, Streaming}; |
26 | | |
27 | | use crate::resource_info::ResourceInfo; |
28 | | |
29 | | pub struct WriteRequestStreamWrapper<T> { |
30 | | pub resource_info: ResourceInfo<'static>, |
31 | | pub bytes_received: usize, |
32 | | stream: T, |
33 | | first_msg: Option<WriteRequest>, |
34 | | write_finished: bool, |
35 | | } |
36 | | |
37 | | impl<T> Debug for WriteRequestStreamWrapper<T> { |
38 | 14 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
39 | 14 | f.debug_struct("WriteRequestStreamWrapper") |
40 | 14 | .field("resource_info", &self.resource_info) |
41 | 14 | .field("bytes_received", &self.bytes_received) |
42 | 14 | .field("first_msg", &self.first_msg) |
43 | 14 | .field("write_finished", &self.write_finished) |
44 | 14 | .finish() |
45 | 14 | } |
46 | | } |
47 | | |
48 | | impl<T, E> WriteRequestStreamWrapper<T> |
49 | | where |
50 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin, |
51 | | E: Into<Error>, |
52 | | { |
53 | 16 | pub async fn from(mut stream: T) -> Result<WriteRequestStreamWrapper<T>, Error> { |
54 | 16 | let first_msg15 = stream |
55 | 16 | .next() |
56 | 16 | .await |
57 | 16 | .err_tip(|| "Error receiving first message in stream"0 )?0 |
58 | 16 | .err_tip(|| "Expected WriteRequest struct in stream"1 )?1 ; |
59 | | |
60 | 15 | let resource_info = ResourceInfo::new(&first_msg.resource_name, true) |
61 | 15 | .err_tip(|| { |
62 | 0 | format!( |
63 | 0 | "Could not extract resource info from first message of stream: {}", |
64 | 0 | first_msg.resource_name |
65 | 0 | ) |
66 | 15 | })?0 |
67 | 15 | .to_owned(); |
68 | 15 | |
69 | 15 | Ok(WriteRequestStreamWrapper { |
70 | 15 | resource_info, |
71 | 15 | bytes_received: 0, |
72 | 15 | stream, |
73 | 15 | first_msg: Some(first_msg), |
74 | 15 | write_finished: false, |
75 | 15 | }) |
76 | 16 | } |
77 | | |
78 | 24 | pub async fn next(&mut self) -> Option<Result<WriteRequest, Error>> { |
79 | 125 | futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await24 |
80 | 24 | } |
81 | | |
82 | 0 | pub fn is_first_msg(&self) -> bool { |
83 | 0 | self.first_msg.is_some() |
84 | 0 | } |
85 | | } |
86 | | |
87 | | impl<T, E> Stream for WriteRequestStreamWrapper<T> |
88 | | where |
89 | | E: Into<Error>, |
90 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin, |
91 | | { |
92 | | type Item = Result<WriteRequest, Error>; |
93 | | |
94 | 129 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
95 | 129 | // If the stream said that the previous message was the last one, then |
96 | 129 | // return a stream EOF (i.e. None). |
97 | 129 | if self.write_finished { Branch (97:12): [True: 0, False: 125]
Branch (97:12): [True: 0, False: 0]
Branch (97:12): [Folded - Ignored]
Branch (97:12): [Folded - Ignored]
Branch (97:12): [True: 1, False: 3]
|
98 | 0 | error_if!( |
99 | 1 | self.bytes_received != self.resource_info.expected_size, Branch (99:17): [True: 0, False: 0]
Branch (99:17): [True: 0, False: 0]
Branch (99:17): [Folded - Ignored]
Branch (99:17): [Folded - Ignored]
Branch (99:17): [True: 0, False: 1]
|
100 | | "Did not send enough data. Expected {}, but so far received {}", |
101 | 0 | self.resource_info.expected_size, |
102 | 0 | self.bytes_received |
103 | | ); |
104 | 1 | return Poll::Ready(None); |
105 | 128 | } |
106 | | |
107 | | // Gets the next message, this is either the cached first or a |
108 | | // subsequent message from the wrapped Stream. |
109 | 128 | let maybe_message27 = if let Some(first_msg15 ) = self.first_msg.take() { Branch (109:36): [True: 14, False: 111]
Branch (109:36): [True: 0, False: 0]
Branch (109:36): [Folded - Ignored]
Branch (109:36): [Folded - Ignored]
Branch (109:36): [True: 1, False: 2]
|
110 | 15 | Ok(first_msg) |
111 | | } else { |
112 | 113 | match Pin::new(&mut self.stream).poll_next(cx) { |
113 | 101 | Poll::Pending => return Poll::Pending, |
114 | 9 | Poll::Ready(Some(maybe_message)) => maybe_message |
115 | 9 | .err_tip(|| format!("Stream error at byte {}", self.bytes_received)0 ), |
116 | 3 | Poll::Ready(None) => Err(make_input_err!("Expected WriteRequest struct in stream")), |
117 | | } |
118 | | }; |
119 | | |
120 | | // If we successfully got a message, update our internal state with the |
121 | | // message meta data. |
122 | 27 | Poll::Ready(Some(maybe_message.and_then(|message| { |
123 | 24 | self.write_finished = message.finish_write; |
124 | 24 | self.bytes_received += message.data.len(); |
125 | 24 | |
126 | 24 | // Check that we haven't read past the expected end. |
127 | 24 | if self.bytes_received > self.resource_info.expected_size { Branch (127:16): [True: 1, False: 20]
Branch (127:16): [True: 0, False: 0]
Branch (127:16): [Folded - Ignored]
Branch (127:16): [Folded - Ignored]
Branch (127:16): [True: 0, False: 3]
|
128 | 1 | Err(make_input_err!( |
129 | 1 | "Sent too much data. Expected {}, but so far received {}", |
130 | 1 | self.resource_info.expected_size, |
131 | 1 | self.bytes_received |
132 | 1 | )) |
133 | | } else { |
134 | 23 | Ok(message) |
135 | | } |
136 | 27 | }24 ))) |
137 | 129 | } |
138 | | } |
139 | | |
140 | | /// This provides a buffer for the first response from GrpcStore.read in order |
141 | | /// to allow the first read to occur within the retry loop. That means that if |
142 | | /// the connection establishes fine, but reading the first byte of the file |
143 | | /// fails we have the ability to retry before returning to the caller. |
144 | | pub struct FirstStream { |
145 | | /// Contains the first response from the stream (which could be an EOF, |
146 | | /// hence the nested Option). This should be populated on creation and |
147 | | /// returned as the first result from the stream. Subsequent reads from the |
148 | | /// stream will use the encapsulated stream. |
149 | | first_response: Option<Option<ReadResponse>>, |
150 | | /// The stream to get responses from when `first_response` is None. |
151 | | stream: Streaming<ReadResponse>, |
152 | | } |
153 | | |
154 | | impl FirstStream { |
155 | 0 | pub fn new(first_response: Option<ReadResponse>, stream: Streaming<ReadResponse>) -> Self { |
156 | 0 | Self { |
157 | 0 | first_response: Some(first_response), |
158 | 0 | stream, |
159 | 0 | } |
160 | 0 | } |
161 | | } |
162 | | |
163 | | impl Stream for FirstStream { |
164 | | type Item = Result<ReadResponse, Status>; |
165 | | |
166 | 0 | fn poll_next( |
167 | 0 | mut self: Pin<&mut Self>, |
168 | 0 | cx: &mut std::task::Context<'_>, |
169 | 0 | ) -> std::task::Poll<Option<Self::Item>> { |
170 | 0 | if let Some(first_response) = self.first_response.take() { Branch (170:16): [True: 0, False: 0]
Branch (170:16): [Folded - Ignored]
|
171 | 0 | return std::task::Poll::Ready(first_response.map(Ok)); |
172 | 0 | } |
173 | 0 | Pin::new(&mut self.stream).poll_next(cx) |
174 | 0 | } |
175 | | } |
176 | | |
177 | | /// This structure wraps all of the information required to perform a write |
178 | | /// request on the `GrpcStore`. It stores the last message retrieved which allows |
179 | | /// the write to resume since the UUID allows upload resume at the server. |
180 | | pub struct WriteState<T, E> |
181 | | where |
182 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
183 | | E: Into<Error> + 'static, |
184 | | { |
185 | | instance_name: String, |
186 | | read_stream_error: Option<Error>, |
187 | | read_stream: WriteRequestStreamWrapper<T>, |
188 | | // Tonic doesn't appear to report an error until it has taken two messages, |
189 | | // therefore we are required to buffer the last two messages. |
190 | | cached_messages: [Option<WriteRequest>; 2], |
191 | | // When resuming after an error, the previous messages are cloned into this |
192 | | // queue upfront to allow them to be served back. |
193 | | resume_queue: [Option<WriteRequest>; 2], |
194 | | // An optimisation to avoid having to manage resume_queue when it's empty. |
195 | | is_resumed: bool, |
196 | | } |
197 | | |
198 | | impl<T, E> WriteState<T, E> |
199 | | where |
200 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
201 | | E: Into<Error> + 'static, |
202 | | { |
203 | 1 | pub fn new(instance_name: String, read_stream: WriteRequestStreamWrapper<T>) -> Self { |
204 | 1 | Self { |
205 | 1 | instance_name, |
206 | 1 | read_stream_error: None, |
207 | 1 | read_stream, |
208 | 1 | cached_messages: [None, None], |
209 | 1 | resume_queue: [None, None], |
210 | 1 | is_resumed: false, |
211 | 1 | } |
212 | 1 | } |
213 | | |
214 | 3 | fn push_message(&mut self, message: WriteRequest) { |
215 | 3 | self.cached_messages.swap(0, 1); |
216 | 3 | self.cached_messages[0] = Some(message); |
217 | 3 | } |
218 | | |
219 | 4 | fn resumed_message(&mut self) -> Option<WriteRequest> { |
220 | 4 | if self.is_resumed { Branch (220:12): [True: 0, False: 0]
Branch (220:12): [True: 0, False: 0]
Branch (220:12): [Folded - Ignored]
Branch (220:12): [Folded - Ignored]
Branch (220:12): [True: 0, False: 4]
|
221 | | // The resume_queue is a circular buffer, that we have to shift, |
222 | | // since its only got two elements its a trivial swap. |
223 | 0 | self.resume_queue.swap(0, 1); |
224 | 0 | let message = self.resume_queue[0].take(); |
225 | 0 | if message.is_none() { Branch (225:16): [True: 0, False: 0]
Branch (225:16): [True: 0, False: 0]
Branch (225:16): [Folded - Ignored]
Branch (225:16): [Folded - Ignored]
Branch (225:16): [True: 0, False: 0]
|
226 | 0 | self.is_resumed = false; |
227 | 0 | } |
228 | 0 | message |
229 | | } else { |
230 | 4 | None |
231 | | } |
232 | 4 | } |
233 | | |
234 | 0 | pub fn can_resume(&self) -> bool { |
235 | 0 | self.read_stream_error.is_none() Branch (235:9): [True: 0, False: 0]
Branch (235:9): [True: 0, False: 0]
Branch (235:9): [Folded - Ignored]
Branch (235:9): [Folded - Ignored]
|
236 | 0 | && (self.cached_messages[0].is_some() || self.read_stream.is_first_msg()) Branch (236:17): [True: 0, False: 0]
Branch (236:17): [True: 0, False: 0]
Branch (236:17): [Folded - Ignored]
Branch (236:17): [Folded - Ignored]
|
237 | 0 | } |
238 | | |
239 | 0 | pub fn resume(&mut self) { |
240 | 0 | self.resume_queue.clone_from(&self.cached_messages); |
241 | 0 | self.is_resumed = true; |
242 | 0 | } |
243 | | |
244 | 1 | pub fn take_read_stream_error(&mut self) -> Option<Error> { |
245 | 1 | self.read_stream_error.take() |
246 | 1 | } |
247 | | } |
248 | | |
249 | | /// A wrapper around `WriteState` to allow it to be reclaimed from the underlying |
250 | | /// write call in the case of failure. |
251 | | pub struct WriteStateWrapper<T, E> |
252 | | where |
253 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
254 | | E: Into<Error> + 'static, |
255 | | { |
256 | | shared_state: Arc<Mutex<WriteState<T, E>>>, |
257 | | } |
258 | | |
259 | | impl<T, E> WriteStateWrapper<T, E> |
260 | | where |
261 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
262 | | E: Into<Error> + 'static, |
263 | | { |
264 | 1 | pub fn new(shared_state: Arc<Mutex<WriteState<T, E>>>) -> Self { |
265 | 1 | Self { shared_state } |
266 | 1 | } |
267 | | } |
268 | | |
269 | | impl<T, E> Stream for WriteStateWrapper<T, E> |
270 | | where |
271 | | T: Stream<Item = Result<WriteRequest, E>> + Unpin + Send + 'static, |
272 | | E: Into<Error> + 'static, |
273 | | { |
274 | | type Item = WriteRequest; |
275 | | |
276 | 4 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
277 | | const IS_UPLOAD_TRUE: bool = true; |
278 | | |
279 | | // This should be an uncontended lock since write was called. |
280 | 4 | let mut local_state = self.shared_state.lock(); |
281 | 4 | // If this is the first or second call after a failure and we have |
282 | 4 | // cached messages, then use the cached write requests. |
283 | 4 | let cached_message = local_state.resumed_message(); |
284 | 4 | if cached_message.is_some() { Branch (284:12): [True: 0, False: 0]
Branch (284:12): [True: 0, False: 0]
Branch (284:12): [Folded - Ignored]
Branch (284:12): [Folded - Ignored]
Branch (284:12): [True: 0, False: 4]
|
285 | 0 | return Poll::Ready(cached_message); |
286 | 4 | } |
287 | | // Read a new write request from the downstream. |
288 | 4 | let Poll::Ready(maybe_message) = Pin::new(&mut local_state.read_stream).poll_next(cx) Branch (288:13): [True: 0, False: 0]
Branch (288:13): [True: 0, False: 0]
Branch (288:13): [Folded - Ignored]
Branch (288:13): [Folded - Ignored]
Branch (288:13): [True: 4, False: 0]
|
289 | | else { |
290 | 0 | return Poll::Pending; |
291 | | }; |
292 | | // Update the instance name in the write request and forward it on. |
293 | 4 | let result = match maybe_message3 { |
294 | 3 | Some(Ok(mut message)) => { |
295 | 3 | if !message.resource_name.is_empty() { Branch (295:20): [True: 0, False: 0]
Branch (295:20): [True: 0, False: 0]
Branch (295:20): [Folded - Ignored]
Branch (295:20): [Folded - Ignored]
Branch (295:20): [True: 1, False: 2]
|
296 | | // Replace the instance name in the resource name if it is |
297 | | // different from the instance name in the write state. |
298 | 1 | match ResourceInfo::new(&message.resource_name, IS_UPLOAD_TRUE) { |
299 | 1 | Ok(mut resource_name) => { |
300 | 1 | if resource_name.instance_name != local_state.instance_name { Branch (300:32): [True: 0, False: 0]
Branch (300:32): [True: 0, False: 0]
Branch (300:32): [Folded - Ignored]
Branch (300:32): [Folded - Ignored]
Branch (300:32): [True: 0, False: 1]
|
301 | 0 | resource_name.instance_name = |
302 | 0 | Cow::Borrowed(&local_state.instance_name); |
303 | 0 | message.resource_name = resource_name.to_string(IS_UPLOAD_TRUE); |
304 | 1 | } |
305 | | } |
306 | 0 | Err(err) => { |
307 | 0 | local_state.read_stream_error = Some(err); |
308 | 0 | return Poll::Ready(None); |
309 | | } |
310 | | } |
311 | 2 | } |
312 | | // Cache the last request in case there is an error to allow |
313 | | // the upload to be resumed. |
314 | 3 | local_state.push_message(message.clone()); |
315 | 3 | Some(message) |
316 | | } |
317 | 0 | Some(Err(err)) => { |
318 | 0 | local_state.read_stream_error = Some(err); |
319 | 0 | None |
320 | | } |
321 | 1 | None => None, |
322 | | }; |
323 | 4 | Poll::Ready(result) |
324 | 4 | } |
325 | | } |