/build/source/nativelink-util/src/tls_utils.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::time::Duration; |
16 | | |
17 | | use nativelink_config::stores::{ClientTlsConfig, GrpcEndpoint}; |
18 | | use nativelink_error::{Code, Error, make_err, make_input_err}; |
19 | | use tonic::transport::Uri; |
20 | | use tracing::{info, warn}; |
21 | | |
22 | 10 | pub fn load_client_config( |
23 | 10 | config: &Option<ClientTlsConfig>, |
24 | 10 | ) -> Result<Option<tonic::transport::ClientTlsConfig>, Error> { |
25 | 10 | let Some(config9 ) = config else { Branch (25:9): [True: 9, False: 1]
Branch (25:9): [Folded - Ignored]
|
26 | 1 | return Ok(None); |
27 | | }; |
28 | | |
29 | 9 | if config.use_native_roots == Some(true) { Branch (29:8): [True: 5, False: 4]
Branch (29:8): [Folded - Ignored]
|
30 | 5 | if config.ca_file.is_some() { Branch (30:12): [True: 0, False: 5]
Branch (30:12): [Folded - Ignored]
|
31 | 0 | warn!("Native root certificates are being used, all certificate files will be ignored"); |
32 | 5 | } |
33 | 5 | return Ok(Some( |
34 | 5 | tonic::transport::ClientTlsConfig::new().with_native_roots(), |
35 | 5 | )); |
36 | 4 | } |
37 | | |
38 | 4 | let Some(ca_file3 ) = &config.ca_file else { Branch (38:9): [True: 3, False: 1]
Branch (38:9): [Folded - Ignored]
|
39 | 1 | return Err(make_err!( |
40 | 1 | Code::Internal, |
41 | 1 | "CA certificate must be provided if not using native root certificates" |
42 | 1 | )); |
43 | | }; |
44 | | |
45 | 3 | let read_config = tonic::transport::ClientTlsConfig::new().ca_certificate( |
46 | 3 | tonic::transport::Certificate::from_pem(std::fs::read_to_string(ca_file)?0 ), |
47 | | ); |
48 | 3 | let config1 = if let Some(client_certificate2 ) = &config.cert_file { Branch (48:25): [True: 2, False: 1]
Branch (48:25): [Folded - Ignored]
|
49 | 2 | let Some(client_key1 ) = &config.key_file else { Branch (49:13): [True: 1, False: 1]
Branch (49:13): [Folded - Ignored]
|
50 | 1 | return Err(make_err!( |
51 | 1 | Code::Internal, |
52 | 1 | "Client certificate specified, but no key" |
53 | 1 | )); |
54 | | }; |
55 | 1 | read_config.identity(tonic::transport::Identity::from_pem( |
56 | 1 | std::fs::read_to_string(client_certificate)?0 , |
57 | 1 | std::fs::read_to_string(client_key)?0 , |
58 | | )) |
59 | | } else { |
60 | 1 | if config.key_file.is_some() { Branch (60:12): [True: 1, False: 0]
Branch (60:12): [Folded - Ignored]
|
61 | 1 | return Err(make_err!( |
62 | 1 | Code::Internal, |
63 | 1 | "Client key specified, but no certificate" |
64 | 1 | )); |
65 | 0 | } |
66 | 0 | read_config |
67 | | }; |
68 | | |
69 | 1 | Ok(Some(config)) |
70 | 10 | } |
71 | | |
72 | 7 | pub fn endpoint_from( |
73 | 7 | endpoint: &str, |
74 | 7 | tls_config: Option<tonic::transport::ClientTlsConfig>, |
75 | 7 | ) -> Result<tonic::transport::Endpoint, Error> { |
76 | 7 | let endpoint6 = Uri::try_from(endpoint) |
77 | 7 | .map_err(|e| make_err!(Code::Internal1 , "Unable to parse endpoint {endpoint}: {e:?}"))?1 ; |
78 | | |
79 | | // Tonic uses the TLS configuration if the scheme is "https", so replace |
80 | | // grpcs with https. |
81 | 6 | let endpoint = if endpoint.scheme_str() == Some("grpcs") { Branch (81:23): [True: 1, False: 5]
Branch (81:23): [Folded - Ignored]
|
82 | 1 | let mut parts = endpoint.into_parts(); |
83 | 1 | parts.scheme = Some("https".parse().map_err(|e| {0 |
84 | 0 | make_err!( |
85 | 0 | Code::Internal, |
86 | | "https is an invalid scheme apparently? {e:?}" |
87 | | ) |
88 | 0 | })?); |
89 | 1 | parts.try_into().map_err(|e| {0 |
90 | 0 | make_err!( |
91 | 0 | Code::Internal, |
92 | | "Error changing Uri from grpcs to https: {e:?}" |
93 | | ) |
94 | 0 | })? |
95 | | } else { |
96 | 5 | endpoint |
97 | | }; |
98 | | |
99 | 6 | let endpoint_transport3 = if let Some(tls_config4 ) = tls_config { Branch (99:37): [True: 4, False: 2]
Branch (99:37): [Folded - Ignored]
|
100 | 4 | let Some(authority3 ) = endpoint.authority() else { Branch (100:13): [True: 3, False: 1]
Branch (100:13): [Folded - Ignored]
|
101 | 1 | return Err(make_input_err!( |
102 | 1 | "Unable to determine authority of endpoint: {endpoint}" |
103 | 1 | )); |
104 | | }; |
105 | 3 | if endpoint.scheme_str() != Some("https") { Branch (105:12): [True: 1, False: 2]
Branch (105:12): [Folded - Ignored]
|
106 | 1 | return Err(make_input_err!( |
107 | 1 | "You have set TLS configuration on {endpoint}, but the scheme is not https or grpcs" |
108 | 1 | )); |
109 | 2 | } |
110 | 2 | let tls_config = tls_config.domain_name(authority.host()); |
111 | 2 | tonic::transport::Endpoint::from(endpoint) |
112 | 2 | .tls_config(tls_config) |
113 | 2 | .map_err(|e| make_input_err!("Setting mTLS configuration: {e:?}"))?0 |
114 | | } else { |
115 | 2 | if endpoint.scheme_str() == Some("https") { Branch (115:12): [True: 1, False: 1]
Branch (115:12): [Folded - Ignored]
|
116 | 1 | return Err(make_input_err!( |
117 | 1 | "The scheme of {endpoint} is https or grpcs, but no TLS configuration was provided" |
118 | 1 | )); |
119 | 1 | } |
120 | 1 | tonic::transport::Endpoint::from(endpoint) |
121 | | }; |
122 | | |
123 | 3 | Ok(endpoint_transport) |
124 | 7 | } |
125 | | |
126 | 0 | pub fn endpoint(endpoint_config: &GrpcEndpoint) -> Result<tonic::transport::Endpoint, Error> { |
127 | 0 | let endpoint = endpoint_from( |
128 | 0 | &endpoint_config.address, |
129 | 0 | load_client_config(&endpoint_config.tls_config)?, |
130 | 0 | )?; |
131 | | |
132 | 0 | let connect_timeout = if endpoint_config.connect_timeout_s > 0 { Branch (132:30): [True: 0, False: 0]
Branch (132:30): [Folded - Ignored]
|
133 | 0 | Duration::from_secs(endpoint_config.connect_timeout_s) |
134 | | } else { |
135 | 0 | Duration::from_secs(30) |
136 | | }; |
137 | 0 | let tcp_keepalive = if endpoint_config.tcp_keepalive_s > 0 { Branch (137:28): [True: 0, False: 0]
Branch (137:28): [Folded - Ignored]
|
138 | 0 | Duration::from_secs(endpoint_config.tcp_keepalive_s) |
139 | | } else { |
140 | 0 | Duration::from_secs(30) |
141 | | }; |
142 | 0 | let http2_keepalive_interval = if endpoint_config.http2_keepalive_interval_s > 0 { Branch (142:39): [True: 0, False: 0]
Branch (142:39): [Folded - Ignored]
|
143 | 0 | Duration::from_secs(endpoint_config.http2_keepalive_interval_s) |
144 | | } else { |
145 | 0 | Duration::from_secs(30) |
146 | | }; |
147 | 0 | let http2_keepalive_timeout = if endpoint_config.http2_keepalive_timeout_s > 0 { Branch (147:38): [True: 0, False: 0]
Branch (147:38): [Folded - Ignored]
|
148 | 0 | Duration::from_secs(endpoint_config.http2_keepalive_timeout_s) |
149 | | } else { |
150 | 0 | Duration::from_secs(20) |
151 | | }; |
152 | | |
153 | 0 | info!( |
154 | | address = %endpoint_config.address, |
155 | | concurrency_limit = ?endpoint_config.concurrency_limit, |
156 | 0 | connect_timeout_s = connect_timeout.as_secs(), |
157 | 0 | tcp_keepalive_s = tcp_keepalive.as_secs(), |
158 | 0 | http2_keepalive_interval_s = http2_keepalive_interval.as_secs(), |
159 | 0 | http2_keepalive_timeout_s = http2_keepalive_timeout.as_secs(), |
160 | 0 | "tls_utils::endpoint: creating gRPC endpoint with keepalive", |
161 | | ); |
162 | | |
163 | 0 | let mut endpoint = endpoint |
164 | 0 | .connect_timeout(connect_timeout) |
165 | 0 | .tcp_keepalive(Some(tcp_keepalive)) |
166 | 0 | .http2_keep_alive_interval(http2_keepalive_interval) |
167 | 0 | .keep_alive_timeout(http2_keepalive_timeout) |
168 | 0 | .keep_alive_while_idle(true); |
169 | | |
170 | 0 | if let Some(concurrency_limit) = endpoint_config.concurrency_limit { Branch (170:12): [True: 0, False: 0]
Branch (170:12): [Folded - Ignored]
|
171 | 0 | endpoint = endpoint.concurrency_limit(concurrency_limit); |
172 | 0 | } |
173 | | |
174 | 0 | Ok(endpoint) |
175 | 0 | } |