Coverage Report

Created: 2025-05-08 18:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-util/src/connection_manager.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::pin::Pin;
16
use core::task::{Context, Poll};
17
use core::time::Duration;
18
use std::collections::VecDeque;
19
use std::sync::Arc;
20
21
use futures::Future;
22
use futures::stream::{FuturesUnordered, StreamExt, unfold};
23
use nativelink_config::stores::Retry;
24
use nativelink_error::{Code, Error, make_err};
25
use tokio::sync::{mpsc, oneshot};
26
use tonic::transport::{Channel, Endpoint, channel};
27
use tracing::{debug, error, info, warn};
28
29
use crate::background_spawn;
30
use crate::retry::{self, Retrier, RetryResult};
31
32
/// A helper utility that enables management of a suite of connections to an
33
/// upstream gRPC endpoint using Tonic.
34
#[derive(Debug)]
35
pub struct ConnectionManager {
36
    // The channel to request connections from the worker.
37
    worker_tx: mpsc::Sender<oneshot::Sender<Connection>>,
38
}
39
40
/// Position of an endpoint inside `ConnectionManagerWorker::endpoints`.
type EndpointIndex = usize;
/// Per-endpoint sequence number that tells one connection to an Endpoint
/// apart from another, so the worker can recognise which connection failed
/// or became available.
type ConnectionIndex = usize;

/// Names exactly one Channel tracked by the worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ChannelIdentifier {
    /// Entry of `ConnectionManagerWorker::endpoints` this Channel was
    /// established against.
    endpoint_index: EndpointIndex,
    /// Which of the connections to that Endpoint this Channel is.
    connection_index: ConnectionIndex,
}
54
55
/// The requests that can be made from a Connection to the
56
/// `ConnectionManagerWorker` such as informing it that it's been dropped or that
57
/// an error occurred.
58
enum ConnectionRequest {
59
    /// Notify that a Connection was dropped, if it was dropped while the
60
    /// connection was still pending, then return the pending Channel to be
61
    /// added back to the available channels.
62
    Dropped(Option<EstablishedChannel>),
63
    /// Notify that a Connection was established, return the Channel to the
64
    /// available channels.
65
    Connected(EstablishedChannel),
66
    /// Notify that there was a transport error on the given Channel, the bool
67
    /// specifies whether the connection was in the process of being established
68
    /// or not (i.e. whether it's been returned to available channels yet).
69
    Error((ChannelIdentifier, bool)),
70
}
71
72
/// The result of a Future that connects to a given Endpoint.  This is a tuple
73
/// of the index into the `ConnectionManagerWorker::endpoints` that this
74
/// connection is for, the iteration of the connection and the result of the
75
/// connection itself.
76
type IndexedChannel = Result<EstablishedChannel, (ChannelIdentifier, Error)>;
77
78
/// A channel that has been established to an endpoint with some metadata around
79
/// it to allow identification of the Channel if it errors in order to correctly
80
/// remove it.
81
#[derive(Debug, Clone)]
82
struct EstablishedChannel {
83
    /// The Channel itself that the meta data relates to.
84
    channel: Channel,
85
    /// The identifier of the channel in the worker.
86
    identifier: ChannelIdentifier,
87
}
88
89
/// The context of the worker used to manage all of the connections.  This
90
/// handles reconnecting to endpoints on errors and multiple connections to a
91
/// given endpoint.
92
struct ConnectionManagerWorker {
93
    /// The endpoints to establish Channels and the identifier of the last
94
    /// connection attempt to that endpoint.
95
    endpoints: Vec<(ConnectionIndex, Endpoint)>,
96
    /// The channel used to communicate between a Connection and the worker.
97
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
98
    /// The number of connections that are currently allowed to be made.
99
    available_connections: usize,
100
    /// Channels that are currently being connected.
101
    connecting_channels: FuturesUnordered<Pin<Box<dyn Future<Output = IndexedChannel> + Send>>>,
102
    /// Connected channels that are available for use.
103
    available_channels: VecDeque<EstablishedChannel>,
104
    /// Requests for a Channel when available.
105
    waiting_connections: VecDeque<oneshot::Sender<Connection>>,
106
    /// The retry configuration for connecting to an Endpoint, on failure will
107
    /// restart the retrier after a 1 second delay.
108
    retrier: Retrier,
109
}
110
111
/// How many connection requests may sit in the worker's queue before callers
/// of `ConnectionManager::connection` feel back pressure.  A small value is
/// fine because every caller waits for a reply anyway.
const WORKER_BACKLOG: usize = 8;
115
116
impl ConnectionManager {
117
    /// Create a connection manager that creates a balance list between a given
118
    /// set of Endpoints.  This will restrict the number of concurrent requests
119
    /// and automatically re-connect upon transport error.
120
0
    pub fn new(
121
0
        endpoints: impl IntoIterator<Item = Endpoint>,
122
0
        mut connections_per_endpoint: usize,
123
0
        mut max_concurrent_requests: usize,
124
0
        retry: Retry,
125
0
        jitter_fn: retry::JitterFn,
126
0
    ) -> Self {
127
0
        let (worker_tx, worker_rx) = mpsc::channel(WORKER_BACKLOG);
128
0
        // The connection messages always come from sync contexts (e.g. drop)
129
0
        // and therefore, we'd end up spawning for them if this was bounded
130
0
        // which defeats the object since there would be no backpressure
131
0
        // applied. Therefore it makes sense for this to be unbounded.
132
0
        let (connection_tx, connection_rx) = mpsc::unbounded_channel();
133
0
        let endpoints = endpoints
134
0
            .into_iter()
135
0
            .map(|endpoint| (0, endpoint))
136
0
            .collect();
137
0
138
0
        if max_concurrent_requests == 0 {
  Branch (138:12): [True: 0, False: 0]
  Branch (138:12): [True: 0, False: 0]
  Branch (138:12): [Folded - Ignored]
  Branch (138:12): [Folded - Ignored]
139
0
            max_concurrent_requests = usize::MAX;
140
0
        }
141
0
        if connections_per_endpoint == 0 {
  Branch (141:12): [True: 0, False: 0]
  Branch (141:12): [True: 0, False: 0]
  Branch (141:12): [Folded - Ignored]
  Branch (141:12): [Folded - Ignored]
142
0
            connections_per_endpoint = 1;
143
0
        }
144
0
        let worker = ConnectionManagerWorker {
145
0
            endpoints,
146
0
            available_connections: max_concurrent_requests,
147
0
            connection_tx,
148
0
            connecting_channels: FuturesUnordered::new(),
149
0
            available_channels: VecDeque::new(),
150
0
            waiting_connections: VecDeque::new(),
151
0
            retrier: Retrier::new(
152
0
                Arc::new(|duration| Box::pin(tokio::time::sleep(duration))),
153
0
                jitter_fn,
154
0
                retry,
155
0
            ),
156
0
        };
157
0
        background_spawn!("connection_manager_worker_spawn", async move {
158
0
            worker
159
0
                .service_requests(connections_per_endpoint, worker_rx, connection_rx)
160
0
                .await;
161
0
        });
162
0
        Self { worker_tx }
163
0
    }
164
165
    /// Get a Connection that can be used as a `tonic::Channel`, except it
166
    /// performs some additional counting to reconnect on error and restrict
167
    /// the number of concurrent connections.
168
0
    pub async fn connection(&self) -> Result<Connection, Error> {
169
0
        let (tx, rx) = oneshot::channel();
170
0
        self.worker_tx
171
0
            .send(tx)
172
0
            .await
173
0
            .map_err(|err| make_err!(Code::Unavailable, "Requesting a new connection: {err:?}"))?;
174
0
        rx.await
175
0
            .map_err(|err| make_err!(Code::Unavailable, "Waiting for a new connection: {err:?}"))
176
0
    }
177
}
178
179
impl ConnectionManagerWorker {
180
0
    async fn service_requests(
181
0
        mut self,
182
0
        connections_per_endpoint: usize,
183
0
        mut worker_rx: mpsc::Receiver<oneshot::Sender<Connection>>,
184
0
        mut connection_rx: mpsc::UnboundedReceiver<ConnectionRequest>,
185
0
    ) {
186
        // Make the initial set of connections, connection failures will be
187
        // handled in the same way as future transport failures, so no need to
188
        // do anything special.
189
0
        for endpoint_index in 0..self.endpoints.len() {
190
0
            for _ in 0..connections_per_endpoint {
191
0
                self.connect_endpoint(endpoint_index, None);
192
0
            }
193
        }
194
195
        // The main worker loop, when select resolves one of its arms the other
196
        // ones are cancelled, therefore it's important that they maintain no
197
        // state while `await`-ing.  This is enforced through the use of
198
        // non-async functions to do all of the work.
199
        loop {
200
0
            tokio::select! {
201
0
                request = worker_rx.recv() => {
202
0
                    let Some(request) = request else {
  Branch (202:25): [True: 0, False: 0]
  Branch (202:25): [True: 0, False: 0]
  Branch (202:25): [Folded - Ignored]
  Branch (202:25): [Folded - Ignored]
203
                        // The ConnectionManager was dropped, shut down the
204
                        // worker.
205
0
                        break;
206
                    };
207
0
                    self.handle_worker(request);
208
                }
209
0
                maybe_request = connection_rx.recv() => {
210
0
                    if let Some(request) = maybe_request {
  Branch (210:28): [True: 0, False: 0]
  Branch (210:28): [True: 0, False: 0]
  Branch (210:28): [Folded - Ignored]
  Branch (210:28): [Folded - Ignored]
211
0
                        self.handle_connection(request);
212
0
                    }
213
                }
214
0
                maybe_connection_result = self.connect_next() => {
215
0
                    if let Some(connection_result) = maybe_connection_result {
  Branch (215:28): [True: 0, False: 0]
  Branch (215:28): [True: 0, False: 0]
  Branch (215:28): [Folded - Ignored]
  Branch (215:28): [Folded - Ignored]
216
0
                        self.handle_connected(connection_result);
217
0
                    }
218
                }
219
            }
220
        }
221
0
    }
222
223
0
    async fn connect_next(&mut self) -> Option<IndexedChannel> {
224
0
        if self.connecting_channels.is_empty() {
  Branch (224:12): [True: 0, False: 0]
  Branch (224:12): [True: 0, False: 0]
  Branch (224:12): [Folded - Ignored]
  Branch (224:12): [Folded - Ignored]
225
            // Make this Future never resolve, we will get cancelled by the
226
            // select if there's some change in state to `self` and can re-enter
227
            // and evaluate `connecting_channels` again.
228
0
            futures::future::pending::<()>().await;
229
0
        }
230
0
        self.connecting_channels.next().await
231
0
    }
232
233
    // This must never be made async otherwise the select may cancel it.
234
0
    fn handle_connected(&mut self, connection_result: IndexedChannel) {
235
0
        match connection_result {
236
0
            Ok(established_channel) => {
237
0
                self.available_channels.push_back(established_channel);
238
0
                self.maybe_available_connection();
239
0
            }
240
            // When the retrier runs out of attempts start again from the
241
            // beginning of the retry period.  Never want to be in a
242
            // situation where we give up on an Endpoint forever.
243
0
            Err((identifier, _)) => {
244
0
                self.connect_endpoint(identifier.endpoint_index, Some(identifier.connection_index));
245
0
            }
246
        }
247
0
    }
248
249
0
    fn connect_endpoint(&mut self, endpoint_index: usize, connection_index: Option<usize>) {
250
0
        let Some((current_connection_index, endpoint)) = self.endpoints.get_mut(endpoint_index)
  Branch (250:13): [True: 0, False: 0]
  Branch (250:13): [Folded - Ignored]
251
        else {
252
            // Unknown endpoint, this should never happen.
253
0
            error!(?endpoint_index, "Connection to unknown endpoint requested");
254
0
            return;
255
        };
256
0
        let is_backoff = connection_index.is_some();
257
0
        let connection_index = connection_index.unwrap_or_else(|| {
258
0
            *current_connection_index += 1;
259
0
            *current_connection_index
260
0
        });
261
0
        if is_backoff {
  Branch (261:12): [True: 0, False: 0]
  Branch (261:12): [Folded - Ignored]
262
0
            warn!(
263
                ?connection_index,
264
0
                endpoint = ?endpoint.uri(),
265
0
                "Connection failed, reconnecting"
266
            );
267
        } else {
268
0
            info!(
269
                ?connection_index,
270
0
                endpoint = ?endpoint.uri(),
271
0
                "Creating new connection"
272
            );
273
        }
274
0
        let identifier = ChannelIdentifier {
275
0
            endpoint_index,
276
0
            connection_index,
277
0
        };
278
0
        let connection_stream = unfold(endpoint.clone(), move |endpoint| async move {
279
0
            let result = endpoint.connect().await.map_err(|err| {
280
0
                make_err!(
281
0
                    Code::Unavailable,
282
0
                    "Failed to connect to {:?}: {err:?}",
283
0
                    endpoint.uri()
284
0
                )
285
0
            });
286
0
            Some((
287
0
                result.map_or_else(RetryResult::Retry, RetryResult::Ok),
288
0
                endpoint,
289
0
            ))
290
0
        });
291
0
        let retrier = self.retrier.clone();
292
0
        self.connecting_channels.push(Box::pin(async move {
293
0
            if is_backoff {
  Branch (293:16): [True: 0, False: 0]
  Branch (293:16): [Folded - Ignored]
294
                // Just in case the retry config is 0, then we need to
295
                // introduce some delay so we aren't in a hard loop.
296
0
                tokio::time::sleep(Duration::from_secs(1)).await;
297
0
            }
298
0
            retrier.retry(connection_stream).await.map_or_else(
299
0
                |err| Err((identifier, err)),
300
0
                |channel| {
301
0
                    Ok(EstablishedChannel {
302
0
                        channel,
303
0
                        identifier,
304
0
                    })
305
0
                },
306
            )
307
0
        }));
308
0
    }
309
310
    // This must never be made async otherwise the select may cancel it.
311
0
    fn handle_worker(&mut self, tx: oneshot::Sender<Connection>) {
312
0
        if let Some(channel) = (self.available_connections > 0)
  Branch (312:16): [True: 0, False: 0]
  Branch (312:16): [Folded - Ignored]
313
0
            .then_some(())
314
0
            .and_then(|()| self.available_channels.pop_front())
315
0
        {
316
0
            self.provide_channel(channel, tx);
317
0
        } else {
318
0
            self.waiting_connections.push_back(tx);
319
0
        }
320
0
    }
321
322
0
    fn provide_channel(&mut self, channel: EstablishedChannel, tx: oneshot::Sender<Connection>) {
323
0
        // We decrement here because we create Connection, this will signal when
324
0
        // it is Dropped and therefore increment this again.
325
0
        self.available_connections -= 1;
326
0
        drop(tx.send(Connection {
327
0
            tx: self.connection_tx.clone(),
328
0
            pending_channel: Some(channel.channel.clone()),
329
0
            channel,
330
0
        }));
331
0
    }
332
333
0
    fn maybe_available_connection(&mut self) {
334
0
        while self.available_connections > 0
  Branch (334:15): [True: 0, False: 0]
  Branch (334:15): [Folded - Ignored]
335
0
            && !self.waiting_connections.is_empty()
  Branch (335:16): [True: 0, False: 0]
  Branch (335:16): [Folded - Ignored]
336
0
            && !self.available_channels.is_empty()
  Branch (336:16): [True: 0, False: 0]
  Branch (336:16): [Folded - Ignored]
337
        {
338
0
            if let Some(channel) = self.available_channels.pop_front() {
  Branch (338:20): [True: 0, False: 0]
  Branch (338:20): [Folded - Ignored]
339
0
                if let Some(tx) = self.waiting_connections.pop_front() {
  Branch (339:24): [True: 0, False: 0]
  Branch (339:24): [Folded - Ignored]
340
0
                    self.provide_channel(channel, tx);
341
0
                } else {
342
0
                    // This should never happen, but better than an unwrap.
343
0
                    self.available_channels.push_front(channel);
344
0
                }
345
0
            }
346
        }
347
0
    }
348
349
    // This must never be made async otherwise the select may cancel it.
350
0
    fn handle_connection(&mut self, request: ConnectionRequest) {
351
0
        match request {
352
0
            ConnectionRequest::Dropped(maybe_channel) => {
353
0
                if let Some(channel) = maybe_channel {
  Branch (353:24): [True: 0, False: 0]
  Branch (353:24): [Folded - Ignored]
354
0
                    self.available_channels.push_back(channel);
355
0
                }
356
0
                self.available_connections += 1;
357
0
                self.maybe_available_connection();
358
            }
359
0
            ConnectionRequest::Connected(channel) => {
360
0
                self.available_channels.push_back(channel);
361
0
                self.maybe_available_connection();
362
0
            }
363
            // Handle a transport error on a connection by making it unavailable
364
            // for use and establishing a new connection to the endpoint.
365
0
            ConnectionRequest::Error((identifier, was_pending)) => {
366
0
                let should_reconnect = if was_pending {
  Branch (366:43): [True: 0, False: 0]
  Branch (366:43): [Folded - Ignored]
367
0
                    true
368
                } else {
369
0
                    let original_length = self.available_channels.len();
370
0
                    self.available_channels
371
0
                        .retain(|channel| channel.identifier != identifier);
372
                    // Only reconnect if it wasn't already disconnected.
373
0
                    original_length != self.available_channels.len()
374
                };
375
0
                if should_reconnect {
  Branch (375:20): [True: 0, False: 0]
  Branch (375:20): [Folded - Ignored]
376
0
                    self.connect_endpoint(identifier.endpoint_index, None);
377
0
                }
378
            }
379
        }
380
0
    }
381
}
382
383
/// An instance of this is obtained for every communication with the gGRPC
384
/// service.  This handles the permit for limiting concurrency, and also
385
/// re-connecting the underlying channel on error.  It depends on users
386
/// reporting all errors.
387
/// NOTE: This should never be cloneable because its lifetime is linked to the
388
///       `ConnectionManagerWorker::available_connections`.
389
#[derive(Debug)]
390
pub struct Connection {
391
    /// Communication with `ConnectionManagerWorker` to inform about transport
392
    /// errors and when the Connection is dropped.
393
    tx: mpsc::UnboundedSender<ConnectionRequest>,
394
    /// If set, the Channel that will be returned to the worker when connection
395
    /// completes (success or failure) or when the Connection is dropped if that
396
    /// happens before connection completes.
397
    pending_channel: Option<Channel>,
398
    /// The identifier to send to `tx`.
399
    channel: EstablishedChannel,
400
}
401
402
impl Drop for Connection {
403
0
    fn drop(&mut self) {
404
0
        let pending_channel = self
405
0
            .pending_channel
406
0
            .take()
407
0
            .map(|channel| EstablishedChannel {
408
0
                channel,
409
0
                identifier: self.channel.identifier,
410
0
            });
411
0
        drop(self.tx.send(ConnectionRequest::Dropped(pending_channel)));
412
0
    }
413
}
414
415
/// A wrapper around the `channel::ResponseFuture` that forwards errors to the `tx`.
416
pub struct ResponseFuture {
417
    /// The wrapped future that actually does the work.
418
    inner: channel::ResponseFuture,
419
    /// Communication with `ConnectionManagerWorker` to inform about transport
420
    /// errors.
421
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
422
    /// The identifier to send to `connection_tx` on a transport error.
423
    identifier: ChannelIdentifier,
424
}
425
426
impl core::fmt::Debug for ResponseFuture {
427
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
428
0
        f.debug_struct("ResponseFuture")
429
0
            .field("inner", &self.inner)
430
0
            .field("connection_tx", &self.connection_tx)
431
0
            .field("identifier", &self.identifier)
432
0
            .finish()
433
0
    }
434
}
435
436
/// This is mostly copied from `tonic::transport::channel` except it wraps it
437
/// to allow messaging about connection success and failure.
438
impl tonic::codegen::Service<tonic::codegen::http::Request<tonic::body::Body>> for Connection {
439
    type Response = tonic::codegen::http::Response<tonic::body::Body>;
440
    type Error = tonic::transport::Error;
441
    type Future = ResponseFuture;
442
443
0
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
444
0
        let result = self.channel.channel.poll_ready(cx);
445
0
        if let Poll::Ready(result) = &result {
  Branch (445:16): [True: 0, False: 0]
  Branch (445:16): [Folded - Ignored]
446
0
            match result {
447
                Ok(()) => {
448
0
                    if let Some(pending_channel) = self.pending_channel.take() {
  Branch (448:28): [True: 0, False: 0]
  Branch (448:28): [Folded - Ignored]
449
0
                        drop(
450
0
                            self.tx
451
0
                                .send(ConnectionRequest::Connected(EstablishedChannel {
452
0
                                    channel: pending_channel,
453
0
                                    identifier: self.channel.identifier,
454
0
                                })),
455
0
                        );
456
0
                    }
457
                }
458
0
                Err(err) => {
459
0
                    debug!(?err, "Error while creating connection on channel");
460
0
                    drop(self.tx.send(ConnectionRequest::Error((
461
0
                        self.channel.identifier,
462
0
                        self.pending_channel.take().is_some(),
463
0
                    ))));
464
                }
465
            }
466
0
        }
467
0
        result
468
0
    }
469
470
0
    fn call(&mut self, request: tonic::codegen::http::Request<tonic::body::Body>) -> Self::Future {
471
0
        ResponseFuture {
472
0
            inner: self.channel.channel.call(request),
473
0
            connection_tx: self.tx.clone(),
474
0
            identifier: self.channel.identifier,
475
0
        }
476
0
    }
477
}
478
479
/// This is mostly copied from `tonic::transport::channel` except it wraps it
480
/// to allow messaging about connection failure.
481
impl Future for ResponseFuture {
482
    type Output =
483
        Result<tonic::codegen::http::Response<tonic::body::Body>, tonic::transport::Error>;
484
485
0
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
486
0
        let result = Pin::new(&mut self.inner).poll(cx);
487
0
        if let Poll::Ready(Err(_)) = &result {
  Branch (487:16): [True: 0, False: 0]
  Branch (487:16): [Folded - Ignored]
488
0
            drop(
489
0
                self.connection_tx
490
0
                    .send(ConnectionRequest::Error((self.identifier, false))),
491
0
            );
492
0
        }
493
0
        result
494
0
    }
495
}