Coverage Report

Created: 2024-10-22 12:33

/build/source/src/bin/nativelink.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::{HashMap, HashSet};
16
use std::net::SocketAddr;
17
use std::sync::Arc;
18
use std::time::{Duration, SystemTime, UNIX_EPOCH};
19
20
use async_lock::Mutex as AsyncMutex;
21
use axum::Router;
22
use clap::Parser;
23
use futures::future::{select_all, BoxFuture, Either, OptionFuture, TryFutureExt};
24
use hyper::{Response, StatusCode};
25
use hyper_util::rt::tokio::TokioIo;
26
use hyper_util::server::conn::auto;
27
use hyper_util::service::TowerToHyperService;
28
use mimalloc::MiMalloc;
29
use nativelink_config::cas_server::{
30
    CasConfig, GlobalConfig, HttpCompressionAlgorithm, ListenerConfig, ServerConfig, WorkerConfig,
31
};
32
use nativelink_config::stores::ConfigDigestHashFunction;
33
use nativelink_error::{make_err, Code, Error, ResultExt};
34
use nativelink_metric::{
35
    MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent, RootMetricsComponent,
36
};
37
use nativelink_metric_collector::{otel_export, MetricsCollectorLayer};
38
use nativelink_scheduler::default_scheduler_factory::scheduler_factory;
39
use nativelink_service::ac_server::AcServer;
40
use nativelink_service::bep_server::BepServer;
41
use nativelink_service::bytestream_server::ByteStreamServer;
42
use nativelink_service::capabilities_server::CapabilitiesServer;
43
use nativelink_service::cas_server::CasServer;
44
use nativelink_service::execution_server::ExecutionServer;
45
use nativelink_service::health_server::HealthServer;
46
use nativelink_service::worker_api_server::WorkerApiServer;
47
use nativelink_store::default_store_factory::store_factory;
48
use nativelink_store::store_manager::StoreManager;
49
use nativelink_util::action_messages::WorkerId;
50
use nativelink_util::common::fs::{set_idle_file_descriptor_timeout, set_open_file_limit};
51
use nativelink_util::digest_hasher::{set_default_digest_hasher_func, DigestHasherFunc};
52
use nativelink_util::health_utils::HealthRegistryBuilder;
53
use nativelink_util::metrics_utils::{set_metrics_enabled_for_this_thread, Counter};
54
use nativelink_util::operation_state_manager::ClientStateManager;
55
use nativelink_util::origin_context::OriginContext;
56
use nativelink_util::store_trait::{
57
    set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
58
};
59
use nativelink_util::task::TaskExecutor;
60
use nativelink_util::{background_spawn, init_tracing, spawn, spawn_blocking};
61
use nativelink_worker::local_worker::new_local_worker;
62
use opentelemetry::metrics::MeterProvider;
63
use opentelemetry_sdk::metrics::SdkMeterProvider;
64
use parking_lot::{Mutex, RwLock};
65
use prometheus::{Encoder, TextEncoder};
66
use rustls_pemfile::{certs as extract_certs, crls as extract_crls};
67
use scopeguard::guard;
68
use tokio::net::TcpListener;
69
#[cfg(target_family = "unix")]
70
use tokio::signal::unix::{signal, SignalKind};
71
use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer};
72
use tokio_rustls::rustls::server::WebPkiClientVerifier;
73
use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig};
74
use tokio_rustls::TlsAcceptor;
75
use tonic::codec::CompressionEncoding;
76
use tonic::transport::Server as TonicServer;
77
use tracing::{error_span, event, trace_span, Level};
78
use tracing_subscriber::layer::SubscriberExt;
79
80
#[global_allocator]
81
static GLOBAL: MiMalloc = MiMalloc;
82
83
/// Note: This must be kept in sync with the documentation in `PrometheusConfig::path`.
84
const DEFAULT_PROMETHEUS_METRICS_PATH: &str = "/metrics";
85
86
/// Note: This must be kept in sync with the documentation in `AdminConfig::path`.
87
const DEFAULT_ADMIN_API_PATH: &str = "/admin";
88
89
/// Note: This must be kept in sync with the documentation in `HealthConfig::path`.
90
const DEFAULT_HEALTH_STATUS_CHECK_PATH: &str = "/status";
91
92
/// Name of environment variable to disable metrics.
93
const METRICS_DISABLE_ENV: &str = "NATIVELINK_DISABLE_METRICS";
94
95
/// Backend for the Bazel remote execution / cache API.
96
#[derive(Parser, Debug)]
97
#[clap(
98
    author = "Trace Machina, Inc. <nativelink@tracemachina.com>",
99
    version,
100
    about,
101
    long_about = None
102
)]
103
struct Args {
104
    /// Config file to use.
105
    #[clap(value_parser)]
106
0
    config_file: String,
107
}
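// Usage sketch (hypothetical path): the binary takes the config file as its only
// positional argument, e.g. `nativelink /etc/nativelink/config.json5`; the file is
// read and parsed with serde_json5 in `get_config()` below.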
108
109
/// The root metrics collector struct. All metrics will be
110
/// collected from this struct traversing down each child
111
/// component.
112
0
#[derive(MetricsComponent)]
113
struct RootMetrics {
114
    #[metric(group = "stores")]
115
    stores: Arc<dyn RootMetricsComponent>,
116
    #[metric(group = "servers")]
117
    servers: HashMap<String, Arc<dyn RootMetricsComponent>>,
118
    #[metric(group = "workers")]
119
    workers: HashMap<String, Arc<dyn RootMetricsComponent>>,
120
    // TODO(allada) We cannot upcast these to RootMetricsComponent because
121
    // of https://github.com/rust-lang/rust/issues/65991.
122
    // TODO(allada) To prevent output from being too verbose we only
123
    // print the action_schedulers.
124
    #[metric(group = "action_schedulers")]
125
    schedulers: HashMap<String, Arc<dyn ClientStateManager>>,
126
}
127
128
impl RootMetricsComponent for RootMetrics {}
129
130
/// Wrapper to allow us to hash `SocketAddr` for metrics.
131
#[derive(Hash, PartialEq, Eq)]
132
struct SocketAddrWrapper(SocketAddr);
133
134
impl MetricsComponent for SocketAddrWrapper {
135
0
    fn publish(
136
0
        &self,
137
0
        _kind: MetricKind,
138
0
        _field_metadata: MetricFieldData,
139
0
    ) -> Result<MetricPublishKnownKindData, nativelink_metric::Error> {
140
0
        Ok(MetricPublishKnownKindData::String(self.0.to_string()))
141
0
    }
142
}
143
144
impl RootMetricsComponent for SocketAddrWrapper {}
145
146
/// Simple wrapper to enable us to register the Hashmap so it can
147
/// report metrics about what clients are connected.
148
0
#[derive(MetricsComponent)]
149
struct ConnectedClientsMetrics {
150
    #[metric(group = "currently_connected_clients")]
151
    inner: Mutex<HashSet<SocketAddrWrapper>>,
152
    #[metric(help = "Total client connections since server started")]
153
    counter: Counter,
154
    #[metric(help = "Timestamp when the server started")]
155
    server_start_ts: u64,
156
}
157
158
impl RootMetricsComponent for ConnectedClientsMetrics {}
159
160
0
async fn inner_main(
161
0
    cfg: CasConfig,
162
0
    server_start_timestamp: u64,
163
0
) -> Result<(), Box<dyn std::error::Error>> {
164
0
    let health_registry_builder =
165
0
        Arc::new(AsyncMutex::new(HealthRegistryBuilder::new("nativelink")));
166
0
167
0
    let store_manager = Arc::new(StoreManager::new());
168
    {
169
0
        let mut health_registry_lock = health_registry_builder.lock().await;
170
171
0
        for (name, store_cfg) in cfg.stores {
172
0
            let health_component_name = format!("stores/{name}");
173
0
            let mut health_register_store =
174
0
                health_registry_lock.sub_builder(&health_component_name);
175
0
            let store = store_factory(&store_cfg, &store_manager, Some(&mut health_register_store))
176
0
                .await
177
0
                .err_tip(|| format!("Failed to create store '{name}'"))?;
178
0
            store_manager.add_store(&name, store);
179
        }
180
    }
181
182
0
    let mut action_schedulers = HashMap::new();
183
0
    let mut worker_schedulers = HashMap::new();
184
0
    if let Some(schedulers_cfg) = cfg.schedulers {
  Branch (184:12): [Folded - Ignored]
185
0
        for (name, scheduler_cfg) in schedulers_cfg {
186
0
            let (maybe_action_scheduler, maybe_worker_scheduler) =
187
0
                scheduler_factory(&scheduler_cfg, &store_manager)
188
0
                    .err_tip(|| format!("Failed to create scheduler '{name}'"))?;
189
0
            if let Some(action_scheduler) = maybe_action_scheduler {
  Branch (189:20): [Folded - Ignored]
190
0
                action_schedulers.insert(name.clone(), action_scheduler.clone());
191
0
            }
192
0
            if let Some(worker_scheduler) = maybe_worker_scheduler {
  Branch (192:20): [Folded - Ignored]
193
0
                worker_schedulers.insert(name.clone(), worker_scheduler.clone());
194
0
            }
195
        }
196
0
    }
197
198
0
    fn into_encoding(from: HttpCompressionAlgorithm) -> Option<CompressionEncoding> {
199
0
        match from {
200
0
            HttpCompressionAlgorithm::gzip => Some(CompressionEncoding::Gzip),
201
0
            HttpCompressionAlgorithm::none => None,
202
        }
203
0
    }
204
205
0
    let mut server_metrics: HashMap<String, Arc<dyn RootMetricsComponent>> = HashMap::new();
206
0
    // Registers all the ConnectedClientsMetrics to the registries
207
0
    // and zips them in. It is done this way to get around the need
208
0
    // for `root_metrics_registry` to become immutable in the loop.
209
0
    let servers_and_clients: Vec<(ServerConfig, _)> = cfg
210
0
        .servers
211
0
        .into_iter()
212
0
        .enumerate()
213
0
        .map(|(i, server_cfg)| {
214
0
            let name = if server_cfg.name.is_empty() {
  Branch (214:27): [Folded - Ignored]
215
0
                format!("{i}")
216
            } else {
217
0
                server_cfg.name.clone()
218
            };
219
0
            let connected_clients_mux = Arc::new(ConnectedClientsMetrics {
220
0
                inner: Mutex::new(HashSet::new()),
221
0
                counter: Counter::default(),
222
0
                server_start_ts: server_start_timestamp,
223
0
            });
224
0
            server_metrics.insert(name.clone(), connected_clients_mux.clone());
225
0
226
0
            (server_cfg, connected_clients_mux)
227
0
        })
228
0
        .collect();
229
0
230
0
    let mut root_futures: Vec<BoxFuture<Result<(), Error>>> = Vec::new();
231
0
232
0
    let root_metrics = Arc::new(RwLock::new(RootMetrics {
233
0
        stores: store_manager.clone(),
234
0
        servers: server_metrics,
235
0
        workers: HashMap::new(), // Will be filled in later.
236
0
        schedulers: action_schedulers.clone(),
237
0
    }));
238
239
0
    for (server_cfg, connected_clients_mux) in servers_and_clients {
240
0
        let services = server_cfg.services.ok_or("'services' must be configured")?;
241
242
        // Currently we only support http as our socket type.
243
0
        let ListenerConfig::http(http_config) = server_cfg.listener;
244
245
0
        let tonic_services = TonicServer::builder()
246
0
            .add_optional_service(
247
0
                services
248
0
                    .ac
249
0
                    .map_or(Ok(None), |cfg| {
250
0
                        AcServer::new(&cfg, &store_manager).map(|v| {
251
0
                            let mut service = v.into_service();
252
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
253
0
                            if let Some(encoding) =
  Branch (253:36): [Folded - Ignored]
254
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
255
0
                            {
256
0
                                service = service.send_compressed(encoding);
257
0
                            }
258
0
                            for encoding in http_config
259
0
                                .compression
260
0
                                .accepted_compression_algorithms
261
0
                                .iter()
262
0
                                // Filter None values.
263
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
264
0
                            {
265
0
                                service = service.accept_compressed(encoding);
266
0
                            }
267
0
                            Some(service)
268
0
                        })
269
0
                    })
270
0
                    .err_tip(|| "Could not create AC service")?,
271
            )
272
            .add_optional_service(
273
0
                services
274
0
                    .cas
275
0
                    .map_or(Ok(None), |cfg| {
276
0
                        CasServer::new(&cfg, &store_manager).map(|v| {
277
0
                            let mut service = v.into_service();
278
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
279
0
                            if let Some(encoding) =
  Branch (279:36): [Folded - Ignored]
280
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
281
0
                            {
282
0
                                service = service.send_compressed(encoding);
283
0
                            }
284
0
                            for encoding in http_config
285
0
                                .compression
286
0
                                .accepted_compression_algorithms
287
0
                                .iter()
288
0
                                // Filter None values.
289
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
290
0
                            {
291
0
                                service = service.accept_compressed(encoding);
292
0
                            }
293
0
                            Some(service)
294
0
                        })
295
0
                    })
296
0
                    .err_tip(|| "Could not create CAS service")?,
297
            )
298
            .add_optional_service(
299
0
                services
300
0
                    .execution
301
0
                    .map_or(Ok(None), |cfg| {
302
0
                        ExecutionServer::new(&cfg, &action_schedulers, &store_manager).map(|v| {
303
0
                            let mut service = v.into_service();
304
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
305
0
                            if let Some(encoding) =
  Branch (305:36): [Folded - Ignored]
306
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
307
0
                            {
308
0
                                service = service.send_compressed(encoding);
309
0
                            }
310
0
                            for encoding in http_config
311
0
                                .compression
312
0
                                .accepted_compression_algorithms
313
0
                                .iter()
314
0
                                // Filter None values.
315
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
316
0
                            {
317
0
                                service = service.accept_compressed(encoding);
318
0
                            }
319
0
                            Some(service)
320
0
                        })
321
0
                    })
322
0
                    .err_tip(|| "Could not create Execution service")?,
323
            )
324
            .add_optional_service(
325
0
                services
326
0
                    .bytestream
327
0
                    .map_or(Ok(None), |cfg| {
328
0
                        ByteStreamServer::new(&cfg, &store_manager).map(|v| {
329
0
                            let mut service = v.into_service();
330
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
331
0
                            if let Some(encoding) =
  Branch (331:36): [Folded - Ignored]
332
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
333
0
                            {
334
0
                                service = service.send_compressed(encoding);
335
0
                            }
336
0
                            for encoding in http_config
337
0
                                .compression
338
0
                                .accepted_compression_algorithms
339
0
                                .iter()
340
0
                                // Filter None values.
341
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
342
0
                            {
343
0
                                service = service.accept_compressed(encoding);
344
0
                            }
345
0
                            Some(service)
346
0
                        })
347
0
                    })
348
0
                    .err_tip(|| "Could not create ByteStream service")?,
349
            )
350
            .add_optional_service(
351
0
                OptionFuture::from(
352
0
                    services
353
0
                        .capabilities
354
0
                        .as_ref()
355
0
                        // Borrow checker fighting here...
356
0
                        .map(|_| {
357
0
                            CapabilitiesServer::new(
358
0
                                services.capabilities.as_ref().unwrap(),
359
0
                                &action_schedulers,
360
0
                            )
361
0
                        }),
362
0
                )
363
0
                .await
364
0
                .map_or(Ok::<Option<CapabilitiesServer>, Error>(None), |server| {
365
0
                    Ok(Some(server?))
366
0
                })
367
0
                .err_tip(|| "Could not create Capabilities service")?
368
0
                .map(|v| {
369
0
                    let mut service = v.into_service();
370
0
                    let send_algo = &http_config.compression.send_compression_algorithm;
371
0
                    if let Some(encoding) =
  Branch (371:28): [Folded - Ignored]
372
0
                        into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
373
0
                    {
374
0
                        service = service.send_compressed(encoding);
375
0
                    }
376
0
                    for encoding in http_config
377
0
                        .compression
378
0
                        .accepted_compression_algorithms
379
0
                        .iter()
380
0
                        // Filter None values.
381
0
                        .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
382
0
                    {
383
0
                        service = service.accept_compressed(encoding);
384
0
                    }
385
0
                    service
386
0
                }),
387
0
            )
388
0
            .add_optional_service(
389
0
                services
390
0
                    .worker_api
391
0
                    .map_or(Ok(None), |cfg| {
392
0
                        WorkerApiServer::new(&cfg, &worker_schedulers).map(|v| {
393
0
                            let mut service = v.into_service();
394
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
395
0
                            if let Some(encoding) =
  Branch (395:36): [Folded - Ignored]
396
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
397
0
                            {
398
0
                                service = service.send_compressed(encoding);
399
0
                            }
400
0
                            for encoding in http_config
401
0
                                .compression
402
0
                                .accepted_compression_algorithms
403
0
                                .iter()
404
0
                                // Filter None values.
405
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
406
0
                            {
407
0
                                service = service.accept_compressed(encoding);
408
0
                            }
409
0
                            Some(service)
410
0
                        })
411
0
                    })
412
0
                    .err_tip(|| "Could not create WorkerApi service")?,
413
            )
414
            .add_optional_service(
415
0
                services
416
0
                    .experimental_bep
417
0
                    .map_or(Ok(None), |cfg| {
418
0
                        BepServer::new(&cfg, &store_manager).map(|v| {
419
0
                            let mut service = v.into_service();
420
0
                            let send_algo = &http_config.compression.send_compression_algorithm;
421
0
                            if let Some(encoding) =
  Branch (421:36): [Folded - Ignored]
422
0
                                into_encoding(send_algo.unwrap_or(HttpCompressionAlgorithm::none))
423
0
                            {
424
0
                                service = service.send_compressed(encoding);
425
0
                            }
426
0
                            for encoding in http_config
427
0
                                .compression
428
0
                                .accepted_compression_algorithms
429
0
                                .iter()
430
0
                                // Filter None values.
431
0
                                .filter_map(|from: &HttpCompressionAlgorithm| into_encoding(*from))
432
0
                            {
433
0
                                service = service.accept_compressed(encoding);
434
0
                            }
435
0
                            Some(service)
436
0
                        })
437
0
                    })
438
0
                    .err_tip(|| "Could not create BEP service")?,
439
            );
440
441
0
        let health_registry = health_registry_builder.lock().await.build();
442
0
443
0
        let mut svc = Router::new()
444
0
            .merge(tonic_services.into_service().into_axum_router())
445
0
            // This is the default service that executes if no other endpoint matches.
446
0
            .fallback((StatusCode::NOT_FOUND, "Not Found"));
447
448
0
        if let Some(health_cfg) = services.health {
  Branch (448:16): [Folded - Ignored]
449
0
            let path = if health_cfg.path.is_empty() {
  Branch (449:27): [Folded - Ignored]
450
0
                DEFAULT_HEALTH_STATUS_CHECK_PATH
451
            } else {
452
0
                &health_cfg.path
453
            };
454
0
            svc = svc.route_service(path, HealthServer::new(health_registry));
455
0
        }
456
457
0
        if let Some(prometheus_cfg) = services.experimental_prometheus {
  Branch (457:16): [Folded - Ignored]
458
0
            fn error_to_response<E: std::error::Error>(e: E) -> Response<axum::body::Body> {
459
0
                let mut response = Response::new(format!("Error: {e:?}").into());
460
0
                *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
461
0
                response
462
0
            }
463
0
            let path = if prometheus_cfg.path.is_empty() {
  Branch (463:27): [Folded - Ignored]
464
0
                DEFAULT_PROMETHEUS_METRICS_PATH
465
            } else {
466
0
                &prometheus_cfg.path
467
            };
468
469
0
            let root_metrics_clone = root_metrics.clone();
470
0
471
0
            svc = svc.route_service(
472
0
                path,
473
0
                axum::routing::get(move |request: hyper::Request<axum::body::Body>| {
474
0
                    Arc::new(OriginContext::new()).wrap_async(
475
0
                        trace_span!("prometheus_ctx"),
476
0
                        async move {
477
0
                            // We spawn on a thread that can block to give more freedom to our metrics
478
0
                            // collection. This allows it to call functions like `tokio::block_in_place`
479
0
                            // if it needs to wait on a future.
480
0
                            spawn_blocking!("prometheus_metrics", move || {
481
0
                                let (layer, output_metrics) = MetricsCollectorLayer::new();
482
0
483
0
                                // Traverse all the MetricsComponent's. The `MetricsCollectorLayer` will
484
0
                                // collect all the metrics and store them in `output_metrics`.
485
0
                                tracing::subscriber::with_default(
486
0
                                    tracing_subscriber::registry().with(layer),
487
0
                                    || {
488
0
                                        let metrics_component = root_metrics_clone.read();
489
0
                                        MetricsComponent::publish(
490
0
                                            &*metrics_component,
491
0
                                            MetricKind::Component,
492
0
                                            MetricFieldData::default(),
493
0
                                        )
494
0
                                    },
495
0
                                )
496
0
                                .map_err(|e| make_err!(Code::Internal, "{e}"))
497
0
                                .err_tip(|| "While processing prometheus metrics")?;
498
499
                                // Convert the collected metrics into OpenTelemetry metrics then
500
                                // encode them into Prometheus format and populate them into a
501
                                // hyper::Response.
502
0
                                let response = {
503
0
                                    let registry = prometheus::Registry::new();
504
0
                                    let exporter = opentelemetry_prometheus::exporter()
505
0
                                        .with_registry(registry.clone())
506
0
                                        .without_counter_suffixes()
507
0
                                        .without_scope_info()
508
0
                                        .build()
509
0
                                        .map_err(|e| make_err!(Code::Internal, "{e}"))
510
0
                                        .err_tip(|| {
511
0
                                            "While creating OpenTelemetry Prometheus exporter"
512
0
                                        })?;
513
514
                                    // Prepare our OpenTelemetry collector/exporter.
515
0
                                    let provider =
516
0
                                        SdkMeterProvider::builder().with_reader(exporter).build();
517
0
                                    let meter = provider.meter("nativelink");
518
519
                                    // TODO(allada) We should put this as part of the config instead of a magic
520
                                    // request header.
521
0
                                    if let Some(json_type) =
  Branch (521:44): [Folded - Ignored]
522
0
                                        request.headers().get("x-nativelink-json")
523
                                    {
524
0
                                        let json_data = if json_type == "pretty" {
  Branch (524:60): [Folded - Ignored]
525
0
                                            serde_json::to_string_pretty(&*output_metrics.lock())
526
0
                                                .map_err(|e| {
527
0
                                                    make_err!(
528
0
                                                        Code::Internal,
529
0
                                                        "Could not convert to json {e:?}"
530
0
                                                    )
531
0
                                                })?
532
                                        } else {
533
0
                                            serde_json::to_string(&*output_metrics.lock()).map_err(
534
0
                                                |e| {
535
0
                                                    make_err!(
536
0
                                                        Code::Internal,
537
0
                                                        "Could not convert to json {e:?}"
538
0
                                                    )
539
0
                                                },
540
0
                                            )?
541
                                        };
542
0
                                        let mut response =
543
0
                                            Response::new(axum::body::Body::from(json_data));
544
0
                                        response.headers_mut().insert(
545
0
                                            hyper::header::CONTENT_TYPE,
546
0
                                            hyper::header::HeaderValue::from_static(
547
0
                                                "application/json",
548
0
                                            ),
549
0
                                        );
550
0
                                        return Ok(response);
551
0
                                    }
552
0
553
0
                                    // Export the metrics to OpenTelemetry.
554
0
                                    otel_export(
555
0
                                        "nativelink".to_string(),
556
0
                                        &meter,
557
0
                                        &output_metrics.lock(),
558
0
                                    );
559
0
560
0
                                    // Translate the OpenTelemetry metrics to Prometheus format and encode
561
0
                                    // them into a hyper::Response.
562
0
                                    let mut result = vec![];
563
0
                                    TextEncoder::new()
564
0
                                        .encode(&registry.gather(), &mut result)
565
0
                                        .unwrap();
566
0
                                    let mut response =
567
0
                                        Response::new(axum::body::Body::from(result));
568
0
                                    // Per spec we should probably use `application/openmetrics-text; version=1.0.0; charset=utf-8`
569
0
                                    // https://github.com/OpenObservability/OpenMetrics/blob/1386544931307dff279688f332890c31b6c5de36/specification/OpenMetrics.md#overall-structure
570
0
                                    // However, this makes debugging more difficult, so we use the old text/plain instead.
571
0
                                    response.headers_mut().insert(
572
0
                                        hyper::header::CONTENT_TYPE,
573
0
                                        hyper::header::HeaderValue::from_static(
574
0
                                            "text/plain; version=0.0.4; charset=utf-8",
575
0
                                        ),
576
0
                                    );
577
0
                                    Result::<_, Error>::Ok(response)
578
0
                                };
579
0
                                response
580
0
                            })
581
0
                            .await
582
0
                            .unwrap_or_else(|e| Ok(error_to_response(e)))
583
0
                            .unwrap_or_else(error_to_response)
584
0
                        },
585
0
                    )
586
0
                }),
587
0
            );
588
0
        }
589
590
0
        if let Some(admin_config) = services.admin {
  Branch (590:16): [Folded - Ignored]
591
0
            let path = if admin_config.path.is_empty() {
  Branch (591:27): [Folded - Ignored]
592
0
                DEFAULT_ADMIN_API_PATH
593
            } else {
594
0
                &admin_config.path
595
            };
596
0
            let worker_schedulers = Arc::new(worker_schedulers.clone());
597
0
            svc = svc.nest_service(
598
0
                path,
599
0
                Router::new().route(
600
0
                    "/scheduler/:instance_name/set_drain_worker/:worker_id/:is_draining",
601
0
                    axum::routing::post(
602
0
                        move |params: axum::extract::Path<(String, String, String)>| async move {
603
0
                            let (instance_name, worker_id, is_draining) = params.0;
604
0
                            (async move {
605
0
                                let is_draining = match is_draining.as_str() {
606
0
                                    "0" => false,
607
0
                                    "1" => true,
608
                                    _ => {
609
0
                                        return Err(make_err!(
610
0
                                            Code::Internal,
611
0
                                            "{} is neither 0 nor 1",
612
0
                                            is_draining
613
0
                                        ))
614
                                    }
615
                                };
616
0
                                worker_schedulers
617
0
                                    .get(&instance_name)
618
0
                                    .err_tip(|| {
619
0
                                        format!(
620
0
                                            "Can not get an instance with the name of '{}'",
621
0
                                            &instance_name
622
0
                                        )
623
0
                                    })?
624
0
                                    .clone()
625
0
                                    .set_drain_worker(
626
0
                                        &WorkerId::try_from(worker_id.clone())?,
627
0
                                        is_draining,
628
                                    )
629
0
                                    .await?;
630
0
                                Ok::<_, Error>(format!("Draining worker {worker_id}"))
631
                            })
632
0
                            .await
633
0
                            .map_err(|e| {
634
0
                                Err::<String, _>((
635
0
                                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
636
0
                                    format!("Error: {e:?}"),
637
0
                                ))
638
0
                            })
639
0
                        },
640
0
                    ),
641
0
                ),
642
0
            );
643
0
        }
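        // Request sketch (hypothetical instance name): with the default `/admin` prefix, draining
        // a worker is `POST /admin/scheduler/<instance_name>/set_drain_worker/<worker_id>/1`,
        // and un-draining uses a trailing `0`; any other value is rejected with an internal error.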
644
645
        // Configure our TLS acceptor if we have TLS configured.
646
0
        let maybe_tls_acceptor = http_config.tls.map_or(Ok(None), |tls_config| {
647
0
            fn read_cert(cert_file: &str) -> Result<Vec<CertificateDer<'static>>, Error> {
648
0
                let mut cert_reader = std::io::BufReader::new(
649
0
                    std::fs::File::open(cert_file)
650
0
                        .err_tip(|| format!("Could not open cert file {cert_file}"))?,
651
                );
652
0
                let certs = extract_certs(&mut cert_reader)
653
0
                    .map(|certificate| certificate.map(CertificateDer::from))
654
0
                    .collect::<Result<Vec<CertificateDer<'_>>, _>>()
655
0
                    .err_tip(|| format!("Could not extract certs from file {cert_file}"))?;
656
0
                Ok(certs)
657
0
            }
658
0
            let certs = read_cert(&tls_config.cert_file)?;
659
0
            let mut key_reader = std::io::BufReader::new(
660
0
                std::fs::File::open(&tls_config.key_file)
661
0
                    .err_tip(|| format!("Could not open key file {}", tls_config.key_file))?,
662
            );
663
0
            let key = match rustls_pemfile::read_one(&mut key_reader)
664
0
                .err_tip(|| format!("Could not extract key(s) from file {}", tls_config.key_file))?
665
            {
666
0
                Some(rustls_pemfile::Item::Pkcs8Key(key)) => key.into(),
667
0
                Some(rustls_pemfile::Item::Sec1Key(key)) => key.into(),
668
0
                Some(rustls_pemfile::Item::Pkcs1Key(key)) => key.into(),
669
                _ => {
670
0
                    return Err(make_err!(
671
0
                        Code::Internal,
672
0
                        "No keys found in file {}",
673
0
                        tls_config.key_file
674
0
                    ))
675
                }
676
            };
677
0
            if let Ok(Some(_)) = rustls_pemfile::read_one(&mut key_reader) {
  Branch (677:20): [Folded - Ignored]
678
0
                return Err(make_err!(
679
0
                    Code::InvalidArgument,
680
0
                    "Expected 1 key in file {}",
681
0
                    tls_config.key_file
682
0
                ));
683
0
            }
684
0
            let verifier = if let Some(client_ca_file) = &tls_config.client_ca_file {
  Branch (684:35): [Folded - Ignored]
685
0
                let mut client_auth_roots = RootCertStore::empty();
686
0
                for cert in read_cert(client_ca_file)?.into_iter() {
687
0
                    client_auth_roots.add(cert).map_err(|e| {
688
0
                        make_err!(Code::Internal, "Could not read client CA: {e:?}")
689
0
                    })?;
690
                }
691
0
                let crls = if let Some(client_crl_file) = &tls_config.client_crl_file {
  Branch (691:35): [Folded - Ignored]
692
0
                    let mut crl_reader = std::io::BufReader::new(
693
0
                        std::fs::File::open(client_crl_file)
694
0
                            .err_tip(|| format!("Could not open CRL file {client_crl_file}"))?,
695
                    );
696
0
                    extract_crls(&mut crl_reader)
697
0
                        .map(|crl| crl.map(CertificateRevocationListDer::from))
698
0
                        .collect::<Result<_, _>>()
699
0
                        .err_tip(|| format!("Could not extract CRLs from file {client_crl_file}"))?
700
                } else {
701
0
                    Vec::new()
702
                };
703
0
                WebPkiClientVerifier::builder(Arc::new(client_auth_roots))
704
0
                    .with_crls(crls)
705
0
                    .build()
706
0
                    .map_err(|e| {
707
0
                        make_err!(
708
0
                            Code::Internal,
709
0
                            "Could not create WebPkiClientVerifier: {e:?}"
710
0
                        )
711
0
                    })?
712
            } else {
713
0
                WebPkiClientVerifier::no_client_auth()
714
            };
715
0
            let mut config = TlsServerConfig::builder()
716
0
                .with_client_cert_verifier(verifier)
717
0
                .with_single_cert(certs, key)
718
0
                .map_err(|e| {
719
0
                    make_err!(Code::Internal, "Could not create TlsServerConfig : {e:?}")
720
0
                })?;
721
722
0
            config.alpn_protocols.push("h2".into());
723
0
            Ok(Some(TlsAcceptor::from(Arc::new(config))))
724
0
        })?;
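        // Config sketch (hypothetical file names): the `tls` block above expects PEM inputs such
        // as `cert_file: "server.pem"` and `key_file: "server.key"`, with optional
        // `client_ca_file` / `client_crl_file` enabling mutual-TLS verification; only "h2" is
        // advertised via ALPN.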
725
726
0
        let socket_addr = http_config.socket_address.parse::<SocketAddr>()?;
727
0
        let tcp_listener = TcpListener::bind(&socket_addr).await?;
728
0
        let mut http = auto::Builder::new(TaskExecutor::default());
729
0
730
0
        let http_config = &http_config.advanced_http;
731
0
        if let Some(value) = http_config.http2_keep_alive_interval {
  Branch (731:16): [Folded - Ignored]
732
0
            http.http2()
733
0
                .keep_alive_interval(Duration::from_secs(u64::from(value)));
734
0
        }
735
736
0
        if let Some(value) = http_config.experimental_http2_max_pending_accept_reset_streams {
  Branch (736:16): [Folded - Ignored]
737
0
            http.http2()
738
0
                .max_pending_accept_reset_streams(usize::try_from(value).err_tip(|| {
739
0
                    "Could not convert experimental_http2_max_pending_accept_reset_streams"
740
0
                })?);
741
0
        }
742
0
        if let Some(value) = http_config.experimental_http2_initial_stream_window_size {
  Branch (742:16): [Folded - Ignored]
743
0
            http.http2().initial_stream_window_size(value);
744
0
        }
745
0
        if let Some(value) = http_config.experimental_http2_initial_connection_window_size {
  Branch (745:16): [Folded - Ignored]
746
0
            http.http2().initial_connection_window_size(value);
747
0
        }
748
0
        if let Some(value) = http_config.experimental_http2_adaptive_window {
  Branch (748:16): [Folded - Ignored]
749
0
            http.http2().adaptive_window(value);
750
0
        }
751
0
        if let Some(value) = http_config.experimental_http2_max_frame_size {
  Branch (751:16): [Folded - Ignored]
752
0
            http.http2().max_frame_size(value);
753
0
        }
754
0
        if let Some(value) = http_config.experimental_http2_max_concurrent_streams {
  Branch (754:16): [Folded - Ignored]
755
0
            http.http2().max_concurrent_streams(value);
756
0
        }
757
0
        if let Some(value) = http_config.experimental_http2_keep_alive_timeout {
  Branch (757:16): [Folded - Ignored]
758
0
            http.http2()
759
0
                .keep_alive_timeout(Duration::from_secs(u64::from(value)));
760
0
        }
761
0
        if let Some(value) = http_config.experimental_http2_max_send_buf_size {
  Branch (761:16): [Folded - Ignored]
762
0
            http.http2().max_send_buf_size(
763
0
                usize::try_from(value).err_tip(|| "Could not convert http2_max_send_buf_size")?,
764
            );
765
0
        }
766
0
        if let Some(true) = http_config.experimental_http2_enable_connect_protocol {
  Branch (766:16): [Folded - Ignored]
767
0
            http.http2().enable_connect_protocol();
768
0
        }
769
0
        if let Some(value) = http_config.experimental_http2_max_header_list_size {
  Branch (769:16): [Folded - Ignored]
770
0
            http.http2().max_header_list_size(value);
771
0
        }
772
773
0
        event!(Level::WARN, "Ready, listening on {socket_addr}",);
774
0
        root_futures.push(Box::pin(async move {
775
            loop {
776
                // Wait for client to connect.
777
0
                let (tcp_stream, remote_addr) = match tcp_listener.accept().await {
778
0
                    Ok(result) => result,
779
0
                    Err(err) => {
780
0
                        event!(Level::ERROR, ?err, "Failed to accept tcp connection");
781
0
                        continue;
782
                    }
783
                };
784
0
                event!(
785
                    target: "nativelink::services",
786
0
                    Level::INFO,
787
                    ?remote_addr,
788
                    ?socket_addr,
789
0
                    "Client connected"
790
                );
791
0
                connected_clients_mux
792
0
                    .inner
793
0
                    .lock()
794
0
                    .insert(SocketAddrWrapper(remote_addr));
795
0
                connected_clients_mux.counter.inc();
796
0
797
0
                // This is the safest way to guarantee that if our future
798
0
                // is ever dropped we will clean up our data.
799
0
                let scope_guard = guard(
800
0
                    Arc::downgrade(&connected_clients_mux),
801
0
                    move |weak_connected_clients_mux| {
802
0
                        event!(
803
                            target: "nativelink::services",
804
0
                            Level::INFO,
805
                            ?remote_addr,
806
                            ?socket_addr,
807
0
                            "Client disconnected"
808
                        );
809
0
                        if let Some(connected_clients_mux) = weak_connected_clients_mux.upgrade() {
  Branch (809:32): [Folded - Ignored]
810
0
                            connected_clients_mux
811
0
                                .inner
812
0
                                .lock()
813
0
                                .remove(&SocketAddrWrapper(remote_addr));
814
0
                        }
815
0
                    },
816
0
                );
817
0
818
0
                let (http, svc, maybe_tls_acceptor) =
819
0
                    (http.clone(), svc.clone(), maybe_tls_acceptor.clone());
820
0
                Arc::new(OriginContext::new()).background_spawn(
821
0
                    error_span!(
822
0
                        target: "nativelink::services",
823
0
                        "http_connection",
824
0
                        ?remote_addr,
825
0
                        ?socket_addr
826
0
                    ),
827
0
                    async move {},
828
0
                );
829
0
                background_spawn!(
830
0
                    name: "http_connection",
831
0
                    fut: async move {
832
0
                        // Move it into our spawn, so if our spawn dies the cleanup happens.
833
0
                        let _guard = scope_guard;
834
0
                        let serve_connection = if let Some(tls_acceptor) = maybe_tls_acceptor {
  Branch (834:55): [Folded - Ignored]
835
0
                            match tls_acceptor.accept(tcp_stream).await {
836
0
                                Ok(tls_stream) => Either::Left(http.serve_connection(
837
0
                                    TokioIo::new(tls_stream),
838
0
                                    TowerToHyperService::new(svc),
839
0
                                )),
840
0
                                Err(err) => {
841
0
                                    event!(Level::ERROR, ?err, "Failed to accept tls stream");
842
0
                                    return;
843
                                }
844
                            }
845
                        } else {
846
0
                            Either::Right(http.serve_connection(
847
0
                                TokioIo::new(tcp_stream),
848
0
                                TowerToHyperService::new(svc),
849
0
                            ))
850
                        };
851
852
0
                        if let Err(err) = serve_connection.await {
  Branch (852:32): [Folded - Ignored]
853
0
                            event!(
854
                                target: "nativelink::services",
855
0
                                Level::ERROR,
856
                                ?err,
857
0
                                "Failed running service"
858
                            );
859
0
                        }
860
0
                    },
861
0
                    target: "nativelink::services",
862
0
                    ?remote_addr,
863
0
                    ?socket_addr,
864
0
                );
865
            }
866
0
        }));
867
0
    }
868
869
    {
870
        // We start workers after our TcpListener is set up so if our worker connects to one
871
        // of these services it will be able to connect.
872
0
        let worker_cfgs = cfg.workers.unwrap_or_default();
873
0
        let mut worker_names = HashSet::with_capacity(worker_cfgs.len());
874
0
        let mut worker_metrics: HashMap<String, Arc<dyn RootMetricsComponent>> = HashMap::new();
875
0
        for (i, worker_cfg) in worker_cfgs.into_iter().enumerate() {
876
0
            let spawn_fut = match worker_cfg {
877
0
                WorkerConfig::local(local_worker_cfg) => {
878
0
                    let fast_slow_store = store_manager
879
0
                        .get_store(&local_worker_cfg.cas_fast_slow_store)
880
0
                        .err_tip(|| {
881
0
                            format!(
882
0
                                "Failed to find store for cas_store_ref in worker config : {}",
883
0
                                local_worker_cfg.cas_fast_slow_store
884
0
                            )
885
0
                        })?;
886
887
0
                    let maybe_ac_store = if let Some(ac_store_ref) =
  Branch (887:49): [Folded - Ignored]
888
0
                        &local_worker_cfg.upload_action_result.ac_store
889
                    {
890
0
                        Some(store_manager.get_store(ac_store_ref).err_tip(|| {
891
0
                            format!("Failed to find store for ac_store in worker config : {ac_store_ref}")
892
0
                        })?)
893
                    } else {
894
0
                        None
895
                    };
896
                    // Note: Defaults to fast_slow_store if not specified. If this ever changes it must
897
                    // be updated in the config documentation for the `historical_results_store` field.
898
0
                    let historical_store = if let Some(cas_store_ref) = &local_worker_cfg
  Branch (898:51): [Folded - Ignored]
899
0
                        .upload_action_result
900
0
                        .historical_results_store
901
                    {
902
0
                        store_manager.get_store(cas_store_ref).err_tip(|| {
903
0
                                format!(
904
0
                                "Failed to find store for historical_results_store in worker config : {cas_store_ref}"
905
0
                            )
906
0
                            })?
907
                    } else {
908
0
                        fast_slow_store.clone()
909
                    };
910
0
                    let (local_worker, metrics) = new_local_worker(
911
0
                        Arc::new(local_worker_cfg),
912
0
                        fast_slow_store,
913
0
                        maybe_ac_store,
914
0
                        historical_store,
915
0
                    )
916
0
                    .await
917
0
                    .err_tip(|| "Could not make LocalWorker")?;
918
0
                    let name = if local_worker.name().is_empty() {
  Branch (918:35): [Folded - Ignored]
919
0
                        format!("worker_{i}")
920
                    } else {
921
0
                        local_worker.name().clone()
922
                    };
923
0
                    if worker_names.contains(&name) {
  Branch (923:24): [Folded - Ignored]
924
0
                        Err(Box::new(make_err!(
925
0
                            Code::InvalidArgument,
926
0
                            "Duplicate worker name '{}' found in config",
927
0
                            name
928
0
                        )))?;
929
0
                    }
930
0
                    worker_names.insert(name.clone());
931
0
                    worker_metrics.insert(name.clone(), metrics);
932
0
                    let fut = Arc::new(OriginContext::new())
933
0
                        .wrap_async(trace_span!("worker_ctx"), local_worker.run());
934
0
                    spawn!("worker", fut, ?name)
935
                }
936
            };
937
0
            root_futures.push(Box::pin(spawn_fut.map_ok_or_else(|e| Err(e.into()), |v| v)));
938
0
        }
939
0
        root_metrics.write().workers = worker_metrics;
940
    }
941
942
0
    if let Err(e) = select_all(root_futures).await.0 {
  Branch (942:12): [Folded - Ignored]
943
0
        panic!("{e:?}");
944
0
    }
945
0
    unreachable!("None of the futures should resolve in main()");
946
0
}
947
948
0
async fn get_config() -> Result<CasConfig, Box<dyn std::error::Error>> {
949
0
    let args = Args::parse();
950
0
    let json_contents = String::from_utf8(
951
0
        std::fs::read(&args.config_file)
952
0
            .err_tip(|| format!("Could not open config file {}", args.config_file))?,
953
0
    )?;
954
0
    Ok(serde_json5::from_str(&json_contents)?)
955
0
}
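// Config-shape sketch inferred from the fields read in this file (all bodies hypothetical; the
// authoritative schema lives in `nativelink-config`): a JSON5 document with a `stores` map, an
// optional `schedulers` map, a `servers` array (each entry holding a `listener.http` block and a
// `services` block), an optional `workers` array, and optional `global` settings.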
956
957
0
fn main() -> Result<(), Box<dyn std::error::Error>> {
958
0
    init_tracing()?;
959
960
0
    let mut cfg = futures::executor::block_on(get_config())?;
961
962
0
    let (mut metrics_enabled, max_blocking_threads) = {
963
0
        // Note: If the default changes make sure you update the documentation in
964
0
        // `config/cas_server.rs`.
965
0
        const DEFAULT_MAX_OPEN_FILES: usize = 512;
966
0
        // Note: If the default changes make sure you update the documentation in
967
0
        // `config/cas_server.rs`.
968
0
        const DEFAULT_IDLE_FILE_DESCRIPTOR_TIMEOUT_MILLIS: u64 = 1000;
969
0
        let global_cfg = if let Some(global_cfg) = &mut cfg.global {
  Branch (969:33): [Folded - Ignored]
970
0
            if global_cfg.max_open_files == 0 {
  Branch (970:16): [Folded - Ignored]
971
0
                global_cfg.max_open_files = DEFAULT_MAX_OPEN_FILES;
972
0
            }
973
0
            if global_cfg.idle_file_descriptor_timeout_millis == 0 {
  Branch (973:16): [Folded - Ignored]
974
0
                global_cfg.idle_file_descriptor_timeout_millis =
975
0
                    DEFAULT_IDLE_FILE_DESCRIPTOR_TIMEOUT_MILLIS;
976
0
            }
977
0
            if global_cfg.default_digest_size_health_check == 0 {
  Branch (977:16): [Folded - Ignored]
978
0
                global_cfg.default_digest_size_health_check = DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG;
979
0
            }
980
981
0
            *global_cfg
982
        } else {
983
0
            GlobalConfig {
984
0
                max_open_files: DEFAULT_MAX_OPEN_FILES,
985
0
                idle_file_descriptor_timeout_millis: DEFAULT_IDLE_FILE_DESCRIPTOR_TIMEOUT_MILLIS,
986
0
                disable_metrics: cfg.servers.iter().all(|v| {
987
0
                    let Some(service) = &v.services else {
  Branch (987:25): [Folded - Ignored]
988
0
                        return true;
989
                    };
990
0
                    service.experimental_prometheus.is_none()
991
0
                }),
992
0
                default_digest_hash_function: None,
993
0
                default_digest_size_health_check: DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG,
994
0
            }
995
        };
996
0
        set_open_file_limit(global_cfg.max_open_files);
997
0
        set_idle_file_descriptor_timeout(Duration::from_millis(
998
0
            global_cfg.idle_file_descriptor_timeout_millis,
999
0
        ))?;
1000
0
        set_default_digest_hasher_func(DigestHasherFunc::from(
1001
0
            global_cfg
1002
0
                .default_digest_hash_function
1003
0
                .unwrap_or(ConfigDigestHashFunction::sha256),
1004
0
        ))?;
1005
0
        set_default_digest_size_health_check(global_cfg.default_digest_size_health_check)?;
1006
        // TODO (#513): prevent deadlocks by setting max blocking threads to the number of open files * 10
1007
0
        (!global_cfg.disable_metrics, global_cfg.max_open_files * 10)
1008
0
    };
1009
0
    // Override metrics enabled if the environment variable is set.
1010
0
    if std::env::var(METRICS_DISABLE_ENV).is_ok() {
  Branch (1010:8): [Folded - Ignored]
1011
0
        metrics_enabled = false;
1012
0
    }
1013
0
    let server_start_time = SystemTime::now()
1014
0
        .duration_since(UNIX_EPOCH)
1015
0
        .unwrap()
1016
0
        .as_secs();
1017
    #[allow(clippy::disallowed_methods)]
1018
    {
1019
0
        let runtime = tokio::runtime::Builder::new_multi_thread()
1020
0
            .max_blocking_threads(max_blocking_threads)
1021
0
            .enable_all()
1022
0
            .on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled))
1023
0
            .build()?;
1024
0
        runtime.spawn(async move {
1025
0
            tokio::signal::ctrl_c()
1026
0
                .await
1027
0
                .expect("Failed to listen to SIGINT");
1028
0
            eprintln!("User terminated process via SIGINT");
1029
0
            std::process::exit(130);
1030
0
        });
1031
0
1032
0
        #[cfg(target_family = "unix")]
1033
0
        runtime.spawn(async move {
1034
0
            signal(SignalKind::terminate())
1035
0
                .expect("Failed to listen to SIGTERM")
1036
0
                .recv()
1037
0
                .await;
1038
0
            eprintln!("Process terminated via SIGTERM");
1039
0
            std::process::exit(143);
1040
0
        });
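        // The exit codes above follow the shell convention of 128 + signal number
        // (SIGINT = 2 -> 130, SIGTERM = 15 -> 143).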
1041
0
1042
0
        runtime.block_on(
1043
0
            Arc::new(OriginContext::new())
1044
0
                .wrap_async(trace_span!("main"), inner_main(cfg, server_start_time)),
1045
        )
1046
    }
1047
0
}