Coverage Report

Created: 2026-04-14 11:55

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/connection_manager.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::pin::Pin;
16
use core::task::{Context, Poll};
17
use core::time::Duration;
18
use std::collections::VecDeque;
19
use std::sync::Arc;
20
21
use futures::Future;
22
use futures::stream::{FuturesUnordered, StreamExt, unfold};
23
use nativelink_config::stores::Retry;
24
use nativelink_error::{Code, Error, make_err};
25
use tokio::sync::{mpsc, oneshot};
26
use tonic::transport::{Channel, Endpoint, channel};
27
use tracing::{debug, error, info, warn};
28
29
use crate::background_spawn;
30
use crate::retry::{self, Retrier, RetryResult};
31
32
/// A helper utility that enables management of a suite of connections to an
33
/// upstream gRPC endpoint using Tonic.
34
#[derive(Debug)]
35
pub struct ConnectionManager {
36
    // The channel to request connections from the worker.
37
    worker_tx: mpsc::Sender<(String, oneshot::Sender<Connection>)>,
38
}
39
40
/// Position of an endpoint within `ConnectionManagerWorker::endpoints`.
type EndpointIndex = usize;
/// Per-endpoint counter value that tells the individual connections to one
/// Endpoint apart, so a failure or completion can be traced to a specific one.
type ConnectionIndex = usize;
45
46
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47
struct ChannelIdentifier {
48
    /// The index into `ConnectionManagerWorker::endpoints` that established this
49
    /// Channel.
50
    endpoint_index: EndpointIndex,
51
    /// A unique identifier for this particular connection to the Endpoint.
52
    connection_index: ConnectionIndex,
53
}
54
55
/// The requests that can be made from a Connection to the
56
/// `ConnectionManagerWorker` such as informing it that it's been dropped or that
57
/// an error occurred.
58
enum ConnectionRequest {
59
    /// Notify that a Connection was dropped, if it was dropped while the
60
    /// connection was still pending, then return the pending Channel to be
61
    /// added back to the available channels.
62
    Dropped(Option<EstablishedChannel>),
63
    /// Notify that a Connection was established, return the Channel to the
64
    /// available channels.
65
    Connected(EstablishedChannel),
66
    /// Notify that there was a transport error on the given Channel, the bool
67
    /// specifies whether the connection was in the process of being established
68
    /// or not (i.e. whether it's been returned to available channels yet).
69
    Error((ChannelIdentifier, bool)),
70
}
71
72
/// The result of a Future that connects to a given Endpoint.  This is a tuple
73
/// of the index into the `ConnectionManagerWorker::endpoints` that this
74
/// connection is for, the iteration of the connection and the result of the
75
/// connection itself.
76
type IndexedChannel = Result<EstablishedChannel, (ChannelIdentifier, Error)>;
77
78
/// A channel that has been established to an endpoint with some metadata around
79
/// it to allow identification of the Channel if it errors in order to correctly
80
/// remove it.
81
#[derive(Debug, Clone)]
82
struct EstablishedChannel {
83
    /// The Channel itself that the meta data relates to.
84
    channel: Channel,
85
    /// The identifier of the channel in the worker.
86
    identifier: ChannelIdentifier,
87
}
88
89
/// The context of the worker used to manage all of the connections.  This
90
/// handles reconnecting to endpoints on errors and multiple connections to a
91
/// given endpoint.
92
struct ConnectionManagerWorker {
93
    /// The endpoints to establish Channels and the identifier of the last
94
    /// connection attempt to that endpoint.
95
    endpoints: Vec<(ConnectionIndex, Endpoint)>,
96
    /// The channel used to communicate between a Connection and the worker.
97
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
98
    /// The number of connections that are currently allowed to be made.
99
    available_connections: usize,
100
    /// Channels that are currently being connected.
101
    connecting_channels: FuturesUnordered<Pin<Box<dyn Future<Output = IndexedChannel> + Send>>>,
102
    /// Connected channels that are available for use.
103
    available_channels: VecDeque<EstablishedChannel>,
104
    /// Requests for a Channel when available - (reason, request)
105
    waiting_connections: VecDeque<(String, oneshot::Sender<Connection>)>,
106
    /// The retry configuration for connecting to an Endpoint, on failure will
107
    /// restart the retrier after a 1 second delay.
108
    retrier: Retrier,
109
}
110
111
/// Capacity of the request queue between `ConnectionManager` and its worker.
///
/// Once the queue is full, callers of `connection` wait, which applies back
/// pressure.  It is kept small because every caller must wait for the
/// worker's reply anyway.
const WORKER_BACKLOG: usize = 8;
115
116
impl ConnectionManager {
117
    /// Create a connection manager that creates a balance list between a given
118
    /// set of Endpoints.  This will restrict the number of concurrent requests
119
    /// and automatically re-connect upon transport error.
120
1
    pub fn new(
121
1
        endpoints: impl IntoIterator<Item = Endpoint>,
122
1
        mut connections_per_endpoint: usize,
123
1
        mut max_concurrent_requests: usize,
124
1
        retry: Retry,
125
1
        jitter_fn: retry::JitterFn,
126
1
    ) -> Self {
127
1
        let (worker_tx, worker_rx) = mpsc::channel(WORKER_BACKLOG);
128
        // The connection messages always come from sync contexts (e.g. drop)
129
        // and therefore, we'd end up spawning for them if this was bounded
130
        // which defeats the object since there would be no backpressure
131
        // applied. Therefore it makes sense for this to be unbounded.
132
1
        let (connection_tx, connection_rx) = mpsc::unbounded_channel();
133
1
        let endpoints = endpoints
134
1
            .into_iter()
135
1
            .map(|endpoint| (0, endpoint))
136
1
            .collect();
137
138
1
        if max_concurrent_requests == 0 {
139
1
            max_concurrent_requests = usize::MAX;
140
1
        
}0
141
1
        if connections_per_endpoint == 0 {
142
1
            connections_per_endpoint = 1;
143
1
        
}0
144
1
        let worker = ConnectionManagerWorker {
145
1
            endpoints,
146
1
            available_connections: max_concurrent_requests,
147
1
            connection_tx,
148
1
            connecting_channels: FuturesUnordered::new(),
149
1
            available_channels: VecDeque::new(),
150
1
            waiting_connections: VecDeque::new(),
151
1
            retrier: Retrier::new(
152
1
                Arc::new(|duration| 
Box::pin0
(
tokio::time::sleep0
(
duration0
))),
153
1
                jitter_fn,
154
1
                retry,
155
            ),
156
        };
157
1
        background_spawn!("connection_manager_worker_spawn", async move 
{0
158
0
            worker
159
0
                .service_requests(connections_per_endpoint, worker_rx, connection_rx)
160
0
                .await;
161
0
        });
162
1
        Self { worker_tx }
163
1
    }
164
165
    /// Get a Connection that can be used as a `tonic::Channel`, except it
166
    /// performs some additional counting to reconnect on error and restrict
167
    /// the number of concurrent connections.
168
0
    pub async fn connection(&self, reason: String) -> Result<Connection, Error> {
169
0
        let (tx, rx) = oneshot::channel();
170
0
        self.worker_tx
171
0
            .send((reason, tx))
172
0
            .await
173
0
            .map_err(|err| make_err!(Code::Unavailable, "Requesting a new connection: {err:?}"))?;
174
0
        rx.await
175
0
            .map_err(|err| make_err!(Code::Unavailable, "Waiting for a new connection: {err:?}"))
176
0
    }
177
}
178
179
impl ConnectionManagerWorker {
180
0
    async fn service_requests(
181
0
        mut self,
182
0
        connections_per_endpoint: usize,
183
0
        mut worker_rx: mpsc::Receiver<(String, oneshot::Sender<Connection>)>,
184
0
        mut connection_rx: mpsc::UnboundedReceiver<ConnectionRequest>,
185
0
    ) {
186
        // Make the initial set of connections, connection failures will be
187
        // handled in the same way as future transport failures, so no need to
188
        // do anything special.
189
0
        for endpoint_index in 0..self.endpoints.len() {
190
0
            for _ in 0..connections_per_endpoint {
191
0
                self.connect_endpoint(endpoint_index, None);
192
0
            }
193
        }
194
195
        // The main worker loop, when select resolves one of its arms the other
196
        // ones are cancelled, therefore it's important that they maintain no
197
        // state while `await`-ing.  This is enforced through the use of
198
        // non-async functions to do all of the work.
199
        loop {
200
0
            tokio::select! {
201
0
                request = worker_rx.recv() => {
202
0
                    let Some((reason, request)) = request else {
203
                        // The ConnectionManager was dropped, shut down the
204
                        // worker.
205
0
                        break;
206
                    };
207
0
                    self.handle_worker(reason, request);
208
                }
209
0
                maybe_request = connection_rx.recv() => {
210
0
                    if let Some(request) = maybe_request {
211
0
                        self.handle_connection(request);
212
0
                    }
213
                }
214
0
                maybe_connection_result = self.connect_next() => {
215
0
                    if let Some(connection_result) = maybe_connection_result {
216
0
                        self.handle_connected(connection_result);
217
0
                    }
218
                }
219
            }
220
        }
221
0
    }
222
223
0
    async fn connect_next(&mut self) -> Option<IndexedChannel> {
224
0
        if self.connecting_channels.is_empty() {
225
            // Make this Future never resolve, we will get cancelled by the
226
            // select if there's some change in state to `self` and can re-enter
227
            // and evaluate `connecting_channels` again.
228
0
            futures::future::pending::<()>().await;
229
0
        }
230
0
        self.connecting_channels.next().await
231
0
    }
232
233
    // This must never be made async otherwise the select may cancel it.
234
0
    fn handle_connected(&mut self, connection_result: IndexedChannel) {
235
0
        match connection_result {
236
0
            Ok(established_channel) => {
237
0
                self.available_channels.push_back(established_channel);
238
0
                self.maybe_available_connection();
239
0
            }
240
            // When the retrier runs out of attempts start again from the
241
            // beginning of the retry period.  Never want to be in a
242
            // situation where we give up on an Endpoint forever.
243
0
            Err((identifier, _)) => {
244
0
                self.connect_endpoint(identifier.endpoint_index, Some(identifier.connection_index));
245
0
            }
246
        }
247
0
    }
248
249
0
    fn connect_endpoint(&mut self, endpoint_index: usize, connection_index: Option<usize>) {
250
0
        let Some((current_connection_index, endpoint)) = self.endpoints.get_mut(endpoint_index)
251
        else {
252
            // Unknown endpoint, this should never happen.
253
0
            error!(?endpoint_index, "Connection to unknown endpoint requested");
254
0
            return;
255
        };
256
0
        let is_backoff = connection_index.is_some();
257
0
        let connection_index = connection_index.unwrap_or_else(|| {
258
0
            *current_connection_index += 1;
259
0
            *current_connection_index
260
0
        });
261
0
        if is_backoff {
262
0
            warn!(
263
                ?connection_index,
264
0
                endpoint = ?endpoint.uri(),
265
                "Connection failed, reconnecting"
266
            );
267
        } else {
268
0
            info!(
269
                ?connection_index,
270
0
                endpoint = ?endpoint.uri(),
271
                "Creating new connection"
272
            );
273
        }
274
0
        let identifier = ChannelIdentifier {
275
0
            endpoint_index,
276
0
            connection_index,
277
0
        };
278
0
        let connection_stream = unfold(endpoint.clone(), move |endpoint| async move {
279
0
            let result = endpoint.connect().await.map_err(|err| {
280
0
                make_err!(
281
0
                    Code::Unavailable,
282
                    "Failed to connect to {:?}: {err:?}",
283
0
                    endpoint.uri()
284
                )
285
0
            });
286
0
            Some((
287
0
                result.map_or_else(RetryResult::Retry, RetryResult::Ok),
288
0
                endpoint,
289
0
            ))
290
0
        });
291
0
        let retrier = self.retrier.clone();
292
0
        self.connecting_channels.push(Box::pin(async move {
293
0
            if is_backoff {
294
                // Just in case the retry config is 0, then we need to
295
                // introduce some delay so we aren't in a hard loop.
296
0
                tokio::time::sleep(Duration::from_secs(1)).await;
297
0
            }
298
0
            retrier.retry(connection_stream).await.map_or_else(
299
0
                |err| Err((identifier, err)),
300
0
                |channel| {
301
0
                    Ok(EstablishedChannel {
302
0
                        channel,
303
0
                        identifier,
304
0
                    })
305
0
                },
306
            )
307
0
        }));
308
0
    }
309
310
    // This must never be made async otherwise the select may cancel it.
311
0
    fn handle_worker(&mut self, reason: String, tx: oneshot::Sender<Connection>) {
312
0
        if let Some(channel) = (self.available_connections > 0)
313
0
            .then_some(())
314
0
            .and_then(|()| self.available_channels.pop_front())
315
        {
316
0
            debug!(reason, "ConnectionManager: request running");
317
0
            self.provide_channel(channel, tx);
318
        } else {
319
0
            debug!(
320
                available_connections = self.available_connections,
321
0
                available_channels = self.available_channels.len(),
322
0
                waiting_connections = self.waiting_connections.len(),
323
                reason,
324
                "ConnectionManager: no connection available, request queued",
325
            );
326
0
            self.waiting_connections.push_back((reason, tx));
327
        }
328
0
    }
329
330
0
    fn provide_channel(&mut self, channel: EstablishedChannel, tx: oneshot::Sender<Connection>) {
331
        // We decrement here because we create Connection, this will signal when
332
        // it is Dropped and therefore increment this again.
333
0
        self.available_connections -= 1;
334
0
        drop(tx.send(Connection {
335
0
            tx: self.connection_tx.clone(),
336
0
            pending_channel: Some(channel.channel.clone()),
337
0
            channel,
338
0
        }));
339
0
    }
340
341
0
    fn maybe_available_connection(&mut self) {
342
0
        while self.available_connections > 0
343
0
            && !self.waiting_connections.is_empty()
344
0
            && !self.available_channels.is_empty()
345
        {
346
0
            if let Some(channel) = self.available_channels.pop_front() {
347
0
                if let Some((reason, tx)) = self.waiting_connections.pop_front() {
348
0
                    debug!(reason, "ConnectionManager: channel available, running");
349
0
                    self.provide_channel(channel, tx);
350
0
                } else {
351
0
                    // This should never happen, but better than an unwrap.
352
0
                    self.available_channels.push_front(channel);
353
0
                }
354
0
            }
355
        }
356
0
    }
357
358
    // This must never be made async otherwise the select may cancel it.
359
0
    fn handle_connection(&mut self, request: ConnectionRequest) {
360
0
        match request {
361
0
            ConnectionRequest::Dropped(maybe_channel) => {
362
0
                if let Some(channel) = maybe_channel {
363
0
                    self.available_channels.push_back(channel);
364
0
                }
365
0
                self.available_connections += 1;
366
0
                self.maybe_available_connection();
367
            }
368
0
            ConnectionRequest::Connected(channel) => {
369
0
                self.available_channels.push_back(channel);
370
0
                self.maybe_available_connection();
371
0
            }
372
            // Handle a transport error on a connection by making it unavailable
373
            // for use and establishing a new connection to the endpoint.
374
0
            ConnectionRequest::Error((identifier, was_pending)) => {
375
0
                let should_reconnect = if was_pending {
376
0
                    true
377
                } else {
378
0
                    let original_length = self.available_channels.len();
379
0
                    self.available_channels
380
0
                        .retain(|channel| channel.identifier != identifier);
381
                    // Only reconnect if it wasn't already disconnected.
382
0
                    original_length != self.available_channels.len()
383
                };
384
0
                if should_reconnect {
385
0
                    self.connect_endpoint(identifier.endpoint_index, None);
386
0
                }
387
            }
388
        }
389
0
    }
390
}
391
392
/// An instance of this is obtained for every communication with the gGRPC
393
/// service.  This handles the permit for limiting concurrency, and also
394
/// re-connecting the underlying channel on error.  It depends on users
395
/// reporting all errors.
396
/// NOTE: This should never be cloneable because its lifetime is linked to the
397
///       `ConnectionManagerWorker::available_connections`.
398
#[derive(Debug)]
399
pub struct Connection {
400
    /// Communication with `ConnectionManagerWorker` to inform about transport
401
    /// errors and when the Connection is dropped.
402
    tx: mpsc::UnboundedSender<ConnectionRequest>,
403
    /// If set, the Channel that will be returned to the worker when connection
404
    /// completes (success or failure) or when the Connection is dropped if that
405
    /// happens before connection completes.
406
    pending_channel: Option<Channel>,
407
    /// The identifier to send to `tx`.
408
    channel: EstablishedChannel,
409
}
410
411
impl Drop for Connection {
412
0
    fn drop(&mut self) {
413
0
        let pending_channel = self
414
0
            .pending_channel
415
0
            .take()
416
0
            .map(|channel| EstablishedChannel {
417
0
                channel,
418
0
                identifier: self.channel.identifier,
419
0
            });
420
0
        drop(self.tx.send(ConnectionRequest::Dropped(pending_channel)));
421
0
    }
422
}
423
424
/// A wrapper around the `channel::ResponseFuture` that forwards errors to the `tx`.
425
pub struct ResponseFuture {
426
    /// The wrapped future that actually does the work.
427
    inner: channel::ResponseFuture,
428
    /// Communication with `ConnectionManagerWorker` to inform about transport
429
    /// errors.
430
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
431
    /// The identifier to send to `connection_tx` on a transport error.
432
    identifier: ChannelIdentifier,
433
}
434
435
impl core::fmt::Debug for ResponseFuture {
436
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
437
0
        f.debug_struct("ResponseFuture")
438
0
            .field("inner", &self.inner)
439
0
            .field("connection_tx", &self.connection_tx)
440
0
            .field("identifier", &self.identifier)
441
0
            .finish()
442
0
    }
443
}
444
445
/// This is mostly copied from `tonic::transport::channel` except it wraps it
446
/// to allow messaging about connection success and failure.
447
impl tonic::codegen::Service<tonic::codegen::http::Request<tonic::body::Body>> for Connection {
448
    type Response = tonic::codegen::http::Response<tonic::body::Body>;
449
    type Error = tonic::transport::Error;
450
    type Future = ResponseFuture;
451
452
0
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
453
0
        let result = self.channel.channel.poll_ready(cx);
454
0
        if let Poll::Ready(result) = &result {
455
0
            match result {
456
                Ok(()) => {
457
0
                    if let Some(pending_channel) = self.pending_channel.take() {
458
0
                        drop(
459
0
                            self.tx
460
0
                                .send(ConnectionRequest::Connected(EstablishedChannel {
461
0
                                    channel: pending_channel,
462
0
                                    identifier: self.channel.identifier,
463
0
                                })),
464
0
                        );
465
0
                    }
466
                }
467
0
                Err(err) => {
468
0
                    debug!(?err, "Error while creating connection on channel");
469
0
                    drop(self.tx.send(ConnectionRequest::Error((
470
0
                        self.channel.identifier,
471
0
                        self.pending_channel.take().is_some(),
472
0
                    ))));
473
                }
474
            }
475
0
        }
476
0
        result
477
0
    }
478
479
0
    fn call(&mut self, request: tonic::codegen::http::Request<tonic::body::Body>) -> Self::Future {
480
0
        ResponseFuture {
481
0
            inner: self.channel.channel.call(request),
482
0
            connection_tx: self.tx.clone(),
483
0
            identifier: self.channel.identifier,
484
0
        }
485
0
    }
486
}
487
488
/// This is mostly copied from `tonic::transport::channel` except it wraps it
489
/// to allow messaging about connection failure.
490
impl Future for ResponseFuture {
491
    type Output =
492
        Result<tonic::codegen::http::Response<tonic::body::Body>, tonic::transport::Error>;
493
494
0
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
495
0
        let result = Pin::new(&mut self.inner).poll(cx);
496
0
        if let Poll::Ready(Err(_)) = &result {
497
0
            drop(
498
0
                self.connection_tx
499
0
                    .send(ConnectionRequest::Error((self.identifier, false))),
500
0
            );
501
0
        }
502
0
        result
503
0
    }
504
}