/build/source/nativelink-service/src/bytestream_server.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::collections::HashMap; |
16 | | use std::collections::hash_map::Entry; |
17 | | use std::convert::Into; |
18 | | use std::fmt::{Debug, Formatter}; |
19 | | use std::pin::Pin; |
20 | | use std::sync::Arc; |
21 | | use std::sync::atomic::{AtomicU64, Ordering}; |
22 | | use std::time::Duration; |
23 | | |
24 | | use futures::future::{BoxFuture, pending}; |
25 | | use futures::stream::unfold; |
26 | | use futures::{Future, Stream, TryFutureExt, try_join}; |
27 | | use nativelink_config::cas_server::ByteStreamConfig; |
28 | | use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err}; |
29 | | use nativelink_proto::google::bytestream::byte_stream_server::{ |
30 | | ByteStream, ByteStreamServer as Server, |
31 | | }; |
32 | | use nativelink_proto::google::bytestream::{ |
33 | | QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, |
34 | | WriteResponse, |
35 | | }; |
36 | | use nativelink_store::grpc_store::GrpcStore; |
37 | | use nativelink_store::store_manager::StoreManager; |
38 | | use nativelink_util::buf_channel::{ |
39 | | DropCloserReadHalf, DropCloserWriteHalf, make_buf_channel_pair, |
40 | | }; |
41 | | use nativelink_util::common::DigestInfo; |
42 | | use nativelink_util::digest_hasher::{ |
43 | | DigestHasherFunc, default_digest_hasher_func, make_ctx_for_hash_func, |
44 | | }; |
45 | | use nativelink_util::origin_event::OriginEventContext; |
46 | | use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; |
47 | | use nativelink_util::resource_info::ResourceInfo; |
48 | | use nativelink_util::spawn; |
49 | | use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; |
50 | | use nativelink_util::task::JoinHandleDropGuard; |
51 | | use parking_lot::Mutex; |
52 | | use tokio::time::sleep; |
53 | | use tonic::{Request, Response, Status, Streaming}; |
54 | | use tracing::{Instrument, Level, enabled, error_span, event, instrument}; |
55 | | |
56 | | /// If this value changes update the documentation in the config definition. |
57 | | const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60); |
58 | | |
59 | | /// If this value changes update the documentation in the config definition. |
60 | | const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024; |
61 | | |
62 | | /// If this value changes update the documentation in the config definition. |
63 | | const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 * 1024 * 1024; |
64 | | |
65 | | type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; |
66 | | type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>; |
67 | | |
68 | | struct StreamState { |
69 | | uuid: String, |
70 | | tx: DropCloserWriteHalf, |
71 | | store_update_fut: StoreUpdateFuture, |
72 | | } |
73 | | |
74 | | impl Debug for StreamState { |
75 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
76 | 0 | f.debug_struct("StreamState") |
77 | 0 | .field("uuid", &self.uuid) |
78 | 0 | .finish() |
79 | 0 | } |
80 | | } |
81 | | |
82 | | /// If a stream is in this state, it will automatically be put back into an `IdleStream` and |
83 | | /// placed back into the `active_uploads` map as an `IdleStream` after it is dropped. |
84 | | /// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`. |
85 | | struct ActiveStreamGuard<'a> { |
86 | | stream_state: Option<StreamState>, |
87 | | bytes_received: Arc<AtomicU64>, |
88 | | bytestream_server: &'a ByteStreamServer, |
89 | | } |
90 | | |
91 | | impl ActiveStreamGuard<'_> { |
92 | | /// Consumes the guard. The stream will be considered "finished", will |
93 | | /// remove it from the `active_uploads`. |
94 | 8 | fn graceful_finish(mut self) { |
95 | 8 | let stream_state = self.stream_state.take().unwrap(); |
96 | 8 | self.bytestream_server |
97 | 8 | .active_uploads |
98 | 8 | .lock() |
99 | 8 | .remove(&stream_state.uuid); |
100 | 8 | } |
101 | | } |
102 | | |
103 | | impl Drop for ActiveStreamGuard<'_> { |
104 | 14 | fn drop(&mut self) { |
105 | 14 | let Some(stream_state6 ) = self.stream_state.take() else { Branch (105:13): [True: 6, False: 8]
Branch (105:13): [Folded - Ignored]
|
106 | 8 | return; // If None it means we don't want it put back into an IdleStream. |
107 | | }; |
108 | 6 | let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads); |
109 | 6 | let mut active_uploads = self.bytestream_server.active_uploads.lock(); |
110 | 6 | let uuid = stream_state.uuid.clone(); |
111 | 6 | let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else { Branch (111:13): [True: 6, False: 0]
Branch (111:13): [Folded - Ignored]
|
112 | 0 | event!( |
113 | 0 | Level::ERROR, |
114 | | err = "Failed to find active upload. This should never happen.", |
115 | | uuid = ?uuid, |
116 | | ); |
117 | 0 | return; |
118 | | }; |
119 | 6 | let sleep_fn = self.bytestream_server.sleep_fn.clone(); |
120 | 6 | active_uploads_slot.1 = Some(IdleStream { |
121 | 6 | stream_state, |
122 | 6 | _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move { |
123 | 3 | (*sleep_fn)().await; |
124 | 0 | if let Some(active_uploads) = weak_active_uploads.upgrade() { Branch (124:24): [True: 0, False: 0]
Branch (124:24): [Folded - Ignored]
|
125 | 0 | let mut active_uploads = active_uploads.lock(); |
126 | 0 | event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid); |
127 | 0 | active_uploads.remove(&uuid); |
128 | 0 | } |
129 | 0 | }), |
130 | | }); |
131 | 14 | } |
132 | | } |
133 | | |
134 | | /// Represents a stream that is in the "idle" state. this means it is not currently being used |
135 | | /// by a client. If it is not used within a certain amount of time it will be removed from the |
136 | | /// `active_uploads` map automatically. |
137 | | #[derive(Debug)] |
138 | | struct IdleStream { |
139 | | stream_state: StreamState, |
140 | | _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, |
141 | | } |
142 | | |
143 | | impl IdleStream { |
144 | 3 | fn into_active_stream( |
145 | 3 | self, |
146 | 3 | bytes_received: Arc<AtomicU64>, |
147 | 3 | bytestream_server: &ByteStreamServer, |
148 | 3 | ) -> ActiveStreamGuard<'_> { |
149 | 3 | ActiveStreamGuard { |
150 | 3 | stream_state: Some(self.stream_state), |
151 | 3 | bytes_received, |
152 | 3 | bytestream_server, |
153 | 3 | } |
154 | 3 | } |
155 | | } |
156 | | |
157 | | type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>); |
158 | | type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>; |
159 | | |
160 | | pub struct ByteStreamServer { |
161 | | stores: HashMap<String, Store>, |
162 | | // Max number of bytes to send on each grpc stream chunk. |
163 | | max_bytes_per_stream: usize, |
164 | | max_decoding_message_size: usize, |
165 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
166 | | sleep_fn: SleepFn, |
167 | | } |
168 | | |
169 | | impl Debug for ByteStreamServer { |
170 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
171 | 0 | f.debug_struct("ByteStreamServer") |
172 | 0 | .field("stores", &self.stores) |
173 | 0 | .field("max_bytes_per_stream", &self.max_bytes_per_stream) |
174 | 0 | .field("max_decoding_message_size", &self.max_decoding_message_size) |
175 | 0 | .field("active_uploads", &self.active_uploads) |
176 | 0 | .finish_non_exhaustive() |
177 | 0 | } |
178 | | } |
179 | | |
180 | | impl ByteStreamServer { |
181 | 14 | pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> { |
182 | 14 | let mut persist_stream_on_disconnect_timeout = |
183 | 14 | Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64); |
184 | 14 | if config.persist_stream_on_disconnect_timeout == 0 { Branch (184:12): [True: 14, False: 0]
Branch (184:12): [Folded - Ignored]
|
185 | 14 | persist_stream_on_disconnect_timeout = DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT; |
186 | 14 | }0 |
187 | 14 | Self::new_with_sleep_fn( |
188 | 14 | config, |
189 | 14 | store_manager, |
190 | 14 | Arc::new(move || Box::pin(sleep(persist_stream_on_disconnect_timeout))3 ), |
191 | | ) |
192 | 14 | } |
193 | | |
194 | 14 | pub fn new_with_sleep_fn( |
195 | 14 | config: &ByteStreamConfig, |
196 | 14 | store_manager: &StoreManager, |
197 | 14 | sleep_fn: SleepFn, |
198 | 14 | ) -> Result<Self, Error> { |
199 | 14 | let mut stores = HashMap::with_capacity(config.cas_stores.len()); |
200 | 28 | for (instance_name, store_name14 ) in &config.cas_stores { |
201 | 14 | let store = store_manager |
202 | 14 | .get_store(store_name) |
203 | 14 | .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", store_name)0 )?0 ; |
204 | 14 | stores.insert(instance_name.to_string(), store); |
205 | | } |
206 | 14 | let max_bytes_per_stream = if config.max_bytes_per_stream == 0 { Branch (206:39): [True: 1, False: 13]
Branch (206:39): [Folded - Ignored]
|
207 | 1 | DEFAULT_MAX_BYTES_PER_STREAM |
208 | | } else { |
209 | 13 | config.max_bytes_per_stream |
210 | | }; |
211 | 14 | let max_decoding_message_size = if config.max_decoding_message_size == 0 { Branch (211:44): [True: 13, False: 1]
Branch (211:44): [Folded - Ignored]
|
212 | 13 | DEFAULT_MAX_DECODING_MESSAGE_SIZE |
213 | | } else { |
214 | 1 | config.max_decoding_message_size |
215 | | }; |
216 | 14 | Ok(ByteStreamServer { |
217 | 14 | stores, |
218 | 14 | max_bytes_per_stream, |
219 | 14 | max_decoding_message_size, |
220 | 14 | active_uploads: Arc::new(Mutex::new(HashMap::new())), |
221 | 14 | sleep_fn, |
222 | 14 | }) |
223 | 14 | } |
224 | | |
225 | 1 | pub fn into_service(self) -> Server<Self> { |
226 | 1 | let max_decoding_message_size = self.max_decoding_message_size; |
227 | 1 | Server::new(self).max_decoding_message_size(max_decoding_message_size) |
228 | 1 | } |
229 | | |
230 | 14 | fn create_or_join_upload_stream( |
231 | 14 | &self, |
232 | 14 | uuid: String, |
233 | 14 | store: Store, |
234 | 14 | digest: DigestInfo, |
235 | 14 | ) -> Result<ActiveStreamGuard<'_>, Error> { |
236 | 14 | let (uuid, bytes_received11 ) = match self.active_uploads.lock().entry(uuid) { |
237 | 3 | Entry::Occupied(mut entry) => { |
238 | 3 | let maybe_idle_stream = entry.get_mut(); |
239 | 3 | let Some(idle_stream) = maybe_idle_stream.1.take() else { Branch (239:21): [True: 3, False: 0]
Branch (239:21): [Folded - Ignored]
|
240 | 0 | return Err(make_input_err!("Cannot upload same UUID simultaneously")); |
241 | | }; |
242 | 3 | let bytes_received = maybe_idle_stream.0.clone(); |
243 | 3 | event!(Level::INFO, msg = "Joining existing stream", entry = ?entry.key0 ()); |
244 | 3 | return Ok(idle_stream.into_active_stream(bytes_received, self)); |
245 | | } |
246 | 11 | Entry::Vacant(entry) => { |
247 | 11 | let bytes_received = Arc::new(AtomicU64::new(0)); |
248 | 11 | let uuid = entry.key().clone(); |
249 | 11 | // Our stream is "in use" if the key is in the map, but the value is None. |
250 | 11 | entry.insert((bytes_received.clone(), None)); |
251 | 11 | (uuid, bytes_received) |
252 | 11 | } |
253 | 11 | }; |
254 | 11 | |
255 | 11 | // Important: Do not return an error from this point onwards without |
256 | 11 | // removing the entry from the map, otherwise that UUID becomes |
257 | 11 | // unusable. |
258 | 11 | |
259 | 11 | let (tx, rx) = make_buf_channel_pair(); |
260 | 11 | let store_update_fut = Box::pin(async move { |
261 | 8 | // We need to wrap `Store::update()` in a another future because we need to capture |
262 | 8 | // `store` to ensure its lifetime follows the future and not the caller. |
263 | 8 | store |
264 | 8 | // Bytestream always uses digest size as the actual byte size. |
265 | 8 | .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes())) |
266 | 8 | .await |
267 | 8 | }); |
268 | 11 | Ok(ActiveStreamGuard { |
269 | 11 | stream_state: Some(StreamState { |
270 | 11 | uuid, |
271 | 11 | tx, |
272 | 11 | store_update_fut, |
273 | 11 | }), |
274 | 11 | bytes_received, |
275 | 11 | bytestream_server: self, |
276 | 11 | }) |
277 | 14 | } |
278 | | |
279 | 3 | async fn inner_read( |
280 | 3 | &self, |
281 | 3 | store: Store, |
282 | 3 | digest: DigestInfo, |
283 | 3 | read_request: ReadRequest, |
284 | 3 | ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + use<>, Error> { |
285 | | struct ReaderState { |
286 | | max_bytes_per_stream: usize, |
287 | | rx: DropCloserReadHalf, |
288 | | maybe_get_part_result: Option<Result<(), Error>>, |
289 | | get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, |
290 | | } |
291 | | |
292 | 3 | let read_limit = u64::try_from(read_request.read_limit) |
293 | 3 | .err_tip(|| "Could not convert read_limit to u64"0 )?0 ; |
294 | | |
295 | 3 | let (tx, rx) = make_buf_channel_pair(); |
296 | | |
297 | 3 | let read_limit = if read_limit != 0 { Branch (297:29): [True: 3, False: 0]
Branch (297:29): [Folded - Ignored]
|
298 | 3 | Some(read_limit) |
299 | | } else { |
300 | 0 | None |
301 | | }; |
302 | | |
303 | | // This allows us to call a destructor when the the object is dropped. |
304 | 3 | let state = Some(ReaderState { |
305 | 3 | rx, |
306 | 3 | max_bytes_per_stream: self.max_bytes_per_stream, |
307 | 3 | maybe_get_part_result: None, |
308 | 3 | get_part_fut: Box::pin(async move { |
309 | 3 | store |
310 | 3 | .get_part( |
311 | 3 | digest, |
312 | 3 | tx, |
313 | 3 | u64::try_from(read_request.read_offset) |
314 | 3 | .err_tip(|| "Could not convert read_offset to u64"0 )?0 , |
315 | 3 | read_limit, |
316 | 3 | ) |
317 | 3 | .await |
318 | 3 | }), |
319 | | }); |
320 | | |
321 | 3 | let read_stream_span = error_span!("read_stream"); |
322 | | |
323 | 9.77k | Ok(Box::pin(unfold(state3 , move |state| { |
324 | 9.77k | async { |
325 | 9.77k | let mut state = state?0 ; // If None our stream is done. |
326 | 9.77k | let mut response = ReadResponse::default(); |
327 | 9.77k | { |
328 | 9.77k | let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream)); |
329 | 9.77k | tokio::pin!(consume_fut); |
330 | | loop { |
331 | 9.77k | tokio::select! { |
332 | 9.77k | read_result9.77k = &mut consume_fut => { |
333 | 9.77k | match read_result { |
334 | 9.76k | Ok(bytes) => { |
335 | 9.76k | if bytes.is_empty() { Branch (335:40): [True: 2, False: 9.76k]
Branch (335:40): [Folded - Ignored]
|
336 | | // EOF. |
337 | 2 | return None; |
338 | 9.76k | } |
339 | 9.76k | if bytes.len() > state.max_bytes_per_stream { Branch (339:40): [True: 0, False: 9.76k]
Branch (339:40): [Folded - Ignored]
|
340 | 0 | let err = make_err!(Code::Internal, "Returned store size was larger than read size"); |
341 | 0 | return Some((Err(err.into()), None)); |
342 | 9.76k | } |
343 | 9.76k | response.data = bytes; |
344 | 9.76k | if enabled!(Level::DEBUG) { |
345 | 0 | event!(Level::INFO, response = ?response); |
346 | | } else { |
347 | 9.76k | event!(Level::INFO, response.data = format!("<redacted len({})>", response.data.len0 ())); |
348 | | } |
349 | 9.76k | break; |
350 | | } |
351 | 1 | Err(mut e) => { |
352 | | // We may need to propagate the error from reading the data through first. |
353 | | // For example, the NotFound error will come through `get_part_fut`, and |
354 | | // will not be present in `e`, but we need to ensure we pass NotFound error |
355 | | // code or the client won't know why it failed. |
356 | 1 | let get_part_result = if let Some(result) = state.maybe_get_part_result { Branch (356:66): [True: 1, False: 0]
Branch (356:66): [Folded - Ignored]
|
357 | 1 | result |
358 | | } else { |
359 | | // This should never be `future::pending()` if maybe_get_part_result is |
360 | | // not set. |
361 | 0 | state.get_part_fut.await |
362 | | }; |
363 | 1 | if let Err(err) = get_part_result { Branch (363:44): [True: 1, False: 0]
Branch (363:44): [Folded - Ignored]
|
364 | 1 | e = err.merge(e); |
365 | 1 | }0 |
366 | 1 | if e.code == Code::NotFound { Branch (366:40): [True: 1, False: 0]
Branch (366:40): [Folded - Ignored]
|
367 | 1 | // Trim the error code. Not Found is quite common and we don't want to send a large |
368 | 1 | // error (debug) message for something that is common. We resize to just the last |
369 | 1 | // message as it will be the most relevant. |
370 | 1 | e.messages.truncate(1); |
371 | 1 | }0 |
372 | 1 | event!(Level::ERROR, response = ?e); |
373 | 1 | return Some((Err(e.into()), None)) |
374 | | } |
375 | | } |
376 | | }, |
377 | 9.77k | result3 = &mut state.get_part_fut => { |
378 | 3 | state.maybe_get_part_result = Some(result); |
379 | 3 | // It is non-deterministic on which future will finish in what order. |
380 | 3 | // It is also possible that the `state.rx.consume()` call above may not be able to |
381 | 3 | // respond even though the publishing future is done. |
382 | 3 | // Because of this we set the writing future to pending so it never finishes. |
383 | 3 | // The `state.rx.consume()` future will eventually finish and return either the |
384 | 3 | // data or an error. |
385 | 3 | // An EOF will terminate the `state.rx.consume()` future, but we are also protected |
386 | 3 | // because we are dropping the writing future, it will drop the `tx` channel |
387 | 3 | // which will eventually propagate an error to the `state.rx.consume()` future if |
388 | 3 | // the EOF was not sent due to some other error. |
389 | 3 | state.get_part_fut = Box::pin(pending()); |
390 | 3 | }, |
391 | | } |
392 | | } |
393 | | } |
394 | 9.76k | Some((Ok(response), Some(state))) |
395 | 9.77k | }.instrument(read_stream_span.clone()) |
396 | 9.77k | }))) |
397 | 3 | } |
398 | | |
399 | | // We instrument tracing here as well as below because `stream` has a hash on it |
400 | | // that is extracted from the first stream message. If we only implemented it below |
401 | | // we would not have the hash available to us. |
402 | | #[instrument( |
403 | | ret(level = Level::INFO), |
404 | 14 | level = Level::ERROR, |
405 | | skip(self, store), |
406 | | fields(stream.first_msg = "<redacted>") |
407 | | )] |
408 | | async fn inner_write( |
409 | | &self, |
410 | | store: Store, |
411 | | digest: DigestInfo, |
412 | | stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>, |
413 | | ) -> Result<Response<WriteResponse>, Error> { |
414 | 14 | async fn process_client_stream( |
415 | 14 | mut stream: WriteRequestStreamWrapper< |
416 | 14 | impl Stream<Item = Result<WriteRequest, Status>> + Unpin, |
417 | 14 | >, |
418 | 14 | tx: &mut DropCloserWriteHalf, |
419 | 14 | outer_bytes_received: &Arc<AtomicU64>, |
420 | 14 | expected_size: u64, |
421 | 14 | ) -> Result<(), Error> { |
422 | | loop { |
423 | 24 | let write_request20 = match stream.next().await { |
424 | | // Code path for when client tries to gracefully close the stream. |
425 | | // If this happens it means there's a problem with the data sent, |
426 | | // because we always close the stream from our end before this point |
427 | | // by counting the number of bytes sent from the client. If they send |
428 | | // less than the amount they said they were going to send and then |
429 | | // close the stream, we know there's a problem. |
430 | | None => { |
431 | 0 | return Err(make_input_err!( |
432 | 0 | "Client closed stream before sending all data" |
433 | 0 | )); |
434 | | } |
435 | | // Code path for client stream error. Probably client disconnect. |
436 | 4 | Some(Err(err)) => return Err(err), |
437 | | // Code path for received chunk of data. |
438 | 20 | Some(Ok(write_request)) => write_request, |
439 | 20 | }; |
440 | 20 | |
441 | 20 | if write_request.write_offset < 0 { Branch (441:20): [True: 1, False: 19]
Branch (441:20): [Folded - Ignored]
|
442 | 1 | return Err(make_input_err!( |
443 | 1 | "Invalid negative write offset in write request: {}", |
444 | 1 | write_request.write_offset |
445 | 1 | )); |
446 | 19 | } |
447 | 19 | let write_offset = write_request.write_offset as u64; |
448 | | |
449 | | // If we get duplicate data because a client didn't know where |
450 | | // it left off from, then we can simply skip it. |
451 | 19 | let data18 = if write_offset < tx.get_bytes_written() { Branch (451:31): [True: 2, False: 17]
Branch (451:31): [Folded - Ignored]
|
452 | 2 | if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() { Branch (452:24): [True: 0, False: 2]
Branch (452:24): [Folded - Ignored]
|
453 | 0 | if write_request.finish_write { Branch (453:28): [True: 0, False: 0]
Branch (453:28): [Folded - Ignored]
|
454 | 0 | return Err(make_input_err!( |
455 | 0 | "Resumed stream finished at {} bytes when we already received {} bytes.", |
456 | 0 | write_offset + write_request.data.len() as u64, |
457 | 0 | tx.get_bytes_written() |
458 | 0 | )); |
459 | 0 | } |
460 | 0 | continue; |
461 | 2 | } |
462 | 2 | write_request |
463 | 2 | .data |
464 | 2 | .slice((tx.get_bytes_written() - write_offset) as usize..) |
465 | | } else { |
466 | 17 | if write_offset != tx.get_bytes_written() { Branch (466:24): [True: 1, False: 16]
Branch (466:24): [Folded - Ignored]
|
467 | 1 | return Err(make_input_err!( |
468 | 1 | "Received out of order data. Got {}, expected {}", |
469 | 1 | write_offset, |
470 | 1 | tx.get_bytes_written() |
471 | 1 | )); |
472 | 16 | } |
473 | 16 | write_request.data |
474 | | }; |
475 | | |
476 | | // Do not process EOF or weird stuff will happen. |
477 | 18 | if !data.is_empty() { Branch (477:20): [True: 13, False: 5]
Branch (477:20): [Folded - Ignored]
|
478 | | // We also need to process the possible EOF branch, so we can't early return. |
479 | 13 | if let Err(mut err0 ) = tx.send(data).await { Branch (479:28): [True: 0, False: 13]
Branch (479:28): [Folded - Ignored]
|
480 | 0 | err.code = Code::Internal; |
481 | 0 | return Err(err); |
482 | 13 | } |
483 | 13 | outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); |
484 | 5 | } |
485 | | |
486 | 18 | if expected_size < tx.get_bytes_written() { Branch (486:20): [True: 0, False: 18]
Branch (486:20): [Folded - Ignored]
|
487 | 0 | return Err(make_input_err!("Received more bytes than expected")); |
488 | 18 | } |
489 | 18 | if write_request.finish_write { Branch (489:20): [True: 8, False: 10]
Branch (489:20): [Folded - Ignored]
|
490 | | // Gracefully close our stream. |
491 | 8 | tx.send_eof() |
492 | 8 | .err_tip(|| "Failed to send EOF in ByteStream::write"0 )?0 ; |
493 | 8 | return Ok(()); |
494 | 10 | } |
495 | | // Continue. |
496 | | } |
497 | | // Unreachable. |
498 | 14 | } |
499 | | |
500 | | let uuid = stream |
501 | | .resource_info |
502 | | .uuid |
503 | | .as_ref() |
504 | 0 | .ok_or_else(|| make_input_err!("UUID must be set if writing data"))? |
505 | | .to_string(); |
506 | | let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?; |
507 | | let expected_size = stream.resource_info.expected_size as u64; |
508 | | |
509 | | let active_stream = active_stream_guard.stream_state.as_mut().unwrap(); |
510 | | try_join!( |
511 | | process_client_stream( |
512 | | stream, |
513 | | &mut active_stream.tx, |
514 | | &active_stream_guard.bytes_received, |
515 | | expected_size |
516 | | ), |
517 | | (&mut active_stream.store_update_fut) |
518 | 0 | .map_err(|err| { err.append("Error updating inner store") }) |
519 | | )?; |
520 | | |
521 | | // Close our guard and consider the stream no longer active. |
522 | | active_stream_guard.graceful_finish(); |
523 | | |
524 | | Ok(Response::new(WriteResponse { |
525 | | committed_size: expected_size as i64, |
526 | | })) |
527 | | } |
528 | | |
529 | 3 | async fn inner_query_write_status( |
530 | 3 | &self, |
531 | 3 | query_request: &QueryWriteStatusRequest, |
532 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Error> { |
533 | 3 | let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)?0 ; |
534 | | |
535 | 3 | let store_clone = self |
536 | 3 | .stores |
537 | 3 | .get(resource_info.instance_name.as_ref()) |
538 | 3 | .err_tip(|| { |
539 | 0 | format!( |
540 | 0 | "'instance_name' not configured for '{}'", |
541 | 0 | &resource_info.instance_name |
542 | 0 | ) |
543 | 0 | })? |
544 | 3 | .clone(); |
545 | | |
546 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
547 | | |
548 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
549 | 3 | if let Some(grpc_store0 ) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (549:16): [True: 0, False: 3]
Branch (549:16): [Folded - Ignored]
|
550 | 0 | return grpc_store |
551 | 0 | .query_write_status(Request::new(query_request.clone())) |
552 | 0 | .await; |
553 | 3 | } |
554 | | |
555 | 3 | let uuid = resource_info |
556 | 3 | .uuid |
557 | 3 | .take() |
558 | 3 | .ok_or_else(|| make_input_err!("UUID must be set if querying write status")0 )?0 ; |
559 | | |
560 | | { |
561 | 3 | let active_uploads = self.active_uploads.lock(); |
562 | 3 | if let Some((received_bytes, _maybe_idle_stream1 )) = active_uploads.get(uuid.as_ref()) { Branch (562:20): [True: 1, False: 2]
Branch (562:20): [Folded - Ignored]
|
563 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
564 | 1 | committed_size: received_bytes.load(Ordering::Acquire) as i64, |
565 | 1 | // If we are in the active_uploads map, but the value is None, |
566 | 1 | // it means the stream is not complete. |
567 | 1 | complete: false, |
568 | 1 | })); |
569 | 2 | } |
570 | 2 | } |
571 | 2 | |
572 | 2 | let has_fut = store_clone.has(digest); |
573 | 2 | let Some(item_size1 ) = has_fut.await.err_tip(|| "Failed to call .has() on store"0 )?0 else { Branch (573:13): [True: 1, False: 1]
Branch (573:13): [Folded - Ignored]
|
574 | | // We lie here and say that the stream needs to start over, even though |
575 | | // it was never started. This can happen when the client disconnects |
576 | | // before sending the first payload, but the client thinks it did send |
577 | | // the payload. |
578 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
579 | 1 | committed_size: 0, |
580 | 1 | complete: false, |
581 | 1 | })); |
582 | | }; |
583 | 1 | Ok(Response::new(QueryWriteStatusResponse { |
584 | 1 | committed_size: item_size as i64, |
585 | 1 | complete: true, |
586 | 1 | })) |
587 | 3 | } |
588 | | } |
589 | | |
590 | | #[tonic::async_trait] |
591 | | impl ByteStream for ByteStreamServer { |
592 | | type ReadStream = ReadStream; |
593 | | |
594 | | #[instrument( |
595 | | err, |
596 | | level = Level::ERROR, |
597 | | skip_all, |
598 | | fields(request = ?grpc_request.get_ref()) |
599 | | )] |
600 | | async fn read( |
601 | | &self, |
602 | | grpc_request: Request<ReadRequest>, |
603 | 6 | ) -> Result<Response<Self::ReadStream>, Status> { |
604 | 3 | let read_request = grpc_request.into_inner(); |
605 | 3 | let ctx = OriginEventContext::new(|| &read_request0 ).await; |
606 | | |
607 | 3 | let resource_info = ResourceInfo::new(&read_request.resource_name, false)?0 ; |
608 | 3 | let instance_name = resource_info.instance_name.as_ref(); |
609 | 3 | let store = self |
610 | 3 | .stores |
611 | 3 | .get(instance_name) |
612 | 3 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
613 | 3 | .clone(); |
614 | | |
615 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
616 | | |
617 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
618 | 3 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (618:16): [True: 0, False: 3]
Branch (618:16): [Folded - Ignored]
|
619 | 0 | let stream = grpc_store.read(Request::new(read_request)).await?; |
620 | 0 | let resp = Ok(Response::new(ctx.wrap_stream(stream))); |
621 | 0 | ctx.emit(|| &resp).await; |
622 | 0 | return resp; |
623 | 3 | } |
624 | | |
625 | 3 | let digest_function = resource_info.digest_function.as_deref().map_or_else( |
626 | 3 | || Ok(default_digest_hasher_func()), |
627 | | DigestHasherFunc::try_from, |
628 | 0 | )?; |
629 | | |
630 | 3 | let resp = make_ctx_for_hash_func(digest_function) |
631 | 3 | .err_tip(|| "In BytestreamServer::read"0 )?0 |
632 | 3 | .wrap_async( |
633 | 3 | error_span!("bytestream_read"), |
634 | 3 | self.inner_read(store, digest, read_request), |
635 | 3 | ) |
636 | 3 | .await |
637 | 3 | .err_tip(|| "In ByteStreamServer::read"0 ) |
638 | 3 | .map(|stream| -> Response<Self::ReadStream> { |
639 | 3 | Response::new(Box::pin(ctx.wrap_stream(stream))) |
640 | 3 | }) |
641 | 3 | .map_err(Into::into); |
642 | 3 | |
643 | 3 | if resp.is_ok() { Branch (643:12): [True: 3, False: 0]
Branch (643:12): [Folded - Ignored]
|
644 | 3 | event!(Level::DEBUG, return = "Ok(<stream>)"); |
645 | 0 | } |
646 | 3 | ctx.emit(|| &resp0 ).await; |
647 | 3 | resp |
648 | 6 | } |
649 | | |
650 | | #[instrument( |
651 | | err, |
652 | | level = Level::ERROR, |
653 | | skip_all, |
654 | | fields(request = ?grpc_request.get_ref()) |
655 | | )] |
656 | | async fn write( |
657 | | &self, |
658 | | grpc_request: Request<Streaming<WriteRequest>>, |
659 | 30 | ) -> Result<Response<WriteResponse>, Status> { |
660 | 15 | let request = grpc_request.into_inner(); |
661 | 15 | let ctx = OriginEventContext::new(|| &request0 ).await; |
662 | 15 | let stream14 = WriteRequestStreamWrapper::from(ctx.wrap_stream(request)) |
663 | 15 | .await |
664 | 15 | .err_tip(|| "Could not unwrap first stream message"1 ) |
665 | 15 | .map_err(Into::<Status>::into)?1 ; |
666 | | |
667 | 14 | let instance_name = stream.resource_info.instance_name.as_ref(); |
668 | 14 | let store = self |
669 | 14 | .stores |
670 | 14 | .get(instance_name) |
671 | 14 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
672 | 14 | .clone(); |
673 | | |
674 | 14 | let digest = DigestInfo::try_new( |
675 | 14 | &stream.resource_info.hash, |
676 | 14 | stream.resource_info.expected_size, |
677 | 14 | ) |
678 | 14 | .err_tip(|| "Invalid digest input in ByteStream::write"0 )?0 ; |
679 | | |
680 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
681 | 14 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (681:16): [True: 0, False: 14]
Branch (681:16): [Folded - Ignored]
|
682 | 0 | let resp = grpc_store.write(stream).await.map_err(Into::into); |
683 | 0 | ctx.emit(|| &resp).await; |
684 | 0 | return resp; |
685 | 14 | } |
686 | | |
687 | 14 | let digest_function = stream |
688 | 14 | .resource_info |
689 | 14 | .digest_function |
690 | 14 | .as_deref() |
691 | 14 | .map_or_else( |
692 | 14 | || Ok(default_digest_hasher_func()), |
693 | | DigestHasherFunc::try_from, |
694 | 0 | )?; |
695 | | |
696 | 14 | let resp = make_ctx_for_hash_func(digest_function) |
697 | 14 | .err_tip(|| "In BytestreamServer::write"0 )?0 |
698 | 14 | .wrap_async( |
699 | 14 | error_span!("bytestream_write"), |
700 | 14 | self.inner_write(store, digest, stream), |
701 | 14 | ) |
702 | 14 | .await |
703 | 14 | .err_tip(|| "In ByteStreamServer::write"6 ) |
704 | 14 | .map_err(Into::into); |
705 | 14 | ctx.emit(|| &resp0 ).await; |
706 | 14 | resp |
707 | 30 | } |
708 | | |
709 | | #[instrument( |
710 | | err, |
711 | | ret(level = Level::INFO), |
712 | | level = Level::ERROR, |
713 | | skip_all, |
714 | | fields(request = ?grpc_request.get_ref()) |
715 | | )] |
716 | | async fn query_write_status( |
717 | | &self, |
718 | | grpc_request: Request<QueryWriteStatusRequest>, |
719 | 6 | ) -> Result<Response<QueryWriteStatusResponse>, Status> { |
720 | 3 | let request = grpc_request.into_inner(); |
721 | 3 | let ctx = OriginEventContext::new(|| &request0 ).await; |
722 | 3 | let resp = self |
723 | 3 | .inner_query_write_status(&request) |
724 | 3 | .await |
725 | 3 | .err_tip(|| "Failed on query_write_status() command"0 ) |
726 | 3 | .map_err(Into::into); |
727 | 3 | ctx.emit(|| &resp0 ).await; |
728 | 3 | resp |
729 | 6 | } |
730 | | } |