/build/source/nativelink-service/src/bytestream_server.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::collections::hash_map::Entry; |
16 | | use std::collections::HashMap; |
17 | | use std::convert::Into; |
18 | | use std::fmt::{Debug, Formatter}; |
19 | | use std::pin::Pin; |
20 | | use std::sync::atomic::{AtomicU64, Ordering}; |
21 | | use std::sync::Arc; |
22 | | use std::time::Duration; |
23 | | |
24 | | use futures::future::{pending, BoxFuture}; |
25 | | use futures::stream::unfold; |
26 | | use futures::{try_join, Future, Stream, TryFutureExt}; |
27 | | use nativelink_config::cas_server::ByteStreamConfig; |
28 | | use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; |
29 | | use nativelink_proto::google::bytestream::byte_stream_server::{ |
30 | | ByteStream, ByteStreamServer as Server, |
31 | | }; |
32 | | use nativelink_proto::google::bytestream::{ |
33 | | QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, |
34 | | WriteResponse, |
35 | | }; |
36 | | use nativelink_store::grpc_store::GrpcStore; |
37 | | use nativelink_store::store_manager::StoreManager; |
38 | | use nativelink_util::buf_channel::{ |
39 | | make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf, |
40 | | }; |
41 | | use nativelink_util::common::DigestInfo; |
42 | | use nativelink_util::digest_hasher::{ |
43 | | default_digest_hasher_func, make_ctx_for_hash_func, DigestHasherFunc, |
44 | | }; |
45 | | use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; |
46 | | use nativelink_util::resource_info::ResourceInfo; |
47 | | use nativelink_util::spawn; |
48 | | use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; |
49 | | use nativelink_util::task::JoinHandleDropGuard; |
50 | | use parking_lot::Mutex; |
51 | | use tokio::time::sleep; |
52 | | use tonic::{Request, Response, Status, Streaming}; |
53 | | use tracing::{enabled, error_span, event, instrument, Instrument, Level}; |
54 | | |
55 | | /// If this value changes update the documentation in the config definition. |
56 | | const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60); |
57 | | |
58 | | /// If this value changes update the documentation in the config definition. |
59 | | const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024; |
60 | | |
61 | | /// If this value changes update the documentation in the config definition. |
62 | | const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 * 1024 * 1024; |
63 | | |
64 | | type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; |
65 | | type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>; |
66 | | |
67 | | struct StreamState { |
68 | | uuid: String, |
69 | | tx: DropCloserWriteHalf, |
70 | | store_update_fut: StoreUpdateFuture, |
71 | | } |
72 | | |
73 | | impl Debug for StreamState { |
74 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
75 | 0 | f.debug_struct("StreamState") |
76 | 0 | .field("uuid", &self.uuid) |
77 | 0 | .finish() |
78 | 0 | } |
79 | | } |
80 | | |
81 | | /// If a stream is in this state, it will automatically be put back into an `IdleStream` and |
82 | | /// placed back into the `active_uploads` map as an `IdleStream` after it is dropped. |
83 | | /// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`. |
84 | | struct ActiveStreamGuard<'a> { |
85 | | stream_state: Option<StreamState>, |
86 | | bytes_received: Arc<AtomicU64>, |
87 | | bytestream_server: &'a ByteStreamServer, |
88 | | } |
89 | | |
90 | | impl<'a> ActiveStreamGuard<'a> { |
91 | | /// Consumes the guard. The stream will be considered "finished" and will |
92 | | /// be removed from `active_uploads`. |
93 | 8 | fn graceful_finish(mut self) { |
94 | 8 | let stream_state = self.stream_state.take().unwrap(); |
95 | 8 | self.bytestream_server |
96 | 8 | .active_uploads |
97 | 8 | .lock() |
98 | 8 | .remove(&stream_state.uuid); |
99 | 8 | } |
100 | | } |
101 | | |
102 | | impl<'a> Drop for ActiveStreamGuard<'a> { |
103 | 14 | fn drop(&mut self) { |
104 | 14 | let Some(stream_state6 ) = self.stream_state.take() else { Branch (104:13): [True: 6, False: 8]
Branch (104:13): [Folded - Ignored]
|
105 | 8 | return; // If None it means we don't want it put back into an IdleStream. |
106 | | }; |
107 | 6 | let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads); |
108 | 6 | let mut active_uploads = self.bytestream_server.active_uploads.lock(); |
109 | 6 | let uuid = stream_state.uuid.clone(); |
110 | 6 | let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else { Branch (110:13): [True: 6, False: 0]
Branch (110:13): [Folded - Ignored]
|
111 | 0 | event!( |
112 | 0 | Level::ERROR, |
113 | | err = "Failed to find active upload. This should never happen.", |
114 | | uuid = ?uuid, |
115 | | ); |
116 | 0 | return; |
117 | | }; |
118 | 6 | let sleep_fn = self.bytestream_server.sleep_fn.clone(); |
119 | 6 | active_uploads_slot.1 = Some(IdleStream { |
120 | 6 | stream_state, |
121 | 6 | _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move { |
122 | 3 | (*sleep_fn)().await0 ; |
123 | 0 | if let Some(active_uploads) = weak_active_uploads.upgrade() { Branch (123:24): [True: 0, False: 0]
Branch (123:24): [Folded - Ignored]
|
124 | 0 | let mut active_uploads = active_uploads.lock(); |
125 | 0 | event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid); |
126 | 0 | active_uploads.remove(&uuid); |
127 | 0 | } |
128 | 6 | }0 ), |
129 | | }); |
130 | 14 | } |
131 | | } |
132 | | |
133 | | /// Represents a stream that is in the "idle" state. This means it is not currently being used |
134 | | /// by a client. If it is not used within a certain amount of time it will be removed from the |
135 | | /// `active_uploads` map automatically. |
136 | | #[derive(Debug)] |
137 | | struct IdleStream { |
138 | | stream_state: StreamState, |
139 | | _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, |
140 | | } |
141 | | |
142 | | impl IdleStream { |
143 | 3 | fn into_active_stream( |
144 | 3 | self, |
145 | 3 | bytes_received: Arc<AtomicU64>, |
146 | 3 | bytestream_server: &ByteStreamServer, |
147 | 3 | ) -> ActiveStreamGuard<'_> { |
148 | 3 | ActiveStreamGuard { |
149 | 3 | stream_state: Some(self.stream_state), |
150 | 3 | bytes_received, |
151 | 3 | bytestream_server, |
152 | 3 | } |
153 | 3 | } |
154 | | } |
155 | | |
156 | | type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>); |
157 | | type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>; |
158 | | |
159 | | pub struct ByteStreamServer { |
160 | | stores: HashMap<String, Store>, |
161 | | // Max number of bytes to send on each grpc stream chunk. |
162 | | max_bytes_per_stream: usize, |
163 | | max_decoding_message_size: usize, |
164 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
165 | | sleep_fn: SleepFn, |
166 | | } |
167 | | |
168 | | impl ByteStreamServer { |
169 | 14 | pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> { |
170 | 14 | let mut persist_stream_on_disconnect_timeout = |
171 | 14 | Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64); |
172 | 14 | if config.persist_stream_on_disconnect_timeout == 0 { Branch (172:12): [True: 14, False: 0]
Branch (172:12): [Folded - Ignored]
|
173 | 14 | persist_stream_on_disconnect_timeout = DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT; |
174 | 14 | }0 |
175 | 14 | Self::new_with_sleep_fn( |
176 | 14 | config, |
177 | 14 | store_manager, |
178 | 14 | Arc::new(move || Box::pin(sleep(persist_stream_on_disconnect_timeout))3 ), |
179 | 14 | ) |
180 | 14 | } |
181 | | |
182 | 14 | pub fn new_with_sleep_fn( |
183 | 14 | config: &ByteStreamConfig, |
184 | 14 | store_manager: &StoreManager, |
185 | 14 | sleep_fn: SleepFn, |
186 | 14 | ) -> Result<Self, Error> { |
187 | 14 | let mut stores = HashMap::with_capacity(config.cas_stores.len()); |
188 | 28 | for (instance_name, store_name14 ) in &config.cas_stores { |
189 | 14 | let store = store_manager |
190 | 14 | .get_store(store_name) |
191 | 14 | .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", store_name)0 )?0 ; |
192 | 14 | stores.insert(instance_name.to_string(), store); |
193 | | } |
194 | 14 | let max_bytes_per_stream = if config.max_bytes_per_stream == 0 { Branch (194:39): [True: 1, False: 13]
Branch (194:39): [Folded - Ignored]
|
195 | 1 | DEFAULT_MAX_BYTES_PER_STREAM |
196 | | } else { |
197 | 13 | config.max_bytes_per_stream |
198 | | }; |
199 | 14 | let max_decoding_message_size = if config.max_decoding_message_size == 0 { Branch (199:44): [True: 13, False: 1]
Branch (199:44): [Folded - Ignored]
|
200 | 13 | DEFAULT_MAX_DECODING_MESSAGE_SIZE |
201 | | } else { |
202 | 1 | config.max_decoding_message_size |
203 | | }; |
204 | 14 | Ok(ByteStreamServer { |
205 | 14 | stores, |
206 | 14 | max_bytes_per_stream, |
207 | 14 | max_decoding_message_size, |
208 | 14 | active_uploads: Arc::new(Mutex::new(HashMap::new())), |
209 | 14 | sleep_fn, |
210 | 14 | }) |
211 | 14 | } |
212 | | |
213 | 1 | pub fn into_service(self) -> Server<Self> { |
214 | 1 | let max_decoding_message_size = self.max_decoding_message_size; |
215 | 1 | Server::new(self).max_decoding_message_size(max_decoding_message_size) |
216 | 1 | } |
217 | | |
218 | 14 | fn create_or_join_upload_stream( |
219 | 14 | &self, |
220 | 14 | uuid: String, |
221 | 14 | store: Store, |
222 | 14 | digest: DigestInfo, |
223 | 14 | ) -> Result<ActiveStreamGuard<'_>, Error> { |
224 | 14 | let (uuid, bytes_received11 ) = match self.active_uploads.lock().entry(uuid) { |
225 | 3 | Entry::Occupied(mut entry) => { |
226 | 3 | let maybe_idle_stream = entry.get_mut(); |
227 | 3 | let Some(idle_stream) = maybe_idle_stream.1.take() else { Branch (227:21): [True: 3, False: 0]
Branch (227:21): [Folded - Ignored]
|
228 | 0 | return Err(make_input_err!("Cannot upload same UUID simultaneously")); |
229 | | }; |
230 | 3 | let bytes_received = maybe_idle_stream.0.clone(); |
231 | 3 | event!(Level::INFO, msg = "Joining existing stream", entry = ?entry.key()0 ); |
232 | 3 | return Ok(idle_stream.into_active_stream(bytes_received, self)); |
233 | | } |
234 | 11 | Entry::Vacant(entry) => { |
235 | 11 | let bytes_received = Arc::new(AtomicU64::new(0)); |
236 | 11 | let uuid = entry.key().clone(); |
237 | 11 | // Our stream is "in use" if the key is in the map, but the value is None. |
238 | 11 | entry.insert((bytes_received.clone(), None)); |
239 | 11 | (uuid, bytes_received) |
240 | 11 | } |
241 | 11 | }; |
242 | 11 | |
243 | 11 | // Important: Do not return an error from this point onwards without |
244 | 11 | // removing the entry from the map, otherwise that UUID becomes |
245 | 11 | // unusable. |
246 | 11 | |
247 | 11 | let (tx, rx) = make_buf_channel_pair(); |
248 | 11 | let store_update_fut = Box::pin(async move { |
249 | 8 | // We need to wrap `Store::update()` in another future because we need to capture |
250 | 8 | // `store` to ensure its lifetime follows the future and not the caller. |
251 | 8 | store |
252 | 8 | // Bytestream always uses digest size as the actual byte size. |
253 | 8 | .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes())) |
254 | 101 | .await |
255 | 11 | }8 ); |
256 | 11 | Ok(ActiveStreamGuard { |
257 | 11 | stream_state: Some(StreamState { |
258 | 11 | uuid, |
259 | 11 | tx, |
260 | 11 | store_update_fut, |
261 | 11 | }), |
262 | 11 | bytes_received, |
263 | 11 | bytestream_server: self, |
264 | 11 | }) |
265 | 14 | } |
266 | | |
267 | 3 | async fn inner_read( |
268 | 3 | &self, |
269 | 3 | store: Store, |
270 | 3 | digest: DigestInfo, |
271 | 3 | read_request: ReadRequest, |
272 | 3 | ) -> Result<Response<ReadStream>, Error> { |
273 | | struct ReaderState { |
274 | | max_bytes_per_stream: usize, |
275 | | rx: DropCloserReadHalf, |
276 | | maybe_get_part_result: Option<Result<(), Error>>, |
277 | | get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, |
278 | | } |
279 | | |
280 | 3 | let read_limit = u64::try_from(read_request.read_limit) |
281 | 3 | .err_tip(|| "Could not convert read_limit to u64"0 )?0 ; |
282 | | |
283 | 3 | let (tx, rx) = make_buf_channel_pair(); |
284 | | |
285 | 3 | let read_limit = if read_limit != 0 { Branch (285:29): [True: 3, False: 0]
Branch (285:29): [Folded - Ignored]
|
286 | 3 | Some(read_limit) |
287 | | } else { |
288 | 0 | None |
289 | | }; |
290 | | |
291 | | // This allows us to call a destructor when the object is dropped. |
292 | 3 | let state = Some(ReaderState { |
293 | 3 | rx, |
294 | 3 | max_bytes_per_stream: self.max_bytes_per_stream, |
295 | 3 | maybe_get_part_result: None, |
296 | 3 | get_part_fut: Box::pin(async move { |
297 | 3 | store |
298 | 3 | .get_part( |
299 | 3 | digest, |
300 | 3 | tx, |
301 | 3 | u64::try_from(read_request.read_offset) |
302 | 3 | .err_tip(|| "Could not convert read_offset to u64"0 )?0 , |
303 | 3 | read_limit, |
304 | | ) |
305 | 0 | .await |
306 | 3 | }), |
307 | 3 | }); |
308 | | |
309 | 3 | let read_stream_span = error_span!("read_stream"); |
310 | | |
311 | 9.77k | Ok(Response::new(Box::pin(unfold(state, 3 move |state| { |
312 | 9.77k | async { |
313 | 9.77k | let mut state9.77k = state?2 ; // If None our stream is done. |
314 | 9.77k | let mut response = ReadResponse::default(); |
315 | 9.77k | { |
316 | 9.77k | let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream)); |
317 | 9.77k | tokio::pin!(consume_fut); |
318 | | loop { |
319 | 9.77k | tokio::select! { |
320 | 9.77k | read_result9.77k = &mut consume_fut => { |
321 | 9.77k | match read_result { |
322 | 9.76k | Ok(bytes) => { |
323 | 9.76k | if bytes.is_empty() { Branch (323:40): [True: 2, False: 9.76k]
Branch (323:40): [Folded - Ignored]
|
324 | | // EOF. |
325 | 2 | return Some((Ok(response), None)); |
326 | 9.76k | } |
327 | 9.76k | if bytes.len() > state.max_bytes_per_stream { Branch (327:40): [True: 0, False: 9.76k]
Branch (327:40): [Folded - Ignored]
|
328 | 0 | let err = make_err!(Code::Internal, "Returned store size was larger than read size"); |
329 | 0 | return Some((Err(err.into()), None)); |
330 | 9.76k | } |
331 | 9.76k | response.data = bytes; |
332 | 9.76k | if enabled!(Level::DEBUG) { |
333 | 0 | event!(Level::INFO, response = ?response); |
334 | | } else { |
335 | 9.76k | event!(Level::INFO, response.data = format!("<redacted len({})>", response.data.len())0 ); |
336 | | } |
337 | 9.76k | break; |
338 | | } |
339 | 1 | Err(mut e) => { |
340 | | // We may need to propagate the error from reading the data through first. |
341 | | // For example, the NotFound error will come through `get_part_fut`, and |
342 | | // will not be present in `e`, but we need to ensure we pass NotFound error |
343 | | // code or the client won't know why it failed. |
344 | 1 | let get_part_result = if let Some(result) = state.maybe_get_part_result { Branch (344:66): [True: 1, False: 0]
Branch (344:66): [Folded - Ignored]
|
345 | 1 | result |
346 | | } else { |
347 | | // This should never be `future::pending()` if maybe_get_part_result is |
348 | | // not set. |
349 | 0 | state.get_part_fut.await |
350 | | }; |
351 | 1 | if let Err(err) = get_part_result { Branch (351:44): [True: 1, False: 0]
Branch (351:44): [Folded - Ignored]
|
352 | 1 | e = err.merge(e); |
353 | 1 | }0 |
354 | 1 | if e.code == Code::NotFound { Branch (354:40): [True: 1, False: 0]
Branch (354:40): [Folded - Ignored]
|
355 | 1 | // Trim the error messages. Not Found is quite common and we don't want to send a large |
356 | 1 | // error (debug) message for something that is common. We resize to just the last |
357 | 1 | // message as it will be the most relevant. |
358 | 1 | e.messages.truncate(1); |
359 | 1 | }0 |
360 | 1 | event!(Level::ERROR, response = ?e); |
361 | 1 | return Some((Err(e.into()), None)) |
362 | | } |
363 | | } |
364 | | }, |
365 | 9.77k | result3 = &mut state.get_part_fut => { |
366 | 3 | state.maybe_get_part_result = Some(result); |
367 | 3 | // It is non-deterministic which future will finish first. |
368 | 3 | // It is also possible that the `state.rx.consume()` call above may not be able to |
369 | 3 | // respond even though the publishing future is done. |
370 | 3 | // Because of this we set the writing future to pending so it never finishes. |
371 | 3 | // The `state.rx.consume()` future will eventually finish and return either the |
372 | 3 | // data or an error. |
373 | 3 | // An EOF will terminate the `state.rx.consume()` future, but we are also protected |
374 | 3 | // because we are dropping the writing future, it will drop the `tx` channel |
375 | 3 | // which will eventually propagate an error to the `state.rx.consume()` future if |
376 | 3 | // the EOF was not sent due to some other error. |
377 | 3 | state.get_part_fut = Box::pin(pending()); |
378 | 3 | }, |
379 | | } |
380 | | } |
381 | | } |
382 | 9.76k | Some((Ok(response), Some(state))) |
383 | 9.77k | }.instrument(read_stream_span.clone()) |
384 | 9.77k | }))))3 |
385 | 3 | } |
386 | | |
387 | | // We instrument tracing here as well as below because `stream` has a hash on it |
388 | | // that is extracted from the first stream message. If we only implemented it below |
389 | | // we would not have the hash available to us. |
390 | 28 | #[instrument( |
391 | | ret(level = Level::INFO), |
392 | 14 | level = Level::ERROR, |
393 | | skip(self, store), |
394 | | fields(stream.first_msg = "<redacted>") |
395 | 28 | )] |
396 | | async fn inner_write( |
397 | | &self, |
398 | | store: Store, |
399 | | digest: DigestInfo, |
400 | | stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>, |
401 | | ) -> Result<Response<WriteResponse>, Error> { |
402 | 14 | async fn process_client_stream( |
403 | 14 | mut stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>, |
404 | 14 | tx: &mut DropCloserWriteHalf, |
405 | 14 | outer_bytes_received: &Arc<AtomicU64>, |
406 | 14 | expected_size: u64, |
407 | 14 | ) -> Result<(), Error> { |
408 | | loop { |
409 | 101 | let write_request20 = match stream.next()24 .await { |
410 | | // Code path for when client tries to gracefully close the stream. |
411 | | // If this happens it means there's a problem with the data sent, |
412 | | // because we always close the stream from our end before this point |
413 | | // by counting the number of bytes sent from the client. If they send |
414 | | // less than the amount they said they were going to send and then |
415 | | // close the stream, we know there's a problem. |
416 | | None => { |
417 | 0 | return Err(make_input_err!( |
418 | 0 | "Client closed stream before sending all data" |
419 | 0 | )) |
420 | | } |
421 | | // Code path for client stream error. Probably client disconnect. |
422 | 4 | Some(Err(err)) => return Err(err), |
423 | | // Code path for received chunk of data. |
424 | 20 | Some(Ok(write_request)) => write_request, |
425 | 20 | }; |
426 | 20 | |
427 | 20 | if write_request.write_offset < 0 { Branch (427:20): [True: 1, False: 19]
Branch (427:20): [Folded - Ignored]
|
428 | 1 | return Err(make_input_err!( |
429 | 1 | "Invalid negative write offset in write request: {}", |
430 | 1 | write_request.write_offset |
431 | 1 | )); |
432 | 19 | } |
433 | 19 | let write_offset = write_request.write_offset as u64; |
434 | | |
435 | | // If we get duplicate data because a client didn't know where |
436 | | // it left off from, then we can simply skip it. |
437 | 19 | let data18 = if write_offset < tx.get_bytes_written() { Branch (437:31): [True: 2, False: 17]
Branch (437:31): [Folded - Ignored]
|
438 | 2 | if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() { Branch (438:24): [True: 0, False: 2]
Branch (438:24): [Folded - Ignored]
|
439 | 0 | if write_request.finish_write { Branch (439:28): [True: 0, False: 0]
Branch (439:28): [Folded - Ignored]
|
440 | 0 | return Err(make_input_err!( |
441 | 0 | "Resumed stream finished at {} bytes when we already received {} bytes.", |
442 | 0 | write_offset + write_request.data.len() as u64, |
443 | 0 | tx.get_bytes_written() |
444 | 0 | )); |
445 | 0 | } |
446 | 0 | continue; |
447 | 2 | } |
448 | 2 | write_request |
449 | 2 | .data |
450 | 2 | .slice((tx.get_bytes_written() - write_offset) as usize..) |
451 | | } else { |
452 | 17 | if write_offset != tx.get_bytes_written() { Branch (452:24): [True: 1, False: 16]
Branch (452:24): [Folded - Ignored]
|
453 | 1 | return Err(make_input_err!( |
454 | 1 | "Received out of order data. Got {}, expected {}", |
455 | 1 | write_offset, |
456 | 1 | tx.get_bytes_written() |
457 | 1 | )); |
458 | 16 | } |
459 | 16 | write_request.data |
460 | | }; |
461 | | |
462 | | // Do not process EOF or weird stuff will happen. |
463 | 18 | if !data.is_empty() { Branch (463:20): [True: 13, False: 5]
Branch (463:20): [Folded - Ignored]
|
464 | | // We also need to process the possible EOF branch, so we can't early return. |
465 | 13 | if let Err(mut err0 ) = tx.send(data).await0 { Branch (465:28): [True: 0, False: 13]
Branch (465:28): [Folded - Ignored]
|
466 | 0 | err.code = Code::Internal; |
467 | 0 | return Err(err); |
468 | 13 | } |
469 | 13 | outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); |
470 | 5 | } |
471 | | |
472 | 18 | if expected_size < tx.get_bytes_written() { Branch (472:20): [True: 0, False: 18]
Branch (472:20): [Folded - Ignored]
|
473 | 0 | return Err(make_input_err!("Received more bytes than expected")); |
474 | 18 | } |
475 | 18 | if write_request.finish_write { Branch (475:20): [True: 8, False: 10]
Branch (475:20): [Folded - Ignored]
|
476 | | // Gracefully close our stream. |
477 | 8 | tx.send_eof() |
478 | 8 | .err_tip(|| "Failed to send EOF in ByteStream::write"0 )?0 ; |
479 | 8 | return Ok(()); |
480 | 10 | } |
481 | | // Continue. |
482 | | } |
483 | | // Unreachable. |
484 | 14 | } |
485 | | |
486 | | let uuid = stream |
487 | | .resource_info |
488 | | .uuid |
489 | | .as_ref() |
490 | 0 | .ok_or_else(|| make_input_err!("UUID must be set if writing data"))? |
491 | | .to_string(); |
492 | | let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?; |
493 | | let expected_size = stream.resource_info.expected_size as u64; |
494 | | |
495 | | let active_stream = active_stream_guard.stream_state.as_mut().unwrap(); |
496 | | try_join!( |
497 | | process_client_stream( |
498 | | stream, |
499 | | &mut active_stream.tx, |
500 | | &active_stream_guard.bytes_received, |
501 | | expected_size |
502 | | ), |
503 | | (&mut active_stream.store_update_fut) |
504 | 0 | .map_err(|err| { err.append("Error updating inner store") }) |
505 | | )?; |
506 | | |
507 | | // Close our guard and consider the stream no longer active. |
508 | | active_stream_guard.graceful_finish(); |
509 | | |
510 | | Ok(Response::new(WriteResponse { |
511 | | committed_size: expected_size as i64, |
512 | | })) |
513 | | } |
514 | | |
515 | 3 | async fn inner_query_write_status( |
516 | 3 | &self, |
517 | 3 | query_request: &QueryWriteStatusRequest, |
518 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Error> { |
519 | 3 | let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)?0 ; |
520 | | |
521 | 3 | let store_clone = self |
522 | 3 | .stores |
523 | 3 | .get(resource_info.instance_name.as_ref()) |
524 | 3 | .err_tip(|| { |
525 | 0 | format!( |
526 | 0 | "'instance_name' not configured for '{}'", |
527 | 0 | &resource_info.instance_name |
528 | 0 | ) |
529 | 3 | })?0 |
530 | 3 | .clone(); |
531 | | |
532 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
533 | | |
534 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
535 | 3 | if let Some(grpc_store0 ) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (535:16): [True: 0, False: 3]
Branch (535:16): [Folded - Ignored]
|
536 | 0 | return grpc_store |
537 | 0 | .query_write_status(Request::new(query_request.clone())) |
538 | 0 | .await; |
539 | 3 | } |
540 | | |
541 | 3 | let uuid = resource_info |
542 | 3 | .uuid |
543 | 3 | .take() |
544 | 3 | .ok_or_else(|| make_input_err!("UUID must be set if querying write status")0 )?0 ; |
545 | | |
546 | | { |
547 | 3 | let active_uploads = self.active_uploads.lock(); |
548 | 3 | if let Some((received_bytes, _maybe_idle_stream1 )) = active_uploads.get(uuid.as_ref()) { Branch (548:20): [True: 1, False: 2]
Branch (548:20): [Folded - Ignored]
|
549 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
550 | 1 | committed_size: received_bytes.load(Ordering::Acquire) as i64, |
551 | 1 | // If we are in the active_uploads map, but the value is None, |
552 | 1 | // it means the stream is not complete. |
553 | 1 | complete: false, |
554 | 1 | })); |
555 | 2 | } |
556 | 2 | } |
557 | 2 | |
558 | 2 | let has_fut = store_clone.has(digest); |
559 | 2 | let Some(item_size1 ) = has_fut.await0 .err_tip(|| "Failed to call .has() on store"0 )?0 else { Branch (559:13): [True: 1, False: 1]
Branch (559:13): [Folded - Ignored]
|
560 | | // We lie here and say that the stream needs to start over, even though |
561 | | // it was never started. This can happen when the client disconnects |
562 | | // before sending the first payload, but the client thinks it did send |
563 | | // the payload. |
564 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
565 | 1 | committed_size: 0, |
566 | 1 | complete: false, |
567 | 1 | })); |
568 | | }; |
569 | 1 | Ok(Response::new(QueryWriteStatusResponse { |
570 | 1 | committed_size: item_size as i64, |
571 | 1 | complete: true, |
572 | 1 | })) |
573 | 3 | } |
574 | | } |
575 | | |
576 | | #[tonic::async_trait] |
577 | | impl ByteStream for ByteStreamServer { |
578 | | type ReadStream = ReadStream; |
579 | | |
580 | | #[allow(clippy::blocks_in_conditions)] |
581 | 3 | #[instrument( |
582 | | err, |
583 | | level = Level::ERROR, |
584 | | skip_all, |
585 | | fields(request = ?grpc_request.get_ref()) |
586 | 6 | )] |
587 | | async fn read( |
588 | | &self, |
589 | | grpc_request: Request<ReadRequest>, |
590 | 3 | ) -> Result<Response<Self::ReadStream>, Status> { |
591 | 3 | let read_request = grpc_request.into_inner(); |
592 | | |
593 | 3 | let resource_info = ResourceInfo::new(&read_request.resource_name, false)?0 ; |
594 | 3 | let instance_name = resource_info.instance_name.as_ref(); |
595 | 3 | let store = self |
596 | 3 | .stores |
597 | 3 | .get(instance_name) |
598 | 3 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
599 | 3 | .clone(); |
600 | | |
601 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
602 | | |
603 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
604 | 3 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (604:16): [True: 0, False: 3]
Branch (604:16): [Folded - Ignored]
|
605 | 0 | let stream = grpc_store.read(Request::new(read_request)).await?; |
606 | 0 | return Ok(Response::new(Box::pin(stream))); |
607 | 3 | } |
608 | | |
609 | 3 | let digest_function = resource_info.digest_function.as_deref().map_or_else( |
610 | 3 | || Ok(default_digest_hasher_func()), |
611 | 3 | DigestHasherFunc::try_from, |
612 | 3 | )?0 ; |
613 | | |
614 | 3 | let resp = make_ctx_for_hash_func(digest_function) |
615 | 3 | .err_tip(|| "In BytestreamServer::read"0 )?0 |
616 | | .wrap_async( |
617 | 3 | error_span!("bytestream_read"), |
618 | 3 | self.inner_read(store, digest, read_request), |
619 | | ) |
620 | 0 | .await |
621 | 3 | .err_tip(|| "In ByteStreamServer::read"0 ) |
622 | 3 | .map_err(Into::into); |
623 | 3 | |
624 | 3 | if resp.is_ok() { Branch (624:12): [True: 3, False: 0]
Branch (624:12): [Folded - Ignored]
|
625 | 3 | event!(Level::DEBUG, return = "Ok(<stream>)"); |
626 | 0 | } |
627 | 3 | resp |
628 | 6 | } |
629 | | |
630 | | #[allow(clippy::blocks_in_conditions)] |
631 | 15 | #[instrument( |
632 | | err, |
633 | | level = Level::ERROR, |
634 | | skip_all, |
635 | | fields(request = ?grpc_request.get_ref()) |
636 | 30 | )] |
637 | | async fn write( |
638 | | &self, |
639 | | grpc_request: Request<Streaming<WriteRequest>>, |
640 | 15 | ) -> Result<Response<WriteResponse>, Status> { |
641 | 15 | let stream14 = WriteRequestStreamWrapper::from(grpc_request.into_inner()) |
642 | 65 | .await |
643 | 15 | .err_tip(|| "Could not unwrap first stream message"1 ) |
644 | 15 | .map_err(Into::<Status>::into)?1 ; |
645 | | |
646 | 14 | let instance_name = stream.resource_info.instance_name.as_ref(); |
647 | 14 | let store = self |
648 | 14 | .stores |
649 | 14 | .get(instance_name) |
650 | 14 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'")0 )?0 |
651 | 14 | .clone(); |
652 | | |
653 | 14 | let digest = DigestInfo::try_new( |
654 | 14 | &stream.resource_info.hash, |
655 | 14 | stream.resource_info.expected_size, |
656 | 14 | ) |
657 | 14 | .err_tip(|| "Invalid digest input in ByteStream::write"0 )?0 ; |
658 | | |
659 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
660 | 14 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (660:16): [True: 0, False: 14]
Branch (660:16): [Folded - Ignored]
|
661 | 0 | return grpc_store.write(stream).await.map_err(Into::into); |
662 | 14 | } |
663 | | |
664 | 14 | let digest_function = stream |
665 | 14 | .resource_info |
666 | 14 | .digest_function |
667 | 14 | .as_deref() |
668 | 14 | .map_or_else( |
669 | 14 | || Ok(default_digest_hasher_func()), |
670 | 14 | DigestHasherFunc::try_from, |
671 | 14 | )?0 ; |
672 | | |
673 | 14 | make_ctx_for_hash_func(digest_function) |
674 | 14 | .err_tip(|| "In BytestreamServer::write"0 )?0 |
675 | | .wrap_async( |
676 | 14 | error_span!("bytestream_write"), |
677 | 14 | self.inner_write(store, digest, stream), |
678 | | ) |
679 | 101 | .await |
680 | 14 | .err_tip(|| "In ByteStreamServer::write"6 ) |
681 | 14 | .map_err(Into::into) |
682 | 30 | } |
683 | | |
684 | | #[allow(clippy::blocks_in_conditions)] |
685 | 3 | #[instrument( |
686 | | err, |
687 | | ret(level = Level::INFO), |
688 | | level = Level::ERROR, |
689 | | skip_all, |
690 | | fields(request = ?grpc_request.get_ref()) |
691 | 6 | )] |
692 | | async fn query_write_status( |
693 | | &self, |
694 | | grpc_request: Request<QueryWriteStatusRequest>, |
695 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Status> { |
696 | 3 | self.inner_query_write_status(&grpc_request.into_inner()) |
697 | 0 | .await |
698 | 3 | .err_tip(|| "Failed on query_write_status() command"0 ) |
699 | 3 | .map_err(Into::into) |
700 | 6 | } |
701 | | } |