Coverage Report

Created: 2025-04-19 16:54

/build/source/src/bin/nativelink.rs
Line | Count | Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::{HashMap, HashSet};
16
use std::net::SocketAddr;
17
use std::sync::Arc;
18
use std::time::{Duration, SystemTime, UNIX_EPOCH};
19
20
use async_lock::Mutex as AsyncMutex;
21
use axum::Router;
22
use clap::Parser;
23
use futures::FutureExt;
24
use futures::future::{BoxFuture, Either, OptionFuture, TryFutureExt, try_join_all};
25
use hyper::StatusCode;
26
use hyper_util::rt::tokio::TokioIo;
27
use hyper_util::server::conn::auto;
28
use hyper_util::service::TowerToHyperService;
29
use mimalloc::MiMalloc;
30
use nativelink_config::cas_server::{
31
    CasConfig, GlobalConfig, HttpCompressionAlgorithm, ListenerConfig, SchedulerConfig,
32
    ServerConfig, StoreConfig, WorkerConfig,
33
};
34
use nativelink_config::stores::ConfigDigestHashFunction;
35
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
36
use nativelink_metric::{
37
    MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent, RootMetricsComponent,
38
};
39
use nativelink_scheduler::default_scheduler_factory::scheduler_factory;
40
use nativelink_service::ac_server::AcServer;
41
use nativelink_service::bep_server::BepServer;
42
use nativelink_service::bytestream_server::ByteStreamServer;
43
use nativelink_service::capabilities_server::CapabilitiesServer;
44
use nativelink_service::cas_server::CasServer;
45
use nativelink_service::execution_server::ExecutionServer;
46
use nativelink_service::fetch_server::FetchServer;
47
use nativelink_service::health_server::HealthServer;
48
use nativelink_service::push_server::PushServer;
49
use nativelink_service::worker_api_server::WorkerApiServer;
50
use nativelink_store::default_store_factory::store_factory;
51
use nativelink_store::store_manager::StoreManager;
52
use nativelink_util::common::fs::set_open_file_limit;
53
use nativelink_util::digest_hasher::{DigestHasherFunc, set_default_digest_hasher_func};
54
use nativelink_util::health_utils::HealthRegistryBuilder;
55
use nativelink_util::metrics_utils::Counter;
56
use nativelink_util::operation_state_manager::ClientStateManager;
57
use nativelink_util::origin_context::OriginContext;
58
use nativelink_util::origin_event_middleware::OriginEventMiddlewareLayer;
59
use nativelink_util::origin_event_publisher::OriginEventPublisher;
60
use nativelink_util::shutdown_guard::{Priority, ShutdownGuard};
61
use nativelink_util::store_trait::{
62
    DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG, set_default_digest_size_health_check,
63
};
64
use nativelink_util::task::TaskExecutor;
65
use nativelink_util::{background_spawn, fs, init_tracing, spawn};
66
use nativelink_worker::local_worker::new_local_worker;
67
use parking_lot::{Mutex, RwLock};
68
use rustls_pemfile::{certs as extract_certs, crls as extract_crls};
69
use scopeguard::guard;
70
use tokio::net::TcpListener;
71
use tokio::select;
72
#[cfg(target_family = "unix")]
73
use tokio::signal::unix::{SignalKind, signal};
74
use tokio::sync::{broadcast, mpsc};
75
use tokio_rustls::TlsAcceptor;
76
use tokio_rustls::rustls::pki_types::CertificateDer;
77
use tokio_rustls::rustls::server::WebPkiClientVerifier;
78
use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig};
79
use tonic::codec::CompressionEncoding;
80
use tonic::service::Routes;
81
use tracing::{Level, error_span, event, trace_span};
82
83
#[global_allocator]
84
static GLOBAL: MiMalloc = MiMalloc;
85
86
/// Note: This must be kept in sync with the documentation in `AdminConfig::path`.
87
const DEFAULT_ADMIN_API_PATH: &str = "/admin";
88
89
// Note: This must be kept in sync with the documentation in `HealthConfig::path`.
90
const DEFAULT_HEALTH_STATUS_CHECK_PATH: &str = "/status";
91
92
// Note: This must be kept in sync with the documentation in
93
// `OriginEventsConfig::max_event_queue_size`.
94
const DEFAULT_MAX_QUEUE_EVENTS: usize = 65536;
95
96
/// Broadcast Channel Capacity
97
/// Note: The actual capacity may be greater than the provided capacity.
98
const BROADCAST_CAPACITY: usize = 1;
99
100
/// Backend for Bazel remote execution / cache API.
101
#[derive(Parser, Debug)]
102
#[clap(
103
    author = "Trace Machina, Inc. <nativelink@tracemachina.com>",
104
    version,
105
    about,
106
    long_about = None
107
)]
108
struct Args {
109
    /// Config file to use.
110
    #[clap(value_parser)]
111
0
    config_file: String,
112
}
113
114
/// The root metrics collector struct. All metrics will be
115
/// collected by traversing from this struct down through each child
116
/// component.
117
#[derive(MetricsComponent)]
118
struct RootMetrics {
119
    #[metric(group = "stores")]
120
    stores: Arc<dyn RootMetricsComponent>,
121
    #[metric(group = "servers")]
122
    servers: HashMap<String, Arc<dyn RootMetricsComponent>>,
123
    #[metric(group = "workers")]
124
    workers: HashMap<String, Arc<dyn RootMetricsComponent>>,
125
    // TODO(allada) We cannot upcast these to RootMetricsComponent because
126
    // of https://github.com/rust-lang/rust/issues/65991.
127
    // TODO(allada) To prevent output from being too verbose we only
128
    // print the action_schedulers.
129
    #[metric(group = "action_schedulers")]
130
    schedulers: HashMap<String, Arc<dyn ClientStateManager>>,
131
}
132
133
impl RootMetricsComponent for RootMetrics {}
134
135
/// Wrapper to allow us to hash `SocketAddr` for metrics.
136
#[derive(Hash, PartialEq, Eq)]
137
struct SocketAddrWrapper(SocketAddr);
138
139
impl MetricsComponent for SocketAddrWrapper {
140
0
    fn publish(
141
0
        &self,
142
0
        _kind: MetricKind,
143
0
        _field_metadata: MetricFieldData,
144
0
    ) -> Result<MetricPublishKnownKindData, nativelink_metric::Error> {
145
0
        Ok(MetricPublishKnownKindData::String(self.0.to_string()))
146
0
    }
147
}
148
149
impl RootMetricsComponent for SocketAddrWrapper {}
150
151
/// Simple wrapper to enable us to register the `HashMap` so it can
152
/// report metrics about which clients are connected.
153
#[derive(MetricsComponent)]
154
struct ConnectedClientsMetrics {
155
    #[metric(group = "currently_connected_clients")]
156
    inner: Mutex<HashSet<SocketAddrWrapper>>,
157
    #[metric(help = "Total client connections since server started")]
158
    counter: Counter,
159
    #[metric(help = "Timestamp when the server started")]
160
    server_start_ts: u64,
161
}
162
163
impl RootMetricsComponent for ConnectedClientsMetrics {}
164
165
trait RoutesExt {
166
    fn add_optional_service<S>(self, svc: Option<S>) -> Self
167
    where
168
        S: tower::Service<axum::http::Request<tonic::body::Body>, Error = std::convert::Infallible>
169
            + tonic::server::NamedService
170
            + Clone
171
            + Send
172
            + Sync
173
            + 'static,
174
        S::Response: axum::response::IntoResponse,
175
        S::Future: Send + 'static;
176
}
177
178
impl RoutesExt for Routes {
179
0
    fn add_optional_service<S>(mut self, svc: Option<S>) -> Self
180
0
    where
181
0
        S: tower::Service<axum::http::Request<tonic::body::Body>, Error = std::convert::Infallible>
182
0
            + tonic::server::NamedService
183
0
            + Clone
184
0
            + Send
185
0
            + Sync
186
0
            + 'static,
187
0
        S::Response: axum::response::IntoResponse,
188
0
        S::Future: Send + 'static,
189
0
    {
190
0
        if let Some(svc) = svc {
  Branch (190:16): [Folded - Ignored]
191
0
            self = self.add_service(svc);
192
0
        }
193
0
        self
194
0
    }
195
}
196
197
0
async fn inner_main(
198
0
    cfg: CasConfig,
199
0
    server_start_timestamp: u64,
200
0
    shutdown_tx: broadcast::Sender<ShutdownGuard>,
201
0
) -> Result<(), Error> {
202
0
    const fn into_encoding(from: HttpCompressionAlgorithm) -> Option<CompressionEncoding> {
203
0
        match from {
204
0
            HttpCompressionAlgorithm::Gzip => Some(CompressionEncoding::Gzip),
205
0
            HttpCompressionAlgorithm::None => None,
206
        }
207
0
    }
208
209
0
    let health_registry_builder =
210
0
        Arc::new(AsyncMutex::new(HealthRegistryBuilder::new("nativelink")));
211
0
212
0
    let store_manager = Arc::new(StoreManager::new());
213
    {
214
0
        let mut health_registry_lock = health_registry_builder.lock().await;
215
216
0
        for StoreConfig { name, spec } in cfg.stores {
217
0
            let health_component_name = format!("stores/{name}");
218
0
            let mut health_register_store =
219
0
                health_registry_lock.sub_builder(&health_component_name);
220
0
            let store = store_factory(&spec, &store_manager, Some(&mut health_register_store))
221
0
                .await
222
0
                .err_tip(|| format!("Failed to create store '{name}'"))?;
223
0
            store_manager.add_store(&name, store);
224
        }
225
    }
226
227
0
    let mut root_futures: Vec<BoxFuture<Result<(), Error>>> = Vec::new();
228
229
0
    let maybe_origin_event_tx = cfg
230
0
        .experimental_origin_events
231
0
        .as_ref()
232
0
        .map(|origin_events_cfg| {
233
0
            let mut max_queued_events = origin_events_cfg.max_event_queue_size;
234
0
            if max_queued_events == 0 {
  Branch (234:16): [Folded - Ignored]
235
0
                max_queued_events = DEFAULT_MAX_QUEUE_EVENTS;
236
0
            }
237
0
            let (tx, rx) = mpsc::channel(max_queued_events);
238
0
            let store_name = origin_events_cfg.publisher.store.as_str();
239
0
            let store = store_manager.get_store(store_name).err_tip(|| {
240
0
                format!("Could not get store {store_name} for origin event publisher")
241
0
            })?;
242
243
0
            root_futures.push(Box::pin(
244
0
                OriginEventPublisher::new(store, rx, shutdown_tx.clone())
245
0
                    .run()
246
0
                    .map(Ok),
247
0
            ));
248
0
249
0
            Ok::<_, Error>(tx)
250
0
        })
251
0
        .transpose()?;
252
253
0
    let mut action_schedulers = HashMap::new();
254
0
    let mut worker_schedulers = HashMap::new();
255
0
    for SchedulerConfig { name, spec } in cfg.schedulers.iter().flatten() {
256
0
        let (maybe_action_scheduler, maybe_worker_scheduler) =
257
0
            scheduler_factory(spec, &store_manager, maybe_origin_event_tx.as_ref())
258
0
                .err_tip(|| format!("Failed to create scheduler '{name}'"))?;
259
0
        if let Some(action_scheduler) = maybe_action_scheduler {
  Branch (259:16): [Folded - Ignored]
260
0
            action_schedulers.insert(name.clone(), action_scheduler.clone());
261
0
        }
262
0
        if let Some(worker_scheduler) = maybe_worker_scheduler {
  Branch (262:16): [Folded - Ignored]
263
0
            worker_schedulers.insert(name.clone(), worker_scheduler.clone());
264
0
        }
265
    }
266
267
0
    let mut server_metrics: HashMap<String, Arc<dyn RootMetricsComponent>> = HashMap::new();
268
0
    // Registers all the ConnectedClientsMetrics to the registries
269
0
    // and zips them in. It is done this way to get around the need
270
0
    // for `root_metrics_registry` to become immutable in the loop.
271
0
    let servers_and_clients: Vec<(ServerConfig, _)> = cfg
272
0
        .servers
273
0
        .into_iter()
274
0
        .enumerate()
275
0
        .map(|(i, server_cfg)| {
276
0
            let name = if server_cfg.name.is_empty() {
  Branch (276:27): [Folded - Ignored]
277
0
                format!("{i}")
278
            } else {
279
0
                server_cfg.name.clone()
280
            };
281
0
            let connected_clients_mux = Arc::new(ConnectedClientsMetrics {
282
0
                inner: Mutex::new(HashSet::new()),
283
0
                counter: Counter::default(),
284
0
                server_start_ts: server_start_timestamp,
285
0
            });
286
0
            server_metrics.insert(name.clone(), connected_clients_mux.clone());
287
0
288
0
            (server_cfg, connected_clients_mux)
289
0
        })
290
0
        .collect();
291
0
292
0
    let root_metrics = Arc::new(RwLock::new(RootMetrics {
293
0
        stores: store_manager.clone(),
294
0
        servers: server_metrics,
295
0
        workers: HashMap::new(), // Will be filled in later.
296
0
        schedulers: action_schedulers.clone(),
297
0
    }));
298
299
0
    for (server_cfg, connected_clients_mux) in servers_and_clients {
300
0
        let services = server_cfg
301
0
            .services
302
0
            .err_tip(|| "'services' must be configured")?;
303
304
        // Currently we only support http as our socket type.
305
0
        let ListenerConfig::Http(http_config) = server_cfg.listener;
306
307
0
        let tonic_services = Routes::builder()
308
0
            .routes()
309
0
            .add_optional_service(
310
0
                services
311
0
                    .ac
312
0
                    .map_or(Ok(None), |cfg| {
313
0
                        AcServer::new(&cfg, &store_manager).map(|v| {
314
0
                            let mut service = v.into_service();
315
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
316
0
                            if let Some(encoding) =
  Branch (316:36): [Folded - Ignored]
317
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
318
0
                            {
319
0
                                service = service.send_compressed(encoding);
320
0
                            }
321
0
                            for encoding in http_config
322
0
                                .compression
323
0
                                .accepted_compression_algorithms
324
0
                                .iter()
325
0
                                // Filter None values.
326
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
327
0
                            {
328
0
                                service = service.accept_compressed(encoding);
329
0
                            }
330
0
                            Some(service)
331
0
                        })
332
0
                    })
333
0
                    .err_tip(|| "Could not create AC service")?,
334
            )
335
0
            .add_optional_service(
336
0
                services
337
0
                    .cas
338
0
                    .map_or(Ok(None), |cfg| {
339
0
                        CasServer::new(&cfg, &store_manager).map(|v| {
340
0
                            let mut service = v.into_service();
341
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
342
0
                            if let Some(encoding) =
  Branch (342:36): [Folded - Ignored]
343
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
344
0
                            {
345
0
                                service = service.send_compressed(encoding);
346
0
                            }
347
0
                            for encoding in http_config
348
0
                                .compression
349
0
                                .accepted_compression_algorithms
350
0
                                .iter()
351
0
                                // Filter None values.
352
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
353
0
                            {
354
0
                                service = service.accept_compressed(encoding);
355
0
                            }
356
0
                            Some(service)
357
0
                        })
358
0
                    })
359
0
                    .err_tip(|| "Could not create CAS service")?,
360
            )
361
0
            .add_optional_service(
362
0
                services
363
0
                    .execution
364
0
                    .map_or(Ok(None), |cfg| {
365
0
                        ExecutionServer::new(&cfg, &action_schedulers, &store_manager).map(|v| {
366
0
                            let mut service = v.into_service();
367
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
368
0
                            if let Some(encoding) =
  Branch (368:36): [Folded - Ignored]
369
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
370
0
                            {
371
0
                                service = service.send_compressed(encoding);
372
0
                            }
373
0
                            for encoding in http_config
374
0
                                .compression
375
0
                                .accepted_compression_algorithms
376
0
                                .iter()
377
0
                                // Filter None values.
378
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
379
0
                            {
380
0
                                service = service.accept_compressed(encoding);
381
0
                            }
382
0
                            Some(service)
383
0
                        })
384
0
                    })
385
0
                    .err_tip(|| "Could not create Execution service")?,
386
            )
387
0
            .add_optional_service(
388
0
                services
389
0
                    .fetch
390
0
                    .map_or(Ok(None), |cfg| {
391
0
                        FetchServer::new(&cfg, &store_manager).map(|v| {
392
0
                            let mut service = v.into_service();
393
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
394
0
                            if let Some(encoding) =
  Branch (394:36): [Folded - Ignored]
395
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
396
0
                            {
397
0
                                service = service.send_compressed(encoding);
398
0
                            }
399
0
                            for encoding in http_config
400
0
                                .compression
401
0
                                .accepted_compression_algorithms
402
0
                                .iter()
403
0
                                // Filter None values.
404
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
405
0
                            {
406
0
                                service = service.accept_compressed(encoding);
407
0
                            }
408
0
                            Some(service)
409
0
                        })
410
0
                    })
411
0
                    .err_tip(|| "Could not create Fetch service")?,
412
            )
413
0
            .add_optional_service(
414
0
                services
415
0
                    .push
416
0
                    .map_or(Ok(None), |cfg| {
417
0
                        PushServer::new(&cfg, &store_manager).map(|v| {
418
0
                            let mut service = v.into_service();
419
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
420
0
                            if let Some(encoding) =
  Branch (420:36): [Folded - Ignored]
421
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
422
0
                            {
423
0
                                service = service.send_compressed(encoding);
424
0
                            }
425
0
                            for encoding in http_config
426
0
                                .compression
427
0
                                .accepted_compression_algorithms
428
0
                                .iter()
429
0
                                // Filter None values.
430
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
431
0
                            {
432
0
                                service = service.accept_compressed(encoding);
433
0
                            }
434
0
                            Some(service)
435
0
                        })
436
0
                    })
437
0
                    .err_tip(|| "Could not create Push service")?,
438
            )
439
0
            .add_optional_service(
440
0
                services
441
0
                    .bytestream
442
0
                    .map_or(Ok(None), |cfg| {
443
0
                        ByteStreamServer::new(&cfg, &store_manager).map(|v| {
444
0
                            let mut service = v.into_service();
445
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
446
0
                            if let Some(encoding) =
  Branch (446:36): [Folded - Ignored]
447
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
448
0
                            {
449
0
                                service = service.send_compressed(encoding);
450
0
                            }
451
0
                            for encoding in http_config
452
0
                                .compression
453
0
                                .accepted_compression_algorithms
454
0
                                .iter()
455
0
                                // Filter None values.
456
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
457
0
                            {
458
0
                                service = service.accept_compressed(encoding);
459
0
                            }
460
0
                            Some(service)
461
0
                        })
462
0
                    })
463
0
                    .err_tip(|| "Could not create ByteStream service")?,
464
            )
465
0
            .add_optional_service(
466
0
                OptionFuture::from(
467
0
                    services
468
0
                        .capabilities
469
0
                        .as_ref()
470
0
                        // Borrow checker fighting here...
471
0
                        .map(|_| {
472
0
                            CapabilitiesServer::new(
473
0
                                services.capabilities.as_ref().unwrap(),
474
0
                                &action_schedulers,
475
0
                            )
476
0
                        }),
477
                )
478
0
                .await
479
0
                .map_or(Ok::<Option<CapabilitiesServer>, Error>(None), |server| {
480
0
                    Ok(Some(server?))
481
0
                })
482
0
                .err_tip(|| "Could not create Capabilities service")?
483
0
                .map(|v| {
484
0
                    let mut service = v.into_service();
485
0
                    let send_algo = &http_config.compression.send_compression_algorithm;
486
0
                    if let Some(encoding) =
  Branch (486:28): [Folded - Ignored]
487
0
                        into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
488
0
                    {
489
0
                        service = service.send_compressed(encoding);
490
0
                    }
491
0
                    for encoding in http_config
492
0
                        .compression
493
0
                        .accepted_compression_algorithms
494
0
                        .iter()
495
0
                        // Filter None values.
496
0
                        .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
497
0
                    {
498
0
                        service = service.accept_compressed(encoding);
499
0
                    }
500
0
                    service
501
0
                }),
502
            )
503
0
            .add_optional_service(
504
0
                services
505
0
                    .worker_api
506
0
                    .map_or(Ok(None), |cfg| {
507
0
                        WorkerApiServer::new(&cfg, &worker_schedulers).map(|v| {
508
0
                            let mut service = v.into_service();
509
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
510
0
                            if let Some(encoding) =
  Branch (510:36): [Folded - Ignored]
511
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
512
0
                            {
513
0
                                service = service.send_compressed(encoding);
514
0
                            }
515
0
                            for encoding in http_config
516
0
                                .compression
517
0
                                .accepted_compression_algorithms
518
0
                                .iter()
519
0
                                // Filter None values.
520
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
521
0
                            {
522
0
                                service = service.accept_compressed(encoding);
523
0
                            }
524
0
                            Some(service)
525
0
                        })
526
0
                    })
527
0
                    .err_tip(|| "Could not create WorkerApi service")?,
528
            )
529
0
            .add_optional_service(
530
0
                services
531
0
                    .experimental_bep
532
0
                    .map_or(Ok(None), |cfg| {
533
0
                        BepServer::new(&cfg, &store_manager).map(|v| {
534
0
                            let mut service = v.into_service();
535
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
536
0
                            if let Some(encoding) =
  Branch (536:36): [Folded - Ignored]
537
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::None))
538
0
                            {
539
0
                                service = service.send_compressed(encoding);
540
0
                            }
541
0
                            for encoding in http_config
542
0
                                .compression
543
0
                                .accepted_compression_algorithms
544
0
                                .iter()
545
0
                                // Filter None values.
546
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
547
0
                            {
548
0
                                service = service.accept_compressed(encoding);
549
0
                            }
550
0
                            Some(service)
551
0
                        })
552
0
                    })
553
0
                    .err_tip(|| "Could not create BEP service")?,
554
            );
555
556
0
        let health_registry = health_registry_builder.lock().await.build();
557
0
558
0
        let mut svc = tonic_services
559
0
            .into_axum_router()
560
0
            .layer(OriginEventMiddlewareLayer::new(
561
0
                maybe_origin_event_tx.clone(),
562
0
                server_cfg.experimental_identity_header.clone(),
563
            ));
564
565
0
        if let Some(health_cfg) = services.health {
  Branch (565:16): [Folded - Ignored]
566
0
            let path = if health_cfg.path.is_empty() {
  Branch (566:27): [Folded - Ignored]
567
0
                DEFAULT_HEALTH_STATUS_CHECK_PATH
568
            } else {
569
0
                &health_cfg.path
570
            };
571
0
            svc = svc.route_service(path, HealthServer::new(health_registry));
572
0
        }
573
574
0
        if let Some(admin_config) = services.admin {
  Branch (574:16): [Folded - Ignored]
575
0
            let path = if admin_config.path.is_empty() {
  Branch (575:27): [Folded - Ignored]
576
0
                DEFAULT_ADMIN_API_PATH
577
            } else {
578
0
                &admin_config.path
579
            };
580
0
            let worker_schedulers = Arc::new(worker_schedulers.clone());
581
0
            svc = svc.nest_service(
582
0
                path,
583
0
                Router::new().route(
584
0
                    "/scheduler/{instance_name}/set_drain_worker/{worker_id}/{is_draining}",
585
0
                    axum::routing::post(
586
0
                        move |params: axum::extract::Path<(String, String, String)>| async move {
587
0
                            let (instance_name, worker_id, is_draining) = params.0;
588
0
                            (async move {
589
0
                                let is_draining = match is_draining.as_str() {
590
0
                                    "0" => false,
591
0
                                    "1" => true,
592
                                    _ => {
593
0
                                        return Err(make_err!(
594
0
                                            Code::Internal,
595
0
                                            "{} is neither 0 nor 1",
596
0
                                            is_draining
597
0
                                        ));
598
                                    }
599
                                };
600
0
                                worker_schedulers
601
0
                                    .get(&instance_name)
602
0
                                    .err_tip(|| {
603
0
                                        format!(
604
0
                                            "Can not get an instance with the name of '{}'",
605
0
                                            &instance_name
606
0
                                        )
607
0
                                    })?
608
0
                                    .clone()
609
0
                                    .set_drain_worker(&worker_id.clone().into(), is_draining)
610
0
                                    .await?;
611
0
                                Ok::<_, Error>(format!("Draining worker {worker_id}"))
612
                            })
613
0
                            .await
614
0
                            .map_err(|e| {
615
0
                                Err::<String, _>((
616
0
                                    StatusCode::INTERNAL_SERVER_ERROR,
617
0
                                    format!("Error: {e:?}"),
618
0
                                ))
619
0
                            })
620
0
                        },
621
                    ),
622
                ),
623
            );
624
0
        }
625
626
0
        svc = svc
627
0
            // This is the default service that executes if no other endpoint matches.
628
0
            .fallback((StatusCode::NOT_FOUND, "Not Found"));
629
630
        // Configure our TLS acceptor if we have TLS configured.
631
0
        let maybe_tls_acceptor = http_config.tls.map_or(Ok(None), |tls_config| {
632
0
            fn read_cert(cert_file: &str) -> Result<Vec<CertificateDer<'static>>, Error> {
633
0
                let mut cert_reader = std::io::BufReader::new(
634
0
                    std::fs::File::open(cert_file)
635
0
                        .err_tip(|| format!("Could not open cert file {cert_file}"))?,
636
                );
637
0
                let certs = extract_certs(&mut cert_reader)
638
0
                    .collect::<Result<Vec<CertificateDer<'_>>, _>>()
639
0
                    .err_tip(|| format!("Could not extract certs from file {cert_file}"))?;
640
0
                Ok(certs)
641
0
            }
642
0
            let certs = read_cert(&tls_config.cert_file)?;
643
0
            let mut key_reader = std::io::BufReader::new(
644
0
                std::fs::File::open(&tls_config.key_file)
645
0
                    .err_tip(|| format!("Could not open key file {}", tls_config.key_file))?,
646
            );
647
0
            let key = match rustls_pemfile::read_one(&mut key_reader)
648
0
                .err_tip(|| format!("Could not extract key(s) from file {}", tls_config.key_file))?
649
            {
650
0
                Some(rustls_pemfile::Item::Pkcs8Key(key)) => key.into(),
651
0
                Some(rustls_pemfile::Item::Sec1Key(key)) => key.into(),
652
0
                Some(rustls_pemfile::Item::Pkcs1Key(key)) => key.into(),
653
                _ => {
654
0
                    return Err(make_err!(
655
0
                        Code::Internal,
656
0
                        "No keys found in file {}",
657
0
                        tls_config.key_file
658
0
                    ));
659
                }
660
            };
661
0
            if let Ok(Some(_)) = rustls_pemfile::read_one(&mut key_reader) {
  Branch (661:20): [Folded - Ignored]
662
0
                return Err(make_err!(
663
0
                    Code::InvalidArgument,
664
0
                    "Expected 1 key in file {}",
665
0
                    tls_config.key_file
666
0
                ));
667
0
            }
668
0
            let verifier = if let Some(client_ca_file) = &tls_config.client_ca_file {
  Branch (668:35): [Folded - Ignored]
669
0
                let mut client_auth_roots = RootCertStore::empty();
670
0
                for cert in read_cert(client_ca_file)? {
671
0
                    client_auth_roots.add(cert).map_err(|e| {
672
0
                        make_err!(Code::Internal, "Could not read client CA: {e:?}")
673
0
                    })?;
674
                }
675
0
                let crls = if let Some(client_crl_file) = &tls_config.client_crl_file {
  Branch (675:35): [Folded - Ignored]
676
0
                    let mut crl_reader = std::io::BufReader::new(
677
0
                        std::fs::File::open(client_crl_file)
678
0
                            .err_tip(|| format!("Could not open CRL file {client_crl_file}"))?,
679
                    );
680
0
                    extract_crls(&mut crl_reader)
681
0
                        .collect::<Result<_, _>>()
682
0
                        .err_tip(|| format!("Could not extract CRLs from file {client_crl_file}"))?
683
                } else {
684
0
                    Vec::new()
685
                };
686
0
                WebPkiClientVerifier::builder(Arc::new(client_auth_roots))
687
0
                    .with_crls(crls)
688
0
                    .build()
689
0
                    .map_err(|e| {
690
0
                        make_err!(
691
0
                            Code::Internal,
692
0
                            "Could not create WebPkiClientVerifier: {e:?}"
693
0
                        )
694
0
                    })?
695
            } else {
696
0
                WebPkiClientVerifier::no_client_auth()
697
            };
698
0
            let mut config = TlsServerConfig::builder()
699
0
                .with_client_cert_verifier(verifier)
700
0
                .with_single_cert(certs, key)
701
0
                .map_err(|e| {
702
0
                    make_err!(Code::Internal, "Could not create TlsServerConfig : {e:?}")
703
0
                })?;
704
705
0
            config.alpn_protocols.push("h2".into());
706
0
            Ok(Some(TlsAcceptor::from(Arc::new(config))))
707
0
        })?;
708
709
0
        let socket_addr = http_config
710
0
            .socket_address
711
0
            .parse::<SocketAddr>()
712
0
            .map_err(|e| {
713
0
                make_input_err!("Invalid address '{}' - {e:?}", http_config.socket_address)
714
0
            })?;
715
0
        let tcp_listener = TcpListener::bind(&socket_addr).await?;
716
0
        let mut http = auto::Builder::new(TaskExecutor::default());
717
0
718
0
        let http_config = &http_config.advanced_http;
719
0
        if let Some(value) = http_config.http2_keep_alive_interval {
  Branch (719:16): [Folded - Ignored]
720
0
            http.http2()
721
0
                .keep_alive_interval(Duration::from_secs(u64::from(value)));
722
0
        }
723
724
0
        if let Some(value) = http_config.experimental_http2_max_pending_accept_reset_streams {
  Branch (724:16): [Folded - Ignored]
725
0
            http.http2()
726
0
                .max_pending_accept_reset_streams(usize::try_from(value).err_tip(
727
0
                    || "Could not convert experimental_http2_max_pending_accept_reset_streams",
728
0
                )?);
729
0
        }
730
0
        if let Some(value) = http_config.experimental_http2_initial_stream_window_size {
  Branch (730:16): [Folded - Ignored]
731
0
            http.http2().initial_stream_window_size(value);
732
0
        }
733
0
        if let Some(value) = http_config.experimental_http2_initial_connection_window_size {
  Branch (733:16): [Folded - Ignored]
734
0
            http.http2().initial_connection_window_size(value);
735
0
        }
736
0
        if let Some(value) = http_config.experimental_http2_adaptive_window {
  Branch (736:16): [Folded - Ignored]
737
0
            http.http2().adaptive_window(value);
738
0
        }
739
0
        if let Some(value) = http_config.experimental_http2_max_frame_size {
  Branch (739:16): [Folded - Ignored]
740
0
            http.http2().max_frame_size(value);
741
0
        }
742
0
        if let Some(value) = http_config.experimental_http2_max_concurrent_streams {
  Branch (742:16): [Folded - Ignored]
743
0
            http.http2().max_concurrent_streams(value);
744
0
        }
745
0
        if let Some(value) = http_config.experimental_http2_keep_alive_timeout {
  Branch (745:16): [Folded - Ignored]
746
0
            http.http2()
747
0
                .keep_alive_timeout(Duration::from_secs(u64::from(value)));
748
0
        }
749
0
        if let Some(value) = http_config.experimental_http2_max_send_buf_size {
  Branch (749:16): [Folded - Ignored]
750
0
            http.http2().max_send_buf_size(
751
0
                usize::try_from(value).err_tip(|| "Could not convert http2_max_send_buf_size")?,
752
            );
753
0
        }
754
0
        if let Some(true) = http_config.experimental_http2_enable_connect_protocol {
  Branch (754:16): [Folded - Ignored]
755
0
            http.http2().enable_connect_protocol();
756
0
        }
757
0
        if let Some(value) = http_config.experimental_http2_max_header_list_size {
  Branch (757:16): [Folded - Ignored]
758
0
            http.http2().max_header_list_size(value);
759
0
        }
760
0
        event!(Level::WARN, "Ready, listening on {socket_addr}",);
761
0
        root_futures.push(Box::pin(async move {
762
            loop {
763
0
                select! {
764
0
                    accept_result = tcp_listener.accept() => {
765
0
                        match accept_result {
766
0
                            Ok((tcp_stream, remote_addr)) => {
767
0
                                event!(
768
                                    target: "nativelink::services",
769
0
                                    Level::INFO,
770
                                    ?remote_addr,
771
                                    ?socket_addr,
772
0
                                    "Client connected"
773
                                );
774
0
                                connected_clients_mux
775
0
                                    .inner
776
0
                                    .lock()
777
0
                                    .insert(SocketAddrWrapper(remote_addr));
778
0
                                connected_clients_mux.counter.inc();
779
0
780
0
                                // This is the safest way to guarantee that if our future
781
0
                                // is ever dropped we will clean up our data.
782
0
                                let scope_guard = guard(
783
0
                                    Arc::downgrade(&connected_clients_mux),
784
0
                                    move |weak_connected_clients_mux| {
785
0
                                        event!(
786
                                            target: "nativelink::services",
787
0
                                            Level::INFO,
788
                                            ?remote_addr,
789
                                            ?socket_addr,
790
0
                                            "Client disconnected"
791
                                        );
792
0
                                        if let Some(connected_clients_mux) = weak_connected_clients_mux.upgrade() {
  Branch (792:48): [Folded - Ignored]
793
0
                                            connected_clients_mux
794
0
                                                .inner
795
0
                                                .lock()
796
0
                                                .remove(&SocketAddrWrapper(remote_addr));
797
0
                                        }
798
0
                                    },
799
                                );
800
801
0
                                let (http, svc, maybe_tls_acceptor) =
802
0
                                    (http.clone(), svc.clone(), maybe_tls_acceptor.clone());
803
0
                                Arc::new(OriginContext::new()).background_spawn(
804
0
                                    error_span!(
805
                                        target: "nativelink::services",
806
                                        "http_connection",
807
                                        ?remote_addr,
808
                                        ?socket_addr
809
                                    ),
810
0
                                    async move {},
811
                                );
812
0
                                background_spawn!(
813
                                    name: "http_connection",
814
0
                                    fut: async move {
815
0
                                        // Move it into our spawn, so if our spawn dies the cleanup happens.
816
0
                                        let _guard = scope_guard;
817
0
                                        let serve_connection = if let Some(tls_acceptor) = maybe_tls_acceptor {
  Branch (817:71): [Folded - Ignored]
818
0
                                            match tls_acceptor.accept(tcp_stream).await {
819
0
                                                Ok(tls_stream) => Either::Left(http.serve_connection(
820
0
                                                    TokioIo::new(tls_stream),
821
0
                                                    TowerToHyperService::new(svc),
822
0
                                                )),
823
0
                                                Err(err) => {
824
0
                                                    event!(Level::ERROR, ?err, "Failed to accept tls stream");
825
0
                                                    return;
826
                                                }
827
                                            }
828
                                        } else {
829
0
                                            Either::Right(http.serve_connection(
830
0
                                                TokioIo::new(tcp_stream),
831
0
                                                TowerToHyperService::new(svc),
832
0
                                            ))
833
                                        };
834
835
0
                                        if let Err(err) = serve_connection.await {
  Branch (835:48): [Folded - Ignored]
836
0
                                            event!(
837
                                                target: "nativelink::services",
838
0
                                                Level::ERROR,
839
                                                ?err,
840
0
                                                "Failed running service"
841
                                            );
842
0
                                        }
843
0
                                    },
844
                                    target: "nativelink::services",
845
                                    ?remote_addr,
846
                                    ?socket_addr,
847
                                );
848
                            },
849
0
                            Err(err) => {
850
0
                                event!(Level::ERROR, ?err, "Failed to accept tcp connection");
851
                            }
852
                        }
853
                    },
854
                }
855
            }
856
            // Unreachable
857
        }));
858
    }
859
860
    {
861
        // We start workers after our TcpListener is set up so that a worker connecting to one
862
        // of these services will be able to reach it.
863
0
        let worker_cfgs = cfg.workers.unwrap_or_default();
864
0
        let mut worker_names = HashSet::with_capacity(worker_cfgs.len());
865
0
        let mut worker_metrics: HashMap<String, Arc<dyn RootMetricsComponent>> = HashMap::new();
866
0
        for (i, worker_cfg) in worker_cfgs.into_iter().enumerate() {
867
0
            let spawn_fut = match worker_cfg {
868
0
                WorkerConfig::Local(local_worker_cfg) => {
869
0
                    let fast_slow_store = store_manager
870
0
                        .get_store(&local_worker_cfg.cas_fast_slow_store)
871
0
                        .err_tip(|| {
872
0
                            format!(
873
0
                                "Failed to find store for cas_store_ref in worker config : {}",
874
0
                                local_worker_cfg.cas_fast_slow_store
875
0
                            )
876
0
                        })?;
877
878
0
                    let maybe_ac_store = if let Some(ac_store_ref) =
  Branch (878:49): [Folded - Ignored]
879
0
                        &local_worker_cfg.upload_action_result.ac_store
880
                    {
881
0
                        Some(store_manager.get_store(ac_store_ref).err_tip(|| {
882
0
                            format!("Failed to find store for ac_store in worker config : {ac_store_ref}")
883
0
                        })?)
884
                    } else {
885
0
                        None
886
                    };
887
                    // Note: Defaults to fast_slow_store if not specified. If this ever changes it must
888
                    // be updated in the config documentation for the `historical_results_store` field.
889
0
                    let historical_store = if let Some(cas_store_ref) = &local_worker_cfg
  Branch (889:51): [Folded - Ignored]
890
0
                        .upload_action_result
891
0
                        .historical_results_store
892
                    {
893
0
                        store_manager.get_store(cas_store_ref).err_tip(|| {
894
0
                                format!(
895
0
                                "Failed to find store for historical_results_store in worker config : {cas_store_ref}"
896
0
                            )
897
0
                            })?
898
                    } else {
899
0
                        fast_slow_store.clone()
900
                    };
901
0
                    let (local_worker, metrics) = new_local_worker(
902
0
                        Arc::new(local_worker_cfg),
903
0
                        fast_slow_store,
904
0
                        maybe_ac_store,
905
0
                        historical_store,
906
0
                    )
907
0
                    .await
908
0
                    .err_tip(|| "Could not make LocalWorker")?;
909
910
0
                    let name = if local_worker.name().is_empty() {
  Branch (910:35): [Folded - Ignored]
911
0
                        format!("worker_{i}")
912
                    } else {
913
0
                        local_worker.name().clone()
914
                    };
915
916
0
                    if worker_names.contains(&name) {
  Branch (916:24): [Folded - Ignored]
917
0
                        Err(make_input_err!(
918
0
                            "Duplicate worker name '{}' found in config",
919
0
                            name
920
0
                        ))?;
921
0
                    }
922
0
                    worker_names.insert(name.clone());
923
0
                    worker_metrics.insert(name.clone(), metrics);
924
0
                    let shutdown_rx = shutdown_tx.subscribe();
925
0
                    let fut = Arc::new(OriginContext::new())
926
0
                        .wrap_async(trace_span!("worker_ctx"), local_worker.run(shutdown_rx));
927
0
                    spawn!("worker", fut, ?name)
928
                }
929
            };
930
0
            root_futures.push(Box::pin(spawn_fut.map_ok_or_else(|e| Err(e.into()), |v| v)));
931
        }
932
0
        root_metrics.write().workers = worker_metrics;
933
    }
934
935
0
    if let Err(e) = try_join_all(root_futures).await {
  Branch (935:12): [Folded - Ignored]
936
0
        panic!("{e:?}");
937
0
    };
938
0
939
0
    Ok(())
940
0
}
941
942
0
fn get_config() -> Result<CasConfig, Box<dyn std::error::Error>> {
943
0
    let args = Args::parse();
944
0
    let json_contents = String::from_utf8(
945
0
        std::fs::read(&args.config_file)
946
0
            .err_tip(|| format!("Could not open config file {}", args.config_file))?,
947
0
    )?;
948
0
    Ok(serde_json5::from_str(&json_contents)?)
949
0
}
950
951
0
fn main() -> Result<(), Box<dyn std::error::Error>> {
952
0
    init_tracing()?;
953
954
0
    let mut cfg = get_config()?;
955
956
0
    let global_cfg = if let Some(global_cfg) = &mut cfg.global {
  Branch (956:29): [Folded - Ignored]
957
0
        if global_cfg.max_open_files == 0 {
  Branch (957:12): [Folded - Ignored]
958
0
            global_cfg.max_open_files = fs::DEFAULT_OPEN_FILE_LIMIT;
959
0
        }
960
0
        if global_cfg.default_digest_size_health_check == 0 {
  Branch (960:12): [Folded - Ignored]
961
0
            global_cfg.default_digest_size_health_check = DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG;
962
0
        }
963
964
0
        *global_cfg
965
    } else {
966
0
        GlobalConfig {
967
0
            max_open_files: fs::DEFAULT_OPEN_FILE_LIMIT,
968
0
            default_digest_hash_function: None,
969
0
            default_digest_size_health_check: DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
970
0
        }
971
    };
972
0
    set_open_file_limit(global_cfg.max_open_files);
973
0
    set_default_digest_hasher_func(DigestHasherFunc::from(
974
0
        global_cfg
975
0
            .default_digest_hash_function
976
0
            .unwrap_or(ConfigDigestHashFunction::Sha256),
977
0
    ))?;
978
0
    set_default_digest_size_health_check(global_cfg.default_digest_size_health_check)?;
979
980
0
    let server_start_time = SystemTime::now()
981
0
        .duration_since(UNIX_EPOCH)
982
0
        .unwrap()
983
0
        .as_secs();
984
985
    #[expect(clippy::disallowed_methods, reason = "starting main runtime")]
986
0
    let runtime = tokio::runtime::Builder::new_multi_thread()
987
0
        .enable_all()
988
0
        .build()?;
989
990
    // Initiates the shutdown process by broadcasting the shutdown signal via the `broadcast::Sender` to all listeners.
991
    // Each listener will perform its cleanup and then drop its `ShutdownGuard`, signaling completion.
992
    // Once all `ShutdownGuard` instances are dropped, the process knows it can safely terminate.
993
0
    let (shutdown_tx, _) = broadcast::channel::<ShutdownGuard>(BROADCAST_CAPACITY);
994
0
    let shutdown_tx_clone = shutdown_tx.clone();
995
0
    let mut shutdown_guard = ShutdownGuard::default();
996
0
997
0
    #[expect(clippy::disallowed_methods, reason = "signal handler on main runtime")]
998
0
    runtime.spawn(async move {
999
0
        tokio::signal::ctrl_c()
1000
0
            .await
1001
0
            .expect("Failed to listen to SIGINT");
1002
0
        eprintln!("User terminated process via SIGINT");
1003
0
        std::process::exit(130);
1004
    });
1005
1006
    #[cfg(target_family = "unix")]
1007
    #[expect(clippy::disallowed_methods, reason = "signal handler on main runtime")]
1008
0
    runtime.spawn(async move {
1009
0
        signal(SignalKind::terminate())
1010
0
            .expect("Failed to listen to SIGTERM")
1011
0
            .recv()
1012
0
            .await;
1013
0
        event!(Level::WARN, "Process terminated via SIGTERM",);
1014
0
        drop(shutdown_tx_clone.send(shutdown_guard.clone()));
1015
0
        let () = shutdown_guard.wait_for(Priority::P0).await;
1016
0
        event!(Level::WARN, "Successfully shut down nativelink.",);
1017
0
        std::process::exit(143);
1018
    });
1019
1020
    #[expect(clippy::disallowed_methods, reason = "waiting on everything to finish")]
1021
0
    runtime
1022
0
        .block_on(Arc::new(OriginContext::new()).wrap_async(
1023
0
            trace_span!("main"),
1024
0
            inner_main(cfg, server_start_time, shutdown_tx),
1025
0
        ))
1026
0
        .err_tip(|| "main() function failed")?;
1027
0
    Ok(())
1028
0
}