Coverage Report

Created: 2025-11-07 13:29

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-service/src/bytestream_server.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::convert::Into;
16
use core::fmt::{Debug, Formatter};
17
use core::pin::Pin;
18
use core::sync::atomic::{AtomicU64, Ordering};
19
use core::time::Duration;
20
use std::collections::HashMap;
21
use std::collections::hash_map::Entry;
22
use std::sync::Arc;
23
use std::time::{SystemTime, UNIX_EPOCH};
24
25
use futures::future::{BoxFuture, pending};
26
use futures::stream::unfold;
27
use futures::{Future, Stream, TryFutureExt, try_join};
28
use nativelink_config::cas_server::{ByteStreamConfig, InstanceName, WithInstanceName};
29
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
30
use nativelink_proto::google::bytestream::byte_stream_server::{
31
    ByteStream, ByteStreamServer as Server,
32
};
33
use nativelink_proto::google::bytestream::{
34
    QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest,
35
    WriteResponse,
36
};
37
use nativelink_store::grpc_store::GrpcStore;
38
use nativelink_store::store_manager::StoreManager;
39
use nativelink_util::buf_channel::{
40
    DropCloserReadHalf, DropCloserWriteHalf, make_buf_channel_pair,
41
};
42
use nativelink_util::common::DigestInfo;
43
use nativelink_util::digest_hasher::{
44
    DigestHasherFunc, default_digest_hasher_func, make_ctx_for_hash_func,
45
};
46
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
47
use nativelink_util::resource_info::ResourceInfo;
48
use nativelink_util::spawn;
49
use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo};
50
use nativelink_util::task::JoinHandleDropGuard;
51
use opentelemetry::context::FutureExt;
52
use parking_lot::Mutex;
53
use tokio::time::sleep;
54
use tonic::{Request, Response, Status, Streaming};
55
use tracing::{Instrument, Level, debug, error, error_span, info, instrument, trace, warn};
56
57
/// Idle-stream timeout used when a config leaves `persist_stream_on_disconnect_timeout` at 0.
///
/// If this value changes update the documentation in the config definition.
const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60);

/// Read chunk size used when a config leaves `max_bytes_per_stream` at 0.
///
/// If this value changes update the documentation in the config definition.
const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024;
62
63
type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>);
64
type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
65
66
pub struct InstanceInfo {
67
    store: Store,
68
    // Max number of bytes to send on each grpc stream chunk.
69
    max_bytes_per_stream: usize,
70
    active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
71
    sleep_fn: SleepFn,
72
}
73
74
impl Debug for InstanceInfo {
75
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
76
0
        f.debug_struct("InstanceInfo")
77
0
            .field("store", &self.store)
78
0
            .field("max_bytes_per_stream", &self.max_bytes_per_stream)
79
0
            .field("active_uploads", &self.active_uploads)
80
0
            .finish()
81
0
    }
82
}
83
84
type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>;
85
type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>;
86
87
struct StreamState {
88
    uuid: String,
89
    tx: DropCloserWriteHalf,
90
    store_update_fut: StoreUpdateFuture,
91
}
92
93
impl Debug for StreamState {
94
0
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
95
0
        f.debug_struct("StreamState")
96
0
            .field("uuid", &self.uuid)
97
0
            .finish()
98
0
    }
99
}
100
101
/// If a stream is in this state, it will automatically be put back into an `IdleStream` and
102
/// placed back into the `active_uploads` map as an `IdleStream` after it is dropped.
103
/// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`.
104
struct ActiveStreamGuard {
105
    stream_state: Option<StreamState>,
106
    bytes_received: Arc<AtomicU64>,
107
    active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>,
108
    sleep_fn: SleepFn,
109
}
110
111
impl ActiveStreamGuard {
112
    /// Consumes the guard. The stream will be considered "finished", will
113
    /// remove it from the `active_uploads`.
114
8
    fn graceful_finish(mut self) {
115
8
        let stream_state = self.stream_state.take().unwrap();
116
8
        self.active_uploads.lock().remove(&stream_state.uuid);
117
8
    }
118
}
119
120
impl Drop for ActiveStreamGuard {
121
15
    fn drop(&mut self) {
122
15
        let Some(
stream_state7
) = self.stream_state.take() else {
  Branch (122:13): [True: 7, False: 8]
  Branch (122:13): [Folded - Ignored]
123
8
            return; // If None it means we don't want it put back into an IdleStream.
124
        };
125
7
        let weak_active_uploads = Arc::downgrade(&self.active_uploads);
126
7
        let mut active_uploads = self.active_uploads.lock();
127
7
        let uuid = stream_state.uuid.clone();
128
7
        let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else {
  Branch (128:13): [True: 7, False: 0]
  Branch (128:13): [Folded - Ignored]
129
0
            error!(
130
                err = "Failed to find active upload. This should never happen.",
131
                uuid = ?uuid,
132
            );
133
0
            return;
134
        };
135
7
        let sleep_fn = self.sleep_fn.clone();
136
7
        active_uploads_slot.1 = Some(IdleStream {
137
7
            stream_state,
138
7
            _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move 
{3
139
3
                (*sleep_fn)().await;
140
0
                if let Some(active_uploads) = weak_active_uploads.upgrade() {
  Branch (140:24): [True: 0, False: 0]
  Branch (140:24): [Folded - Ignored]
141
0
                    let mut active_uploads = active_uploads.lock();
142
0
                    info!(msg = "Removing idle stream", uuid = ?uuid);
143
0
                    active_uploads.remove(&uuid);
144
0
                }
145
0
            }),
146
        });
147
15
    }
148
}
149
150
/// Represents a stream that is in the "idle" state. this means it is not currently being used
151
/// by a client. If it is not used within a certain amount of time it will be removed from the
152
/// `active_uploads` map automatically.
153
#[derive(Debug)]
154
struct IdleStream {
155
    stream_state: StreamState,
156
    _timeout_streaam_drop_guard: JoinHandleDropGuard<()>,
157
}
158
159
impl IdleStream {
160
3
    fn into_active_stream(
161
3
        self,
162
3
        bytes_received: Arc<AtomicU64>,
163
3
        instance_info: &InstanceInfo,
164
3
    ) -> ActiveStreamGuard {
165
3
        ActiveStreamGuard {
166
3
            stream_state: Some(self.stream_state),
167
3
            bytes_received,
168
3
            active_uploads: instance_info.active_uploads.clone(),
169
3
            sleep_fn: instance_info.sleep_fn.clone(),
170
3
        }
171
3
    }
172
}
173
174
#[derive(Debug)]
175
pub struct ByteStreamServer {
176
    instance_infos: HashMap<InstanceName, InstanceInfo>,
177
}
178
179
impl ByteStreamServer {
180
    /// Derives a new upload key from `base_uuid` by appending `-` and the current time,
    /// expressed as nanoseconds since the Unix epoch in lowercase hex.
    ///
    /// A clock that reads earlier than the epoch contributes `0` instead of failing.
    fn generate_unique_uuid(base_uuid: &str) -> String {
        let nanos_since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            Err(_) => 0,
        };
        format!("{base_uuid}-{nanos_since_epoch:x}")
    }
188
189
15
    pub fn new(
190
15
        configs: &[WithInstanceName<ByteStreamConfig>],
191
15
        store_manager: &StoreManager,
192
15
    ) -> Result<Self, Error> {
193
15
        let mut instance_infos: HashMap<String, InstanceInfo> = HashMap::new();
194
30
        for 
config15
in configs {
195
15
            let persist_stream_on_disconnect_timeout =
196
15
                if config.persist_stream_on_disconnect_timeout == 0 {
  Branch (196:20): [True: 15, False: 0]
  Branch (196:20): [Folded - Ignored]
197
15
                    DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT
198
                } else {
199
0
                    Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64)
200
                };
201
15
            let _old_value = instance_infos.insert(
202
15
                config.instance_name.clone(),
203
15
                Self::new_with_sleep_fn(
204
15
                    config,
205
15
                    store_manager,
206
15
                    Arc::new(move || 
Box::pin3
(
sleep3
(
persist_stream_on_disconnect_timeout3
))),
207
0
                )?,
208
            );
209
        }
210
15
        Ok(Self { instance_infos })
211
15
    }
212
213
15
    pub fn new_with_sleep_fn(
214
15
        config: &WithInstanceName<ByteStreamConfig>,
215
15
        store_manager: &StoreManager,
216
15
        sleep_fn: SleepFn,
217
15
    ) -> Result<InstanceInfo, Error> {
218
15
        let store = store_manager
219
15
            .get_store(&config.cas_store)
220
15
            .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", 
config.cas_store0
))
?0
;
221
15
        let max_bytes_per_stream = if config.max_bytes_per_stream == 0 {
  Branch (221:39): [True: 1, False: 14]
  Branch (221:39): [Folded - Ignored]
222
1
            DEFAULT_MAX_BYTES_PER_STREAM
223
        } else {
224
14
            config.max_bytes_per_stream
225
        };
226
15
        Ok(InstanceInfo {
227
15
            store,
228
15
            max_bytes_per_stream,
229
15
            active_uploads: Arc::new(Mutex::new(HashMap::new())),
230
15
            sleep_fn,
231
15
        })
232
15
    }
233
234
1
    pub fn into_service(self) -> Server<Self> {
235
1
        Server::new(self)
236
1
    }
237
238
    /// Creates or joins an upload stream for the given UUID.
239
    ///
240
    /// This function handles three scenarios:
241
    /// 1. UUID doesn't exist - creates a new upload stream
242
    /// 2. UUID exists but is idle - resumes the existing stream
243
    /// 3. UUID exists and is active - generates a unique UUID by appending a nanosecond
244
    ///    timestamp to avoid collision, then creates a new stream with that UUID
245
    ///
246
    /// The nanosecond timestamp ensures virtually zero probability of collision since
247
    /// two concurrent uploads would need to both collide on the original UUID AND
248
    /// generate the unique UUID in the exact same nanosecond.
249
15
    fn create_or_join_upload_stream(
250
15
        &self,
251
15
        uuid: &str,
252
15
        instance: &InstanceInfo,
253
15
        digest: DigestInfo,
254
15
    ) -> ActiveStreamGuard {
255
15
        let (
uuid12
,
bytes_received12
) = match instance.active_uploads.lock().entry(uuid.to_string()) {
256
3
            Entry::Occupied(mut entry) => {
257
3
                let maybe_idle_stream = entry.get_mut();
258
3
                if let Some(idle_stream) = maybe_idle_stream.1.take() {
  Branch (258:24): [True: 3, False: 0]
  Branch (258:24): [Folded - Ignored]
259
                    // Case 2: Stream exists but is idle, we can resume it
260
3
                    let bytes_received = maybe_idle_stream.0.clone();
261
3
                    info!(msg = "Joining existing stream", entry = ?entry.key());
262
3
                    return idle_stream.into_active_stream(bytes_received, instance);
263
0
                }
264
                // Case 3: Stream is active - generate a unique UUID to avoid collision
265
                // Using nanosecond timestamp makes collision probability essentially zero
266
0
                let original_uuid = entry.key().clone();
267
0
                let unique_uuid = Self::generate_unique_uuid(&original_uuid);
268
0
                warn!(
269
                    msg = "UUID collision detected, generating unique UUID to prevent conflict",
270
                    original_uuid = ?original_uuid,
271
                    unique_uuid = ?unique_uuid
272
                );
273
                // Entry goes out of scope here, releasing the lock
274
275
0
                let bytes_received = Arc::new(AtomicU64::new(0));
276
0
                let mut active_uploads = instance.active_uploads.lock();
277
                // Insert with the unique UUID - this should never collide due to nanosecond precision
278
0
                active_uploads.insert(unique_uuid.clone(), (bytes_received.clone(), None));
279
0
                (unique_uuid, bytes_received)
280
            }
281
12
            Entry::Vacant(entry) => {
282
                // Case 1: UUID doesn't exist, create new stream
283
12
                let bytes_received = Arc::new(AtomicU64::new(0));
284
12
                let uuid = entry.key().clone();
285
                // Our stream is "in use" if the key is in the map, but the value is None.
286
12
                entry.insert((bytes_received.clone(), None));
287
12
                (uuid, bytes_received)
288
            }
289
        };
290
291
        // Important: Do not return an error from this point onwards without
292
        // removing the entry from the map, otherwise that UUID becomes
293
        // unusable.
294
295
12
        let (tx, rx) = make_buf_channel_pair();
296
12
        let store = instance.store.clone();
297
12
        let store_update_fut = Box::pin(async move 
{8
298
            // We need to wrap `Store::update()` in a another future because we need to capture
299
            // `store` to ensure its lifetime follows the future and not the caller.
300
8
            store
301
8
                // Bytestream always uses digest size as the actual byte size.
302
8
                .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes()))
303
8
                .await
304
8
        });
305
12
        ActiveStreamGuard {
306
12
            stream_state: Some(StreamState {
307
12
                uuid,
308
12
                tx,
309
12
                store_update_fut,
310
12
            }),
311
12
            bytes_received,
312
12
            active_uploads: instance.active_uploads.clone(),
313
12
            sleep_fn: instance.sleep_fn.clone(),
314
12
        }
315
15
    }
316
317
3
    async fn inner_read(
318
3
        &self,
319
3
        instance: &InstanceInfo,
320
3
        digest: DigestInfo,
321
3
        read_request: ReadRequest,
322
3
    ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + use<>, Error> {
323
        struct ReaderState {
324
            max_bytes_per_stream: usize,
325
            rx: DropCloserReadHalf,
326
            maybe_get_part_result: Option<Result<(), Error>>,
327
            get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
328
        }
329
330
3
        let read_limit = u64::try_from(read_request.read_limit)
331
3
            .err_tip(|| "Could not convert read_limit to u64")
?0
;
332
333
3
        let (tx, rx) = make_buf_channel_pair();
334
335
3
        let read_limit = if read_limit != 0 {
  Branch (335:29): [True: 3, False: 0]
  Branch (335:29): [Folded - Ignored]
336
3
            Some(read_limit)
337
        } else {
338
0
            None
339
        };
340
341
        // This allows us to call a destructor when the the object is dropped.
342
3
        let store = instance.store.clone();
343
3
        let state = Some(ReaderState {
344
3
            rx,
345
3
            max_bytes_per_stream: instance.max_bytes_per_stream,
346
3
            maybe_get_part_result: None,
347
3
            get_part_fut: Box::pin(async move {
348
3
                store
349
3
                    .get_part(
350
3
                        digest,
351
3
                        tx,
352
3
                        u64::try_from(read_request.read_offset)
353
3
                            .err_tip(|| "Could not convert read_offset to u64")
?0
,
354
3
                        read_limit,
355
                    )
356
3
                    .await
357
3
            }),
358
        });
359
360
3
        let read_stream_span = error_span!("read_stream");
361
362
9.77k
        Ok(
Box::pin3
(
unfold3
(
state3
, move |state| {
363
9.77k
            async {
364
9.77k
            let mut state = state
?0
; // If None our stream is done.
365
9.77k
            let mut response = ReadResponse::default();
366
            {
367
9.77k
                let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream));
368
9.77k
                tokio::pin!(consume_fut);
369
                loop {
370
9.77k
                    tokio::select! {
371
9.77k
                        
read_result9.77k
= &mut consume_fut => {
372
9.77k
                            match read_result {
373
9.76k
                                Ok(bytes) => {
374
9.76k
                                    if bytes.is_empty() {
  Branch (374:40): [True: 2, False: 9.76k]
  Branch (374:40): [Folded - Ignored]
375
                                        // EOF.
376
2
                                        return None;
377
9.76k
                                    }
378
9.76k
                                    if bytes.len() > state.max_bytes_per_stream {
  Branch (378:40): [True: 0, False: 9.76k]
  Branch (378:40): [Folded - Ignored]
379
0
                                        let err = make_err!(Code::Internal, "Returned store size was larger than read size");
380
0
                                        return Some((Err(err.into()), None));
381
9.76k
                                    }
382
9.76k
                                    response.data = bytes;
383
9.76k
                                    trace!(response = ?response);
384
9.76k
                                    debug!(response.data = format!("<redacted len({})>", response.data.len()));
385
9.76k
                                    break;
386
                                }
387
1
                                Err(mut e) => {
388
                                    // We may need to propagate the error from reading the data through first.
389
                                    // For example, the NotFound error will come through `get_part_fut`, and
390
                                    // will not be present in `e`, but we need to ensure we pass NotFound error
391
                                    // code or the client won't know why it failed.
392
1
                                    let get_part_result = if let Some(result) = state.maybe_get_part_result {
  Branch (392:66): [True: 1, False: 0]
  Branch (392:66): [Folded - Ignored]
393
1
                                        result
394
                                    } else {
395
                                        // This should never be `future::pending()` if maybe_get_part_result is
396
                                        // not set.
397
0
                                        state.get_part_fut.await
398
                                    };
399
1
                                    if let Err(err) = get_part_result {
  Branch (399:44): [True: 1, False: 0]
  Branch (399:44): [Folded - Ignored]
400
1
                                        e = err.merge(e);
401
1
                                    
}0
402
1
                                    if e.code == Code::NotFound {
  Branch (402:40): [True: 1, False: 0]
  Branch (402:40): [Folded - Ignored]
403
1
                                        // Trim the error code. Not Found is quite common and we don't want to send a large
404
1
                                        // error (debug) message for something that is common. We resize to just the last
405
1
                                        // message as it will be the most relevant.
406
1
                                        e.messages.truncate(1);
407
1
                                    
}0
408
1
                                    error!(response = ?e);
409
1
                                    return Some((Err(e.into()), None))
410
                                }
411
                            }
412
                        },
413
9.77k
                        
result3
= &mut state.get_part_fut => {
414
3
                            state.maybe_get_part_result = Some(result);
415
3
                            // It is non-deterministic on which future will finish in what order.
416
3
                            // It is also possible that the `state.rx.consume()` call above may not be able to
417
3
                            // respond even though the publishing future is done.
418
3
                            // Because of this we set the writing future to pending so it never finishes.
419
3
                            // The `state.rx.consume()` future will eventually finish and return either the
420
3
                            // data or an error.
421
3
                            // An EOF will terminate the `state.rx.consume()` future, but we are also protected
422
3
                            // because we are dropping the writing future, it will drop the `tx` channel
423
3
                            // which will eventually propagate an error to the `state.rx.consume()` future if
424
3
                            // the EOF was not sent due to some other error.
425
3
                            state.get_part_fut = Box::pin(pending());
426
3
                        },
427
                    }
428
                }
429
            }
430
9.76k
            Some((Ok(response), Some(state)))
431
9.77k
        }.instrument(read_stream_span.clone())
432
9.77k
        })))
433
3
    }
434
435
    // We instrument tracing here as well as below because `stream` has a hash on it
436
    // that is extracted from the first stream message. If we only implemented it below
437
    // we would not have the hash available to us.
438
    #[instrument(
439
        ret(level = Level::DEBUG),
440
        level = Level::ERROR,
441
        skip(self, instance_info),
442
    )]
443
15
    async fn inner_write(
444
15
        &self,
445
15
        instance_info: &InstanceInfo,
446
15
        digest: DigestInfo,
447
15
        stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>,
448
15
    ) -> Result<Response<WriteResponse>, Error> {
449
15
        async fn process_client_stream(
450
15
            mut stream: WriteRequestStreamWrapper<
451
15
                impl Stream<Item = Result<WriteRequest, Status>> + Unpin,
452
15
            >,
453
15
            tx: &mut DropCloserWriteHalf,
454
15
            outer_bytes_received: &Arc<AtomicU64>,
455
15
            expected_size: u64,
456
15
        ) -> Result<(), Error> {
457
            loop {
458
25
                let 
write_request20
= match stream.next().await {
459
                    // Code path for when client tries to gracefully close the stream.
460
                    // If this happens it means there's a problem with the data sent,
461
                    // because we always close the stream from our end before this point
462
                    // by counting the number of bytes sent from the client. If they send
463
                    // less than the amount they said they were going to send and then
464
                    // close the stream, we know there's a problem.
465
                    None => {
466
0
                        return Err(make_input_err!(
467
0
                            "Client closed stream before sending all data"
468
0
                        ));
469
                    }
470
                    // Code path for client stream error. Probably client disconnect.
471
5
                    Some(Err(err)) => return Err(err),
472
                    // Code path for received chunk of data.
473
20
                    Some(Ok(write_request)) => write_request,
474
                };
475
476
20
                if write_request.write_offset < 0 {
  Branch (476:20): [True: 1, False: 19]
  Branch (476:20): [Folded - Ignored]
477
1
                    return Err(make_input_err!(
478
1
                        "Invalid negative write offset in write request: {}",
479
1
                        write_request.write_offset
480
1
                    ));
481
19
                }
482
19
                let write_offset = write_request.write_offset as u64;
483
484
                // If we get duplicate data because a client didn't know where
485
                // it left off from, then we can simply skip it.
486
19
                let 
data18
= if write_offset < tx.get_bytes_written() {
  Branch (486:31): [True: 2, False: 17]
  Branch (486:31): [Folded - Ignored]
487
2
                    if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() {
  Branch (487:24): [True: 0, False: 2]
  Branch (487:24): [Folded - Ignored]
488
0
                        if write_request.finish_write {
  Branch (488:28): [True: 0, False: 0]
  Branch (488:28): [Folded - Ignored]
489
0
                            return Err(make_input_err!(
490
0
                                "Resumed stream finished at {} bytes when we already received {} bytes.",
491
0
                                write_offset + write_request.data.len() as u64,
492
0
                                tx.get_bytes_written()
493
0
                            ));
494
0
                        }
495
0
                        continue;
496
2
                    }
497
2
                    write_request.data.slice(
498
2
                        usize::try_from(tx.get_bytes_written() - write_offset)
499
2
                            .unwrap_or(usize::MAX)..,
500
                    )
501
                } else {
502
17
                    if write_offset != tx.get_bytes_written() {
  Branch (502:24): [True: 1, False: 16]
  Branch (502:24): [Folded - Ignored]
503
1
                        return Err(make_input_err!(
504
1
                            "Received out of order data. Got {}, expected {}",
505
1
                            write_offset,
506
1
                            tx.get_bytes_written()
507
1
                        ));
508
16
                    }
509
16
                    write_request.data
510
                };
511
512
                // Do not process EOF or weird stuff will happen.
513
18
                if !data.is_empty() {
  Branch (513:20): [True: 13, False: 5]
  Branch (513:20): [Folded - Ignored]
514
                    // We also need to process the possible EOF branch, so we can't early return.
515
13
                    if let Err(
mut err0
) = tx.send(data).await {
  Branch (515:28): [True: 0, False: 13]
  Branch (515:28): [Folded - Ignored]
516
0
                        err.code = Code::Internal;
517
0
                        return Err(err);
518
13
                    }
519
13
                    outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release);
520
5
                }
521
522
18
                if expected_size < tx.get_bytes_written() {
  Branch (522:20): [True: 0, False: 18]
  Branch (522:20): [Folded - Ignored]
523
0
                    return Err(make_input_err!("Received more bytes than expected"));
524
18
                }
525
18
                if write_request.finish_write {
  Branch (525:20): [True: 8, False: 10]
  Branch (525:20): [Folded - Ignored]
526
                    // Gracefully close our stream.
527
8
                    tx.send_eof()
528
8
                        .err_tip(|| "Failed to send EOF in ByteStream::write")
?0
;
529
8
                    return Ok(());
530
10
                }
531
                // Continue.
532
            }
533
            // Unreachable.
534
15
        }
535
536
        let uuid = stream
537
            .resource_info
538
            .uuid
539
            .as_ref()
540
            .ok_or_else(|| make_input_err!("UUID must be set if writing data"))?;
541
        let mut active_stream_guard =
542
            self.create_or_join_upload_stream(uuid, instance_info, digest);
543
        let expected_size = stream.resource_info.expected_size as u64;
544
545
        let active_stream = active_stream_guard.stream_state.as_mut().unwrap();
546
        try_join!(
547
            process_client_stream(
548
                stream,
549
                &mut active_stream.tx,
550
                &active_stream_guard.bytes_received,
551
                expected_size
552
            ),
553
            (&mut active_stream.store_update_fut)
554
0
                .map_err(|err| { err.append("Error updating inner store") })
555
        )?;
556
557
        // Close our guard and consider the stream no longer active.
558
        active_stream_guard.graceful_finish();
559
560
        Ok(Response::new(WriteResponse {
561
            committed_size: expected_size as i64,
562
        }))
563
15
    }
564
565
3
    async fn inner_query_write_status(
566
3
        &self,
567
3
        query_request: &QueryWriteStatusRequest,
568
3
    ) -> Result<Response<QueryWriteStatusResponse>, Error> {
569
3
        let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)
?0
;
570
571
3
        let instance = self
572
3
            .instance_infos
573
3
            .get(resource_info.instance_name.as_ref())
574
3
            .err_tip(|| 
{0
575
0
                format!(
576
0
                    "'instance_name' not configured for '{}'",
577
0
                    &resource_info.instance_name
578
                )
579
0
            })?;
580
3
        let store_clone = instance.store.clone();
581
582
3
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)
?0
;
583
584
        // If we are a GrpcStore we shortcut here, as this is a special store.
585
3
        if let Some(
grpc_store0
) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) {
  Branch (585:16): [True: 0, False: 3]
  Branch (585:16): [Folded - Ignored]
586
0
            return grpc_store
587
0
                .query_write_status(Request::new(query_request.clone()))
588
0
                .await;
589
3
        }
590
591
3
        let uuid = resource_info
592
3
            .uuid
593
3
            .take()
594
3
            .ok_or_else(|| make_input_err!("UUID must be set if querying write status"))
?0
;
595
596
        {
597
3
            let active_uploads = instance.active_uploads.lock();
598
3
            if let Some((
received_bytes1
,
_maybe_idle_stream1
)) = active_uploads.get(uuid.as_ref()) {
  Branch (598:20): [True: 1, False: 2]
  Branch (598:20): [Folded - Ignored]
599
1
                return Ok(Response::new(QueryWriteStatusResponse {
600
1
                    committed_size: received_bytes.load(Ordering::Acquire) as i64,
601
1
                    // If we are in the active_uploads map, but the value is None,
602
1
                    // it means the stream is not complete.
603
1
                    complete: false,
604
1
                }));
605
2
            }
606
        }
607
608
2
        let has_fut = store_clone.has(digest);
609
2
        let Some(
item_size1
) = has_fut.await.err_tip(|| "Failed to call .has() on store")
?0
else {
  Branch (609:13): [True: 1, False: 1]
  Branch (609:13): [Folded - Ignored]
610
            // We lie here and say that the stream needs to start over, even though
611
            // it was never started. This can happen when the client disconnects
612
            // before sending the first payload, but the client thinks it did send
613
            // the payload.
614
1
            return Ok(Response::new(QueryWriteStatusResponse {
615
1
                committed_size: 0,
616
1
                complete: false,
617
1
            }));
618
        };
619
1
        Ok(Response::new(QueryWriteStatusResponse {
620
1
            committed_size: item_size as i64,
621
1
            complete: true,
622
1
        }))
623
3
    }
624
}
625
626
#[tonic::async_trait]
627
impl ByteStream for ByteStreamServer {
628
    type ReadStream = ReadStream;
629
630
    #[instrument(
631
        err,
632
        level = Level::ERROR,
633
        skip_all,
634
        fields(request = ?grpc_request.get_ref())
635
    )]
636
    async fn read(
637
        &self,
638
        grpc_request: Request<ReadRequest>,
639
    ) -> Result<Response<Self::ReadStream>, Status> {
640
        let read_request = grpc_request.into_inner();
641
        let resource_info = ResourceInfo::new(&read_request.resource_name, false)?;
642
        let instance_name = resource_info.instance_name.as_ref();
643
        let instance = self
644
            .instance_infos
645
            .get(instance_name)
646
0
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?;
647
        let store = instance.store.clone();
648
649
        let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?;
650
651
        // If we are a GrpcStore we shortcut here, as this is a special store.
652
        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
653
            let stream = Box::pin(grpc_store.read(Request::new(read_request)).await?);
654
            return Ok(Response::new(stream));
655
        }
656
657
        let digest_function = resource_info.digest_function.as_deref().map_or_else(
658
3
            || Ok(default_digest_hasher_func()),
659
            DigestHasherFunc::try_from,
660
        )?;
661
662
        let resp = self
663
            .inner_read(instance, digest, read_request)
664
            .instrument(error_span!("bytestream_read"))
665
            .with_context(
666
                make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::read")?,
667
            )
668
            .await
669
            .err_tip(|| "In ByteStreamServer::read")
670
3
            .map(|stream| -> Response<Self::ReadStream> { Response::new(Box::pin(stream)) })
671
            .map_err(Into::into);
672
673
        if resp.is_ok() {
674
            debug!(return = "Ok(<stream>)");
675
        }
676
677
        resp
678
    }
679
680
    #[instrument(
681
        err,
682
        level = Level::ERROR,
683
        skip_all,
684
        fields(request = ?grpc_request.get_ref())
685
    )]
686
    async fn write(
687
        &self,
688
        grpc_request: Request<Streaming<WriteRequest>>,
689
    ) -> Result<Response<WriteResponse>, Status> {
690
        let request = grpc_request.into_inner();
691
        let stream = WriteRequestStreamWrapper::from(request)
692
            .await
693
            .err_tip(|| "Could not unwrap first stream message")
694
            .map_err(Into::<Status>::into)?;
695
696
        let instance_name = stream.resource_info.instance_name.as_ref();
697
        let instance = self
698
            .instance_infos
699
            .get(instance_name)
700
0
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?;
701
        let store = instance.store.clone();
702
703
        let digest = DigestInfo::try_new(
704
            &stream.resource_info.hash,
705
            stream.resource_info.expected_size,
706
        )
707
        .err_tip(|| "Invalid digest input in ByteStream::write")?;
708
709
        // If we are a GrpcStore we shortcut here, as this is a special store.
710
        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest.into())) {
711
            let resp = grpc_store.write(stream).await.map_err(Into::into);
712
            return resp;
713
        }
714
715
        let digest_function = stream
716
            .resource_info
717
            .digest_function
718
            .as_deref()
719
            .map_or_else(
720
15
                || Ok(default_digest_hasher_func()),
721
                DigestHasherFunc::try_from,
722
            )?;
723
724
        self.inner_write(instance, digest, stream)
725
            .instrument(error_span!("bytestream_write"))
726
            .with_context(
727
                make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::write")?,
728
            )
729
            .await
730
            .err_tip(|| "In ByteStreamServer::write")
731
            .map_err(Into::into)
732
    }
733
734
    #[instrument(
735
        err,
736
        ret(level = Level::INFO),
737
        level = Level::ERROR,
738
        skip_all,
739
        fields(request = ?grpc_request.get_ref())
740
    )]
741
    async fn query_write_status(
742
        &self,
743
        grpc_request: Request<QueryWriteStatusRequest>,
744
    ) -> Result<Response<QueryWriteStatusResponse>, Status> {
745
        let request = grpc_request.into_inner();
746
        self.inner_query_write_status(&request)
747
            .await
748
            .err_tip(|| "Failed on query_write_status() command")
749
            .map_err(Into::into)
750
    }
751
}