Coverage Report

Created: 2024-11-22 20:17

/build/source/nativelink-util/src/connection_manager.rs
Every instrumented line and branch in this file reports an execution count of 0.

Source:
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::stream::{unfold, FuturesUnordered, StreamExt};
use futures::Future;
use nativelink_config::stores::Retry;
use nativelink_error::{make_err, Code, Error};
use tokio::sync::{mpsc, oneshot};
use tonic::transport::{channel, Channel, Endpoint};
use tracing::{event, Level};

use crate::background_spawn;
use crate::retry::{self, Retrier, RetryResult};

/// A helper utility that enables management of a suite of connections to an
/// upstream gRPC endpoint using Tonic.
pub struct ConnectionManager {
    // The channel to request connections from the worker.
    worker_tx: mpsc::Sender<oneshot::Sender<Connection>>,
}

/// The index into ConnectionManagerWorker::endpoints.
type EndpointIndex = usize;
/// The identifier for a given connection to a given Endpoint, used to identify
/// when a particular connection has failed or becomes available.
type ConnectionIndex = usize;

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct ChannelIdentifier {
    /// The index into ConnectionManagerWorker::endpoints that established this
    /// Channel.
    endpoint_index: EndpointIndex,
    /// A unique identifier for this particular connection to the Endpoint.
    connection_index: ConnectionIndex,
}

/// The requests that can be made from a Connection to the
/// ConnectionManagerWorker such as informing it that it's been dropped or that
/// an error occurred.
enum ConnectionRequest {
    /// Notify that a Connection was dropped, if it was dropped while the
    /// connection was still pending, then return the pending Channel to be
    /// added back to the available channels.
    Dropped(Option<EstablishedChannel>),
    /// Notify that a Connection was established, return the Channel to the
    /// available channels.
    Connected(EstablishedChannel),
    /// Notify that there was a transport error on the given Channel, the bool
    /// specifies whether the connection was in the process of being established
    /// or not (i.e. whether it's been returned to available channels yet).
    Error((ChannelIdentifier, bool)),
}

/// The result of a Future that connects to a given Endpoint.  This is a tuple
/// of the index into the ConnectionManagerWorker::endpoints that this
/// connection is for, the iteration of the connection and the result of the
/// connection itself.
type IndexedChannel = Result<EstablishedChannel, (ChannelIdentifier, Error)>;

/// A channel that has been established to an endpoint with some metadata around
/// it to allow identification of the Channel if it errors in order to correctly
/// remove it.
#[derive(Clone)]
struct EstablishedChannel {
    /// The Channel itself that the meta data relates to.
    channel: Channel,
    /// The identifier of the channel in the worker.
    identifier: ChannelIdentifier,
}

/// The context of the worker used to manage all of the connections.  This
/// handles reconnecting to endpoints on errors and multiple connections to a
/// given endpoint.
struct ConnectionManagerWorker {
    /// The endpoints to establish Channels and the identifier of the last
    /// connection attempt to that endpoint.
    endpoints: Vec<(ConnectionIndex, Endpoint)>,
    /// The channel used to communicate between a Connection and the worker.
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
    /// The number of connections that are currently allowed to be made.
    available_connections: usize,
    /// Channels that are currently being connected.
    connecting_channels: FuturesUnordered<Pin<Box<dyn Future<Output = IndexedChannel> + Send>>>,
    /// Connected channels that are available for use.
    available_channels: VecDeque<EstablishedChannel>,
    /// Requests for a Channel when available.
    waiting_connections: VecDeque<oneshot::Sender<Connection>>,
    /// The retry configuration for connecting to an Endpoint, on failure will
    /// restart the retrier after a 1 second delay.
    retrier: Retrier,
}

/// The maximum number of queued requests to obtain a connection from the
/// worker before applying back pressure to the requestor.  It makes sense to
/// keep this small since it has to wait for a response anyway.
const WORKER_BACKLOG: usize = 8;

impl ConnectionManager {
    /// Create a connection manager that creates a balance list between a given
    /// set of Endpoints.  This will restrict the number of concurrent requests
    /// and automatically re-connect upon transport error.
    pub fn new(
        endpoints: impl IntoIterator<Item = Endpoint>,
        mut connections_per_endpoint: usize,
        mut max_concurrent_requests: usize,
        retry: Retry,
        jitter_fn: retry::JitterFn,
    ) -> Self {
        let (worker_tx, worker_rx) = mpsc::channel(WORKER_BACKLOG);
        // The connection messages always come from sync contexts (e.g. drop)
        // and therefore, we'd end up spawning for them if this was bounded
        // which defeats the object since there would be no backpressure
        // applied. Therefore it makes sense for this to be unbounded.
        let (connection_tx, connection_rx) = mpsc::unbounded_channel();
        let endpoints = Vec::from_iter(endpoints.into_iter().map(|endpoint| (0, endpoint)));
        if max_concurrent_requests == 0 {
            max_concurrent_requests = usize::MAX;
        }
        if connections_per_endpoint == 0 {
            connections_per_endpoint = 1;
        }
        let worker = ConnectionManagerWorker {
            endpoints,
            available_connections: max_concurrent_requests,
            connection_tx,
            connecting_channels: FuturesUnordered::new(),
            available_channels: VecDeque::new(),
            waiting_connections: VecDeque::new(),
            retrier: Retrier::new(
                Arc::new(|duration| Box::pin(tokio::time::sleep(duration))),
                jitter_fn,
                retry,
            ),
        };
        background_spawn!("connection_manager_worker_spawn", async move {
            worker
                .service_requests(connections_per_endpoint, worker_rx, connection_rx)
                .await;
        });
        Self { worker_tx }
    }

    /// Get a Connection that can be used as a tonic::Channel, except it
    /// performs some additional counting to reconnect on error and restrict
    /// the number of concurrent connections.
    pub async fn connection(&self) -> Result<Connection, Error> {
        let (tx, rx) = oneshot::channel();
        self.worker_tx
            .send(tx)
            .await
            .map_err(|err| make_err!(Code::Unavailable, "Requesting a new connection: {err:?}"))?;
        rx.await
            .map_err(|err| make_err!(Code::Unavailable, "Waiting for a new connection: {err:?}"))
    }
}

impl ConnectionManagerWorker {
    async fn service_requests(
        mut self,
        connections_per_endpoint: usize,
        mut worker_rx: mpsc::Receiver<oneshot::Sender<Connection>>,
        mut connection_rx: mpsc::UnboundedReceiver<ConnectionRequest>,
    ) {
        // Make the initial set of connections, connection failures will be
        // handled in the same way as future transport failures, so no need to
        // do anything special.
        for endpoint_index in 0..self.endpoints.len() {
            for _ in 0..connections_per_endpoint {
                self.connect_endpoint(endpoint_index, None);
            }
        }

        // The main worker loop, when select resolves one of its arms the other
        // ones are cancelled, therefore it's important that they maintain no
        // state while `await`-ing.  This is enforced through the use of
        // non-async functions to do all of the work.
        loop {
            tokio::select! {
                request = worker_rx.recv() => {
                    let Some(request) = request else {
                        // The ConnectionManager was dropped, shut down the
                        // worker.
                        break;
                    };
                    self.handle_worker(request);
                }
                maybe_request = connection_rx.recv() => {
                    if let Some(request) = maybe_request {
                        self.handle_connection(request);
                    }
                }
                maybe_connection_result = self.connect_next() => {
                    if let Some(connection_result) = maybe_connection_result {
                        self.handle_connected(connection_result);
                    }
                }
            }
        }
    }

    async fn connect_next(&mut self) -> Option<IndexedChannel> {
        if self.connecting_channels.is_empty() {
            // Make this Future never resolve, we will get cancelled by the
            // select if there's some change in state to `self` and can re-enter
            // and evaluate `connecting_channels` again.
            futures::future::pending::<()>().await;
        }
        self.connecting_channels.next().await
    }

    // This must never be made async otherwise the select may cancel it.
    fn handle_connected(&mut self, connection_result: IndexedChannel) {
        match connection_result {
            Ok(established_channel) => {
                self.available_channels.push_back(established_channel);
                self.maybe_available_connection();
            }
            // When the retrier runs out of attempts start again from the
            // beginning of the retry period.  Never want to be in a
            // situation where we give up on an Endpoint forever.
            Err((identifier, _)) => {
                self.connect_endpoint(identifier.endpoint_index, Some(identifier.connection_index));
            }
        }
    }

    fn connect_endpoint(&mut self, endpoint_index: usize, connection_index: Option<usize>) {
        let Some((current_connection_index, endpoint)) = self.endpoints.get_mut(endpoint_index)
        else {
            // Unknown endpoint, this should never happen.
            event!(
                Level::ERROR,
                ?endpoint_index,
                "Connection to unknown endpoint requested"
            );
            return;
        };
        let is_backoff = connection_index.is_some();
        let connection_index = connection_index.unwrap_or_else(|| {
            *current_connection_index += 1;
            *current_connection_index
        });
        if is_backoff {
            event!(
                Level::WARN,
                ?connection_index,
                endpoint = ?endpoint.uri(),
                "Connection failed, reconnecting"
            );
        } else {
            event!(
                Level::INFO,
                ?connection_index,
                endpoint = ?endpoint.uri(),
                "Creating new connection"
            );
        }
        let identifier = ChannelIdentifier {
            endpoint_index,
            connection_index,
        };
        let connection_stream = unfold(endpoint.clone(), move |endpoint| async move {
            let result = endpoint.connect().await.map_err(|err| {
                make_err!(
                    Code::Unavailable,
                    "Failed to connect to {:?}: {err:?}",
                    endpoint.uri()
                )
            });
            Some((
                result.map_or_else(RetryResult::Retry, RetryResult::Ok),
                endpoint,
            ))
        });
        let retrier = self.retrier.clone();
        self.connecting_channels.push(Box::pin(async move {
            if is_backoff {
                // Just in case the retry config is 0, then we need to
                // introduce some delay so we aren't in a hard loop.
                tokio::time::sleep(Duration::from_secs(1)).await;
            }
            retrier.retry(connection_stream).await.map_or_else(
                |err| Err((identifier, err)),
                |channel| {
                    Ok(EstablishedChannel {
                        channel,
                        identifier,
                    })
                },
            )
        }));
    }

    // This must never be made async otherwise the select may cancel it.
    fn handle_worker(&mut self, tx: oneshot::Sender<Connection>) {
        if let Some(channel) = (self.available_connections > 0)
            .then_some(())
            .and_then(|_| self.available_channels.pop_front())
        {
            self.provide_channel(channel, tx);
        } else {
            self.waiting_connections.push_back(tx);
        }
    }

    fn provide_channel(&mut self, channel: EstablishedChannel, tx: oneshot::Sender<Connection>) {
        // We decrement here because we create Connection, this will signal when
        // it is Dropped and therefore increment this again.
        self.available_connections -= 1;
        let _ = tx.send(Connection {
            connection_tx: self.connection_tx.clone(),
            pending_channel: Some(channel.channel.clone()),
            channel,
        });
    }

    fn maybe_available_connection(&mut self) {
        while self.available_connections > 0
            && !self.waiting_connections.is_empty()
            && !self.available_channels.is_empty()
        {
            if let Some(channel) = self.available_channels.pop_front() {
                if let Some(tx) = self.waiting_connections.pop_front() {
                    self.provide_channel(channel, tx);
                } else {
                    // This should never happen, but better than an unwrap.
                    self.available_channels.push_front(channel);
                }
            }
        }
    }

    // This must never be made async otherwise the select may cancel it.
    fn handle_connection(&mut self, request: ConnectionRequest) {
        match request {
            ConnectionRequest::Dropped(maybe_channel) => {
                if let Some(channel) = maybe_channel {
                    self.available_channels.push_back(channel);
                }
                self.available_connections += 1;
                self.maybe_available_connection();
            }
            ConnectionRequest::Connected(channel) => {
                self.available_channels.push_back(channel);
                self.maybe_available_connection();
            }
            // Handle a transport error on a connection by making it unavailable
            // for use and establishing a new connection to the endpoint.
            ConnectionRequest::Error((identifier, was_pending)) => {
                let should_reconnect = if was_pending {
                    true
                } else {
                    let original_length = self.available_channels.len();
                    self.available_channels
                        .retain(|channel| channel.identifier != identifier);
                    // Only reconnect if it wasn't already disconnected.
                    original_length != self.available_channels.len()
                };
                if should_reconnect {
                    self.connect_endpoint(identifier.endpoint_index, None);
                }
            }
        }
    }
}

/// An instance of this is obtained for every communication with the gRPC
/// service.  This handles the permit for limiting concurrency, and also
/// re-connecting the underlying channel on error.  It depends on users
/// reporting all errors.
/// NOTE: This should never be cloneable because its lifetime is linked to the
///       ConnectionManagerWorker::available_connections.
pub struct Connection {
    /// Communication with ConnectionManagerWorker to inform about transport
    /// errors and when the Connection is dropped.
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
    /// If set, the Channel that will be returned to the worker when connection
    /// completes (success or failure) or when the Connection is dropped if that
    /// happens before connection completes.
    pending_channel: Option<Channel>,
    /// The identifier to send to connection_tx.
    channel: EstablishedChannel,
}

impl Drop for Connection {
    fn drop(&mut self) {
        let pending_channel = self
            .pending_channel
            .take()
            .map(|channel| EstablishedChannel {
                channel,
                identifier: self.channel.identifier,
            });
        let _ = self
            .connection_tx
            .send(ConnectionRequest::Dropped(pending_channel));
    }
}

/// A wrapper around the channel::ResponseFuture that forwards errors to the
/// connection_tx.
pub struct ResponseFuture {
    /// The wrapped future that actually does the work.
    inner: channel::ResponseFuture,
    /// Communication with ConnectionManagerWorker to inform about transport
    /// errors.
    connection_tx: mpsc::UnboundedSender<ConnectionRequest>,
    /// The identifier to send to connection_tx on a transport error.
    identifier: ChannelIdentifier,
}

/// This is mostly copied from tonic::transport::channel except it wraps it
/// to allow messaging about connection success and failure.
impl tonic::codegen::Service<tonic::codegen::http::Request<tonic::body::BoxBody>> for Connection {
    type Response = tonic::codegen::http::Response<tonic::body::BoxBody>;
    type Error = tonic::transport::Error;
    type Future = ResponseFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let result = self.channel.channel.poll_ready(cx);
        if let Poll::Ready(result) = &result {
            match result {
                Ok(_) => {
                    if let Some(pending_channel) = self.pending_channel.take() {
                        let _ = self.connection_tx.send(ConnectionRequest::Connected(
                            EstablishedChannel {
                                channel: pending_channel,
                                identifier: self.channel.identifier,
                            },
                        ));
                    }
                }
                Err(err) => {
                    event!(
                        Level::DEBUG,
                        ?err,
                        "Error while creating connection on channel"
                    );
                    let _ = self.connection_tx.send(ConnectionRequest::Error((
                        self.channel.identifier,
                        self.pending_channel.take().is_some(),
                    )));
                }
            }
        }
        result
    }

    fn call(
        &mut self,
        request: tonic::codegen::http::Request<tonic::body::BoxBody>,
    ) -> Self::Future {
        ResponseFuture {
            inner: self.channel.channel.call(request),
            connection_tx: self.connection_tx.clone(),
            identifier: self.channel.identifier,
        }
    }
}

/// This is mostly copied from tonic::transport::channel except it wraps it
/// to allow messaging about connection failure.
impl Future for ResponseFuture {
    type Output =
        Result<tonic::codegen::http::Response<tonic::body::BoxBody>, tonic::transport::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let result = Pin::new(&mut self.inner).poll(cx);
        if let Poll::Ready(Err(_)) = &result {
            let _ = self
                .connection_tx
                .send(ConnectionRequest::Error((self.identifier, false)));
        }
        result
    }
}
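
For orientation, a minimal usage sketch of the public API (ConnectionManager::new and ConnectionManager::connection) follows. It is illustrative only and not part of the file under test: the module path is inferred from the file path at the top of this report, the endpoint address is a placeholder, and it assumes that nativelink_config::stores::Retry implements Default and that retry::JitterFn is an Arc'd Fn(Duration) -> Duration closure.

use std::sync::Arc;
use std::time::Duration;

use nativelink_config::stores::Retry;
use nativelink_error::Error;
// Module path inferred from the file path shown at the top of this report.
use nativelink_util::connection_manager::ConnectionManager;
use tonic::transport::Endpoint;

async fn example_usage() -> Result<(), Error> {
    // Placeholder upstream address; several Endpoints may be supplied to
    // balance connections across multiple upstreams.
    let endpoint = Endpoint::from_static("http://127.0.0.1:50051");

    let manager = ConnectionManager::new(
        [endpoint],
        2,                // connections_per_endpoint (0 is clamped to 1 by `new`)
        0,                // max_concurrent_requests (0 means unlimited)
        Retry::default(), // assumption: Retry derives Default
        // Assumption: JitterFn is Arc<dyn Fn(Duration) -> Duration + Send + Sync>;
        // the identity closure applies no jitter to the retry delay.
        Arc::new(|delay: Duration| delay),
    );

    // Waits until a Channel is connected and a concurrency slot is available.
    // The returned Connection implements tonic's Service trait, so it can back
    // a generated gRPC client; dropping it releases the slot to the worker.
    let _connection = manager.connection().await?;
    Ok(())
}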