Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-service/src/bytestream_server.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::hash_map::Entry;
16
use std::collections::HashMap;
17
use std::convert::Into;
18
use std::fmt::{Debug, Formatter};
19
use std::pin::Pin;
20
use std::sync::atomic::{AtomicU64, Ordering};
21
use std::sync::Arc;
22
use std::time::Duration;
23
24
use futures::future::{pending, BoxFuture};
25
use futures::stream::unfold;
26
use futures::{try_join, Future, Stream, TryFutureExt};
27
use nativelink_config::cas_server::ByteStreamConfig;
28
use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
29
use nativelink_proto::google::bytestream::byte_stream_server::{
30
    ByteStream, ByteStreamServer as Server,
31
};
32
use nativelink_proto::google::bytestream::{
33
    QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest,
34
    WriteResponse,
35
};
36
use nativelink_store::grpc_store::GrpcStore;
37
use nativelink_store::store_manager::StoreManager;
38
use nativelink_util::buf_channel::{
39
    make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf,
40
};
41
use nativelink_util::common::DigestInfo;
42
use nativelink_util::digest_hasher::{
43
    default_digest_hasher_func, make_ctx_for_hash_func, DigestHasherFunc,
44
};
45
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
46
use nativelink_util::resource_info::ResourceInfo;
47
use nativelink_util::spawn;
48
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
49
use nativelink_util::task::JoinHandleDropGuard;
50
use parking_lot::Mutex;
51
use tokio::time::sleep;
52
use tonic::{Request, Response, Status, Streaming};
53
use tracing::{enabled, error_span, event, instrument, Instrument, Level};
54
55
/// How long an interrupted upload is kept (as an idle stream) for the client to resume it,
/// used when `persist_stream_on_disconnect_timeout` is configured as `0`.
///
/// If this value changes update the documentation in the config definition.
const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60);

/// Chunk size (64 KiB) used for read responses when `max_bytes_per_stream` is configured as `0`.
///
/// If this value changes update the documentation in the config definition.
const DEFAULT_MAX_BYTES_PER_STREAM: usize = 1 << 16;

/// Decoding limit (4 MiB) applied to the gRPC service when `max_decoding_message_size` is
/// configured as `0`.
///
/// If this value changes update the documentation in the config definition.
const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 << 20;
63
64
type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>;
65
type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>;
66
67
/// State of one in-progress upload: the write half of the channel feeding the store and
/// the future driving the store update that consumes the other half of that channel.
struct StreamState {
    /// Client-supplied upload UUID; also the key of this upload in `active_uploads`.
    uuid: String,
    /// Write half of the channel whose read half is passed to `Store::update()`.
    tx: DropCloserWriteHalf,
    /// Future running `Store::update()` for this upload.
    store_update_fut: StoreUpdateFuture,
}
72
73
impl Debug for StreamState {
74
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
75
0
        f.debug_struct("StreamState")
76
0
            .field("uuid", &self.uuid)
77
0
            .finish()
78
0
    }
79
}
80
81
/// Guard owning an upload while a client is actively writing to it.
///
/// If the guard is dropped while it still holds its `StreamState`, that state is parked as an
/// `IdleStream` in the upload's `active_uploads` slot so a reconnecting client can resume it.
/// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`,
/// which removes the upload from `active_uploads` instead.
struct ActiveStreamGuard<'a> {
    /// `Some` until taken by `graceful_finish()` or by `Drop`.
    stream_state: Option<StreamState>,
    /// Bytes forwarded to the store so far; the same `Arc` is stored in the
    /// `active_uploads` entry so `QueryWriteStatus` can read it.
    bytes_received: Arc<AtomicU64>,
    /// Server owning the `active_uploads` map this guard updates.
    bytestream_server: &'a ByteStreamServer,
}
89
90
impl<'a> ActiveStreamGuard<'a> {
91
    /// Consumes the guard. The stream will be considered "finished", will
92
    /// remove it from the active_uploads.
93
8
    fn graceful_finish(mut self) {
94
8
        let stream_state = self.stream_state.take().unwrap();
95
8
        self.bytestream_server
96
8
            .active_uploads
97
8
            .lock()
98
8
            .remove(&stream_state.uuid);
99
8
    }
100
}
101
102
impl<'a> Drop for ActiveStreamGuard<'a> {
103
14
    fn drop(&mut self) {
104
14
        let Some(
stream_state6
) = self.stream_state.take() else {
  Branch (104:13): [True: 6, False: 8]
  Branch (104:13): [Folded - Ignored]
105
8
            return; // If None it means we don't want it put back into an IdleStream.
106
        };
107
6
        let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads);
108
6
        let mut active_uploads = self.bytestream_server.active_uploads.lock();
109
6
        let uuid = stream_state.uuid.clone();
110
6
        let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else {
  Branch (110:13): [True: 6, False: 0]
  Branch (110:13): [Folded - Ignored]
111
0
            event!(
112
0
                Level::ERROR,
113
                err = "Failed to find active upload. This should never happen.",
114
                uuid = ?uuid,
115
            );
116
0
            return;
117
        };
118
6
        let sleep_fn = self.bytestream_server.sleep_fn.clone();
119
6
        active_uploads_slot.1 = Some(IdleStream {
120
6
            stream_state,
121
6
            _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move {
122
3
                (*sleep_fn)().
await0
;
123
0
                if let Some(active_uploads) = weak_active_uploads.upgrade() {
  Branch (123:24): [True: 0, False: 0]
  Branch (123:24): [Folded - Ignored]
124
0
                    let mut active_uploads = active_uploads.lock();
125
0
                    event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid);
126
0
                    active_uploads.remove(&uuid);
127
0
                }
128
6
            
}0
),
129
        });
130
14
    }
131
}
132
133
/// Represents a stream that is in the "idle" state. This means it is not currently being used
/// by a client. If it is not used within a certain amount of time (controlled by the server's
/// `sleep_fn`) it will be removed from the `active_uploads` map automatically.
#[derive(Debug)]
struct IdleStream {
    /// Upload state preserved so a reconnecting client can resume the upload.
    stream_state: StreamState,
    /// Guard for the spawned eviction task. Presumably dropping it (e.g. when the stream is
    /// re-joined) cancels the task — confirm against `JoinHandleDropGuard`.
    /// NOTE(review): "streaam" is a typo in the field name.
    _timeout_streaam_drop_guard: JoinHandleDropGuard<()>,
}
141
142
impl IdleStream {
143
3
    fn into_active_stream(
144
3
        self,
145
3
        bytes_received: Arc<AtomicU64>,
146
3
        bytestream_server: &ByteStreamServer,
147
3
    ) -> ActiveStreamGuard<'_> {
148
3
        ActiveStreamGuard {
149
3
            stream_state: Some(self.stream_state),
150
3
            bytes_received,
151
3
            bytestream_server,
152
3
        }
153
3
    }
154
}
155
156
type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>);
157
type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
158
159
/// Implementation of the gRPC `ByteStream` service backed by the configured CAS stores.
pub struct ByteStreamServer {
    /// CAS store to use for each configured instance name.
    stores: HashMap<String, Store>,
    /// Max number of bytes to send on each grpc stream chunk.
    max_bytes_per_stream: usize,
    /// Decoding limit applied to the tonic service in `into_service()`.
    max_decoding_message_size: usize,
    /// Uploads keyed by UUID. A `None` idle slot means a client is currently writing;
    /// `Some` holds a parked stream awaiting resumption or eviction.
    active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
    /// Produces the future awaited before an idle stream is evicted from `active_uploads`.
    sleep_fn: SleepFn,
}
167
168
impl ByteStreamServer {
169
14
    pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> {
170
14
        let mut persist_stream_on_disconnect_timeout =
171
14
            Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64);
172
14
        if config.persist_stream_on_disconnect_timeout == 0 {
  Branch (172:12): [True: 14, False: 0]
  Branch (172:12): [Folded - Ignored]
173
14
            persist_stream_on_disconnect_timeout = DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT;
174
14
        }
0
175
14
        Self::new_with_sleep_fn(
176
14
            config,
177
14
            store_manager,
178
14
            Arc::new(move || 
Box::pin(sleep(persist_stream_on_disconnect_timeout))3
),
179
14
        )
180
14
    }
181
182
14
    pub fn new_with_sleep_fn(
183
14
        config: &ByteStreamConfig,
184
14
        store_manager: &StoreManager,
185
14
        sleep_fn: SleepFn,
186
14
    ) -> Result<Self, Error> {
187
14
        let mut stores = HashMap::with_capacity(config.cas_stores.len());
188
28
        for (
instance_name, store_name14
) in &config.cas_stores {
189
14
            let store = store_manager
190
14
                .get_store(store_name)
191
14
                .ok_or_else(|| 
make_input_err!("'cas_store': '{}' does not exist", store_name)0
)
?0
;
192
14
            stores.insert(instance_name.to_string(), store);
193
        }
194
14
        let max_bytes_per_stream = if config.max_bytes_per_stream == 0 {
  Branch (194:39): [True: 1, False: 13]
  Branch (194:39): [Folded - Ignored]
195
1
            DEFAULT_MAX_BYTES_PER_STREAM
196
        } else {
197
13
            config.max_bytes_per_stream
198
        };
199
14
        let max_decoding_message_size = if config.max_decoding_message_size == 0 {
  Branch (199:44): [True: 13, False: 1]
  Branch (199:44): [Folded - Ignored]
200
13
            DEFAULT_MAX_DECODING_MESSAGE_SIZE
201
        } else {
202
1
            config.max_decoding_message_size
203
        };
204
14
        Ok(ByteStreamServer {
205
14
            stores,
206
14
            max_bytes_per_stream,
207
14
            max_decoding_message_size,
208
14
            active_uploads: Arc::new(Mutex::new(HashMap::new())),
209
14
            sleep_fn,
210
14
        })
211
14
    }
212
213
1
    pub fn into_service(self) -> Server<Self> {
214
1
        let max_decoding_message_size = self.max_decoding_message_size;
215
1
        Server::new(self).max_decoding_message_size(max_decoding_message_size)
216
1
    }
217
218
14
    fn create_or_join_upload_stream(
219
14
        &self,
220
14
        uuid: String,
221
14
        store: Store,
222
14
        digest: DigestInfo,
223
14
    ) -> Result<ActiveStreamGuard<'_>, Error> {
224
14
        let (
uuid, bytes_received11
) = match self.active_uploads.lock().entry(uuid) {
225
3
            Entry::Occupied(mut entry) => {
226
3
                let maybe_idle_stream = entry.get_mut();
227
3
                let Some(idle_stream) = maybe_idle_stream.1.take() else {
  Branch (227:21): [True: 3, False: 0]
  Branch (227:21): [Folded - Ignored]
228
0
                    return Err(make_input_err!("Cannot upload same UUID simultaneously"));
229
                };
230
3
                let bytes_received = maybe_idle_stream.0.clone();
231
3
                event!(Level::INFO, msg = "Joining existing stream", entry = ?
entry.key()0
);
232
3
                return Ok(idle_stream.into_active_stream(bytes_received, self));
233
            }
234
11
            Entry::Vacant(entry) => {
235
11
                let bytes_received = Arc::new(AtomicU64::new(0));
236
11
                let uuid = entry.key().clone();
237
11
                // Our stream is "in use" if the key is in the map, but the value is None.
238
11
                entry.insert((bytes_received.clone(), None));
239
11
                (uuid, bytes_received)
240
11
            }
241
11
        };
242
11
243
11
        // Important: Do not return an error from this point onwards without
244
11
        // removing the entry from the map, otherwise that UUID becomes
245
11
        // unusable.
246
11
247
11
        let (tx, rx) = make_buf_channel_pair();
248
11
        let store_update_fut = Box::pin(async move {
249
8
            // We need to wrap `Store::update()` in a another future because we need to capture
250
8
            // `store` to ensure its lifetime follows the future and not the caller.
251
8
            store
252
8
                // Bytestream always uses digest size as the actual byte size.
253
8
                .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes()))
254
101
                .await
255
11
        
}8
);
256
11
        Ok(ActiveStreamGuard {
257
11
            stream_state: Some(StreamState {
258
11
                uuid,
259
11
                tx,
260
11
                store_update_fut,
261
11
            }),
262
11
            bytes_received,
263
11
            bytestream_server: self,
264
11
        })
265
14
    }
266
267
3
    async fn inner_read(
268
3
        &self,
269
3
        store: Store,
270
3
        digest: DigestInfo,
271
3
        read_request: ReadRequest,
272
3
    ) -> Result<Response<ReadStream>, Error> {
273
        struct ReaderState {
274
            max_bytes_per_stream: usize,
275
            rx: DropCloserReadHalf,
276
            maybe_get_part_result: Option<Result<(), Error>>,
277
            get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
278
        }
279
280
3
        let read_limit = u64::try_from(read_request.read_limit)
281
3
            .err_tip(|| 
"Could not convert read_limit to u64"0
)
?0
;
282
283
3
        let (tx, rx) = make_buf_channel_pair();
284
285
3
        let read_limit = if read_limit != 0 {
  Branch (285:29): [True: 3, False: 0]
  Branch (285:29): [Folded - Ignored]
286
3
            Some(read_limit)
287
        } else {
288
0
            None
289
        };
290
291
        // This allows us to call a destructor when the the object is dropped.
292
3
        let state = Some(ReaderState {
293
3
            rx,
294
3
            max_bytes_per_stream: self.max_bytes_per_stream,
295
3
            maybe_get_part_result: None,
296
3
            get_part_fut: Box::pin(async move {
297
3
                store
298
3
                    .get_part(
299
3
                        digest,
300
3
                        tx,
301
3
                        u64::try_from(read_request.read_offset)
302
3
                            .err_tip(|| 
"Could not convert read_offset to u64"0
)
?0
,
303
3
                        read_limit,
304
                    )
305
0
                    .await
306
3
            }),
307
3
        });
308
309
3
        let read_stream_span = error_span!("read_stream");
310
311
9.77k
        
Ok(Response::new(Box::pin(unfold(state, 3
move |state| {
312
9.77k
            async {
313
9.77k
            let 
mut state9.77k
= state
?2
; // If None our stream is done.
314
9.77k
            let mut response = ReadResponse::default();
315
9.77k
            {
316
9.77k
                let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream));
317
9.77k
                tokio::pin!(consume_fut);
318
                loop {
319
9.77k
                    tokio::select! {
320
9.77k
                        
read_result9.77k
= &mut consume_fut => {
321
9.77k
                            match read_result {
322
9.76k
                                Ok(bytes) => {
323
9.76k
                                    if bytes.is_empty() {
  Branch (323:40): [True: 2, False: 9.76k]
  Branch (323:40): [Folded - Ignored]
324
                                        // EOF.
325
2
                                        return Some((Ok(response), None));
326
9.76k
                                    }
327
9.76k
                                    if bytes.len() > state.max_bytes_per_stream {
  Branch (327:40): [True: 0, False: 9.76k]
  Branch (327:40): [Folded - Ignored]
328
0
                                        let err = make_err!(Code::Internal, "Returned store size was larger than read size");
329
0
                                        return Some((Err(err.into()), None));
330
9.76k
                                    }
331
9.76k
                                    response.data = bytes;
332
9.76k
                                    if enabled!(Level::DEBUG) {
333
0
                                        event!(Level::INFO, response = ?response);
334
                                    } else {
335
9.76k
                                        event!(Level::INFO, response.data = 
format!("<redacted len({})>", response.data.len())0
);
336
                                    }
337
9.76k
                                    break;
338
                                }
339
1
                                Err(mut e) => {
340
                                    // We may need to propagate the error from reading the data through first.
341
                                    // For example, the NotFound error will come through `get_part_fut`, and
342
                                    // will not be present in `e`, but we need to ensure we pass NotFound error
343
                                    // code or the client won't know why it failed.
344
1
                                    let get_part_result = if let Some(result) = state.maybe_get_part_result {
  Branch (344:66): [True: 1, False: 0]
  Branch (344:66): [Folded - Ignored]
345
1
                                        result
346
                                    } else {
347
                                        // This should never be `future::pending()` if maybe_get_part_result is
348
                                        // not set.
349
0
                                        state.get_part_fut.await
350
                                    };
351
1
                                    if let Err(err) = get_part_result {
  Branch (351:44): [True: 1, False: 0]
  Branch (351:44): [Folded - Ignored]
352
1
                                        e = err.merge(e);
353
1
                                    }
0
354
1
                                    if e.code == Code::NotFound {
  Branch (354:40): [True: 1, False: 0]
  Branch (354:40): [Folded - Ignored]
355
1
                                        // Trim the error code. Not Found is quite common and we don't want to send a large
356
1
                                        // error (debug) message for something that is common. We resize to just the last
357
1
                                        // message as it will be the most relevant.
358
1
                                        e.messages.truncate(1);
359
1
                                    }
0
360
1
                                    event!(Level::ERROR, response = ?e);
361
1
                                    return Some((Err(e.into()), None))
362
                                }
363
                            }
364
                        },
365
9.77k
                        
result3
= &mut state.get_part_fut => {
366
3
                            state.maybe_get_part_result = Some(result);
367
3
                            // It is non-deterministic on which future will finish in what order.
368
3
                            // It is also possible that the `state.rx.consume()` call above may not be able to
369
3
                            // respond even though the publishing future is done.
370
3
                            // Because of this we set the writing future to pending so it never finishes.
371
3
                            // The `state.rx.consume()` future will eventually finish and return either the
372
3
                            // data or an error.
373
3
                            // An EOF will terminate the `state.rx.consume()` future, but we are also protected
374
3
                            // because we are dropping the writing future, it will drop the `tx` channel
375
3
                            // which will eventually propagate an error to the `state.rx.consume()` future if
376
3
                            // the EOF was not sent due to some other error.
377
3
                            state.get_part_fut = Box::pin(pending());
378
3
                        },
379
                    }
380
                }
381
            }
382
9.76k
            Some((Ok(response), Some(state)))
383
9.77k
        }.instrument(read_stream_span.clone())
384
9.77k
        }
))))3
385
3
    }
386
387
    // We instrument tracing here as well as below because `stream` has a hash on it
388
    // that is extracted from the first stream message. If we only implemented it below
389
    // we would not have the hash available to us.
390
28
    #[instrument(
391
        ret(level = Level::INFO),
392
14
        level = Level::ERROR,
393
        skip(self, store),
394
        fields(stream.first_msg = "<redacted>")
395
28
    )]
396
    async fn inner_write(
397
        &self,
398
        store: Store,
399
        digest: DigestInfo,
400
        stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>,
401
    ) -> Result<Response<WriteResponse>, Error> {
402
14
        async fn process_client_stream(
403
14
            mut stream: WriteRequestStreamWrapper<Streaming<WriteRequest>, Status>,
404
14
            tx: &mut DropCloserWriteHalf,
405
14
            outer_bytes_received: &Arc<AtomicU64>,
406
14
            expected_size: u64,
407
14
        ) -> Result<(), Error> {
408
            loop {
409
101
                let 
write_request20
= match
stream.next()24
.await {
410
                    // Code path for when client tries to gracefully close the stream.
411
                    // If this happens it means there's a problem with the data sent,
412
                    // because we always close the stream from our end before this point
413
                    // by counting the number of bytes sent from the client. If they send
414
                    // less than the amount they said they were going to send and then
415
                    // close the stream, we know there's a problem.
416
                    None => {
417
0
                        return Err(make_input_err!(
418
0
                            "Client closed stream before sending all data"
419
0
                        ))
420
                    }
421
                    // Code path for client stream error. Probably client disconnect.
422
4
                    Some(Err(err)) => return Err(err),
423
                    // Code path for received chunk of data.
424
20
                    Some(Ok(write_request)) => write_request,
425
20
                };
426
20
427
20
                if write_request.write_offset < 0 {
  Branch (427:20): [True: 1, False: 19]
  Branch (427:20): [Folded - Ignored]
428
1
                    return Err(make_input_err!(
429
1
                        "Invalid negative write offset in write request: {}",
430
1
                        write_request.write_offset
431
1
                    ));
432
19
                }
433
19
                let write_offset = write_request.write_offset as u64;
434
435
                // If we get duplicate data because a client didn't know where
436
                // it left off from, then we can simply skip it.
437
19
                let 
data18
= if write_offset < tx.get_bytes_written() {
  Branch (437:31): [True: 2, False: 17]
  Branch (437:31): [Folded - Ignored]
438
2
                    if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() {
  Branch (438:24): [True: 0, False: 2]
  Branch (438:24): [Folded - Ignored]
439
0
                        if write_request.finish_write {
  Branch (439:28): [True: 0, False: 0]
  Branch (439:28): [Folded - Ignored]
440
0
                            return Err(make_input_err!(
441
0
                                "Resumed stream finished at {} bytes when we already received {} bytes.",
442
0
                                write_offset + write_request.data.len() as u64,
443
0
                                tx.get_bytes_written()
444
0
                            ));
445
0
                        }
446
0
                        continue;
447
2
                    }
448
2
                    write_request
449
2
                        .data
450
2
                        .slice((tx.get_bytes_written() - write_offset) as usize..)
451
                } else {
452
17
                    if write_offset != tx.get_bytes_written() {
  Branch (452:24): [True: 1, False: 16]
  Branch (452:24): [Folded - Ignored]
453
1
                        return Err(make_input_err!(
454
1
                            "Received out of order data. Got {}, expected {}",
455
1
                            write_offset,
456
1
                            tx.get_bytes_written()
457
1
                        ));
458
16
                    }
459
16
                    write_request.data
460
                };
461
462
                // Do not process EOF or weird stuff will happen.
463
18
                if !data.is_empty() {
  Branch (463:20): [True: 13, False: 5]
  Branch (463:20): [Folded - Ignored]
464
                    // We also need to process the possible EOF branch, so we can't early return.
465
13
                    if let Err(
mut err0
) = tx.send(data).
await0
{
  Branch (465:28): [True: 0, False: 13]
  Branch (465:28): [Folded - Ignored]
466
0
                        err.code = Code::Internal;
467
0
                        return Err(err);
468
13
                    }
469
13
                    outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release);
470
5
                }
471
472
18
                if expected_size < tx.get_bytes_written() {
  Branch (472:20): [True: 0, False: 18]
  Branch (472:20): [Folded - Ignored]
473
0
                    return Err(make_input_err!("Received more bytes than expected"));
474
18
                }
475
18
                if write_request.finish_write {
  Branch (475:20): [True: 8, False: 10]
  Branch (475:20): [Folded - Ignored]
476
                    // Gracefully close our stream.
477
8
                    tx.send_eof()
478
8
                        .err_tip(|| 
"Failed to send EOF in ByteStream::write"0
)
?0
;
479
8
                    return Ok(());
480
10
                }
481
                // Continue.
482
            }
483
            // Unreachable.
484
14
        }
485
486
        let uuid = stream
487
            .resource_info
488
            .uuid
489
            .as_ref()
490
0
            .ok_or_else(|| make_input_err!("UUID must be set if writing data"))?
491
            .to_string();
492
        let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?;
493
        let expected_size = stream.resource_info.expected_size as u64;
494
495
        let active_stream = active_stream_guard.stream_state.as_mut().unwrap();
496
        try_join!(
497
            process_client_stream(
498
                stream,
499
                &mut active_stream.tx,
500
                &active_stream_guard.bytes_received,
501
                expected_size
502
            ),
503
            (&mut active_stream.store_update_fut)
504
0
                .map_err(|err| { err.append("Error updating inner store") })
505
        )?;
506
507
        // Close our guard and consider the stream no longer active.
508
        active_stream_guard.graceful_finish();
509
510
        Ok(Response::new(WriteResponse {
511
            committed_size: expected_size as i64,
512
        }))
513
    }
514
515
3
    async fn inner_query_write_status(
516
3
        &self,
517
3
        query_request: &QueryWriteStatusRequest,
518
3
    ) -> Result<Response<QueryWriteStatusResponse>, Error> {
519
3
        let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)
?0
;
520
521
3
        let store_clone = self
522
3
            .stores
523
3
            .get(resource_info.instance_name.as_ref())
524
3
            .err_tip(|| {
525
0
                format!(
526
0
                    "'instance_name' not configured for '{}'",
527
0
                    &resource_info.instance_name
528
0
                )
529
3
            })
?0
530
3
            .clone();
531
532
3
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)
?0
;
533
534
        // If we are a GrpcStore we shortcut here, as this is a special store.
535
3
        if let Some(
grpc_store0
) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (535:16): [True: 0, False: 3]
  Branch (535:16): [Folded - Ignored]
536
0
            return grpc_store
537
0
                .query_write_status(Request::new(query_request.clone()))
538
0
                .await;
539
3
        }
540
541
3
        let uuid = resource_info
542
3
            .uuid
543
3
            .take()
544
3
            .ok_or_else(|| 
make_input_err!("UUID must be set if querying write status")0
)
?0
;
545
546
        {
547
3
            let active_uploads = self.active_uploads.lock();
548
3
            if let Some((
received_bytes, _maybe_idle_stream1
)) = active_uploads.get(uuid.as_ref()) {
  Branch (548:20): [True: 1, False: 2]
  Branch (548:20): [Folded - Ignored]
549
1
                return Ok(Response::new(QueryWriteStatusResponse {
550
1
                    committed_size: received_bytes.load(Ordering::Acquire) as i64,
551
1
                    // If we are in the active_uploads map, but the value is None,
552
1
                    // it means the stream is not complete.
553
1
                    complete: false,
554
1
                }));
555
2
            }
556
2
        }
557
2
558
2
        let has_fut = store_clone.has(digest);
559
2
        let Some(
item_size1
) = has_fut.
await0
.err_tip(||
"Failed to call .has() on store"0
)
?0
else {
  Branch (559:13): [True: 1, False: 1]
  Branch (559:13): [Folded - Ignored]
560
            // We lie here and say that the stream needs to start over, even though
561
            // it was never started. This can happen when the client disconnects
562
            // before sending the first payload, but the client thinks it did send
563
            // the payload.
564
1
            return Ok(Response::new(QueryWriteStatusResponse {
565
1
                committed_size: 0,
566
1
                complete: false,
567
1
            }));
568
        };
569
1
        Ok(Response::new(QueryWriteStatusResponse {
570
1
            committed_size: item_size as i64,
571
1
            complete: true,
572
1
        }))
573
3
    }
574
}
575
576
#[tonic::async_trait]
577
impl ByteStream for ByteStreamServer {
578
    type ReadStream = ReadStream;
579
580
    #[allow(clippy::blocks_in_conditions)]
581
3
    #[instrument(
582
        err,
583
        level = Level::ERROR,
584
        skip_all,
585
        fields(request = ?grpc_request.get_ref())
586
6
    )]
587
    async fn read(
588
        &self,
589
        grpc_request: Request<ReadRequest>,
590
3
    ) -> Result<Response<Self::ReadStream>, Status> {
591
3
        let read_request = grpc_request.into_inner();
592
593
3
        let resource_info = ResourceInfo::new(&read_request.resource_name, false)
?0
;
594
3
        let instance_name = resource_info.instance_name.as_ref();
595
3
        let store = self
596
3
            .stores
597
3
            .get(instance_name)
598
3
            .err_tip(|| 
format!("'instance_name' not configured for '{instance_name}'")0
)
?0
599
3
            .clone();
600
601
3
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)
?0
;
602
603
        // If we are a GrpcStore we shortcut here, as this is a special store.
604
3
        if let Some(
grpc_store0
) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (604:16): [True: 0, False: 3]
  Branch (604:16): [Folded - Ignored]
605
0
            let stream = grpc_store.read(Request::new(read_request)).await?;
606
0
            return Ok(Response::new(Box::pin(stream)));
607
3
        }
608
609
3
        let digest_function = resource_info.digest_function.as_deref().map_or_else(
610
3
            || Ok(default_digest_hasher_func()),
611
3
            DigestHasherFunc::try_from,
612
3
        )
?0
;
613
614
3
        let resp = make_ctx_for_hash_func(digest_function)
615
3
            .err_tip(|| 
"In BytestreamServer::read"0
)
?0
616
            .wrap_async(
617
3
                error_span!("bytestream_read"),
618
3
                self.inner_read(store, digest, read_request),
619
            )
620
0
            .await
621
3
            .err_tip(|| 
"In ByteStreamServer::read"0
)
622
3
            .map_err(Into::into);
623
3
624
3
        if resp.is_ok() {
  Branch (624:12): [True: 3, False: 0]
  Branch (624:12): [Folded - Ignored]
625
3
            event!(Level::DEBUG, return = "Ok(<stream>)");
626
0
        }
627
3
        resp
628
6
    }
629
630
    #[allow(clippy::blocks_in_conditions)]
631
15
    #[instrument(
632
        err,
633
        level = Level::ERROR,
634
        skip_all,
635
        fields(request = ?grpc_request.get_ref())
636
30
    )]
637
    async fn write(
638
        &self,
639
        grpc_request: Request<Streaming<WriteRequest>>,
640
15
    ) -> Result<Response<WriteResponse>, Status> {
641
15
        let 
stream14
= WriteRequestStreamWrapper::from(grpc_request.into_inner())
642
65
            .await
643
15
            .err_tip(|| 
"Could not unwrap first stream message"1
)
644
15
            .map_err(Into::<Status>::into)
?1
;
645
646
14
        let instance_name = stream.resource_info.instance_name.as_ref();
647
14
        let store = self
648
14
            .stores
649
14
            .get(instance_name)
650
14
            .err_tip(|| 
format!("'instance_name' not configured for '{instance_name}'")0
)
?0
651
14
            .clone();
652
653
14
        let digest = DigestInfo::try_new(
654
14
            &stream.resource_info.hash,
655
14
            stream.resource_info.expected_size,
656
14
        )
657
14
        .err_tip(|| 
"Invalid digest input in ByteStream::write"0
)
?0
;
658
659
        // If we are a GrpcStore we shortcut here, as this is a special store.
660
14
        if let Some(
grpc_store0
) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (660:16): [True: 0, False: 14]
  Branch (660:16): [Folded - Ignored]
661
0
            return grpc_store.write(stream).await.map_err(Into::into);
662
14
        }
663
664
14
        let digest_function = stream
665
14
            .resource_info
666
14
            .digest_function
667
14
            .as_deref()
668
14
            .map_or_else(
669
14
                || Ok(default_digest_hasher_func()),
670
14
                DigestHasherFunc::try_from,
671
14
            )
?0
;
672
673
14
        make_ctx_for_hash_func(digest_function)
674
14
            .err_tip(|| 
"In BytestreamServer::write"0
)
?0
675
            .wrap_async(
676
14
                error_span!("bytestream_write"),
677
14
                self.inner_write(store, digest, stream),
678
            )
679
101
            .await
680
14
            .err_tip(|| 
"In ByteStreamServer::write"6
)
681
14
            .map_err(Into::into)
682
30
    }
683
684
    #[allow(clippy::blocks_in_conditions)]
685
3
    #[instrument(
686
        err,
687
        ret(level = Level::INFO),
688
        level = Level::ERROR,
689
        skip_all,
690
        fields(request = ?grpc_request.get_ref())
691
6
    )]
692
    async fn query_write_status(
693
        &self,
694
        grpc_request: Request<QueryWriteStatusRequest>,
695
3
    ) -> Result<Response<QueryWriteStatusResponse>, Status> {
696
3
        self.inner_query_write_status(&grpc_request.into_inner())
697
0
            .await
698
3
            .err_tip(|| 
"Failed on query_write_status() command"0
)
699
3
            .map_err(Into::into)
700
6
    }
701
}