Coverage Report

Created: 2026-06-04 10:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/connection_manager.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::pin::Pin;
16
use core::task::{Context, Poll};
17
use core::time::Duration;
18
use std::collections::VecDeque;
19
use std::sync::Arc;
20
21
use futures::Future;
22
use futures::stream::{FuturesUnordered, StreamExt, unfold};
23
use nativelink_config::stores::Retry;
24
use nativelink_error::{Code, Error, make_err};
25
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
26
use tonic::transport::{Channel, Endpoint, channel};
27
use tracing::{debug, error, info, warn};
28
29
use crate::background_spawn;
30
use crate::retry::{self, Retrier, RetryResult};
31
32
/// A helper utility that enables management of a suite of connections to an
33
/// upstream gRPC endpoint using Tonic.
34
#[derive(Debug)]
35
pub struct ConnectionManager {
36
    // The channel to request connections from the worker.
37
    worker_tx: mpsc::Sender<(String, oneshot::Sender<Connection>)>,
38
}
39
40
/// The index into `ConnectionManagerWorker::endpoints`.
41
type EndpointIndex = usize;
42
/// The identifier for a given connection to a given Endpoint, used to identify
43
/// when a particular connection has failed or becomes available.
44
type ConnectionIndex = usize;
45
46
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47
struct ChannelIdentifier {
48
    /// The index into `ConnectionManagerWorker::endpoints` that established this
49
    /// Channel.
50
    endpoint_index: EndpointIndex,
51
    /// A unique identifier for this particular connection to the Endpoint.
52
    connection_index: ConnectionIndex,
53
}
54
55
/// The requests that can be made from a Connection to the
56
/// `ConnectionManagerWorker` such as informing it that it's been dropped or that
57
/// an error occurred.
58
enum ConnectionRequest {
59
    /// Notify that a Connection was dropped, if it was dropped while the
60
    /// connection was still pending, then return the pending Channel to be
61
    /// added back to the available channels.
62
    Dropped(Option<EstablishedChannel>),
63
    /// Notify that a Connection was established, return the Channel to the
64
    /// available channels.
65
    Connected(EstablishedChannel),
66
    /// Notify that there was a transport error on the given Channel, the bool
67
    /// specifies whether the connection was in the process of being established
68
    /// or not (i.e. whether it's been returned to available channels yet).
69
    Error((ChannelIdentifier, bool)),
70
}
71
72
/// The result of a Future that connects to a given Endpoint.  This is a tuple
73
/// of the index into the `ConnectionManagerWorker::endpoints` that this
74
/// connection is for, the iteration of the connection and the result of the
75
/// connection itself.
76
type IndexedChannel = Result<EstablishedChannel, (ChannelIdentifier, Error)>;
77
78
/// A channel that has been established to an endpoint with some metadata around
79
/// it to allow identification of the Channel if it errors in order to correctly
80
/// remove it.
81
#[derive(Debug, Clone)]
82
struct EstablishedChannel {
83
    /// The Channel itself that the meta data relates to.
84
    channel: Channel,
85
    /// The identifier of the channel in the worker.
86
    identifier: ChannelIdentifier,
87
}
88
89
/// The context of the worker used to manage all of the connections.  This
90
/// handles reconnecting to endpoints on errors and multiple connections to a
91
/// given endpoint.
92
struct ConnectionManagerWorker {
93
    /// The endpoints to establish Channels and the identifier of the last
94
    /// connection attempt to that endpoint.
95
    endpoints: Vec<(ConnectionIndex, Endpoint)>,
96
    /// The channel used to communicate between a Connection and the worker.
97
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
98
    /// Gates the maximum number of in-flight `Connection` objects.
99
    /// Was an explicit `usize` counter; now an `Arc<Semaphore>` so the
100
    /// `OwnedSemaphorePermit` held by each `Connection` releases on
101
    /// drop (RAII), instead of relying on a `ConnectionRequest::Dropped`
102
    /// round-trip that could be lost on tonic transport errors or task
103
    /// aborts.
104
    available_connections: Arc<Semaphore>,
105
    /// Channels that are currently being connected.
106
    connecting_channels: FuturesUnordered<Pin<Box<dyn Future<Output = IndexedChannel> + Send>>>,
107
    /// Connected channels that are available for use.
108
    available_channels: VecDeque<EstablishedChannel>,
109
    /// Requests for a Channel when available - (reason, request)
110
    waiting_connections: VecDeque<(String, oneshot::Sender<Connection>)>,
111
    /// The retry configuration for connecting to an Endpoint, on failure will
112
    /// restart the retrier after a 1 second delay.
113
    retrier: Retrier,
114
}
115
116
/// The maximum number of queued requests to obtain a connection from the
117
/// worker before applying back pressure to the requestor.  It makes sense to
118
/// keep this small since it has to wait for a response anyway.
119
const WORKER_BACKLOG: usize = 8;
120
121
impl ConnectionManager {
122
    /// Create a connection manager that creates a balance list between a given
123
    /// set of Endpoints.  This will restrict the number of concurrent requests
124
    /// and automatically re-connect upon transport error.
125
9
    pub fn new(
126
9
        endpoints: impl IntoIterator<Item = Endpoint>,
127
9
        mut connections_per_endpoint: usize,
128
9
        mut max_concurrent_requests: usize,
129
9
        retry: Retry,
130
9
        jitter_fn: retry::JitterFn,
131
9
    ) -> Self {
132
9
        let (worker_tx, worker_rx) = mpsc::channel(WORKER_BACKLOG);
133
        // The connection messages always come from sync contexts (e.g. drop)
134
        // and therefore, we'd end up spawning for them if this was bounded
135
        // which defeats the object since there would be no backpressure
136
        // applied. Therefore it makes sense for this to be unbounded.
137
9
        let (connection_tx, connection_rx) = mpsc::unbounded_channel();
138
9
        let endpoints = endpoints
139
9
            .into_iter()
140
9
            .map(|endpoint| (0, endpoint))
141
9
            .collect();
142
143
9
        if max_concurrent_requests == 0 {
144
6
            max_concurrent_requests = Semaphore::MAX_PERMITS;
145
6
        } else {
146
3
            max_concurrent_requests = max_concurrent_requests.min(Semaphore::MAX_PERMITS);
147
3
        }
148
9
        if connections_per_endpoint == 0 {
149
6
            connections_per_endpoint = 1;
150
6
        
}3
151
9
        let worker = ConnectionManagerWorker {
152
9
            endpoints,
153
9
            available_connections: Arc::new(Semaphore::new(max_concurrent_requests)),
154
9
            connection_tx,
155
9
            connecting_channels: FuturesUnordered::new(),
156
9
            available_channels: VecDeque::new(),
157
9
            waiting_connections: VecDeque::new(),
158
9
            retrier: Retrier::new(
159
9
                Arc::new(|duration| 
Box::pin0
(
tokio::time::sleep0
(
duration0
))),
160
9
                jitter_fn,
161
9
                retry,
162
            ),
163
        };
164
9
        background_spawn!("connection_manager_worker_spawn", async move 
{8
165
8
            worker
166
8
                .service_requests(connections_per_endpoint, worker_rx, connection_rx)
167
8
                .await;
168
0
        });
169
9
        Self { worker_tx }
170
9
    }
171
172
    /// Get a Connection that can be used as a `tonic::Channel`, except it
173
    /// performs some additional counting to reconnect on error and restrict
174
    /// the number of concurrent connections.
175
220
    pub async fn connection(&self, reason: String) -> Result<Connection, Error> {
176
220
        let (tx, rx) = oneshot::channel();
177
220
        self.worker_tx
178
220
            .send((reason, tx))
179
220
            .await
180
220
            .map_err(|err| 
make_err!0
(
Code::Unavailable0
, "Requesting a new connection: {err:?}"))
?0
;
181
220
        rx.await
182
212
            .map_err(|err| 
make_err!0
(
Code::Unavailable0
, "Waiting for a new connection: {err:?}"))
183
212
    }
184
}
185
186
impl ConnectionManagerWorker {
187
8
    async fn service_requests(
188
8
        mut self,
189
8
        connections_per_endpoint: usize,
190
8
        mut worker_rx: mpsc::Receiver<(String, oneshot::Sender<Connection>)>,
191
8
        mut connection_rx: mpsc::UnboundedReceiver<ConnectionRequest>,
192
8
    ) {
193
        // Make the initial set of connections, connection failures will be
194
        // handled in the same way as future transport failures, so no need to
195
        // do anything special.
196
8
        for endpoint_index in 0..self.endpoints.len() {
197
12
            for _ in 
0..connections_per_endpoint8
{
198
12
                self.connect_endpoint(endpoint_index, None);
199
12
            }
200
        }
201
202
        // The main worker loop, when select resolves one of its arms the other
203
        // ones are cancelled, therefore it's important that they maintain no
204
        // state while `await`-ing.  This is enforced through the use of
205
        // non-async functions to do all of the work.
206
        loop {
207
454
            tokio::select! {
208
454
                
request220
= worker_rx.recv() => {
209
220
                    let Some((reason, request)) = request else {
210
                        // The ConnectionManager was dropped, shut down the
211
                        // worker.
212
0
                        break;
213
                    };
214
220
                    self.handle_worker(reason, request);
215
                }
216
454
                
maybe_request214
= connection_rx.recv() => {
217
214
                    if let Some(request) = maybe_request {
218
214
                        self.handle_connection(request);
219
214
                    
}0
220
                }
221
454
                
maybe_connection_result12
= self.connect_next() => {
222
12
                    if let Some(connection_result) = maybe_connection_result {
223
12
                        self.handle_connected(connection_result);
224
12
                    
}0
225
                }
226
            }
227
        }
228
0
    }
229
230
454
    async fn connect_next(&mut self) -> Option<IndexedChannel> 
{308
231
308
        if self.connecting_channels.is_empty() {
232
            // Make this Future never resolve, we will get cancelled by the
233
            // select if there's some change in state to `self` and can re-enter
234
            // and evaluate `connecting_channels` again.
235
287
            futures::future::pending::<()>().await;
236
21
        }
237
21
        self.connecting_channels.next().await
238
12
    }
239
240
    // This must never be made async otherwise the select may cancel it.
241
12
    fn handle_connected(&mut self, connection_result: IndexedChannel) {
242
12
        match connection_result {
243
12
            Ok(established_channel) => {
244
12
                self.available_channels.push_back(established_channel);
245
12
                self.maybe_available_connection();
246
12
            }
247
            // When the retrier runs out of attempts start again from the
248
            // beginning of the retry period.  Never want to be in a
249
            // situation where we give up on an Endpoint forever.
250
0
            Err((identifier, _)) => {
251
0
                self.connect_endpoint(identifier.endpoint_index, Some(identifier.connection_index));
252
0
            }
253
        }
254
12
    }
255
256
12
    fn connect_endpoint(&mut self, endpoint_index: usize, connection_index: Option<usize>) {
257
12
        let Some((current_connection_index, endpoint)) = self.endpoints.get_mut(endpoint_index)
258
        else {
259
            // Unknown endpoint, this should never happen.
260
0
            error!(?endpoint_index, "Connection to unknown endpoint requested");
261
0
            return;
262
        };
263
12
        let is_backoff = connection_index.is_some();
264
12
        let connection_index = connection_index.unwrap_or_else(|| {
265
12
            *current_connection_index += 1;
266
12
            *current_connection_index
267
12
        });
268
12
        if is_backoff {
269
0
            warn!(
270
                ?connection_index,
271
0
                endpoint = ?endpoint.uri(),
272
                "Connection failed, reconnecting"
273
            );
274
        } else {
275
12
            info!(
276
                ?connection_index,
277
12
                endpoint = ?endpoint.uri(),
278
                "Creating new connection"
279
            );
280
        }
281
12
        let identifier = ChannelIdentifier {
282
12
            endpoint_index,
283
12
            connection_index,
284
12
        };
285
12
        let connection_stream = unfold(endpoint.clone(), move |endpoint| async move {
286
12
            let result = endpoint.connect().await.map_err(|err| 
{0
287
0
                make_err!(
288
0
                    Code::Unavailable,
289
                    "Failed to connect to {:?}: {err:?}",
290
0
                    endpoint.uri()
291
                )
292
0
            });
293
12
            Some((
294
12
                result.map_or_else(RetryResult::Retry, RetryResult::Ok),
295
12
                endpoint,
296
12
            ))
297
24
        });
298
12
        let retrier = self.retrier.clone();
299
12
        self.connecting_channels.push(Box::pin(async move {
300
12
            if is_backoff {
301
                // Just in case the retry config is 0, then we need to
302
                // introduce some delay so we aren't in a hard loop.
303
0
                tokio::time::sleep(Duration::from_secs(1)).await;
304
12
            }
305
12
            retrier.retry(connection_stream).await.map_or_else(
306
0
                |err| Err((identifier, err)),
307
12
                |channel| {
308
12
                    Ok(EstablishedChannel {
309
12
                        channel,
310
12
                        identifier,
311
12
                    })
312
12
                },
313
            )
314
12
        }));
315
12
    }
316
317
    // This must never be made async otherwise the select may cancel it.
318
220
    fn handle_worker(&mut self, reason: String, tx: oneshot::Sender<Connection>) {
319
220
        let maybe_permit = self.available_connections.clone().try_acquire_owned().ok();
320
220
        if let Some(
permit219
) = maybe_permit
321
219
            && let Some(
channel140
) = self.available_channels.pop_front()
322
        {
323
140
            debug!(reason, "ConnectionManager: request running");
324
140
            self.provide_channel(channel, tx, permit);
325
        } else {
326
80
            debug!(
327
80
                available_permits = self.available_connections.available_permits(),
328
80
                available_channels = self.available_channels.len(),
329
80
                waiting_connections = self.waiting_connections.len(),
330
                reason,
331
                "ConnectionManager: no connection available, request queued",
332
            );
333
80
            self.waiting_connections.push_back((reason, tx));
334
        }
335
220
    }
336
337
220
    fn provide_channel(
338
220
        &self,
339
220
        channel: EstablishedChannel,
340
220
        tx: oneshot::Sender<Connection>,
341
220
        permit: OwnedSemaphorePermit,
342
220
    ) {
343
220
        drop(tx.send(Connection {
344
220
            tx: self.connection_tx.clone(),
345
220
            pending_channel: Some(channel.channel.clone()),
346
220
            channel,
347
220
            _permit: permit,
348
220
        }));
349
220
    }
350
351
226
    fn maybe_available_connection(&mut self) {
352
306
        while !self.waiting_connections.is_empty() && 
!self.available_channels.is_empty()89
{
353
80
            let Some(permit) = self.available_connections.clone().try_acquire_owned().ok() else {
354
0
                break;
355
            };
356
80
            let Some(channel) = self.available_channels.pop_front() else {
357
0
                drop(permit);
358
0
                break;
359
            };
360
80
            let Some((reason, tx)) = self.waiting_connections.pop_front() else {
361
0
                self.available_channels.push_front(channel);
362
0
                drop(permit);
363
0
                break;
364
            };
365
80
            debug!(reason, "ConnectionManager: channel available, running");
366
80
            self.provide_channel(channel, tx, permit);
367
        }
368
226
    }
369
370
    // This must never be made async otherwise the select may cancel it.
371
214
    fn handle_connection(&mut self, request: ConnectionRequest) {
372
214
        match request {
373
209
            ConnectionRequest::Dropped(maybe_channel) => {
374
209
                if let Some(channel) = maybe_channel {
375
209
                    self.available_channels.push_back(channel);
376
209
                
}0
377
209
                self.maybe_available_connection();
378
            }
379
5
            ConnectionRequest::Connected(channel) => {
380
5
                self.available_channels.push_back(channel);
381
5
                self.maybe_available_connection();
382
5
            }
383
            // Handle a transport error on a connection by making it unavailable
384
            // for use and establishing a new connection to the endpoint.
385
0
            ConnectionRequest::Error((identifier, was_pending)) => {
386
0
                let should_reconnect = if was_pending {
387
0
                    true
388
                } else {
389
0
                    let original_length = self.available_channels.len();
390
0
                    self.available_channels
391
0
                        .retain(|channel| channel.identifier != identifier);
392
                    // Only reconnect if it wasn't already disconnected.
393
0
                    original_length != self.available_channels.len()
394
                };
395
0
                if should_reconnect {
396
0
                    self.connect_endpoint(identifier.endpoint_index, None);
397
0
                }
398
            }
399
        }
400
214
    }
401
}
402
403
/// An instance of this is obtained for every communication with the gGRPC
404
/// service.  This handles the permit for limiting concurrency, and also
405
/// re-connecting the underlying channel on error.  It depends on users
406
/// reporting all errors.
407
/// NOTE: This should never be cloneable because its lifetime is linked to the
408
///       semaphore permit it carries — `_permit` is released exactly once,
409
///       when the `Connection` drops.
410
#[derive(Debug)]
411
pub struct Connection {
412
    /// Communication with `ConnectionManagerWorker` to inform about transport
413
    /// errors and when the Connection is dropped.
414
    tx: mpsc::UnboundedSender<ConnectionRequest>,
415
    /// If set, the Channel that will be returned to the worker when connection
416
    /// completes (success or failure) or when the Connection is dropped if that
417
    /// happens before connection completes.
418
    pending_channel: Option<Channel>,
419
    /// The identifier to send to `tx`.
420
    channel: EstablishedChannel,
421
    _permit: OwnedSemaphorePermit,
422
}
423
424
impl Drop for Connection {
425
220
    fn drop(&mut self) {
426
220
        let pending_channel = self
427
220
            .pending_channel
428
220
            .take()
429
220
            .map(|channel| EstablishedChannel {
430
215
                channel,
431
215
                identifier: self.channel.identifier,
432
215
            });
433
220
        drop(self.tx.send(ConnectionRequest::Dropped(pending_channel)));
434
220
    }
435
}
436
437
/// A wrapper around the `channel::ResponseFuture` that forwards errors to the `tx`.
438
pub struct ResponseFuture {
439
    /// The wrapped future that actually does the work.
440
    inner: channel::ResponseFuture,
441
    /// Communication with `ConnectionManagerWorker` to inform about transport
442
    /// errors.
443
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
444
    /// The identifier to send to `connection_tx` on a transport error.
445
    identifier: ChannelIdentifier,
446
}
447
448
impl core::fmt::Debug for ResponseFuture {
449
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
450
0
        f.debug_struct("ResponseFuture")
451
0
            .field("inner", &self.inner)
452
0
            .field("connection_tx", &self.connection_tx)
453
0
            .field("identifier", &self.identifier)
454
0
            .finish()
455
0
    }
456
}
457
458
/// This is mostly copied from `tonic::transport::channel` except it wraps it
459
/// to allow messaging about connection success and failure.
460
impl tonic::codegen::Service<tonic::codegen::http::Request<tonic::body::Body>> for Connection {
461
    type Response = tonic::codegen::http::Response<tonic::body::Body>;
462
    type Error = tonic::transport::Error;
463
    type Future = ResponseFuture;
464
465
5
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
466
5
        let result = self.channel.channel.poll_ready(cx);
467
5
        if let Poll::Ready(result) = &result {
468
5
            match result {
469
                Ok(()) => {
470
5
                    if let Some(pending_channel) = self.pending_channel.take() {
471
5
                        drop(
472
5
                            self.tx
473
5
                                .send(ConnectionRequest::Connected(EstablishedChannel {
474
5
                                    channel: pending_channel,
475
5
                                    identifier: self.channel.identifier,
476
5
                                })),
477
5
                        );
478
5
                    
}0
479
                }
480
0
                Err(err) => {
481
0
                    debug!(?err, "Error while creating connection on channel");
482
0
                    drop(self.tx.send(ConnectionRequest::Error((
483
0
                        self.channel.identifier,
484
0
                        self.pending_channel.take().is_some(),
485
0
                    ))));
486
                }
487
            }
488
0
        }
489
5
        result
490
5
    }
491
492
5
    fn call(&mut self, request: tonic::codegen::http::Request<tonic::body::Body>) -> Self::Future {
493
5
        ResponseFuture {
494
5
            inner: self.channel.channel.call(request),
495
5
            connection_tx: self.tx.clone(),
496
5
            identifier: self.channel.identifier,
497
5
        }
498
5
    }
499
}
500
501
/// This is mostly copied from `tonic::transport::channel` except it wraps it
502
/// to allow messaging about connection failure.
503
impl Future for ResponseFuture {
504
    type Output =
505
        Result<tonic::codegen::http::Response<tonic::body::Body>, tonic::transport::Error>;
506
507
15
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
508
15
        let result = Pin::new(&mut self.inner).poll(cx);
509
5
        if let Poll::Ready(Err(_)) = &result {
510
0
            drop(
511
0
                self.connection_tx
512
0
                    .send(ConnectionRequest::Error((self.identifier, false))),
513
0
            );
514
15
        }
515
15
        result
516
15
    }
517
}