Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-scheduler/src/memory_awaited_action_db.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::{BTreeMap, BTreeSet, HashMap};
16
use std::ops::{Bound, RangeBounds};
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use async_lock::Mutex;
21
use futures::{FutureExt, Stream};
22
use nativelink_config::stores::EvictionPolicy;
23
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
24
use nativelink_metric::MetricsComponent;
25
use nativelink_util::action_messages::{
26
    ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, OperationId,
27
};
28
use nativelink_util::chunked_stream::ChunkedStream;
29
use nativelink_util::evicting_map::{EvictingMap, LenEntry};
30
use nativelink_util::instant_wrapper::InstantWrapper;
31
use nativelink_util::spawn;
32
use nativelink_util::task::JoinHandleDropGuard;
33
use tokio::sync::{mpsc, watch, Notify};
34
use tracing::{event, Level};
35
36
use crate::awaited_action_db::{
37
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction,
38
    SortedAwaitedActionState,
39
};
40
41
/// Number of events to process per cycle.
42
const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024;
43
44
/// Duration to wait before sending client keep alive messages.
45
const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10);
46
47
/// Represents a client that is currently listening to an action.
48
/// When the client is dropped, it will send the [`AwaitedAction`] to the
49
/// `event_tx` if there are other cleanups needed.
50
#[derive(Debug)]
51
struct ClientAwaitedAction {
52
    /// The OperationId that the client is listening to.
53
    operation_id: OperationId,
54
55
    /// The sender to notify of this struct being dropped.
56
    event_tx: mpsc::UnboundedSender<ActionEvent>,
57
}
58
59
impl ClientAwaitedAction {
60
25
    pub fn new(operation_id: OperationId, event_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
61
25
        Self {
62
25
            operation_id,
63
25
            event_tx,
64
25
        }
65
25
    }
66
67
3
    pub fn operation_id(&self) -> &OperationId {
68
3
        &self.operation_id
69
3
    }
70
}
71
72
impl Drop for ClientAwaitedAction {
73
25
    fn drop(&mut self) {
74
25
        // If we failed to send it means noone is listening.
75
25
        let _ = self.event_tx.send(ActionEvent::ClientDroppedOperation(
76
25
            self.operation_id.clone(),
77
25
        ));
78
25
    }
79
}
80
81
/// Trait to be able to use the EvictingMap with [`ClientAwaitedAction`].
82
/// Note: We only use EvictingMap for a time based eviction, which is
83
/// why the implementation has fixed default values in it.
84
impl LenEntry for ClientAwaitedAction {
85
    #[inline]
86
79
    fn len(&self) -> u64 {
87
79
        0
88
79
    }
89
90
    #[inline]
91
0
    fn is_empty(&self) -> bool {
92
0
        true
93
0
    }
94
}
95
96
/// Actions the AwaitedActionsDb needs to process.
97
#[derive(Debug)]
98
pub(crate) enum ActionEvent {
99
    /// A client has sent a keep alive message.
100
    ClientKeepAlive(OperationId),
101
    /// A client has dropped and pointed to OperationId.
102
    ClientDroppedOperation(OperationId),
103
}
104
105
/// Information required to track an individual client
106
/// keep alive config and state.
107
struct ClientInfo<I: InstantWrapper, NowFn: Fn() -> I> {
108
    /// The client operation id.
109
    client_operation_id: OperationId,
110
    /// The last time a keep alive was sent.
111
    last_keep_alive: I,
112
    /// The function to get the current time.
113
    now_fn: NowFn,
114
    /// The sender to notify of this struct had an event.
115
    event_tx: mpsc::UnboundedSender<ActionEvent>,
116
}
117
118
/// Subscriber that clients can be used to monitor when AwaitedActions change.
119
pub struct MemoryAwaitedActionSubscriber<I: InstantWrapper, NowFn: Fn() -> I> {
120
    /// The receiver to listen for changes.
121
    awaited_action_rx: watch::Receiver<AwaitedAction>,
122
    /// If a client id is known this is the info needed to keep the client
123
    /// action alive.
124
    client_info: Option<ClientInfo<I, NowFn>>,
125
}
126
127
impl<I: InstantWrapper, NowFn: Fn() -> I> MemoryAwaitedActionSubscriber<I, NowFn> {
128
80
    fn new(mut awaited_action_rx: watch::Receiver<AwaitedAction>) -> Self {
129
80
        awaited_action_rx.mark_changed();
130
80
        Self {
131
80
            awaited_action_rx,
132
80
            client_info: None,
133
80
        }
134
80
    }
135
136
28
    fn new_with_client(
137
28
        mut awaited_action_rx: watch::Receiver<AwaitedAction>,
138
28
        client_operation_id: OperationId,
139
28
        event_tx: mpsc::UnboundedSender<ActionEvent>,
140
28
        now_fn: NowFn,
141
28
    ) -> Self
142
28
    where
143
28
        NowFn: Fn() -> I,
144
28
    {
145
28
        awaited_action_rx.mark_changed();
146
28
        Self {
147
28
            awaited_action_rx,
148
28
            client_info: Some(ClientInfo {
149
28
                client_operation_id,
150
28
                last_keep_alive: I::from_secs(0),
151
28
                now_fn,
152
28
                event_tx,
153
28
            }),
154
28
        }
155
28
    }
156
}
157
158
impl<I, NowFn> AwaitedActionSubscriber for MemoryAwaitedActionSubscriber<I, NowFn>
159
where
160
    I: InstantWrapper,
161
    NowFn: Fn() -> I + Send + Sync + 'static,
162
{
163
54
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
164
42
        let client_operation_id = {
165
54
            let changed_fut = self.awaited_action_rx.changed().map(|r| {
166
42
                r.map_err(|e| {
167
0
                    make_err!(
168
0
                        Code::Internal,
169
0
                        "Failed to wait for awaited action to change {e:?}"
170
0
                    )
171
42
                })
172
54
            });
173
54
            let Some(client_info) = self.client_info.as_mut() else {
  Branch (173:17): [True: 0, False: 0]
  Branch (173:17): [Folded - Ignored]
  Branch (173:17): [True: 54, False: 0]
174
0
                changed_fut.await?;
175
0
                return Ok(self.awaited_action_rx.borrow().clone());
176
            };
177
54
            tokio::pin!(changed_fut);
178
            loop {
179
158
                if client_info.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
  Branch (179:20): [True: 0, False: 0]
  Branch (179:20): [Folded - Ignored]
  Branch (179:20): [True: 55, False: 103]
180
55
                    client_info.last_keep_alive = (client_info.now_fn)();
181
55
                    // Failing to send just means our receiver dropped.
182
55
                    let _ = client_info.event_tx.send(ActionEvent::ClientKeepAlive(
183
55
                        client_info.client_operation_id.clone(),
184
55
                    ));
185
103
                }
186
158
                let sleep_fut = (client_info.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
187
158
                tokio::select! {
188
158
                    
result42
= &mut changed_fut => {
189
42
                        result
?0
;
190
42
                        break;
191
                    }
192
158
                    _ = sleep_fut => {
193
104
                        // If we haven't received any updates for a while, we should
194
104
                        // let the database know that we are still listening to prevent
195
104
                        // the action from being dropped.
196
104
                    }
197
                }
198
            }
199
42
            client_info.client_operation_id.clone()
200
42
        };
201
42
        // At this stage we know that this event is a client request, so we need
202
42
        // to populate the client_operation_id.
203
42
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
204
42
        let mut state = awaited_action.state().as_ref().clone();
205
42
        state.client_operation_id = client_operation_id;
206
42
        awaited_action.set_state(Arc::new(state), None);
207
42
        Ok(awaited_action)
208
42
    }
209
210
162
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
211
162
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
212
162
        if let Some(
client_info16
) = self.client_info.as_ref() {
  Branch (212:16): [True: 0, False: 0]
  Branch (212:16): [Folded - Ignored]
  Branch (212:16): [True: 16, False: 146]
213
16
            let mut state = awaited_action.state().as_ref().clone();
214
16
            state.client_operation_id = client_info.client_operation_id.clone();
215
16
            awaited_action.set_state(Arc::new(state), None);
216
146
        }
217
162
        Ok(awaited_action)
218
162
    }
219
}
220
221
/// A struct that is used to keep the devloper from trying to
222
/// return early from a function.
223
struct NoEarlyReturn;
224
225
0
#[derive(Default, MetricsComponent)]
226
struct SortedAwaitedActions {
227
    #[metric(group = "unknown")]
228
    unknown: BTreeSet<SortedAwaitedAction>,
229
    #[metric(group = "cache_check")]
230
    cache_check: BTreeSet<SortedAwaitedAction>,
231
    #[metric(group = "queued")]
232
    queued: BTreeSet<SortedAwaitedAction>,
233
    #[metric(group = "executing")]
234
    executing: BTreeSet<SortedAwaitedAction>,
235
    #[metric(group = "completed")]
236
    completed: BTreeSet<SortedAwaitedAction>,
237
}
238
239
impl SortedAwaitedActions {
240
39
    fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet<SortedAwaitedAction> {
241
39
        match state {
242
0
            ActionStage::Unknown => &mut self.unknown,
243
0
            ActionStage::CacheCheck => &mut self.cache_check,
244
26
            ActionStage::Queued => &mut self.queued,
245
13
            ActionStage::Executing => &mut self.executing,
246
0
            ActionStage::Completed(_) => &mut self.completed,
247
0
            ActionStage::CompletedFromCache(_) => &mut self.completed,
248
        }
249
39
    }
250
251
62
    fn insert_sort_map_for_stage(
252
62
        &mut self,
253
62
        stage: &ActionStage,
254
62
        sorted_awaited_action: &SortedAwaitedAction,
255
62
    ) -> Result<(), Error> {
256
62
        let newly_inserted = match stage {
257
0
            ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()),
258
0
            ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()),
259
30
            ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()),
260
26
            ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()),
261
6
            ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()),
262
            ActionStage::CompletedFromCache(_) => {
263
0
                self.completed.insert(sorted_awaited_action.clone())
264
            }
265
        };
266
62
        if !newly_inserted {
  Branch (266:12): [True: 0, False: 62]
  Branch (266:12): [Folded - Ignored]
267
0
            return Err(make_err!(
268
0
                Code::Internal,
269
0
                "Tried to insert an action that was already in the sorted map. This should never happen. {:?} - {:?}",
270
0
                stage,
271
0
                sorted_awaited_action
272
0
            ));
273
62
        }
274
62
        Ok(())
275
62
    }
276
277
39
    fn process_state_changes(
278
39
        &mut self,
279
39
        old_awaited_action: &AwaitedAction,
280
39
        new_awaited_action: &AwaitedAction,
281
39
    ) -> Result<(), Error> {
282
39
        let btree = self.btree_for_state(&old_awaited_action.state().stage);
283
39
        let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction {
284
39
            sort_key: old_awaited_action.sort_key(),
285
39
            operation_id: new_awaited_action.operation_id().clone(),
286
39
        });
287
288
39
        let Some(sorted_awaited_action) = maybe_sorted_awaited_action else {
  Branch (288:13): [True: 39, False: 0]
  Branch (288:13): [Folded - Ignored]
289
0
            return Err(make_err!(
290
0
                Code::Internal,
291
0
                "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}",
292
0
                new_awaited_action.operation_id(),
293
0
                new_awaited_action,
294
0
            ));
295
        };
296
297
39
        self.insert_sort_map_for_stage(&new_awaited_action.state().stage, &sorted_awaited_action)
298
39
            .err_tip(|| 
"In AwaitedActionDb::update_awaited_action"0
)
?0
;
299
39
        Ok(())
300
39
    }
301
}
302
303
/// The database for storing the state of all actions.
304
0
#[derive(MetricsComponent)]
305
pub struct AwaitedActionDbImpl<I: InstantWrapper, NowFn: Fn() -> I> {
306
    /// A lookup table to lookup the state of an action by its client operation id.
307
    #[metric(group = "client_operation_ids")]
308
    client_operation_to_awaited_action: EvictingMap<OperationId, Arc<ClientAwaitedAction>, I>,
309
310
    /// A lookup table to lookup the state of an action by its worker operation id.
311
    #[metric(group = "operation_ids")]
312
    operation_id_to_awaited_action: BTreeMap<OperationId, watch::Sender<AwaitedAction>>,
313
314
    /// A lookup table to lookup the state of an action by its unique qualifier.
315
    #[metric(group = "action_info_hash_key_to_awaited_action")]
316
    action_info_hash_key_to_awaited_action: HashMap<ActionUniqueKey, OperationId>,
317
318
    /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting
319
    /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`].
320
    ///
321
    /// See [`AwaitedActionSortKey`] for more information on the ordering.
322
    #[metric(group = "sorted_action_infos")]
323
    sorted_action_info_hash_keys: SortedAwaitedActions,
324
325
    /// The number of connected clients for each operation id.
326
    #[metric(group = "connected_clients_for_operation_id")]
327
    connected_clients_for_operation_id: HashMap<OperationId, usize>,
328
329
    /// Where to send notifications about important events related to actions.
330
    action_event_tx: mpsc::UnboundedSender<ActionEvent>,
331
332
    /// The function to get the current time.
333
    now_fn: NowFn,
334
}
335
336
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbImpl<I, NowFn> {
337
503
    async fn get_awaited_action_by_id(
338
503
        &self,
339
503
        client_operation_id: &OperationId,
340
503
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
341
503
        let maybe_client_awaited_action = self
342
503
            .client_operation_to_awaited_action
343
503
            .get(client_operation_id)
344
0
            .await;
345
503
        let Some(
client_awaited_action3
) = maybe_client_awaited_action else {
  Branch (345:13): [True: 0, False: 0]
  Branch (345:13): [Folded - Ignored]
  Branch (345:13): [True: 3, False: 500]
346
500
            return Ok(None);
347
        };
348
349
3
        self.operation_id_to_awaited_action
350
3
            .get(client_awaited_action.operation_id())
351
3
            .map(|tx| {
352
3
                Some(MemoryAwaitedActionSubscriber::new_with_client(
353
3
                    tx.subscribe(),
354
3
                    client_operation_id.clone(),
355
3
                    self.action_event_tx.clone(),
356
3
                    self.now_fn.clone(),
357
3
                ))
358
3
            })
359
3
            .ok_or_else(|| {
360
0
                make_err!(
361
0
                    Code::Internal,
362
0
                    "Failed to get client operation id {client_operation_id:?}"
363
0
                )
364
3
            })
365
503
    }
366
367
    /// Processes action events that need to be handled by the database.
368
54
    async fn handle_action_events(
369
54
        &mut self,
370
54
        action_events: impl IntoIterator<Item = ActionEvent>,
371
54
    ) -> NoEarlyReturn {
372
54
        for action in action_events.into_iter() {
373
54
            event!(Level::DEBUG, ?action, 
"Handling action"0
);
374
54
            match action {
375
0
                ActionEvent::ClientDroppedOperation(operation_id) => {
376
                    // Cleanup operation_id_to_awaited_action.
377
0
                    let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else {
  Branch (377:25): [True: 0, False: 0]
  Branch (377:25): [Folded - Ignored]
  Branch (377:25): [True: 0, False: 0]
378
0
                        event!(
379
0
                            Level::ERROR,
380
                            ?operation_id,
381
0
                            "operation_id_to_awaited_action does not have operation_id"
382
                        );
383
0
                        continue;
384
                    };
385
0
                    let connected_clients = if let Some(connected_clients) = self
  Branch (385:52): [True: 0, False: 0]
  Branch (385:52): [Folded - Ignored]
  Branch (385:52): [True: 0, False: 0]
386
0
                        .connected_clients_for_operation_id
387
0
                        .remove(&operation_id)
388
                    {
389
0
                        connected_clients - 1
390
                    } else {
391
0
                        event!(
392
0
                            Level::ERROR,
393
                            ?operation_id,
394
0
                            "connected_clients_for_operation_id does not have operation_id"
395
                        );
396
0
                        0
397
                    };
398
                    // Note: It is rare to have more than one client listening
399
                    // to the same action, so we assume that we are the last
400
                    // client and insert it back into the map if we detect that
401
                    // there are still clients listening (ie: the happy path
402
                    // is operation.connected_clients == 0).
403
0
                    if connected_clients != 0 {
  Branch (403:24): [True: 0, False: 0]
  Branch (403:24): [Folded - Ignored]
  Branch (403:24): [True: 0, False: 0]
404
0
                        self.operation_id_to_awaited_action
405
0
                            .insert(operation_id.clone(), tx);
406
0
                        self.connected_clients_for_operation_id
407
0
                            .insert(operation_id, connected_clients);
408
0
                        continue;
409
0
                    }
410
0
                    event!(
411
0
                        Level::DEBUG,
412
                        ?operation_id,
413
0
                        "Clearing operation from state manager"
414
                    );
415
0
                    let awaited_action = tx.borrow().clone();
416
0
                    // Cleanup action_info_hash_key_to_awaited_action if it was marked cached.
417
0
                    match &awaited_action.action_info().unique_qualifier {
418
0
                        ActionUniqueQualifier::Cachable(action_key) => {
419
0
                            let maybe_awaited_action = self
420
0
                                .action_info_hash_key_to_awaited_action
421
0
                                .remove(action_key);
422
0
                            if !awaited_action.state().stage.is_finished()
  Branch (422:32): [True: 0, False: 0]
  Branch (422:32): [Folded - Ignored]
  Branch (422:32): [True: 0, False: 0]
423
0
                                && maybe_awaited_action.is_none()
  Branch (423:36): [True: 0, False: 0]
  Branch (423:36): [Folded - Ignored]
  Branch (423:36): [True: 0, False: 0]
424
                            {
425
0
                                event!(
426
0
                                    Level::ERROR,
427
                                    ?operation_id,
428
                                    ?awaited_action,
429
                                    ?action_key,
430
0
                                    "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
431
                                );
432
0
                            }
433
                        }
434
0
                        ActionUniqueQualifier::Uncachable(_action_key) => {
435
0
                            // This Operation should not be in the hash_key map.
436
0
                        }
437
                    }
438
439
                    // Cleanup sorted_awaited_action.
440
0
                    let sort_key = awaited_action.sort_key();
441
0
                    let sort_btree_for_state = self
442
0
                        .sorted_action_info_hash_keys
443
0
                        .btree_for_state(&awaited_action.state().stage);
444
0
445
0
                    let maybe_sorted_awaited_action =
446
0
                        sort_btree_for_state.take(&SortedAwaitedAction {
447
0
                            sort_key,
448
0
                            operation_id: operation_id.clone(),
449
0
                        });
450
0
                    if maybe_sorted_awaited_action.is_none() {
  Branch (450:24): [True: 0, False: 0]
  Branch (450:24): [Folded - Ignored]
  Branch (450:24): [True: 0, False: 0]
451
0
                        event!(
452
0
                            Level::ERROR,
453
                            ?operation_id,
454
                            ?sort_key,
455
0
                            "Expected maybe_sorted_awaited_action to have {sort_key:?}",
456
                        );
457
0
                    }
458
                }
459
54
                ActionEvent::ClientKeepAlive(client_id) => {
460
54
                    let maybe_size = self
461
54
                        .client_operation_to_awaited_action
462
54
                        .size_for_key(&client_id)
463
0
                        .await;
464
54
                    if maybe_size.is_none() {
  Branch (464:24): [True: 0, False: 0]
  Branch (464:24): [Folded - Ignored]
  Branch (464:24): [True: 0, False: 54]
465
0
                        event!(
466
0
                            Level::ERROR,
467
                            ?client_id,
468
0
                            "client_operation_to_awaited_action does not have client_id",
469
                        );
470
54
                    }
471
                }
472
            }
473
        }
474
54
        NoEarlyReturn
475
54
    }
476
477
0
    fn get_awaited_actions_range(
478
0
        &self,
479
0
        start: Bound<&OperationId>,
480
0
        end: Bound<&OperationId>,
481
0
    ) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber<I, NowFn>)> {
482
0
        self.operation_id_to_awaited_action
483
0
            .range((start, end))
484
0
            .map(|(operation_id, tx)| {
485
0
                (
486
0
                    operation_id,
487
0
                    MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()),
488
0
                )
489
0
            })
490
0
    }
491
492
80
    fn get_by_operation_id(
493
80
        &self,
494
80
        operation_id: &OperationId,
495
80
    ) -> Option<MemoryAwaitedActionSubscriber<I, NowFn>> {
496
80
        self.operation_id_to_awaited_action
497
80
            .get(operation_id)
498
80
            .map(|tx| MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()))
499
80
    }
500
501
124
    fn get_range_of_actions<'a, 'b>(
502
124
        &'a self,
503
124
        state: SortedAwaitedActionState,
504
124
        range: impl RangeBounds<SortedAwaitedAction> + 'b,
505
124
    ) -> impl DoubleEndedIterator<
506
124
        Item = Result<
507
124
            (
508
124
                &'a SortedAwaitedAction,
509
124
                MemoryAwaitedActionSubscriber<I, NowFn>,
510
124
            ),
511
124
            Error,
512
124
        >,
513
124
    > + 'a {
514
124
        let btree = match state {
515
0
            SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
516
124
            SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued,
517
0
            SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing,
518
0
            SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed,
519
        };
520
124
        btree.range(range).map(|sorted_awaited_action| {
521
40
            let operation_id = &sorted_awaited_action.operation_id;
522
40
            self.get_by_operation_id(operation_id)
523
40
                .ok_or_else(|| {
524
0
                    make_err!(
525
0
                        Code::Internal,
526
0
                        "Failed to get operation id {}",
527
0
                        operation_id
528
0
                    )
529
40
                })
530
40
                .map(|subscriber| (sorted_awaited_action, subscriber))
531
124
        })
532
124
    }
533
534
39
    fn process_state_changes_for_hash_key_map(
535
39
        action_info_hash_key_to_awaited_action: &mut HashMap<ActionUniqueKey, OperationId>,
536
39
        new_awaited_action: &AwaitedAction,
537
39
    ) {
538
39
        // Only process changes if the stage is not finished.
539
39
        if !new_awaited_action.state().stage.is_finished() {
  Branch (539:12): [True: 0, False: 0]
  Branch (539:12): [Folded - Ignored]
  Branch (539:12): [True: 33, False: 6]
540
33
            return;
541
6
        }
542
6
        match &new_awaited_action.action_info().unique_qualifier {
543
6
            ActionUniqueQualifier::Cachable(action_key) => {
544
6
                let maybe_awaited_action =
545
6
                    action_info_hash_key_to_awaited_action.remove(action_key);
546
6
                match maybe_awaited_action {
547
6
                    Some(removed_operation_id) => {
548
6
                        if &removed_operation_id != new_awaited_action.operation_id() {
  Branch (548:28): [True: 0, False: 0]
  Branch (548:28): [Folded - Ignored]
  Branch (548:28): [True: 0, False: 6]
549
0
                            event!(
550
0
                                Level::ERROR,
551
                                ?removed_operation_id,
552
                                ?new_awaited_action,
553
                                ?action_key,
554
0
                                "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
555
                            );
556
6
                        }
557
                    }
558
                    None => {
559
0
                        event!(
560
0
                            Level::ERROR,
561
                            ?new_awaited_action,
562
                            ?action_key,
563
0
                            "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key",
564
                        );
565
                    }
566
                }
567
            }
568
0
            ActionUniqueQualifier::Uncachable(_action_key) => {
569
0
                // If we are not cachable, the action should not be in the
570
0
                // hash_key map, so we don't need to process anything in
571
0
                // action_info_hash_key_to_awaited_action.
572
0
            }
573
        }
574
39
    }
575
576
39
    fn update_awaited_action(
577
39
        &mut self,
578
39
        mut new_awaited_action: AwaitedAction,
579
39
    ) -> Result<(), Error> {
580
39
        let tx = self
581
39
            .operation_id_to_awaited_action
582
39
            .get(new_awaited_action.operation_id())
583
39
            .ok_or_else(|| {
584
0
                make_err!(
585
0
                    Code::Internal,
586
0
                    "OperationId does not exist in map in AwaitedActionDb::update_awaited_action"
587
0
                )
588
39
            })
?0
;
589
        {
590
            // Note: It's important to drop old_awaited_action before we call
591
            // send_replace or we will have a deadlock.
592
39
            let old_awaited_action = tx.borrow();
593
39
594
39
            // Do not process changes if the action version is not in sync with
595
39
            // what the sender based the update on.
596
39
            if old_awaited_action.version() != new_awaited_action.version() {
  Branch (596:16): [True: 0, False: 0]
  Branch (596:16): [Folded - Ignored]
  Branch (596:16): [True: 0, False: 39]
597
0
                return Err(make_err!(
598
0
                    // From: https://grpc.github.io/grpc/core/md_doc_statuscodes.html
599
0
                    // Use ABORTED if the client should retry at a higher level
600
0
                    // (e.g., when a client-specified test-and-set fails,
601
0
                    // indicating the client should restart a read-modify-write
602
0
                    // sequence)
603
0
                    Code::Aborted,
604
0
                    "{} Expected {} but got {} for operation_id {:?} - {:?}",
605
0
                    "Tried to update an awaited action with an incorrect version.",
606
0
                    old_awaited_action.version(),
607
0
                    new_awaited_action.version(),
608
0
                    old_awaited_action,
609
0
                    new_awaited_action,
610
0
                ));
611
39
            }
612
39
            new_awaited_action.increment_version();
613
614
0
            error_if!(
615
39
                old_awaited_action.action_info().unique_qualifier
  Branch (615:17): [True: 0, False: 0]
  Branch (615:17): [Folded - Ignored]
  Branch (615:17): [True: 0, False: 39]
616
39
                    != new_awaited_action.action_info().unique_qualifier,
617
                "Unique key changed for operation_id {:?} - {:?} - {:?}",
618
0
                new_awaited_action.operation_id(),
619
0
                old_awaited_action.action_info(),
620
0
                new_awaited_action.action_info(),
621
            );
622
39
            let is_same_stage = old_awaited_action
623
39
                .state()
624
39
                .stage
625
39
                .is_same_stage(&new_awaited_action.state().stage);
626
39
627
39
            if !is_same_stage {
  Branch (627:16): [True: 0, False: 0]
  Branch (627:16): [Folded - Ignored]
  Branch (627:16): [True: 39, False: 0]
628
39
                self.sorted_action_info_hash_keys
629
39
                    .process_state_changes(&old_awaited_action, &new_awaited_action)
?0
;
630
39
                Self::process_state_changes_for_hash_key_map(
631
39
                    &mut self.action_info_hash_key_to_awaited_action,
632
39
                    &new_awaited_action,
633
39
                );
634
0
            }
635
        }
636
637
        // Notify all listeners of the new state and ignore if no one is listening.
638
        // Note: Do not use `.send()` as it will not update the state if all listeners
639
        // are dropped.
640
39
        let _ = tx.send_replace(new_awaited_action);
641
39
642
39
        Ok(())
643
39
    }
644
645
    /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to
646
    /// listen for changes. We don't do this in-line because it is important
647
    /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into
648
    /// the map. Failing to do so may result in memory leaks. This is because
649
    /// [`ClientAwaitedAction`] implements a drop function that will trigger
650
    /// cleanup of the other maps on drop.
651
23
    fn make_client_awaited_action(
652
23
        &mut self,
653
23
        operation_id: &OperationId,
654
23
        awaited_action: AwaitedAction,
655
23
    ) -> (Arc<ClientAwaitedAction>, watch::Receiver<AwaitedAction>) {
656
23
        let (tx, rx) = watch::channel(awaited_action);
657
23
        let client_awaited_action = Arc::new(ClientAwaitedAction::new(
658
23
            operation_id.clone(),
659
23
            self.action_event_tx.clone(),
660
23
        ));
661
23
        self.operation_id_to_awaited_action
662
23
            .insert(operation_id.clone(), tx);
663
23
        self.connected_clients_for_operation_id
664
23
            .insert(operation_id.clone(), 1);
665
23
        (client_awaited_action, rx)
666
23
    }
667
668
25
    async fn add_action(
669
25
        &mut self,
670
25
        client_operation_id: OperationId,
671
25
        action_info: Arc<ActionInfo>,
672
25
    ) -> Result<MemoryAwaitedActionSubscriber<I, NowFn>, Error> {
673
        // Check to see if the action is already known and subscribe if it is.
674
25
        let subscription_result = self
675
25
            .try_subscribe(
676
25
                &client_operation_id,
677
25
                &action_info.unique_qualifier,
678
25
                action_info.priority,
679
25
            )
680
0
            .await
681
25
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
);
682
25
        match subscription_result {
683
0
            Err(err) => return Err(err),
684
2
            Ok(Some(subscription)) => return Ok(subscription),
685
23
            Ok(None) => { /* Add item to queue. */ }
686
        }
687
688
23
        let maybe_unique_key = match &action_info.unique_qualifier {
689
23
            ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()),
690
0
            ActionUniqueQualifier::Uncachable(_unique_key) => None,
691
        };
692
23
        let operation_id = OperationId::default();
693
23
        let awaited_action =
694
23
            AwaitedAction::new(operation_id.clone(), action_info, (self.now_fn)().now());
695
23
        debug_assert!(
696
0
            ActionStage::Queued == awaited_action.state().stage,
697
0
            "Expected action to be queued"
698
        );
699
23
        let sort_key = awaited_action.sort_key();
700
23
701
23
        let (client_awaited_action, rx) =
702
23
            self.make_client_awaited_action(&operation_id.clone(), awaited_action);
703
23
704
23
        event!(
705
23
            Level::DEBUG,
706
            ?client_operation_id,
707
            ?operation_id,
708
            ?client_awaited_action,
709
0
            "Adding action"
710
        );
711
712
23
        self.client_operation_to_awaited_action
713
23
            .insert(client_operation_id.clone(), client_awaited_action)
714
0
            .await;
715
716
        // Note: We only put items in the map that are cachable.
717
23
        if let Some(unique_key) = maybe_unique_key {
  Branch (717:16): [True: 0, False: 0]
  Branch (717:16): [Folded - Ignored]
  Branch (717:16): [True: 23, False: 0]
718
23
            let old_value = self
719
23
                .action_info_hash_key_to_awaited_action
720
23
                .insert(unique_key, operation_id.clone());
721
23
            if let Some(
old_value0
) = old_value {
  Branch (721:20): [True: 0, False: 0]
  Branch (721:20): [Folded - Ignored]
  Branch (721:20): [True: 0, False: 23]
722
0
                event!(
723
0
                    Level::ERROR,
724
                    ?operation_id,
725
                    ?old_value,
726
0
                    "action_info_hash_key_to_awaited_action already has unique_key"
727
                );
728
23
            }
729
0
        }
730
731
23
        self.sorted_action_info_hash_keys
732
23
            .insert_sort_map_for_stage(
733
23
                &ActionStage::Queued,
734
23
                &SortedAwaitedAction {
735
23
                    sort_key,
736
23
                    operation_id,
737
23
                },
738
23
            )
739
23
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
)
?0
;
740
741
23
        Ok(MemoryAwaitedActionSubscriber::new_with_client(
742
23
            rx,
743
23
            client_operation_id,
744
23
            self.action_event_tx.clone(),
745
23
            self.now_fn.clone(),
746
23
        ))
747
25
    }
748
749
25
    async fn try_subscribe(
750
25
        &mut self,
751
25
        client_operation_id: &OperationId,
752
25
        unique_qualifier: &ActionUniqueQualifier,
753
25
        // TODO(allada) To simplify the scheduler 2024 refactor, we
754
25
        // removed the ability to upgrade priorities of actions.
755
25
        // we should add priority upgrades back in.
756
25
        _priority: i32,
757
25
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
758
25
        let unique_key = match unique_qualifier {
759
25
            ActionUniqueQualifier::Cachable(unique_key) => unique_key,
760
0
            ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None),
761
        };
762
763
25
        let Some(
operation_id2
) = self.action_info_hash_key_to_awaited_action.get(unique_key) else {
  Branch (763:13): [True: 0, False: 0]
  Branch (763:13): [Folded - Ignored]
  Branch (763:13): [True: 2, False: 23]
764
23
            return Ok(None); // Not currently running.
765
        };
766
767
2
        let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else {
  Branch (767:13): [True: 0, False: 0]
  Branch (767:13): [Folded - Ignored]
  Branch (767:13): [True: 2, False: 0]
768
0
            return Err(make_err!(
769
0
                Code::Internal,
770
0
                "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
771
0
            ));
772
        };
773
774
0
        error_if!(
775
2
            tx.borrow().state().stage.is_finished(),
  Branch (775:13): [True: 0, False: 0]
  Branch (775:13): [Folded - Ignored]
  Branch (775:13): [True: 0, False: 2]
776
            "Tried to subscribe to a completed action but it already finished. This should never happen. {:?}",
777
0
            tx.borrow()
778
        );
779
780
2
        let maybe_connected_clients = self
781
2
            .connected_clients_for_operation_id
782
2
            .get_mut(operation_id);
783
2
        let Some(connected_clients) = maybe_connected_clients else {
  Branch (783:13): [True: 0, False: 0]
  Branch (783:13): [Folded - Ignored]
  Branch (783:13): [True: 2, False: 0]
784
0
            return Err(make_err!(
785
0
                Code::Internal,
786
0
                "connected_clients_for_operation_id and operation_id_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
787
0
            ));
788
        };
789
2
        *connected_clients += 1;
790
2
791
2
        let subscription = tx.subscribe();
792
2
793
2
        self.client_operation_to_awaited_action
794
2
            .insert(
795
2
                client_operation_id.clone(),
796
2
                Arc::new(ClientAwaitedAction::new(
797
2
                    operation_id.clone(),
798
2
                    self.action_event_tx.clone(),
799
2
                )),
800
2
            )
801
0
            .await;
802
803
2
        Ok(Some(MemoryAwaitedActionSubscriber::new_with_client(
804
2
            subscription,
805
2
            client_operation_id.clone(),
806
2
            self.action_event_tx.clone(),
807
2
            self.now_fn.clone(),
808
2
        )))
809
25
    }
810
}
811
812
0
/// In-memory implementation of [`AwaitedActionDb`].
///
/// All state lives in `inner` behind an async mutex; a background task (owned by
/// `_handle_awaited_action_events`) applies action events to that state.
#[derive(MetricsComponent)]
pub struct MemoryAwaitedActionDb<I: InstantWrapper, NowFn: Fn() -> I> {
    /// Shared database state guarded by an async mutex; marked `#[metric]` so the
    /// `MetricsComponent` derive includes it.
    #[metric]
    inner: Arc<Mutex<AwaitedActionDbImpl<I, NowFn>>>,
    /// Signalled with `notify_one` after an action is added or updated
    /// (presumably so the scheduler can look for new work — confirm with callers).
    tasks_change_notify: Arc<Notify>,
    /// Guard for the background task spawned in `new` that drains action events.
    /// NOTE(review): presumably aborts the task when dropped — confirm
    /// `JoinHandleDropGuard` semantics. Field order also fixes drop order.
    _handle_awaited_action_events: JoinHandleDropGuard<()>,
}
819
820
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static>
821
    MemoryAwaitedActionDb<I, NowFn>
822
{
823
19
    pub fn new(
824
19
        eviction_config: &EvictionPolicy,
825
19
        tasks_change_notify: Arc<Notify>,
826
19
        now_fn: NowFn,
827
19
    ) -> Self {
828
19
        let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel();
829
19
        let inner = Arc::new(Mutex::new(AwaitedActionDbImpl {
830
19
            client_operation_to_awaited_action: EvictingMap::new(eviction_config, (now_fn)()),
831
19
            operation_id_to_awaited_action: BTreeMap::new(),
832
19
            action_info_hash_key_to_awaited_action: HashMap::new(),
833
19
            sorted_action_info_hash_keys: SortedAwaitedActions::default(),
834
19
            connected_clients_for_operation_id: HashMap::new(),
835
19
            action_event_tx,
836
19
            now_fn,
837
19
        }));
838
19
        let weak_inner = Arc::downgrade(&inner);
839
19
        Self {
840
19
            inner,
841
19
            tasks_change_notify,
842
19
            _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move {
843
19
                let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE);
844
                loop {
845
73
                    dropped_operation_ids.clear();
846
73
                    action_event_rx
847
73
                        .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE)
848
54
                        .await;
849
54
                    let Some(inner) = weak_inner.upgrade() else {
  Branch (849:25): [True: 0, False: 0]
  Branch (849:25): [Folded - Ignored]
  Branch (849:25): [True: 54, False: 0]
850
0
                        return; // Nothing to cleanup, our struct is dropped.
851
                    };
852
54
                    let mut inner = inner.lock().
await0
;
853
54
                    inner
854
54
                        .handle_action_events(dropped_operation_ids.drain(..))
855
0
                        .await;
856
                }
857
19
            
}0
),
858
        }
859
19
    }
860
}
861
862
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> AwaitedActionDb
863
    for MemoryAwaitedActionDb<I, NowFn>
864
{
865
    type Subscriber = MemoryAwaitedActionSubscriber<I, NowFn>;
866
867
503
    async fn get_awaited_action_by_id(
868
503
        &self,
869
503
        client_operation_id: &OperationId,
870
503
    ) -> Result<Option<Self::Subscriber>, Error> {
871
503
        self.inner
872
503
            .lock()
873
0
            .await
874
503
            .get_awaited_action_by_id(client_operation_id)
875
0
            .await
876
503
    }
877
878
0
    async fn get_all_awaited_actions(
879
0
        &self,
880
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
881
0
        Ok(ChunkedStream::new(
882
0
            Bound::Unbounded,
883
0
            Bound::Unbounded,
884
0
            move |start, end, mut output| async move {
885
0
                let inner = self.inner.lock().await;
886
0
                let mut maybe_new_start = None;
887
888
0
                for (operation_id, item) in
889
0
                    inner.get_awaited_actions_range(start.as_ref(), end.as_ref())
890
0
                {
891
0
                    output.push_back(item);
892
0
                    maybe_new_start = Some(operation_id);
893
0
                }
894
895
0
                Ok(maybe_new_start
896
0
                    .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output)))
897
0
            },
898
0
        ))
899
0
    }
900
901
40
    async fn get_by_operation_id(
902
40
        &self,
903
40
        operation_id: &OperationId,
904
40
    ) -> Result<Option<Self::Subscriber>, Error> {
905
40
        Ok(self.inner.lock().
await0
.get_by_operation_id(operation_id))
906
40
    }
907
908
87
    async fn get_range_of_actions(
909
87
        &self,
910
87
        state: SortedAwaitedActionState,
911
87
        start: Bound<SortedAwaitedAction>,
912
87
        end: Bound<SortedAwaitedAction>,
913
87
        desc: bool,
914
87
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
915
87
        Ok(ChunkedStream::new(
916
87
            start,
917
87
            end,
918
124
            move |start, end, mut output| async move {
919
124
                let inner = self.inner.lock().
await0
;
920
124
                let mut done = true;
921
124
                let mut new_start = start.as_ref();
922
124
                let mut new_end = end.as_ref();
923
124
924
124
                let iterator = inner
925
124
                    .get_range_of_actions(state, (start.as_ref(), end.as_ref()))
926
124
                    .map(|res| 
res.err_tip(40
||
"In AwaitedActionDb::get_range_of_actions"0
)40
);
927
124
928
124
                // TODO(allada) This should probably use the `.left()/right()` pattern,
929
124
                // but that doesn't exist in the std or any libraries we use.
930
124
                if desc {
  Branch (930:20): [True: 0, False: 0]
  Branch (930:20): [Folded - Ignored]
  Branch (930:20): [True: 124, False: 0]
931
124
                    for 
result40
in iterator.rev() {
932
40
                        let (sorted_awaited_action, item) =
933
40
                            result.err_tip(|| 
"In AwaitedActionDb::get_range_of_actions"0
)
?0
;
934
40
                        output.push_back(item);
935
40
                        new_end = Bound::Excluded(sorted_awaited_action);
936
40
                        done = false;
937
                    }
938
                } else {
939
0
                    for result in iterator {
940
0
                        let (sorted_awaited_action, item) =
941
0
                            result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
942
0
                        output.push_back(item);
943
0
                        new_start = Bound::Excluded(sorted_awaited_action);
944
0
                        done = false;
945
                    }
946
                }
947
124
                if done {
  Branch (947:20): [True: 0, False: 0]
  Branch (947:20): [Folded - Ignored]
  Branch (947:20): [True: 87, False: 37]
948
87
                    return Ok(None);
949
37
                }
950
37
                Ok(Some(((new_start.cloned(), new_end.cloned()), output)))
951
248
            },
952
87
        ))
953
87
    }
954
955
39
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
956
39
        self.inner
957
39
            .lock()
958
0
            .await
959
39
            .update_awaited_action(new_awaited_action)
?0
;
960
39
        self.tasks_change_notify.notify_one();
961
39
        Ok(())
962
39
    }
963
964
25
    async fn add_action(
965
25
        &self,
966
25
        client_operation_id: OperationId,
967
25
        action_info: Arc<ActionInfo>,
968
25
    ) -> Result<Self::Subscriber, Error> {
969
25
        let subscriber = self
970
25
            .inner
971
25
            .lock()
972
0
            .await
973
25
            .add_action(client_operation_id, action_info)
974
0
            .await?;
975
25
        self.tasks_change_notify.notify_one();
976
25
        Ok(subscriber)
977
25
    }
978
}