Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-service/src/cas_server.rs
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{HashMap, VecDeque};
use std::convert::Into;
use std::pin::Pin;

use bytes::Bytes;
use futures::stream::{FuturesUnordered, Stream};
use futures::TryStreamExt;
use nativelink_config::cas_server::{CasStoreConfig, InstanceName};
use nativelink_error::{error_if, make_input_err, Code, Error, ResultExt};
use nativelink_proto::build::bazel::remote::execution::v2::content_addressable_storage_server::{
    ContentAddressableStorage, ContentAddressableStorageServer as Server,
};
use nativelink_proto::build::bazel::remote::execution::v2::{
    batch_read_blobs_response, batch_update_blobs_response, compressor, BatchReadBlobsRequest,
    BatchReadBlobsResponse, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse, Directory,
    FindMissingBlobsRequest, FindMissingBlobsResponse, GetTreeRequest, GetTreeResponse,
};
use nativelink_proto::google::rpc::Status as GrpcStatus;
use nativelink_store::ac_utils::get_and_decode_digest;
use nativelink_store::grpc_store::GrpcStore;
use nativelink_store::store_manager::StoreManager;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::make_ctx_for_hash_func;
use nativelink_util::store_trait::{Store, StoreLike};
use tonic::{Request, Response, Status};
use tracing::{error_span, event, instrument, Level};

pub struct CasServer {
    stores: HashMap<String, Store>,
}

type GetTreeStream = Pin<Box<dyn Stream<Item = Result<GetTreeResponse, Status>> + Send + 'static>>;

impl CasServer {
    pub fn new(
        config: &HashMap<InstanceName, CasStoreConfig>,
        store_manager: &StoreManager,
    ) -> Result<Self, Error> {
        let mut stores = HashMap::with_capacity(config.len());
        for (instance_name, cas_cfg) in config {
            let store = store_manager.get_store(&cas_cfg.cas_store).ok_or_else(|| {
                make_input_err!("'cas_store': '{}' does not exist", cas_cfg.cas_store)
            })?;
            stores.insert(instance_name.to_string(), store);
        }
        Ok(CasServer { stores })
    }
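
    // A minimal construction sketch (hedged; `cas_config` and the store
    // registration are assumed to happen elsewhere, normally when the service
    // configuration is loaded at startup):
    //
    //     let store_manager = StoreManager::new();
    //     // ... register the store referenced by each `cas_cfg.cas_store` ...
    //     let cas_server = CasServer::new(&cas_config, &store_manager)
    //         .expect("every `cas_store` name must exist in the StoreManager");
    //
    // Construction fails fast with an input error if any instance references
    // a store name that was never registered.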

    pub fn into_service(self) -> Server<CasServer> {
        Server::new(self)
    }
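
    // Wiring sketch (assumes a surrounding tonic server built elsewhere): the
    // returned `Server<CasServer>` is an ordinary tonic service, so it can be
    // mounted like any other generated service:
    //
    //     tonic::transport::Server::builder()
    //         .add_service(cas_server.into_service())
    //         .serve(addr)
    //         .await?;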

    async fn inner_find_missing_blobs(
        &self,
        request: FindMissingBlobsRequest,
    ) -> Result<Response<FindMissingBlobsResponse>, Error> {
        let instance_name = &request.instance_name;
        let store = self
            .stores
            .get(instance_name)
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?
            .clone();

        let mut requested_blobs = Vec::with_capacity(request.blob_digests.len());
        for digest in request.blob_digests.iter() {
            requested_blobs.push(DigestInfo::try_from(digest.clone())?.into());
        }
        let sizes = store
            .has_many(&requested_blobs)
            .await
            .err_tip(|| "In find_missing_blobs")?;
        let missing_blob_digests = sizes
            .into_iter()
            .zip(request.blob_digests)
            .filter_map(|(maybe_size, digest)| maybe_size.map_or_else(|| Some(digest), |_| None))
            .collect();

        Ok(Response::new(FindMissingBlobsResponse {
            missing_blob_digests,
        }))
    }
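
    // Semantics note: `has_many` yields `Some(size)` for blobs that exist and
    // `None` for blobs that do not, so the `zip`/`filter_map` above keeps
    // exactly the digests whose lookup came back `None`. For example, if the
    // request carries digests [a, b, c] and only `b` is stored, the response's
    // `missing_blob_digests` is [a, c].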

    async fn inner_batch_update_blobs(
        &self,
        request: BatchUpdateBlobsRequest,
    ) -> Result<Response<BatchUpdateBlobsResponse>, Error> {
        let instance_name = &request.instance_name;

        let store = self
            .stores
            .get(instance_name)
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?
            .clone();

        // If we are a GrpcStore we shortcut here, as this is a special store.
        // Note: We don't know the digests here, so we perform a very shallow
        // check to see if it's a grpc store.
        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
            return grpc_store.batch_update_blobs(Request::new(request)).await;
        }

        let store_ref = &store;
        let update_futures: FuturesUnordered<_> = request
            .requests
            .into_iter()
            .map(|request| async move {
                let digest = request
                    .digest
                    .clone()
                    .err_tip(|| "Digest not found in request")?;
                let request_data = request.data;
                let digest_info = DigestInfo::try_from(digest.clone())?;
                let size_bytes = usize::try_from(digest_info.size_bytes())
                    .err_tip(|| "Digest size_bytes was not convertible to usize")?;
                error_if!(
                    size_bytes != request_data.len(),
                    "Digest for upload had mismatching sizes, digest said {}, data said {}",
                    size_bytes,
                    request_data.len()
                );
                let result = store_ref
                    .update_oneshot(digest_info, request_data)
                    .await
                    .err_tip(|| "Error writing to store");
                Ok::<_, Error>(batch_update_blobs_response::Response {
                    digest: Some(digest),
                    status: Some(result.map_or_else(Into::into, |_| GrpcStatus::default())),
                })
            })
            .collect();
        let responses = update_futures
            .try_collect::<Vec<batch_update_blobs_response::Response>>()
            .await?;

        Ok(Response::new(BatchUpdateBlobsResponse { responses }))
    }
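
    // Error-shape note: a failed `update_oneshot` does not abort the whole
    // RPC. Each per-blob `Result` is folded into that response's `status`
    // field (`Into::into` for the error, `GrpcStatus::default()` for
    // success), so the top-level RPC succeeds even when individual uploads
    // fail.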

    async fn inner_batch_read_blobs(
        &self,
        request: BatchReadBlobsRequest,
    ) -> Result<Response<BatchReadBlobsResponse>, Error> {
        let instance_name = &request.instance_name;

        let store = self
            .stores
            .get(instance_name)
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?
            .clone();

        // If we are a GrpcStore we shortcut here, as this is a special store.
        // Note: We don't know the digests here, so we perform a very shallow
        // check to see if it's a grpc store.
        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
            return grpc_store.batch_read_blobs(Request::new(request)).await;
        }

        let store_ref = &store;
        let read_futures: FuturesUnordered<_> = request
            .digests
            .into_iter()
            .map(|digest| async move {
                let digest_copy = DigestInfo::try_from(digest.clone())?;
                // TODO(allada) There is a security risk here of someone taking all the memory on the instance.
                let result = store_ref
                    .get_part_unchunked(digest_copy, 0, None)
                    .await
                    .err_tip(|| "Error reading from store");
                let (status, data) = result.map_or_else(
                    |mut e| {
                        if e.code == Code::NotFound {
                            // Trim the error code. Not Found is quite common and we don't want to send a large
                            // error (debug) message for something that is common. We resize to just the last
                            // message as it will be the most relevant.
                            e.messages.resize_with(1, String::new);
                        }
                        (e.into(), Bytes::new())
                    },
                    |v| (GrpcStatus::default(), v),
                );
                Ok::<_, Error>(batch_read_blobs_response::Response {
                    status: Some(status),
                    digest: Some(digest),
                    compressor: compressor::Value::Identity.into(),
                    data,
                })
            })
            .collect();
        let responses = read_futures
            .try_collect::<Vec<batch_read_blobs_response::Response>>()
            .await?;

        Ok(Response::new(BatchReadBlobsResponse { responses }))
    }
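
    // Ordering note: `FuturesUnordered` resolves futures in completion order,
    // so `responses` is not guaranteed to match the order of
    // `request.digests`; callers should correlate entries through the echoed
    // `digest` field.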

    async fn inner_get_tree(
        &self,
        request: GetTreeRequest,
    ) -> Result<Response<GetTreeStream>, Error> {
        let instance_name = &request.instance_name;

        let store = self
            .stores
            .get(instance_name)
            .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?
            .clone();

        // If we are a GrpcStore we shortcut here, as this is a special store.
        // Note: We don't know the digests here, so we perform a very shallow
        // check to see if it's a grpc store.
        if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(None) {
            let stream = grpc_store
                .get_tree(Request::new(request))
                .await?
                .into_inner();
            return Ok(Response::new(Box::pin(stream)));
        }
        let root_digest: DigestInfo = request
            .root_digest
            .err_tip(|| "Expected root_digest to exist in GetTreeRequest")?
            .try_into()
            .err_tip(|| "In GetTreeRequest::root_digest")?;

        let mut deque: VecDeque<DigestInfo> = VecDeque::new();
        let mut directories: Vec<Directory> = Vec::new();
        // `page_token` holds the `{hash_str}-{size_bytes}` of the current request's first directory digest.
        let page_token_digest = if request.page_token.is_empty() {
            root_digest
        } else {
            let mut page_token_parts = request.page_token.split('-');
            DigestInfo::try_new(
                page_token_parts
                    .next()
                    .err_tip(|| "Failed to parse `hash_str` in `page_token`")?,
                page_token_parts
                    .next()
                    .err_tip(|| "Failed to parse `size_bytes` in `page_token`")?
                    .parse::<i64>()
                    .err_tip(|| "Failed to parse `size_bytes` as i64")?,
            )
            .err_tip(|| "Failed to parse `page_token` as `Digest` in `GetTreeRequest`")?
        };
        let page_size = request.page_size;
        // If `page_size` is 0, paging is not necessary.
        let mut page_token_matched = page_size == 0;
        deque.push_back(root_digest);

        while !deque.is_empty() {
            let digest: DigestInfo = deque.pop_front().err_tip(|| "In VecDeque::pop_front")?;
            let directory = get_and_decode_digest::<Directory>(&store, digest.into())
                .await
                .err_tip(|| "Converting digest to Directory")?;
            if digest == page_token_digest {
                page_token_matched = true;
            }
            for directory in &directory.directories {
                let digest: DigestInfo = directory
                    .digest
                    .clone()
                    .err_tip(|| "Expected Digest to exist in Directory::directories::digest")?
                    .try_into()
                    .err_tip(|| "In Directory::directories::digest")?;
                deque.push_back(digest);
            }
            if page_token_matched {
                directories.push(directory);
                if directories.len() as i32 == page_size {
                    break;
                }
            }
        }
        // `next_page_token` holds the `{hash_str}-{size_bytes}` of the next request's first directory digest.
        // It will be an empty string when the end of the directory tree has been reached.
        let next_page_token: String = if let Some(value) = deque.front() {
            format!("{value}")
        } else {
            String::new()
        };

        Ok(Response::new(Box::pin(futures::stream::once(async {
            Ok(GetTreeResponse {
                directories,
                next_page_token,
            })
        }))))
    }
}
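
// Pagination walkthrough for `inner_get_tree` (illustrative values): with
// `page_size = 2`, an empty `page_token`, and a tree whose breadth-first
// order is [root, a, b, c], the first call returns directories [root, a] and
// a `next_page_token` of "{hash_str}-{size_bytes}" for `b` (the front of the
// deque). Sending that token back resumes the walk from the root, collects
// nothing until `b` matches, and returns [b, c] with an empty
// `next_page_token`.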

#[tonic::async_trait]
impl ContentAddressableStorage for CasServer {
    type GetTreeStream = GetTreeStream;

    #[allow(clippy::blocks_in_conditions)]
    #[instrument(
        err,
        ret(level = Level::INFO),
        level = Level::ERROR,
        skip_all,
        fields(request = ?grpc_request.get_ref())
    )]
    async fn find_missing_blobs(
        &self,
        grpc_request: Request<FindMissingBlobsRequest>,
    ) -> Result<Response<FindMissingBlobsResponse>, Status> {
        let request = grpc_request.into_inner();
        make_ctx_for_hash_func(request.digest_function)
            .err_tip(|| "In CasServer::find_missing_blobs")?
            .wrap_async(
                error_span!("cas_server_find_missing_blobs"),
                self.inner_find_missing_blobs(request),
            )
            .await
            .err_tip(|| "Failed on find_missing_blobs() command")
            .map_err(Into::into)
    }
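
    // Pattern note: each handler below repeats the shape of
    // `find_missing_blobs` above: decode the request, build a hashing context
    // from `request.digest_function` via `make_ctx_for_hash_func`, then run
    // the corresponding `inner_*` method inside an `error_span!` so traces
    // carry the RPC name.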

    #[allow(clippy::blocks_in_conditions)]
    #[instrument(
        err,
        ret(level = Level::INFO),
        level = Level::ERROR,
        skip_all,
        fields(request = ?grpc_request.get_ref())
    )]
    async fn batch_update_blobs(
        &self,
        grpc_request: Request<BatchUpdateBlobsRequest>,
    ) -> Result<Response<BatchUpdateBlobsResponse>, Status> {
        let request = grpc_request.into_inner();
        make_ctx_for_hash_func(request.digest_function)
            .err_tip(|| "In CasServer::batch_update_blobs")?
            .wrap_async(
                error_span!("cas_server_batch_update_blobs"),
                self.inner_batch_update_blobs(request),
            )
            .await
            .err_tip(|| "Failed on batch_update_blobs() command")
            .map_err(Into::into)
    }

    #[allow(clippy::blocks_in_conditions)]
    #[instrument(
        err,
        ret(level = Level::INFO),
        level = Level::ERROR,
        skip_all,
        fields(request = ?grpc_request.get_ref())
    )]
    async fn batch_read_blobs(
        &self,
        grpc_request: Request<BatchReadBlobsRequest>,
    ) -> Result<Response<BatchReadBlobsResponse>, Status> {
        let request = grpc_request.into_inner();
        make_ctx_for_hash_func(request.digest_function)
            .err_tip(|| "In CasServer::batch_read_blobs")?
            .wrap_async(
                error_span!("cas_server_batch_read_blobs"),
                self.inner_batch_read_blobs(request),
            )
            .await
            .err_tip(|| "Failed on batch_read_blobs() command")
            .map_err(Into::into)
    }

    #[allow(clippy::blocks_in_conditions)]
    #[instrument(
        err,
        level = Level::ERROR,
        skip_all,
        fields(request = ?grpc_request.get_ref())
    )]
    async fn get_tree(
        &self,
        grpc_request: Request<GetTreeRequest>,
    ) -> Result<Response<Self::GetTreeStream>, Status> {
        let request = grpc_request.into_inner();
        let resp = make_ctx_for_hash_func(request.digest_function)
            .err_tip(|| "In CasServer::get_tree")?
            .wrap_async(
                error_span!("cas_server_get_tree"),
                self.inner_get_tree(request),
            )
            .await
            .err_tip(|| "Failed on get_tree() command")
            .map_err(Into::into);
        if resp.is_ok() {
            event!(Level::DEBUG, return = "Ok(<stream>)");
        }
        resp
    }
}
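
For reference, a minimal client-side sketch that exercises the
find_missing_blobs handler above. This is a hedged example, not part of the
covered file: the endpoint, instance name, and digest are placeholders, and it
assumes the tonic-generated client module
(`content_addressable_storage_client`) from the same proto tree plus a tokio
runtime.

use nativelink_proto::build::bazel::remote::execution::v2::{
    content_addressable_storage_client::ContentAddressableStorageClient, Digest,
    FindMissingBlobsRequest,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder endpoint; point this at a running CAS service.
    let mut client =
        ContentAddressableStorageClient::connect("http://localhost:50051").await?;
    let response = client
        .find_missing_blobs(FindMissingBlobsRequest {
            instance_name: "main".to_string(),
            // Placeholder digest (SHA-256 of the empty blob); any digest not
            // present in the store is echoed back in `missing_blob_digests`.
            blob_digests: vec![Digest {
                hash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
                    .to_string(),
                size_bytes: 0,
            }],
            ..Default::default()
        })
        .await?
        .into_inner();
    println!("missing: {:?}", response.missing_blob_digests);
    Ok(())
}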