Coverage Report

Created: 2025-01-30 02:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/memory_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::{BTreeMap, BTreeSet, HashMap};
16
use std::ops::{Bound, RangeBounds};
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use async_lock::Mutex;
21
use futures::{FutureExt, Stream};
22
use nativelink_config::stores::EvictionPolicy;
23
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
24
use nativelink_metric::MetricsComponent;
25
use nativelink_util::action_messages::{
26
    ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, OperationId,
27
};
28
use nativelink_util::chunked_stream::ChunkedStream;
29
use nativelink_util::evicting_map::{EvictingMap, LenEntry};
30
use nativelink_util::instant_wrapper::InstantWrapper;
31
use nativelink_util::spawn;
32
use nativelink_util::task::JoinHandleDropGuard;
33
use tokio::sync::{mpsc, watch, Notify};
34
use tracing::{event, Level};
35
36
use crate::awaited_action_db::{
37
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction,
38
    SortedAwaitedActionState,
39
};
40
41
/// Number of events to process per cycle.
42
const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024;
43
44
/// Duration to wait before sending client keep alive messages.
45
const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10);
46
47
/// Represents a client that is currently listening to an action.
48
/// When the client is dropped, it will send the `AwaitedAction` to the
49
/// `event_tx` if there are other cleanups needed.
50
#[derive(Debug)]
51
struct ClientAwaitedAction {
52
    /// The `OperationId` that the client is listening to.
53
    operation_id: OperationId,
54
55
    /// The sender to notify of this struct being dropped.
56
    event_tx: mpsc::UnboundedSender<ActionEvent>,
57
}
58
59
impl ClientAwaitedAction {
60
27
    pub fn new(operation_id: OperationId, event_tx: mpsc::UnboundedSender<ActionEvent>) -> Self {
61
27
        Self {
62
27
            operation_id,
63
27
            event_tx,
64
27
        }
65
27
    }
66
67
3
    pub fn operation_id(&self) -> &OperationId {
68
3
        &self.operation_id
69
3
    }
70
}
71
72
impl Drop for ClientAwaitedAction {
73
27
    fn drop(&mut self) {
74
27
        // If we failed to send it means no one is listening.
75
27
        let _ = self.event_tx.send(ActionEvent::ClientDroppedOperation(
76
27
            self.operation_id.clone(),
77
27
        ));
78
27
    }
79
}
80
81
/// Trait to be able to use the `EvictingMap` with `ClientAwaitedAction`.
82
/// Note: We only use `EvictingMap` for a time based eviction, which is
83
/// why the implementation has fixed default values in it.
84
impl LenEntry for ClientAwaitedAction {
85
    #[inline]
86
83
    fn len(&self) -> u64 {
87
83
        0
88
83
    }
89
90
    #[inline]
91
0
    fn is_empty(&self) -> bool {
92
0
        true
93
0
    }
94
}
95
96
/// Actions the `AwaitedActionsDb` needs to process.
97
#[derive(Debug)]
98
pub(crate) enum ActionEvent {
99
    /// A client has sent a keep alive message.
100
    ClientKeepAlive(OperationId),
101
    /// A client has dropped and pointed to `OperationId`.
102
    ClientDroppedOperation(OperationId),
103
}
104
105
/// Information required to track an individual client
106
/// keep alive config and state.
107
struct ClientInfo<I: InstantWrapper, NowFn: Fn() -> I> {
108
    /// The client operation id.
109
    client_operation_id: OperationId,
110
    /// The last time a keep alive was sent.
111
    last_keep_alive: I,
112
    /// The function to get the current time.
113
    now_fn: NowFn,
114
    /// The sender used to notify that this struct had an event.
115
    event_tx: mpsc::UnboundedSender<ActionEvent>,
116
}
117
118
/// Subscriber that clients can use to monitor when `AwaitedActions` change.
119
pub struct MemoryAwaitedActionSubscriber<I: InstantWrapper, NowFn: Fn() -> I> {
120
    /// The receiver to listen for changes.
121
    awaited_action_rx: watch::Receiver<AwaitedAction>,
122
    /// If a client id is known this is the info needed to keep the client
123
    /// action alive.
124
    client_info: Option<ClientInfo<I, NowFn>>,
125
}
126
127
impl<I: InstantWrapper, NowFn: Fn() -> I> MemoryAwaitedActionSubscriber<I, NowFn> {
128
83
    fn new(mut awaited_action_rx: watch::Receiver<AwaitedAction>) -> Self {
129
83
        awaited_action_rx.mark_changed();
130
83
        Self {
131
83
            awaited_action_rx,
132
83
            client_info: None,
133
83
        }
134
83
    }
135
136
30
    fn new_with_client(
137
30
        mut awaited_action_rx: watch::Receiver<AwaitedAction>,
138
30
        client_operation_id: OperationId,
139
30
        event_tx: mpsc::UnboundedSender<ActionEvent>,
140
30
        now_fn: NowFn,
141
30
    ) -> Self
142
30
    where
143
30
        NowFn: Fn() -> I,
144
30
    {
145
30
        awaited_action_rx.mark_changed();
146
30
        Self {
147
30
            awaited_action_rx,
148
30
            client_info: Some(ClientInfo {
149
30
                client_operation_id,
150
30
                last_keep_alive: I::from_secs(0),
151
30
                now_fn,
152
30
                event_tx,
153
30
            }),
154
30
        }
155
30
    }
156
}
157
158
impl<I, NowFn> AwaitedActionSubscriber for MemoryAwaitedActionSubscriber<I, NowFn>
159
where
160
    I: InstantWrapper,
161
    NowFn: Fn() -> I + Send + Sync + 'static,
162
{
163
58
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
164
44
        let client_operation_id = {
165
58
            let changed_fut = self.awaited_action_rx.changed().map(|r| {
166
44
                r.map_err(|e| {
167
0
                    make_err!(
168
0
                        Code::Internal,
169
0
                        "Failed to wait for awaited action to change {e:?}"
170
0
                    )
171
44
                })
172
58
            });
173
58
            let Some(client_info) = self.client_info.as_mut() else {
  Branch (173:17): [True: 0, False: 0]
  Branch (173:17): [Folded - Ignored]
  Branch (173:17): [True: 58, False: 0]
174
0
                changed_fut.await?;
175
0
                return Ok(self.awaited_action_rx.borrow().clone());
176
            };
177
58
            tokio::pin!(changed_fut);
178
            loop {
179
164
                if client_info.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
  Branch (179:20): [True: 0, False: 0]
  Branch (179:20): [Folded - Ignored]
  Branch (179:20): [True: 55, False: 109]
180
55
                    client_info.last_keep_alive = (client_info.now_fn)();
181
55
                    // Failing to send just means our receiver dropped.
182
55
                    let _ = client_info.event_tx.send(ActionEvent::ClientKeepAlive(
183
55
                        client_info.client_operation_id.clone(),
184
55
                    ));
185
109
                }
186
164
                let sleep_fut = (client_info.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
187
164
                tokio::select! {
188
164
                    
result44
= &mut changed_fut => {
189
44
                        result
?0
;
190
44
                        break;
191
                    }
192
164
                    () = sleep_fut => {
193
106
                        // If we haven't received any updates for a while, we should
194
106
                        // let the database know that we are still listening to prevent
195
106
                        // the action from being dropped.
196
106
                    }
197
                }
198
            }
199
44
            client_info.client_operation_id.clone()
200
44
        };
201
44
        // At this stage we know that this event is a client request, so we need
202
44
        // to populate the client_operation_id.
203
44
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
204
44
        awaited_action.set_client_operation_id(client_operation_id);
205
44
        Ok(awaited_action)
206
44
    }
207
208
168
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
209
168
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
210
168
        if let Some(
client_info16
) = self.client_info.as_ref() {
  Branch (210:16): [True: 0, False: 0]
  Branch (210:16): [Folded - Ignored]
  Branch (210:16): [True: 16, False: 152]
211
16
            awaited_action.set_client_operation_id(client_info.client_operation_id.clone());
212
152
        }
213
168
        Ok(awaited_action)
214
168
    }
215
}
216
217
/// A struct that is used to keep the developer from trying to
218
/// return early from a function.
219
struct NoEarlyReturn;
220
221
#[derive(Default, MetricsComponent)]
222
struct SortedAwaitedActions {
223
    #[metric(group = "unknown")]
224
    unknown: BTreeSet<SortedAwaitedAction>,
225
    #[metric(group = "cache_check")]
226
    cache_check: BTreeSet<SortedAwaitedAction>,
227
    #[metric(group = "queued")]
228
    queued: BTreeSet<SortedAwaitedAction>,
229
    #[metric(group = "executing")]
230
    executing: BTreeSet<SortedAwaitedAction>,
231
    #[metric(group = "completed")]
232
    completed: BTreeSet<SortedAwaitedAction>,
233
}
234
235
impl SortedAwaitedActions {
236
39
    fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet<SortedAwaitedAction> {
237
39
        match state {
238
0
            ActionStage::Unknown => &mut self.unknown,
239
0
            ActionStage::CacheCheck => &mut self.cache_check,
240
26
            ActionStage::Queued => &mut self.queued,
241
13
            ActionStage::Executing => &mut self.executing,
242
0
            ActionStage::Completed(_) | ActionStage::CompletedFromCache(_) => &mut self.completed,
243
        }
244
39
    }
245
246
63
    fn insert_sort_map_for_stage(
247
63
        &mut self,
248
63
        stage: &ActionStage,
249
63
        sorted_awaited_action: &SortedAwaitedAction,
250
63
    ) -> Result<(), Error> {
251
63
        let newly_inserted = match stage {
252
0
            ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()),
253
0
            ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()),
254
31
            ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()),
255
26
            ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()),
256
6
            ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()),
257
            ActionStage::CompletedFromCache(_) => {
258
0
                self.completed.insert(sorted_awaited_action.clone())
259
            }
260
        };
261
63
        if !newly_inserted {
  Branch (261:12): [True: 0, False: 63]
  Branch (261:12): [Folded - Ignored]
262
0
            return Err(make_err!(
263
0
                Code::Internal,
264
0
                "Tried to insert an action that was already in the sorted map. This should never happen. {:?} - {:?}",
265
0
                stage,
266
0
                sorted_awaited_action
267
0
            ));
268
63
        }
269
63
        Ok(())
270
63
    }
271
272
39
    fn process_state_changes(
273
39
        &mut self,
274
39
        old_awaited_action: &AwaitedAction,
275
39
        new_awaited_action: &AwaitedAction,
276
39
    ) -> Result<(), Error> {
277
39
        let btree = self.btree_for_state(&old_awaited_action.state().stage);
278
39
        let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction {
279
39
            sort_key: old_awaited_action.sort_key(),
280
39
            operation_id: new_awaited_action.operation_id().clone(),
281
39
        });
282
283
39
        let Some(sorted_awaited_action) = maybe_sorted_awaited_action else {
  Branch (283:13): [True: 39, False: 0]
  Branch (283:13): [Folded - Ignored]
284
0
            return Err(make_err!(
285
0
                Code::Internal,
286
0
                "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}",
287
0
                new_awaited_action.operation_id(),
288
0
                new_awaited_action,
289
0
            ));
290
        };
291
292
39
        self.insert_sort_map_for_stage(&new_awaited_action.state().stage, &sorted_awaited_action)
293
39
            .err_tip(|| 
"In AwaitedActionDb::update_awaited_action"0
)
?0
;
294
39
        Ok(())
295
39
    }
296
}
297
298
/// The database for storing the state of all actions.
299
#[derive(MetricsComponent)]
300
pub struct AwaitedActionDbImpl<I: InstantWrapper, NowFn: Fn() -> I> {
301
    /// A lookup table to lookup the state of an action by its client operation id.
302
    #[metric(group = "client_operation_ids")]
303
    client_operation_to_awaited_action: EvictingMap<OperationId, Arc<ClientAwaitedAction>, I>,
304
305
    /// A lookup table to lookup the state of an action by its worker operation id.
306
    #[metric(group = "operation_ids")]
307
    operation_id_to_awaited_action: BTreeMap<OperationId, watch::Sender<AwaitedAction>>,
308
309
    /// A lookup table to lookup the state of an action by its unique qualifier.
310
    #[metric(group = "action_info_hash_key_to_awaited_action")]
311
    action_info_hash_key_to_awaited_action: HashMap<ActionUniqueKey, OperationId>,
312
313
    /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting
314
    /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`].
315
    ///
316
    /// See [`AwaitedActionSortKey`] for more information on the ordering.
317
    #[metric(group = "sorted_action_infos")]
318
    sorted_action_info_hash_keys: SortedAwaitedActions,
319
320
    /// The number of connected clients for each operation id.
321
    #[metric(group = "connected_clients_for_operation_id")]
322
    connected_clients_for_operation_id: HashMap<OperationId, usize>,
323
324
    /// Where to send notifications about important events related to actions.
325
    action_event_tx: mpsc::UnboundedSender<ActionEvent>,
326
327
    /// The function to get the current time.
328
    now_fn: NowFn,
329
}
330
331
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbImpl<I, NowFn> {
332
503
    async fn get_awaited_action_by_id(
333
503
        &self,
334
503
        client_operation_id: &OperationId,
335
503
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
336
503
        let maybe_client_awaited_action = self
337
503
            .client_operation_to_awaited_action
338
503
            .get(client_operation_id)
339
503
            .await;
340
503
        let Some(
client_awaited_action3
) = maybe_client_awaited_action else {
  Branch (340:13): [True: 0, False: 0]
  Branch (340:13): [Folded - Ignored]
  Branch (340:13): [True: 3, False: 500]
341
500
            return Ok(None);
342
        };
343
344
3
        self.operation_id_to_awaited_action
345
3
            .get(client_awaited_action.operation_id())
346
3
            .map(|tx| {
347
3
                Some(MemoryAwaitedActionSubscriber::new_with_client(
348
3
                    tx.subscribe(),
349
3
                    client_operation_id.clone(),
350
3
                    self.action_event_tx.clone(),
351
3
                    self.now_fn.clone(),
352
3
                ))
353
3
            })
354
3
            .ok_or_else(|| {
355
0
                make_err!(
356
0
                    Code::Internal,
357
0
                    "Failed to get client operation id {client_operation_id:?}"
358
0
                )
359
3
            })
360
503
    }
361
362
    /// Processes action events that need to be handled by the database.
363
55
    async fn handle_action_events(
364
55
        &mut self,
365
55
        action_events: impl IntoIterator<Item = ActionEvent>,
366
55
    ) -> NoEarlyReturn {
367
110
        for 
action55
in action_events {
368
55
            event!(Level::DEBUG, ?action, 
"Handling action"0
);
369
55
            match action {
370
1
                ActionEvent::ClientDroppedOperation(operation_id) => {
371
                    // Cleanup operation_id_to_awaited_action.
372
1
                    let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else {
  Branch (372:25): [True: 0, False: 0]
  Branch (372:25): [Folded - Ignored]
  Branch (372:25): [True: 1, False: 0]
373
0
                        event!(
374
0
                            Level::ERROR,
375
                            ?operation_id,
376
0
                            "operation_id_to_awaited_action does not have operation_id"
377
                        );
378
0
                        continue;
379
                    };
380
1
                    let connected_clients = if let Some(connected_clients) = self
  Branch (380:52): [True: 0, False: 0]
  Branch (380:52): [Folded - Ignored]
  Branch (380:52): [True: 1, False: 0]
381
1
                        .connected_clients_for_operation_id
382
1
                        .remove(&operation_id)
383
                    {
384
1
                        connected_clients - 1
385
                    } else {
386
0
                        event!(
387
0
                            Level::ERROR,
388
                            ?operation_id,
389
0
                            "connected_clients_for_operation_id does not have operation_id"
390
                        );
391
0
                        0
392
                    };
393
                    // Note: It is rare to have more than one client listening
394
                    // to the same action, so we assume that we are the last
395
                    // client and insert it back into the map if we detect that
396
                    // there are still clients listening (ie: the happy path
397
                    // is operation.connected_clients == 0).
398
1
                    if connected_clients != 0 {
  Branch (398:24): [True: 0, False: 0]
  Branch (398:24): [Folded - Ignored]
  Branch (398:24): [True: 1, False: 0]
399
1
                        self.operation_id_to_awaited_action
400
1
                            .insert(operation_id.clone(), tx);
401
1
                        self.connected_clients_for_operation_id
402
1
                            .insert(operation_id, connected_clients);
403
1
                        continue;
404
0
                    }
405
0
                    event!(
406
0
                        Level::DEBUG,
407
                        ?operation_id,
408
0
                        "Clearing operation from state manager"
409
                    );
410
0
                    let awaited_action = tx.borrow().clone();
411
0
                    // Cleanup action_info_hash_key_to_awaited_action if it was marked cached.
412
0
                    match &awaited_action.action_info().unique_qualifier {
413
0
                        ActionUniqueQualifier::Cachable(action_key) => {
414
0
                            let maybe_awaited_action = self
415
0
                                .action_info_hash_key_to_awaited_action
416
0
                                .remove(action_key);
417
0
                            if !awaited_action.state().stage.is_finished()
  Branch (417:32): [True: 0, False: 0]
  Branch (417:32): [Folded - Ignored]
  Branch (417:32): [True: 0, False: 0]
418
0
                                && maybe_awaited_action.is_none()
  Branch (418:36): [True: 0, False: 0]
  Branch (418:36): [Folded - Ignored]
  Branch (418:36): [True: 0, False: 0]
419
                            {
420
0
                                event!(
421
0
                                    Level::ERROR,
422
                                    ?operation_id,
423
                                    ?awaited_action,
424
                                    ?action_key,
425
0
                                    "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
426
                                );
427
0
                            }
428
                        }
429
0
                        ActionUniqueQualifier::Uncachable(_action_key) => {
430
0
                            // This Operation should not be in the hash_key map.
431
0
                        }
432
                    }
433
434
                    // Cleanup sorted_awaited_action.
435
0
                    let sort_key = awaited_action.sort_key();
436
0
                    let sort_btree_for_state = self
437
0
                        .sorted_action_info_hash_keys
438
0
                        .btree_for_state(&awaited_action.state().stage);
439
0
440
0
                    let maybe_sorted_awaited_action =
441
0
                        sort_btree_for_state.take(&SortedAwaitedAction {
442
0
                            sort_key,
443
0
                            operation_id: operation_id.clone(),
444
0
                        });
445
0
                    if maybe_sorted_awaited_action.is_none() {
  Branch (445:24): [True: 0, False: 0]
  Branch (445:24): [Folded - Ignored]
  Branch (445:24): [True: 0, False: 0]
446
0
                        event!(
447
0
                            Level::ERROR,
448
                            ?operation_id,
449
                            ?sort_key,
450
0
                            "Expected maybe_sorted_awaited_action to have {sort_key:?}",
451
                        );
452
0
                    }
453
                }
454
54
                ActionEvent::ClientKeepAlive(client_id) => {
455
54
                    let maybe_size = self
456
54
                        .client_operation_to_awaited_action
457
54
                        .size_for_key(&client_id)
458
54
                        .await;
459
54
                    if maybe_size.is_none() {
  Branch (459:24): [True: 0, False: 0]
  Branch (459:24): [Folded - Ignored]
  Branch (459:24): [True: 0, False: 54]
460
0
                        event!(
461
0
                            Level::ERROR,
462
                            ?client_id,
463
0
                            "client_operation_to_awaited_action does not have client_id",
464
                        );
465
54
                    }
466
                }
467
            }
468
        }
469
55
        NoEarlyReturn
470
55
    }
471
472
0
    fn get_awaited_actions_range(
473
0
        &self,
474
0
        start: Bound<&OperationId>,
475
0
        end: Bound<&OperationId>,
476
0
    ) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber<I, NowFn>)> {
477
0
        self.operation_id_to_awaited_action
478
0
            .range((start, end))
479
0
            .map(|(operation_id, tx)| {
480
0
                (
481
0
                    operation_id,
482
0
                    MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()),
483
0
                )
484
0
            })
485
0
    }
486
487
83
    fn get_by_operation_id(
488
83
        &self,
489
83
        operation_id: &OperationId,
490
83
    ) -> Option<MemoryAwaitedActionSubscriber<I, NowFn>> {
491
83
        self.operation_id_to_awaited_action
492
83
            .get(operation_id)
493
83
            .map(|tx| MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()))
494
83
    }
495
496
130
    fn get_range_of_actions<'a, 'b>(
497
130
        &'a self,
498
130
        state: SortedAwaitedActionState,
499
130
        range: impl RangeBounds<SortedAwaitedAction> + 'b,
500
130
    ) -> impl DoubleEndedIterator<
501
130
        Item = Result<
502
130
            (
503
130
                &'a SortedAwaitedAction,
504
130
                MemoryAwaitedActionSubscriber<I, NowFn>,
505
130
            ),
506
130
            Error,
507
130
        >,
508
130
    > + 'a {
509
130
        let btree = match state {
510
0
            SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
511
130
            SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued,
512
0
            SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing,
513
0
            SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed,
514
        };
515
130
        btree.range(range).map(|sorted_awaited_action| {
516
43
            let operation_id = &sorted_awaited_action.operation_id;
517
43
            self.get_by_operation_id(operation_id)
518
43
                .ok_or_else(|| {
519
0
                    make_err!(
520
0
                        Code::Internal,
521
0
                        "Failed to get operation id {}",
522
0
                        operation_id
523
0
                    )
524
43
                })
525
43
                .map(|subscriber| (sorted_awaited_action, subscriber))
526
130
        })
527
130
    }
528
529
39
    fn process_state_changes_for_hash_key_map(
530
39
        action_info_hash_key_to_awaited_action: &mut HashMap<ActionUniqueKey, OperationId>,
531
39
        new_awaited_action: &AwaitedAction,
532
39
    ) {
533
39
        // Only process changes if the stage is not finished.
534
39
        if !new_awaited_action.state().stage.is_finished() {
  Branch (534:12): [True: 0, False: 0]
  Branch (534:12): [Folded - Ignored]
  Branch (534:12): [True: 33, False: 6]
535
33
            return;
536
6
        }
537
6
        match &new_awaited_action.action_info().unique_qualifier {
538
6
            ActionUniqueQualifier::Cachable(action_key) => {
539
6
                let maybe_awaited_action =
540
6
                    action_info_hash_key_to_awaited_action.remove(action_key);
541
6
                match maybe_awaited_action {
542
6
                    Some(removed_operation_id) => {
543
6
                        if &removed_operation_id != new_awaited_action.operation_id() {
  Branch (543:28): [True: 0, False: 0]
  Branch (543:28): [Folded - Ignored]
  Branch (543:28): [True: 0, False: 6]
544
0
                            event!(
545
0
                                Level::ERROR,
546
                                ?removed_operation_id,
547
                                ?new_awaited_action,
548
                                ?action_key,
549
0
                                "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
550
                            );
551
6
                        }
552
                    }
553
                    None => {
554
0
                        event!(
555
0
                            Level::ERROR,
556
                            ?new_awaited_action,
557
                            ?action_key,
558
0
                            "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key",
559
                        );
560
                    }
561
                }
562
            }
563
0
            ActionUniqueQualifier::Uncachable(_action_key) => {
564
0
                // If we are not cachable, the action should not be in the
565
0
                // hash_key map, so we don't need to process anything in
566
0
                // action_info_hash_key_to_awaited_action.
567
0
            }
568
        }
569
39
    }
570
571
39
    fn update_awaited_action(
572
39
        &mut self,
573
39
        mut new_awaited_action: AwaitedAction,
574
39
    ) -> Result<(), Error> {
575
39
        let tx = self
576
39
            .operation_id_to_awaited_action
577
39
            .get(new_awaited_action.operation_id())
578
39
            .ok_or_else(|| {
579
0
                make_err!(
580
0
                    Code::Internal,
581
0
                    "OperationId does not exist in map in AwaitedActionDb::update_awaited_action"
582
0
                )
583
39
            })
?0
;
584
        {
585
            // Note: It's important to drop old_awaited_action before we call
586
            // send_replace or we will have a deadlock.
587
39
            let old_awaited_action = tx.borrow();
588
39
589
39
            // Do not process changes if the action version is not in sync with
590
39
            // what the sender based the update on.
591
39
            if old_awaited_action.version() != new_awaited_action.version() {
  Branch (591:16): [True: 0, False: 0]
  Branch (591:16): [Folded - Ignored]
  Branch (591:16): [True: 0, False: 39]
592
0
                return Err(make_err!(
593
0
                    // From: https://grpc.github.io/grpc/core/md_doc_statuscodes.html
594
0
                    // Use ABORTED if the client should retry at a higher level
595
0
                    // (e.g., when a client-specified test-and-set fails,
596
0
                    // indicating the client should restart a read-modify-write
597
0
                    // sequence)
598
0
                    Code::Aborted,
599
0
                    "{} Expected {} but got {} for operation_id {:?} - {:?}",
600
0
                    "Tried to update an awaited action with an incorrect version.",
601
0
                    old_awaited_action.version(),
602
0
                    new_awaited_action.version(),
603
0
                    old_awaited_action,
604
0
                    new_awaited_action,
605
0
                ));
606
39
            }
607
39
            new_awaited_action.increment_version();
608
609
0
            error_if!(
610
39
                old_awaited_action.action_info().unique_qualifier
  Branch (610:17): [True: 0, False: 0]
  Branch (610:17): [Folded - Ignored]
  Branch (610:17): [True: 0, False: 39]
611
39
                    != new_awaited_action.action_info().unique_qualifier,
612
                "Unique key changed for operation_id {:?} - {:?} - {:?}",
613
0
                new_awaited_action.operation_id(),
614
0
                old_awaited_action.action_info(),
615
0
                new_awaited_action.action_info(),
616
            );
617
39
            let is_same_stage = old_awaited_action
618
39
                .state()
619
39
                .stage
620
39
                .is_same_stage(&new_awaited_action.state().stage);
621
39
622
39
            if !is_same_stage {
  Branch (622:16): [True: 0, False: 0]
  Branch (622:16): [Folded - Ignored]
  Branch (622:16): [True: 39, False: 0]
623
39
                self.sorted_action_info_hash_keys
624
39
                    .process_state_changes(&old_awaited_action, &new_awaited_action)
?0
;
625
39
                Self::process_state_changes_for_hash_key_map(
626
39
                    &mut self.action_info_hash_key_to_awaited_action,
627
39
                    &new_awaited_action,
628
39
                );
629
0
            }
630
        }
631
632
        // Notify all listeners of the new state and ignore if no one is listening.
633
        // Note: Do not use `.send()` as it will not update the state if all listeners
634
        // are dropped.
635
39
        let _ = tx.send_replace(new_awaited_action);
636
39
637
39
        Ok(())
638
39
    }
639
640
    /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to
641
    /// listen for changes. We don't do this in-line because it is important
642
    /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into
643
    /// the map. Failing to do so may result in memory leaks. This is because
644
    /// [`ClientAwaitedAction`] implements a drop function that will trigger
645
    /// cleanup of the other maps on drop.
646
24
    fn make_client_awaited_action(
647
24
        &mut self,
648
24
        operation_id: &OperationId,
649
24
        awaited_action: AwaitedAction,
650
24
    ) -> (Arc<ClientAwaitedAction>, watch::Receiver<AwaitedAction>) {
651
24
        let (tx, rx) = watch::channel(awaited_action);
652
24
        let client_awaited_action = Arc::new(ClientAwaitedAction::new(
653
24
            operation_id.clone(),
654
24
            self.action_event_tx.clone(),
655
24
        ));
656
24
        self.operation_id_to_awaited_action
657
24
            .insert(operation_id.clone(), tx);
658
24
        self.connected_clients_for_operation_id
659
24
            .insert(operation_id.clone(), 1);
660
24
        (client_awaited_action, rx)
661
24
    }
662
663
27
    async fn add_action(
664
27
        &mut self,
665
27
        client_operation_id: OperationId,
666
27
        action_info: Arc<ActionInfo>,
667
27
    ) -> Result<MemoryAwaitedActionSubscriber<I, NowFn>, Error> {
668
        // Check to see if the action is already known and subscribe if it is.
669
27
        let subscription_result = self
670
27
            .try_subscribe(
671
27
                &client_operation_id,
672
27
                &action_info.unique_qualifier,
673
27
                action_info.priority,
674
27
            )
675
27
            .await
676
27
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
);
677
27
        match subscription_result {
678
0
            Err(err) => return Err(err),
679
3
            Ok(Some(subscription)) => return Ok(subscription),
680
24
            Ok(None) => { /* Add item to queue. */ }
681
        }
682
683
24
        let maybe_unique_key = match &action_info.unique_qualifier {
684
24
            ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()),
685
0
            ActionUniqueQualifier::Uncachable(_unique_key) => None,
686
        };
687
24
        let operation_id = OperationId::default();
688
24
        let awaited_action =
689
24
            AwaitedAction::new(operation_id.clone(), action_info, (self.now_fn)().now());
690
24
        debug_assert!(
691
0
            ActionStage::Queued == awaited_action.state().stage,
692
0
            "Expected action to be queued"
693
        );
694
24
        let sort_key = awaited_action.sort_key();
695
24
696
24
        let (client_awaited_action, rx) =
697
24
            self.make_client_awaited_action(&operation_id.clone(), awaited_action);
698
24
699
24
        event!(
700
24
            Level::DEBUG,
701
            ?client_operation_id,
702
            ?operation_id,
703
            ?client_awaited_action,
704
0
            "Adding action"
705
        );
706
707
24
        self.client_operation_to_awaited_action
708
24
            .insert(client_operation_id.clone(), client_awaited_action)
709
24
            .await;
710
711
        // Note: We only put items in the map that are cachable.
712
24
        if let Some(unique_key) = maybe_unique_key {
  Branch (712:16): [True: 0, False: 0]
  Branch (712:16): [Folded - Ignored]
  Branch (712:16): [True: 24, False: 0]
713
24
            let old_value = self
714
24
                .action_info_hash_key_to_awaited_action
715
24
                .insert(unique_key, operation_id.clone());
716
24
            if let Some(
old_value0
) = old_value {
  Branch (716:20): [True: 0, False: 0]
  Branch (716:20): [Folded - Ignored]
  Branch (716:20): [True: 0, False: 24]
717
0
                event!(
718
0
                    Level::ERROR,
719
                    ?operation_id,
720
                    ?old_value,
721
0
                    "action_info_hash_key_to_awaited_action already has unique_key"
722
                );
723
24
            }
724
0
        }
725
726
24
        self.sorted_action_info_hash_keys
727
24
            .insert_sort_map_for_stage(
728
24
                &ActionStage::Queued,
729
24
                &SortedAwaitedAction {
730
24
                    sort_key,
731
24
                    operation_id,
732
24
                },
733
24
            )
734
24
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
)
?0
;
735
736
24
        Ok(MemoryAwaitedActionSubscriber::new_with_client(
737
24
            rx,
738
24
            client_operation_id,
739
24
            self.action_event_tx.clone(),
740
24
            self.now_fn.clone(),
741
24
        ))
742
27
    }
743
744
27
    async fn try_subscribe(
745
27
        &mut self,
746
27
        client_operation_id: &OperationId,
747
27
        unique_qualifier: &ActionUniqueQualifier,
748
27
        // TODO(allada) To simplify the scheduler 2024 refactor, we
749
27
        // removed the ability to upgrade priorities of actions.
750
27
        // we should add priority upgrades back in.
751
27
        _priority: i32,
752
27
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
753
27
        let unique_key = match unique_qualifier {
754
27
            ActionUniqueQualifier::Cachable(unique_key) => unique_key,
755
0
            ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None),
756
        };
757
758
27
        let Some(
operation_id3
) = self.action_info_hash_key_to_awaited_action.get(unique_key) else {
  Branch (758:13): [True: 0, False: 0]
  Branch (758:13): [Folded - Ignored]
  Branch (758:13): [True: 3, False: 24]
759
24
            return Ok(None); // Not currently running.
760
        };
761
762
3
        let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else {
  Branch (762:13): [True: 0, False: 0]
  Branch (762:13): [Folded - Ignored]
  Branch (762:13): [True: 3, False: 0]
763
0
            return Err(make_err!(
764
0
                Code::Internal,
765
0
                "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
766
0
            ));
767
        };
768
769
0
        error_if!(
770
3
            tx.borrow().state().stage.is_finished(),
  Branch (770:13): [True: 0, False: 0]
  Branch (770:13): [Folded - Ignored]
  Branch (770:13): [True: 0, False: 3]
771
            "Tried to subscribe to a completed action but it already finished. This should never happen. {:?}",
772
0
            tx.borrow()
773
        );
774
775
3
        let maybe_connected_clients = self
776
3
            .connected_clients_for_operation_id
777
3
            .get_mut(operation_id);
778
3
        let Some(connected_clients) = maybe_connected_clients else {
  Branch (778:13): [True: 0, False: 0]
  Branch (778:13): [Folded - Ignored]
  Branch (778:13): [True: 3, False: 0]
779
0
            return Err(make_err!(
780
0
                Code::Internal,
781
0
                "connected_clients_for_operation_id and operation_id_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
782
0
            ));
783
        };
784
3
        *connected_clients += 1;
785
3
786
3
        // Immediately mark the keep alive, we don't need to wake anyone
787
3
        // so we always fake that it was not actually changed.
788
3
        // Failing update the client could lead to the client connecting
789
3
        // then not updating the keep alive in time, resulting in the
790
3
        // operation timing out due to async behavior.
791
3
        tx.send_if_modified(|awaited_action| {
792
3
            awaited_action.update_client_keep_alive((self.now_fn)().now());
793
3
            false
794
3
        });
795
3
        let subscription = tx.subscribe();
796
3
797
3
        self.client_operation_to_awaited_action
798
3
            .insert(
799
3
                client_operation_id.clone(),
800
3
                Arc::new(ClientAwaitedAction::new(
801
3
                    operation_id.clone(),
802
3
                    self.action_event_tx.clone(),
803
3
                )),
804
3
            )
805
3
            .await;
806
807
3
        Ok(Some(MemoryAwaitedActionSubscriber::new_with_client(
808
3
            subscription,
809
3
            client_operation_id.clone(),
810
3
            self.action_event_tx.clone(),
811
3
            self.now_fn.clone(),
812
3
        )))
813
27
    }
814
}
815
816
/// In-memory implementation of [`AwaitedActionDb`].
///
/// All state lives in an `AwaitedActionDbImpl` shared behind an async mutex;
/// the methods on this type lock it and delegate to the inner implementation.
#[derive(MetricsComponent)]
pub struct MemoryAwaitedActionDb<I: InstantWrapper, NowFn: Fn() -> I> {
    /// Shared database state, also reachable (weakly) from the background
    /// event-handling task.
    #[metric]
    inner: Arc<Mutex<AwaitedActionDbImpl<I, NowFn>>>,
    /// Notified after an action is successfully added or updated.
    tasks_change_notify: Arc<Notify>,
    /// Guard for the spawned "handle_awaited_action_events" task.
    // NOTE(review): presumably dropping the guard stops the task — confirm
    // against `JoinHandleDropGuard`.
    _handle_awaited_action_events: JoinHandleDropGuard<()>,
}
823
824
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static>
825
    MemoryAwaitedActionDb<I, NowFn>
826
{
827
20
    pub fn new(
828
20
        eviction_config: &EvictionPolicy,
829
20
        tasks_change_notify: Arc<Notify>,
830
20
        now_fn: NowFn,
831
20
    ) -> Self {
832
20
        let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel();
833
20
        let inner = Arc::new(Mutex::new(AwaitedActionDbImpl {
834
20
            client_operation_to_awaited_action: EvictingMap::new(eviction_config, (now_fn)()),
835
20
            operation_id_to_awaited_action: BTreeMap::new(),
836
20
            action_info_hash_key_to_awaited_action: HashMap::new(),
837
20
            sorted_action_info_hash_keys: SortedAwaitedActions::default(),
838
20
            connected_clients_for_operation_id: HashMap::new(),
839
20
            action_event_tx,
840
20
            now_fn,
841
20
        }));
842
20
        let weak_inner = Arc::downgrade(&inner);
843
20
        Self {
844
20
            inner,
845
20
            tasks_change_notify,
846
20
            _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move {
847
20
                let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE);
848
                loop {
849
75
                    dropped_operation_ids.clear();
850
75
                    action_event_rx
851
75
                        .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE)
852
75
                        .await;
853
55
                    let Some(inner) = weak_inner.upgrade() else {
  Branch (853:25): [True: 0, False: 0]
  Branch (853:25): [Folded - Ignored]
  Branch (853:25): [True: 55, False: 0]
854
0
                        return; // Nothing to cleanup, our struct is dropped.
855
                    };
856
55
                    let mut inner = inner.lock().await;
857
55
                    inner
858
55
                        .handle_action_events(dropped_operation_ids.drain(..))
859
55
                        .await;
860
                }
861
20
            
}0
),
862
        }
863
20
    }
864
}
865
866
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> AwaitedActionDb
867
    for MemoryAwaitedActionDb<I, NowFn>
868
{
869
    type Subscriber = MemoryAwaitedActionSubscriber<I, NowFn>;
870
871
503
    async fn get_awaited_action_by_id(
872
503
        &self,
873
503
        client_operation_id: &OperationId,
874
503
    ) -> Result<Option<Self::Subscriber>, Error> {
875
503
        self.inner
876
503
            .lock()
877
503
            .await
878
503
            .get_awaited_action_by_id(client_operation_id)
879
503
            .await
880
503
    }
881
882
0
    async fn get_all_awaited_actions(
883
0
        &self,
884
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
885
0
        Ok(ChunkedStream::new(
886
0
            Bound::Unbounded,
887
0
            Bound::Unbounded,
888
0
            move |start, end, mut output| async move {
889
0
                let inner = self.inner.lock().await;
890
0
                let mut maybe_new_start = None;
891
892
0
                for (operation_id, item) in
893
0
                    inner.get_awaited_actions_range(start.as_ref(), end.as_ref())
894
0
                {
895
0
                    output.push_back(item);
896
0
                    maybe_new_start = Some(operation_id);
897
0
                }
898
899
0
                Ok(maybe_new_start
900
0
                    .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output)))
901
0
            },
902
0
        ))
903
0
    }
904
905
40
    async fn get_by_operation_id(
906
40
        &self,
907
40
        operation_id: &OperationId,
908
40
    ) -> Result<Option<Self::Subscriber>, Error> {
909
40
        Ok(self.inner.lock().await.get_by_operation_id(operation_id))
910
40
    }
911
912
90
    async fn get_range_of_actions(
913
90
        &self,
914
90
        state: SortedAwaitedActionState,
915
90
        start: Bound<SortedAwaitedAction>,
916
90
        end: Bound<SortedAwaitedAction>,
917
90
        desc: bool,
918
90
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
919
90
        Ok(ChunkedStream::new(
920
90
            start,
921
90
            end,
922
130
            move |start, end, mut output| async move {
923
130
                let inner = self.inner.lock().await;
924
130
                let mut done = true;
925
130
                let mut new_start = start.as_ref();
926
130
                let mut new_end = end.as_ref();
927
130
928
130
                let iterator = inner
929
130
                    .get_range_of_actions(state, (start.as_ref(), end.as_ref()))
930
130
                    .map(|res| 
res.err_tip(43
||
"In AwaitedActionDb::get_range_of_actions"0
)43
);
931
130
932
130
                // TODO(allada) This should probably use the `.left()/right()` pattern,
933
130
                // but that doesn't exist in the std or any libraries we use.
934
130
                if desc {
  Branch (934:20): [True: 0, False: 0]
  Branch (934:20): [Folded - Ignored]
  Branch (934:20): [True: 130, False: 0]
935
130
                    for 
result43
in iterator.rev() {
936
43
                        let (sorted_awaited_action, item) =
937
43
                            result.err_tip(|| 
"In AwaitedActionDb::get_range_of_actions"0
)
?0
;
938
43
                        output.push_back(item);
939
43
                        new_end = Bound::Excluded(sorted_awaited_action);
940
43
                        done = false;
941
                    }
942
                } else {
943
0
                    for result in iterator {
944
0
                        let (sorted_awaited_action, item) =
945
0
                            result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
946
0
                        output.push_back(item);
947
0
                        new_start = Bound::Excluded(sorted_awaited_action);
948
0
                        done = false;
949
                    }
950
                }
951
130
                if done {
  Branch (951:20): [True: 0, False: 0]
  Branch (951:20): [Folded - Ignored]
  Branch (951:20): [True: 90, False: 40]
952
90
                    return Ok(None);
953
40
                }
954
40
                Ok(Some(((new_start.cloned(), new_end.cloned()), output)))
955
260
            },
956
90
        ))
957
90
    }
958
959
39
    /// Applies `new_awaited_action` to the shared state and, on success,
    /// signals `tasks_change_notify`.
    ///
    /// The lock is released before the notification is sent.
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
        {
            let mut db_guard = self.inner.lock().await;
            db_guard.update_awaited_action(new_awaited_action)?;
        }
        self.tasks_change_notify.notify_one();
        Ok(())
    }
967
968
27
    async fn add_action(
969
27
        &self,
970
27
        client_operation_id: OperationId,
971
27
        action_info: Arc<ActionInfo>,
972
27
    ) -> Result<Self::Subscriber, Error> {
973
27
        let subscriber = self
974
27
            .inner
975
27
            .lock()
976
27
            .await
977
27
            .add_action(client_operation_id, action_info)
978
27
            .await
?0
;
979
27
        self.tasks_change_notify.notify_one();
980
27
        Ok(subscriber)
981
27
    }
982
}