Coverage Report

Created: 2025-05-30 16:37

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-service/src/bytestream_server.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::convert::Into;
16
use core::fmt::{Debug, Formatter};
17
use core::pin::Pin;
18
use core::sync::atomic::{AtomicU64, Ordering};
19
use core::time::Duration;
20
use std::collections::HashMap;
21
use std::collections::hash_map::Entry;
22
use std::sync::Arc;
23
24
use futures::future::{BoxFuture, pending};
25
use futures::stream::unfold;
26
use futures::{Future, Stream, TryFutureExt, try_join};
27
use nativelink_config::cas_server::ByteStreamConfig;
28
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
29
use nativelink_proto::google::bytestream::byte_stream_server::{
30
    ByteStream, ByteStreamServer as Server,
31
};
32
use nativelink_proto::google::bytestream::{
33
    QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest,
34
    WriteResponse,
35
};
36
use nativelink_store::grpc_store::GrpcStore;
37
use nativelink_store::store_manager::StoreManager;
38
use nativelink_util::buf_channel::{
39
    DropCloserReadHalf, DropCloserWriteHalf, make_buf_channel_pair,
40
};
41
use nativelink_util::common::DigestInfo;
42
use nativelink_util::digest_hasher::{
43
    DigestHasherFunc, default_digest_hasher_func, make_ctx_for_hash_func,
44
};
45
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
46
use nativelink_util::resource_info::ResourceInfo;
47
use nativelink_util::spawn;
48
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
49
use nativelink_util::task::JoinHandleDropGuard;
50
use opentelemetry::context::FutureExt;
51
use parking_lot::Mutex;
52
use tokio::time::sleep;
53
use tonic::{Request, Response, Status, Streaming};
54
use tracing::{Instrument, Level, debug, error, error_span, info, instrument, trace};
55
56
/// If this value changes update the documentation in the config definition.
57
const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60);
58
59
/// If this value changes update the documentation in the config definition.
60
const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024;
61
62
/// If this value changes update the documentation in the config definition.
63
const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 * 1024 * 1024;
64
65
type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>;
66
type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>;
67
68
struct StreamState {
69
    uuid: String,
70
    tx: DropCloserWriteHalf,
71
    store_update_fut: StoreUpdateFuture,
72
}
73
74
impl Debug for StreamState {
75
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
76
0
        f.debug_struct("StreamState")
77
0
            .field("uuid", &self.uuid)
78
0
            .finish()
79
0
    }
80
}
81
82
/// If a stream is in this state, it will automatically be put back into an `IdleStream` and
83
/// placed back into the `active_uploads` map as an `IdleStream` after it is dropped.
84
/// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`.
85
struct ActiveStreamGuard<'a> {
86
    stream_state: Option<StreamState>,
87
    bytes_received: Arc<AtomicU64>,
88
    bytestream_server: &'a ByteStreamServer,
89
}
90
91
impl ActiveStreamGuard<'_> {
92
    /// Consumes the guard. The stream will be considered "finished", will
93
    /// remove it from the `active_uploads`.
94
8
    fn graceful_finish(mut self) {
95
8
        let stream_state = self.stream_state.take().unwrap();
96
8
        self.bytestream_server
97
8
            .active_uploads
98
8
            .lock()
99
8
            .remove(&stream_state.uuid);
100
8
    }
101
}
102
103
impl Drop for ActiveStreamGuard<'_> {
104
15
    fn drop(&mut self) {
105
15
        let Some(
stream_state7
) = self.stream_state.take() else {
  Branch (105:13): [True: 7, False: 8]
  Branch (105:13): [Folded - Ignored]
106
8
            return; // If None it means we don't want it put back into an IdleStream.
107
        };
108
7
        let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads);
109
7
        let mut active_uploads = self.bytestream_server.active_uploads.lock();
110
7
        let uuid = stream_state.uuid.clone();
111
7
        let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else {
  Branch (111:13): [True: 7, False: 0]
  Branch (111:13): [Folded - Ignored]
112
0
            error!(
113
                err = "Failed to find active upload. This should never happen.",
114
                uuid = ?uuid,
115
            );
116
0
            return;
117
        };
118
7
        let sleep_fn = self.bytestream_server.sleep_fn.clone();
119
7
        active_uploads_slot.1 = Some(IdleStream {
120
7
            stream_state,
121
7
            _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move 
{3
122
3
                (*sleep_fn)().await;
123
0
                if let Some(active_uploads) = weak_active_uploads.upgrade() {
  Branch (123:24): [True: 0, False: 0]
  Branch (123:24): [Folded - Ignored]
124
0
                    let mut active_uploads = active_uploads.lock();
125
0
                    info!(msg = "Removing idle stream", uuid = ?uuid);
126
0
                    active_uploads.remove(&uuid);
127
0
                }
128
0
            }),
129
        });
130
15
    }
131
}
132
133
/// Represents a stream that is in the "idle" state. this means it is not currently being used
134
/// by a client. If it is not used within a certain amount of time it will be removed from the
135
/// `active_uploads` map automatically.
136
#[derive(Debug)]
137
struct IdleStream {
138
    stream_state: StreamState,
139
    _timeout_streaam_drop_guard: JoinHandleDropGuard<()>,
140
}
141
142
impl IdleStream {
143
3
    fn into_active_stream(
144
3
        self,
145
3
        bytes_received: Arc<AtomicU64>,
146
3
        bytestream_server: &ByteStreamServer,
147
3
    ) -> ActiveStreamGuard<'_> {
148
3
        ActiveStreamGuard {
149
3
            stream_state: Some(self.stream_state),
150
3
            bytes_received,
151
3
            bytestream_server,
152
3
        }
153
3
    }
154
}
155
156
type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>);
157
type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
158
159
pub struct ByteStreamServer {
160
    stores: HashMap<String, Store>,
161
    // Max number of bytes to send on each grpc stream chunk.
162
    max_bytes_per_stream: usize,
163
    max_decoding_message_size: usize,
164
    active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
165
    sleep_fn: SleepFn,
166
}
167
168
impl Debug for ByteStreamServer {
169
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
170
0
        f.debug_struct("ByteStreamServer")
171
0
            .field("stores", &self.stores)
172
0
            .field("max_bytes_per_stream", &self.max_bytes_per_stream)
173
0
            .field("max_decoding_message_size", &self.max_decoding_message_size)
174
0
            .field("active_uploads", &self.active_uploads)
175
0
            .finish_non_exhaustive()
176
0
    }
177
}
178
179
impl ByteStreamServer {
180
15
    pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> {
181
15
        let persist_stream_on_disconnect_timeout =
182
15
            if config.persist_stream_on_disconnect_timeout == 0 {
  Branch (182:16): [True: 15, False: 0]
  Branch (182:16): [Folded - Ignored]
183
15
                DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT
184
            } else {
185
0
                Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64)
186
            };
187
15
        Self::new_with_sleep_fn(
188
15
            config,
189
15
            store_manager,
190
15
            Arc::new(move || 
Box::pin3
(
sleep3
(
persist_stream_on_disconnect_timeout3
))),
191
        )
192
15
    }
193
194
15
    pub fn new_with_sleep_fn(
195
15
        config: &ByteStreamConfig,
196
15
        store_manager: &StoreManager,
197
15
        sleep_fn: SleepFn,
198
15
    ) -> Result<Self, Error> {
199
15
        let mut stores = HashMap::with_capacity(config.cas_stores.len());
200
30
        for (
instance_name15
,
store_name15
) in &config.cas_stores {
201
15
            let store = store_manager
202
15
                .get_store(store_name)
203
15
                .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", store_name))
?0
;
204
15
            stores.insert(instance_name.to_string(), store);
205
        }
206
15
        let max_bytes_per_stream = if config.max_bytes_per_stream == 0 {
  Branch (206:39): [True: 1, False: 14]
  Branch (206:39): [Folded - Ignored]
207
1
            DEFAULT_MAX_BYTES_PER_STREAM
208
        } else {
209
14
            config.max_bytes_per_stream
210
        };
211
15
        let max_decoding_message_size = if config.max_decoding_message_size == 0 {
  Branch (211:44): [True: 14, False: 1]
  Branch (211:44): [Folded - Ignored]
212
14
            DEFAULT_MAX_DECODING_MESSAGE_SIZE
213
        } else {
214
1
            config.max_decoding_message_size
215
        };
216
15
        Ok(Self {
217
15
            stores,
218
15
            max_bytes_per_stream,
219
15
            max_decoding_message_size,
220
15
            active_uploads: Arc::new(Mutex::new(HashMap::new())),
221
15
            sleep_fn,
222
15
        })
223
15
    }
224
225
1
    pub fn into_service(self) -> Server<Self> {
226
1
        let max_decoding_message_size = self.max_decoding_message_size;
227
1
        Server::new(self).max_decoding_message_size(max_decoding_message_size)
228
1
    }
229
230
15
    fn create_or_join_upload_stream(
231
15
        &self,
232
15
        uuid: String,
233
15
        store: Store,
234
15
        digest: DigestInfo,
235
15
    ) -> Result<ActiveStreamGuard<'_>, Error> {
236
15
        let (
uuid12
,
bytes_received12
) = match self.active_uploads.lock().entry(uuid) {
237
3
            Entry::Occupied(mut entry) => {
238
3
                let maybe_idle_stream = entry.get_mut();
239
3
                let Some(idle_stream) = maybe_idle_stream.1.take() else {
  Branch (239:21): [True: 3, False: 0]
  Branch (239:21): [Folded - Ignored]
240
0
                    return Err(make_input_err!("Cannot upload same UUID simultaneously"));
241
                };
242
3
                let bytes_received = maybe_idle_stream.0.clone();
243
3
                info!(msg = "Joining existing stream", entry = ?entry.key());
244
3
                return Ok(idle_stream.into_active_stream(bytes_received, self));
245
            }
246
12
            Entry::Vacant(entry) => {
247
12
                let bytes_received = Arc::new(AtomicU64::new(0));
248
12
                let uuid = entry.key().clone();
249
                // Our stream is "in use" if the key is in the map, but the value is None.
250
12
                entry.insert((bytes_received.clone(), None));
251
12
                (uuid, bytes_received)
252
            }
253
        };
254
255
        // Important: Do not return an error from this point onwards without
256
        // removing the entry from the map, otherwise that UUID becomes
257
        // unusable.
258
259
12
        let (tx, rx) = make_buf_channel_pair();
260
12
        let store_update_fut = Box::pin(async move 
{8
261
            // We need to wrap `Store::update()` in a another future because we need to capture
262
            // `store` to ensure its lifetime follows the future and not the caller.
263
8
            store
264
8
                // Bytestream always uses digest size as the actual byte size.
265
8
                .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes()))
266
8
                .await
267
8
        });
268
12
        Ok(ActiveStreamGuard {
269
12
            stream_state: Some(StreamState {
270
12
                uuid,
271
12
                tx,
272
12
                store_update_fut,
273
12
            }),
274
12
            bytes_received,
275
12
            bytestream_server: self,
276
12
        })
277
15
    }
278
279
3
    async fn inner_read(
280
3
        &self,
281
3
        store: Store,
282
3
        digest: DigestInfo,
283
3
        read_request: ReadRequest,
284
3
    ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + use<>, Error> {
285
        struct ReaderState {
286
            max_bytes_per_stream: usize,
287
            rx: DropCloserReadHalf,
288
            maybe_get_part_result: Option<Result<(), Error>>,
289
            get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
290
        }
291
292
3
        let read_limit = u64::try_from(read_request.read_limit)
293
3
            .err_tip(|| "Could not convert read_limit to u64")
?0
;
294
295
3
        let (tx, rx) = make_buf_channel_pair();
296
297
3
        let read_limit = if read_limit != 0 {
  Branch (297:29): [True: 3, False: 0]
  Branch (297:29): [Folded - Ignored]
298
3
            Some(read_limit)
299
        } else {
300
0
            None
301
        };
302
303
        // This allows us to call a destructor when the the object is dropped.
304
3
        let state = Some(ReaderState {
305
3
            rx,
306
3
            max_bytes_per_stream: self.max_bytes_per_stream,
307
3
            maybe_get_part_result: None,
308
3
            get_part_fut: Box::pin(async move {
309
3
                store
310
3
                    .get_part(
311
3
                        digest,
312
3
                        tx,
313
3
                        u64::try_from(read_request.read_offset)
314
3
                            .err_tip(|| "Could not convert read_offset to u64")
?0
,
315
3
                        read_limit,
316
                    )
317
3
                    .await
318
3
            }),
319
        });
320
321
3
        let read_stream_span = error_span!("read_stream");
322
323
9.77k
        Ok(
Box::pin3
(
unfold3
(
state3
, move |state| {
324
9.77k
            async {
325
9.77k
            let mut state = state
?0
; // If None our stream is done.
326
9.77k
            let mut response = ReadResponse::default();
327
            {
328
9.77k
                let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream));
329
9.77k
                tokio::pin!(consume_fut);
330
                loop {
331
9.77k
                    tokio::select! {
332
9.77k
                        
read_result9.77k
= &mut consume_fut => {
333
9.77k
                            match read_result {
334
9.76k
                                Ok(bytes) => {
335
9.76k
                                    if bytes.is_empty() {
  Branch (335:40): [True: 2, False: 9.76k]
  Branch (335:40): [Folded - Ignored]
336
                                        // EOF.
337
2
                                        return None;
338
9.76k
                                    }
339
9.76k
                                    if bytes.len() > state.max_bytes_per_stream {
  Branch (339:40): [True: 0, False: 9.76k]
  Branch (339:40): [Folded - Ignored]
340
0
                                        let err = make_err!(Code::Internal, "Returned store size was larger than read size");
341
0
                                        return Some((Err(err.into()), None));
342
9.76k
                                    }
343
9.76k
                                    response.data = bytes;
344
9.76k
                                    trace!(response = ?response);
345
9.76k
                                    debug!(response.data = 
format!0
(
"<redacted len({})>"0
,
response.data0
.
len0
()));
346
9.76k
                                    break;
347
                                }
348
1
                                Err(mut e) => {
349
                                    // We may need to propagate the error from reading the data through first.
350
                                    // For example, the NotFound error will come through `get_part_fut`, and
351
                                    // will not be present in `e`, but we need to ensure we pass NotFound error
352
                                    // code or the client won't know why it failed.
353
1
                                    let get_part_result = if let Some(result) = state.maybe_get_part_result {
  Branch (353:66): [True: 1, False: 0]
  Branch (353:66): [Folded - Ignored]
354
1
                                        result
355
                                    } else {
356
                                        // This should never be `future::pending()` if maybe_get_part_result is
357
                                        // not set.
358
0
                                        state.get_part_fut.await
359
                                    };
360
1
                                    if let Err(err) = get_part_result {
  Branch (360:44): [True: 1, False: 0]
  Branch (360:44): [Folded - Ignored]
361
1
                                        e = err.merge(e);
362
1
                                    
}0
363
1
                                    if e.code == Code::NotFound {
  Branch (363:40): [True: 1, False: 0]
  Branch (363:40): [Folded - Ignored]
364
1
                                        // Trim the error code. Not Found is quite common and we don't want to send a large
365
1
                                        // error (debug) message for something that is common. We resize to just the last
366
1
                                        // message as it will be the most relevant.
367
1
                                        e.messages.truncate(1);
368
1
                                    
}0
369
1
                                    error!(response = ?e);
370
1
                                    return Some((Err(e.into()), None))
371
                                }
372
                            }
373
                        },
374
9.77k
                        
result3
= &mut state.get_part_fut => {
375
3
                            state.maybe_get_part_result = Some(result);
376
3
                            // It is non-deterministic on which future will finish in what order.
377
3
                            // It is also possible that the `state.rx.consume()` call above may not be able to
378
3
                            // respond even though the publishing future is done.
379
3
                            // Because of this we set the writing future to pending so it never finishes.
380
3
                            // The `state.rx.consume()` future will eventually finish and return either the
381
3
                            // data or an error.
382
3
                            // An EOF will terminate the `state.rx.consume()` future, but we are also protected
383
3
                            // because we are dropping the writing future, it will drop the `tx` channel
384
3
                            // which will eventually propagate an error to the `state.rx.consume()` future if
385
3
                            // the EOF was not sent due to some other error.
386
3
                            state.get_part_fut = Box::pin(pending());
387
3
                        },
388
                    }
389
                }
390
            }
391
9.76k
            Some((Ok(response), Some(state)))
392
9.77k
        }.instrument(read_stream_span.clone())
393
9.77k
        })))
394
3
    }
395
396
    // We instrument tracing here as well as below because `stream` has a hash on it
397
    // that is extracted from the first stream message. If we only implemented it below
398
    // we would not have the hash available to us.
399
    #[instrument(
400
        ret(level = Level::DEBUG),
401
15
        level = Level::ERROR,
402
        skip(self, store),
403
        fields(stream.first_msg = "<redacted>")
404
    )]
405
    async fn inner_write(
406
        &self,
407
        store: Store,
408
        digest: DigestInfo,
409
        stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>,
410
    ) -> Result<Response<WriteResponse>, Error> {
411
15
        async fn process_client_stream(
412
15
            mut stream: WriteRequestStreamWrapper<
413
15
                impl Stream<Item = Result<WriteRequest, Status>> + Unpin,
414
15
            >,
415
15
            tx: &mut DropCloserWriteHalf,
416
15
            outer_bytes_received: &Arc<AtomicU64>,
417
15
            expected_size: u64,
418
15
        ) -> Result<(), Error> {
419
            loop {
420
25
                let 
write_request20
= match stream.next().await {
421
                    // Code path for when client tries to gracefully close the stream.
422
                    // If this happens it means there's a problem with the data sent,
423
                    // because we always close the stream from our end before this point
424
                    // by counting the number of bytes sent from the client. If they send
425
                    // less than the amount they said they were going to send and then
426
                    // close the stream, we know there's a problem.
427
                    None => {
428
0
                        return Err(make_input_err!(
429
0
                            "Client closed stream before sending all data"
430
0
                        ));
431
                    }
432
                    // Code path for client stream error. Probably client disconnect.
433
5
                    Some(Err(err)) => return Err(err),
434
                    // Code path for received chunk of data.
435
20
                    Some(Ok(write_request)) => write_request,
436
                };
437
438
20
                if write_request.write_offset < 0 {
  Branch (438:20): [True: 1, False: 19]
  Branch (438:20): [Folded - Ignored]
439
1
                    return Err(make_input_err!(
440
1
                        "Invalid negative write offset in write request: {}",
441
1
                        write_request.write_offset
442
1
                    ));
443
19
                }
444
19
                let write_offset = write_request.write_offset as u64;
445
446
                // If we get duplicate data because a client didn't know where
447
                // it left off from, then we can simply skip it.
448
19
                let 
data18
= if write_offset < tx.get_bytes_written() {
  Branch (448:31): [True: 2, False: 17]
  Branch (448:31): [Folded - Ignored]
449
2
                    if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() {
  Branch (449:24): [True: 0, False: 2]
  Branch (449:24): [Folded - Ignored]
450
0
                        if write_request.finish_write {
  Branch (450:28): [True: 0, False: 0]
  Branch (450:28): [Folded - Ignored]
451
0
                            return Err(make_input_err!(
452
0
                                "Resumed stream finished at {} bytes when we already received {} bytes.",
453
0
                                write_offset + write_request.data.len() as u64,
454
0
                                tx.get_bytes_written()
455
0
                            ));
456
0
                        }
457
0
                        continue;
458
2
                    }
459
2
                    write_request
460
2
                        .data
461
2
                        .slice((tx.get_bytes_written() - write_offset) as usize..)
462
                } else {
463
17
                    if write_offset != tx.get_bytes_written() {
  Branch (463:24): [True: 1, False: 16]
  Branch (463:24): [Folded - Ignored]
464
1
                        return Err(make_input_err!(
465
1
                            "Received out of order data. Got {}, expected {}",
466
1
                            write_offset,
467
1
                            tx.get_bytes_written()
468
1
                        ));
469
16
                    }
470
16
                    write_request.data
471
                };
472
473
                // Do not process EOF or weird stuff will happen.
474
18
                if !data.is_empty() {
  Branch (474:20): [True: 13, False: 5]
  Branch (474:20): [Folded - Ignored]
475
                    // We also need to process the possible EOF branch, so we can't early return.
476
13
                    if let Err(
mut err0
) = tx.send(data).await {
  Branch (476:28): [True: 0, False: 13]
  Branch (476:28): [Folded - Ignored]
477
0
                        err.code = Code::Internal;
478
0
                        return Err(err);
479
13
                    }
480
13
                    outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release);
481
5
                }
482
483
18
                if expected_size < tx.get_bytes_written() {
  Branch (483:20): [True: 0, False: 18]
  Branch (483:20): [Folded - Ignored]
484
0
                    return Err(make_input_err!("Received more bytes than expected"));
485
18
                }
486
18
                if write_request.finish_write {
  Branch (486:20): [True: 8, False: 10]
  Branch (486:20): [Folded - Ignored]
487
                    // Gracefully close our stream.
488
8
                    tx.send_eof()
489
8
                        .err_tip(|| "Failed to send EOF in ByteStream::write")
?0
;
490
8
                    return Ok(());
491
10
                }
492
                // Continue.
493
            }
494
            // Unreachable.
495
15
        }
496
497
        let uuid = stream
498
            .resource_info
499
            .uuid
500
            .as_ref()
501
            .ok_or_else(|| make_input_err!("UUID must be set if writing data"))?
502
            .to_string();
503
        let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?;
504
        let expected_size = stream.resource_info.expected_size as u64;
505
506
        let active_stream = active_stream_guard.stream_state.as_mut().unwrap();
507
        try_join!(
508
            process_client_stream(
509
                stream,
510
                &mut active_stream.tx,
511
                &active_stream_guard.bytes_received,
512
                expected_size
513
            ),
514
            (&mut active_stream.store_update_fut)
515
0
                .map_err(|err| { err.append("Error updating inner store") })
516
        )?;
517
518
        // Close our guard and consider the stream no longer active.
519
        active_stream_guard.graceful_finish();
520
521
        Ok(Response::new(WriteResponse {
522
            committed_size: expected_size as i64,
523
        }))
524
    }
525
526
3
    async fn inner_query_write_status(
527
3
        &self,
528
3
        query_request: &QueryWriteStatusRequest,
529
3
    ) -> Result<Response<QueryWriteStatusResponse>, Error> {
530
3
        let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)
?0
;
531
532
3
        let store_clone = self
533
3
            .stores
534
3
            .get(resource_info.instance_name.as_ref())
535
3
            .err_tip(|| 
{0
536
0
                format!(
537
0
                    "'instance_name' not configured for '{}'",
538
0
                    &resource_info.instance_name
539
                )
540
0
            })?
541
3
            .clone();
542
543
3
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)
?0
;
544
545
        // If we are a GrpcStore we shortcut here, as this is a special store.
546
3
        if let Some(
grpc_store0
) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (546:16): [True: 0, False: 3]
  Branch (546:16): [Folded - Ignored]
547
0
            return grpc_store
548
0
                .query_write_status(Request::new(query_request.clone()))
549
0
                .await;
550
3
        }
551
552
3
        let uuid = resource_info
553
3
            .uuid
554
3
            .take()
555
3
            .ok_or_else(|| make_input_err!("UUID must be set if querying write status"))
?0
;
556
557
        {
558
3
            let active_uploads = self.active_uploads.lock();
559
3
            if let Some((
received_bytes1
,
_maybe_idle_stream1
)) = active_uploads.get(uuid.as_ref()) {
  Branch (559:20): [True: 1, False: 2]
  Branch (559:20): [Folded - Ignored]
560
1
                return Ok(Response::new(QueryWriteStatusResponse {
561
1
                    committed_size: received_bytes.load(Ordering::Acquire) as i64,
562
1
                    // If we are in the active_uploads map, but the value is None,
563
1
                    // it means the stream is not complete.
564
1
                    complete: false,
565
1
                }));
566
2
            }
567
        }
568
569
2
        let has_fut = store_clone.has(digest);
570
2
        let Some(
item_size1
) = has_fut.await.err_tip(|| "Failed to call .has() on store")
?0
else {
  Branch (570:13): [True: 1, False: 1]
  Branch (570:13): [Folded - Ignored]
571
            // We lie here and say that the stream needs to start over, even though
572
            // it was never started. This can happen when the client disconnects
573
            // before sending the first payload, but the client thinks it did send
574
            // the payload.
575
1
            return Ok(Response::new(QueryWriteStatusResponse {
576
1
                committed_size: 0,
577
1
                complete: false,
578
1
            }));
579
        };
580
1
        Ok(Response::new(QueryWriteStatusResponse {
581
1
            committed_size: item_size as i64,
582
1
            complete: true,
583
1
        }))
584
3
    }
585
}
586
587
#[tonic::async_trait]
588
impl ByteStream for ByteStreamServer {
589
    type ReadStream = ReadStream;
590
591
    #[instrument(
592
        err,
593
        level = Level::ERROR,
594
        skip_all,
595
        fields(request = ?grpc_request.get_ref())
596
    )]
597
    async fn read(
598
        &self,
599
        grpc_request: Request<ReadRequest>,
600
6
    ) -> Result<Response<Self::ReadStream>, Status> {
601
3
        let read_request = grpc_request.into_inner();
602
3
        let resource_info = ResourceInfo::new(&read_request.resource_name, false)
?0
;
603
3
        let instance_name = resource_info.instance_name.as_ref();
604
3
        let store = self
605
3
            .stores
606
3
            .get(instance_name)
607
3
            .err_tip(|| format!(
"'instance_name' not configured for '{instance_name}'"0
))
?0
608
3
            .clone();
609
610
3
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)
?0
;
611
612
        // If we are a GrpcStore we shortcut here, as this is a special store.
613
3
        if let Some(
grpc_store0
) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (613:16): [True: 0, False: 3]
  Branch (613:16): [Folded - Ignored]
614
0
            let stream = Box::pin(grpc_store.read(Request::new(read_request)).await?);
615
0
            return Ok(Response::new(stream));
616
3
        }
617
618
3
        let digest_function = resource_info.digest_function.as_deref().map_or_else(
619
3
            || Ok(default_digest_hasher_func()),
620
            DigestHasherFunc::try_from,
621
0
        )?;
622
623
3
        let resp = self
624
3
            .inner_read(store, digest, read_request)
625
3
            .instrument(error_span!("bytestream_read"))
626
3
            .with_context(
627
3
                make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::read")
?0
,
628
            )
629
3
            .await
630
3
            .err_tip(|| "In ByteStreamServer::read")
631
3
            .map(|stream| -> Response<Self::ReadStream> { Response::new(Box::pin(stream)) })
632
3
            .map_err(Into::into);
633
634
3
        if resp.is_ok() {
  Branch (634:12): [True: 3, False: 0]
  Branch (634:12): [Folded - Ignored]
635
3
            debug!(return = "Ok(<stream>)");
636
0
        }
637
638
3
        resp
639
6
    }
640
641
    #[instrument(
642
        err,
643
        level = Level::ERROR,
644
        skip_all,
645
        fields(request = ?grpc_request.get_ref())
646
    )]
647
    async fn write(
648
        &self,
649
        grpc_request: Request<Streaming<WriteRequest>>,
650
32
    ) -> Result<Response<WriteResponse>, Status> {
651
16
        let request = grpc_request.into_inner();
652
16
        let 
stream15
= WriteRequestStreamWrapper::from(request)
653
16
            .await
654
16
            .err_tip(|| "Could not unwrap first stream message")
655
16
            .map_err(Into::<Status>::into)
?1
;
656
657
15
        let instance_name = stream.resource_info.instance_name.as_ref();
658
15
        let store = self
659
15
            .stores
660
15
            .get(instance_name)
661
15
            .err_tip(|| format!(
"'instance_name' not configured for '{instance_name}'"0
))
?0
662
15
            .clone();
663
664
15
        let digest = DigestInfo::try_new(
665
15
            &stream.resource_info.hash,
666
15
            stream.resource_info.expected_size,
667
        )
668
15
        .err_tip(|| "Invalid digest input in ByteStream::write")
?0
;
669
670
        // If we are a GrpcStore we shortcut here, as this is a special store.
671
15
        if let Some(
grpc_store0
) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (671:16): [True: 0, False: 15]
  Branch (671:16): [Folded - Ignored]
672
0
            let resp = grpc_store.write(stream).await.map_err(Into::into);
673
0
            return resp;
674
15
        }
675
676
15
        let digest_function = stream
677
15
            .resource_info
678
15
            .digest_function
679
15
            .as_deref()
680
15
            .map_or_else(
681
15
                || Ok(default_digest_hasher_func()),
682
                DigestHasherFunc::try_from,
683
0
            )?;
684
685
15
        self.inner_write(store, digest, stream)
686
15
            .instrument(error_span!("bytestream_write"))
687
15
            .with_context(
688
15
                make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::write")
?0
,
689
            )
690
15
            .await
691
15
            .err_tip(|| "In ByteStreamServer::write")
692
15
            .map_err(Into::into)
693
32
    }
694
695
    #[instrument(
696
        err,
697
        ret(level = Level::INFO),
698
        level = Level::ERROR,
699
        skip_all,
700
        fields(request = ?grpc_request.get_ref())
701
    )]
702
    async fn query_write_status(
703
        &self,
704
        grpc_request: Request<QueryWriteStatusRequest>,
705
6
    ) -> Result<Response<QueryWriteStatusResponse>, Status> {
706
3
        let request = grpc_request.into_inner();
707
3
        self.inner_query_write_status(&request)
708
3
            .await
709
3
            .err_tip(|| "Failed on query_write_status() command")
710
3
            .map_err(Into::into)
711
6
    }
712
}