Coverage Report

Created: 2025-03-08 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/memory_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::{BTreeMap, BTreeSet, HashMap};
16
use std::ops::{Bound, RangeBounds};
17
use std::sync::Arc;
18
19
use async_lock::Mutex;
20
use futures::{FutureExt, Stream};
21
use nativelink_config::stores::EvictionPolicy;
22
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
23
use nativelink_metric::MetricsComponent;
24
use nativelink_util::action_messages::{
25
    ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, OperationId,
26
};
27
use nativelink_util::chunked_stream::ChunkedStream;
28
use nativelink_util::evicting_map::{EvictingMap, LenEntry};
29
use nativelink_util::instant_wrapper::InstantWrapper;
30
use nativelink_util::spawn;
31
use nativelink_util::task::JoinHandleDropGuard;
32
use tokio::sync::{mpsc, watch, Notify};
33
use tracing::{event, Level};
34
35
use crate::awaited_action_db::{
36
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction,
37
    SortedAwaitedActionState, CLIENT_KEEPALIVE_DURATION,
38
};
39
40
/// Number of events to process per cycle.
41
const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024;
42
43
/// Represents a client that is currently listening to an action.
44
/// When the client is dropped, it sends an `ActionEvent::ClientDroppedOperation`
45
/// for its operation to `event_tx` so that any remaining cleanup can run.
46
#[derive(Debug)]
47
struct ClientAwaitedAction {
48
    /// The `OperationId` that the client is listening to.
49
    operation_id: OperationId,
50
51
    /// The sender to notify of this struct being dropped.
52
    event_tx: mpsc::UnboundedSender<ActionEvent>,
53
}
54
55
impl ClientAwaitedAction {
56
27
    pub const fn new(
57
27
        operation_id: OperationId,
58
27
        event_tx: mpsc::UnboundedSender<ActionEvent>,
59
27
    ) -> Self {
60
27
        Self {
61
27
            operation_id,
62
27
            event_tx,
63
27
        }
64
27
    }
65
66
3
    pub const fn operation_id(&self) -> &OperationId {
67
3
        &self.operation_id
68
3
    }
69
}
70
71
impl Drop for ClientAwaitedAction {
72
27
    fn drop(&mut self) {
73
27
        // If we failed to send, it means no one is listening.
74
27
        let _ = self.event_tx.send(ActionEvent::ClientDroppedOperation(
75
27
            self.operation_id.clone(),
76
27
        ));
77
27
    }
78
}
79
80
/// Trait to be able to use the `EvictingMap` with `ClientAwaitedAction`.
81
/// Note: We only use `EvictingMap` for a time based eviction, which is
82
/// why the implementation has fixed default values in it.
83
impl LenEntry for ClientAwaitedAction {
84
    #[inline]
85
29
    fn len(&self) -> u64 {
86
29
        0
87
29
    }
88
89
    #[inline]
90
0
    fn is_empty(&self) -> bool {
91
0
        true
92
0
    }
93
}
94
95
/// Actions the `AwaitedActionsDb` needs to process.
96
#[derive(Debug)]
97
pub(crate) enum ActionEvent {
98
    /// A client has sent a keep alive message.
99
    ClientKeepAlive(OperationId),
100
    /// A client has dropped and pointed to `OperationId`.
101
    ClientDroppedOperation(OperationId),
102
}
103
104
/// Information required to track an individual client
105
/// keep alive config and state.
106
struct ClientInfo<I: InstantWrapper, NowFn: Fn() -> I> {
107
    /// The client operation id.
108
    client_operation_id: OperationId,
109
    /// The last time a keep alive was sent.
110
    last_keep_alive: I,
111
    /// The function to get the current time.
112
    now_fn: NowFn,
113
    /// The sender used to notify that this struct had an event.
114
    event_tx: mpsc::UnboundedSender<ActionEvent>,
115
}
116
117
/// Subscriber that clients can use to monitor when `AwaitedActions` change.
118
pub struct MemoryAwaitedActionSubscriber<I: InstantWrapper, NowFn: Fn() -> I> {
119
    /// The receiver to listen for changes.
120
    awaited_action_rx: watch::Receiver<AwaitedAction>,
121
    /// If a client id is known this is the info needed to keep the client
122
    /// action alive.
123
    client_info: Option<ClientInfo<I, NowFn>>,
124
}
125
126
impl<I: InstantWrapper, NowFn: Fn() -> I> MemoryAwaitedActionSubscriber<I, NowFn> {
127
583
    fn new(mut awaited_action_rx: watch::Receiver<AwaitedAction>) -> Self {
128
583
        awaited_action_rx.mark_changed();
129
583
        Self {
130
583
            awaited_action_rx,
131
583
            client_info: None,
132
583
        }
133
583
    }
134
135
30
    fn new_with_client(
136
30
        mut awaited_action_rx: watch::Receiver<AwaitedAction>,
137
30
        client_operation_id: OperationId,
138
30
        event_tx: mpsc::UnboundedSender<ActionEvent>,
139
30
        now_fn: NowFn,
140
30
    ) -> Self
141
30
    where
142
30
        NowFn: Fn() -> I,
143
30
    {
144
30
        awaited_action_rx.mark_changed();
145
30
        Self {
146
30
            awaited_action_rx,
147
30
            client_info: Some(ClientInfo {
148
30
                client_operation_id,
149
30
                last_keep_alive: I::from_secs(0),
150
30
                now_fn,
151
30
                event_tx,
152
30
            }),
153
30
        }
154
30
    }
155
}
156
157
impl<I, NowFn> AwaitedActionSubscriber for MemoryAwaitedActionSubscriber<I, NowFn>
158
where
159
    I: InstantWrapper,
160
    NowFn: Fn() -> I + Send + Sync + 'static,
161
{
162
58
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
163
44
        let client_operation_id = {
164
58
            let changed_fut = self.awaited_action_rx.changed().map(|r| {
165
44
                r.map_err(|e| {
166
0
                    make_err!(
167
0
                        Code::Internal,
168
0
                        "Failed to wait for awaited action to change {e:?}"
169
0
                    )
170
44
                })
171
58
            });
172
58
            let Some(client_info) = self.client_info.as_mut() else {
  Branch (172:17): [True: 0, False: 0]
  Branch (172:17): [Folded - Ignored]
  Branch (172:17): [True: 58, False: 0]
173
0
                changed_fut.await?;
174
0
                return Ok(self.awaited_action_rx.borrow().clone());
175
            };
176
58
            tokio::pin!(changed_fut);
177
            loop {
178
160
                if client_info.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION {
  Branch (178:20): [True: 0, False: 0]
  Branch (178:20): [Folded - Ignored]
  Branch (178:20): [True: 55, False: 105]
179
55
                    client_info.last_keep_alive = (client_info.now_fn)();
180
55
                    // Failing to send just means our receiver dropped.
181
55
                    let _ = client_info.event_tx.send(ActionEvent::ClientKeepAlive(
182
55
                        client_info.client_operation_id.clone(),
183
55
                    ));
184
105
                }
185
160
                let sleep_fut = (client_info.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
186
160
                tokio::select! {
187
160
                    
result44
= &mut changed_fut => {
188
44
                        result
?0
;
189
44
                        break;
190
                    }
191
160
                    () = sleep_fut => {
192
102
                        // If we haven't received any updates for a while, we should
193
102
                        // let the database know that we are still listening to prevent
194
102
                        // the action from being dropped.
195
102
                    }
196
                }
197
            }
198
44
            client_info.client_operation_id.clone()
199
44
        };
200
44
        // At this stage we know that this event is a client request, so we need
201
44
        // to populate the client_operation_id.
202
44
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
203
44
        awaited_action.set_client_operation_id(client_operation_id);
204
44
        Ok(awaited_action)
205
44
    }
206
207
668
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
208
668
        let mut awaited_action = self.awaited_action_rx.borrow().clone();
209
668
        if let Some(
client_info16
) = self.client_info.as_ref() {
  Branch (209:16): [True: 0, False: 0]
  Branch (209:16): [Folded - Ignored]
  Branch (209:16): [True: 16, False: 652]
210
16
            awaited_action.set_client_operation_id(client_info.client_operation_id.clone());
211
652
        }
212
668
        Ok(awaited_action)
213
668
    }
214
}
215
216
/// A struct that is used to keep the developer from trying to
217
/// return early from a function.
218
struct NoEarlyReturn;
219
220
#[derive(Default, MetricsComponent)]
221
struct SortedAwaitedActions {
222
    #[metric(group = "unknown")]
223
    unknown: BTreeSet<SortedAwaitedAction>,
224
    #[metric(group = "cache_check")]
225
    cache_check: BTreeSet<SortedAwaitedAction>,
226
    #[metric(group = "queued")]
227
    queued: BTreeSet<SortedAwaitedAction>,
228
    #[metric(group = "executing")]
229
    executing: BTreeSet<SortedAwaitedAction>,
230
    #[metric(group = "completed")]
231
    completed: BTreeSet<SortedAwaitedAction>,
232
}
233
234
impl SortedAwaitedActions {
235
39
    fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet<SortedAwaitedAction> {
236
39
        match state {
237
0
            ActionStage::Unknown => &mut self.unknown,
238
0
            ActionStage::CacheCheck => &mut self.cache_check,
239
26
            ActionStage::Queued => &mut self.queued,
240
13
            ActionStage::Executing => &mut self.executing,
241
0
            ActionStage::Completed(_) | ActionStage::CompletedFromCache(_) => &mut self.completed,
242
        }
243
39
    }
244
245
63
    fn insert_sort_map_for_stage(
246
63
        &mut self,
247
63
        stage: &ActionStage,
248
63
        sorted_awaited_action: &SortedAwaitedAction,
249
63
    ) -> Result<(), Error> {
250
63
        let newly_inserted = match stage {
251
0
            ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()),
252
0
            ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()),
253
31
            ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()),
254
26
            ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()),
255
6
            ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()),
256
            ActionStage::CompletedFromCache(_) => {
257
0
                self.completed.insert(sorted_awaited_action.clone())
258
            }
259
        };
260
63
        if !newly_inserted {
  Branch (260:12): [True: 0, False: 63]
  Branch (260:12): [Folded - Ignored]
261
0
            return Err(make_err!(
262
0
                Code::Internal,
263
0
                "Tried to insert an action that was already in the sorted map. This should never happen. {:?} - {:?}",
264
0
                stage,
265
0
                sorted_awaited_action
266
0
            ));
267
63
        }
268
63
        Ok(())
269
63
    }
270
271
39
    fn process_state_changes(
272
39
        &mut self,
273
39
        old_awaited_action: &AwaitedAction,
274
39
        new_awaited_action: &AwaitedAction,
275
39
    ) -> Result<(), Error> {
276
39
        let btree = self.btree_for_state(&old_awaited_action.state().stage);
277
39
        let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction {
278
39
            sort_key: old_awaited_action.sort_key(),
279
39
            operation_id: new_awaited_action.operation_id().clone(),
280
39
        });
281
282
39
        let Some(sorted_awaited_action) = maybe_sorted_awaited_action else {
  Branch (282:13): [True: 39, False: 0]
  Branch (282:13): [Folded - Ignored]
283
0
            return Err(make_err!(
284
0
                Code::Internal,
285
0
                "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}",
286
0
                new_awaited_action.operation_id(),
287
0
                new_awaited_action,
288
0
            ));
289
        };
290
291
39
        self.insert_sort_map_for_stage(&new_awaited_action.state().stage, &sorted_awaited_action)
292
39
            .err_tip(|| 
"In AwaitedActionDb::update_awaited_action"0
)
?0
;
293
39
        Ok(())
294
39
    }
295
}
296
297
/// The database for storing the state of all actions.
298
#[derive(MetricsComponent)]
299
pub struct AwaitedActionDbImpl<I: InstantWrapper, NowFn: Fn() -> I> {
300
    /// A lookup table to look up the state of an action by its client operation id.
301
    #[metric(group = "client_operation_ids")]
302
    client_operation_to_awaited_action: EvictingMap<OperationId, Arc<ClientAwaitedAction>, I>,
303
304
    /// A lookup table to look up the state of an action by its worker operation id.
305
    #[metric(group = "operation_ids")]
306
    operation_id_to_awaited_action: BTreeMap<OperationId, watch::Sender<AwaitedAction>>,
307
308
    /// A lookup table to look up the state of an action by its unique qualifier.
309
    #[metric(group = "action_info_hash_key_to_awaited_action")]
310
    action_info_hash_key_to_awaited_action: HashMap<ActionUniqueKey, OperationId>,
311
312
    /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting
313
    /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`].
314
    ///
315
    /// See [`AwaitedActionSortKey`] for more information on the ordering.
316
    #[metric(group = "sorted_action_infos")]
317
    sorted_action_info_hash_keys: SortedAwaitedActions,
318
319
    /// The number of connected clients for each operation id.
320
    #[metric(group = "connected_clients_for_operation_id")]
321
    connected_clients_for_operation_id: HashMap<OperationId, usize>,
322
323
    /// Where to send notifications about important events related to actions.
324
    action_event_tx: mpsc::UnboundedSender<ActionEvent>,
325
326
    /// The function to get the current time.
327
    now_fn: NowFn,
328
}
329
330
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbImpl<I, NowFn> {
331
3
    async fn get_awaited_action_by_id(
332
3
        &self,
333
3
        client_operation_id: &OperationId,
334
3
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
335
3
        let maybe_client_awaited_action = self
336
3
            .client_operation_to_awaited_action
337
3
            .get(client_operation_id)
338
3
            .await;
339
3
        let Some(client_awaited_action) = maybe_client_awaited_action else {
  Branch (339:13): [True: 0, False: 0]
  Branch (339:13): [Folded - Ignored]
  Branch (339:13): [True: 3, False: 0]
340
0
            return Ok(None);
341
        };
342
343
3
        self.operation_id_to_awaited_action
344
3
            .get(client_awaited_action.operation_id())
345
3
            .map(|tx| {
346
3
                Some(MemoryAwaitedActionSubscriber::new_with_client(
347
3
                    tx.subscribe(),
348
3
                    client_operation_id.clone(),
349
3
                    self.action_event_tx.clone(),
350
3
                    self.now_fn.clone(),
351
3
                ))
352
3
            })
353
3
            .ok_or_else(|| {
354
0
                make_err!(
355
0
                    Code::Internal,
356
0
                    "Failed to get client operation id {client_operation_id:?}"
357
0
                )
358
3
            })
359
3
    }
360
361
    /// Processes action events that need to be handled by the database.
362
55
    async fn handle_action_events(
363
55
        &mut self,
364
55
        action_events: impl IntoIterator<Item = ActionEvent>,
365
55
    ) -> NoEarlyReturn {
366
110
        for 
action55
in action_events {
367
55
            event!(Level::DEBUG, ?action, 
"Handling action"0
);
368
55
            match action {
369
1
                ActionEvent::ClientDroppedOperation(operation_id) => {
370
                    // Cleanup operation_id_to_awaited_action.
371
1
                    let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else {
  Branch (371:25): [True: 0, False: 0]
  Branch (371:25): [Folded - Ignored]
  Branch (371:25): [True: 1, False: 0]
372
0
                        event!(
373
0
                            Level::ERROR,
374
                            ?operation_id,
375
0
                            "operation_id_to_awaited_action does not have operation_id"
376
                        );
377
0
                        continue;
378
                    };
379
1
                    let connected_clients = if let Some(connected_clients) = self
  Branch (379:52): [True: 0, False: 0]
  Branch (379:52): [Folded - Ignored]
  Branch (379:52): [True: 1, False: 0]
380
1
                        .connected_clients_for_operation_id
381
1
                        .remove(&operation_id)
382
                    {
383
1
                        connected_clients - 1
384
                    } else {
385
0
                        event!(
386
0
                            Level::ERROR,
387
                            ?operation_id,
388
0
                            "connected_clients_for_operation_id does not have operation_id"
389
                        );
390
0
                        0
391
                    };
392
                    // Note: It is rare to have more than one client listening
393
                    // to the same action, so we assume that we are the last
394
                    // client and insert it back into the map if we detect that
395
                    // there are still clients listening (ie: the happy path
396
                    // is operation.connected_clients == 0).
397
1
                    if connected_clients != 0 {
  Branch (397:24): [True: 0, False: 0]
  Branch (397:24): [Folded - Ignored]
  Branch (397:24): [True: 1, False: 0]
398
1
                        self.operation_id_to_awaited_action
399
1
                            .insert(operation_id.clone(), tx);
400
1
                        self.connected_clients_for_operation_id
401
1
                            .insert(operation_id, connected_clients);
402
1
                        continue;
403
0
                    }
404
0
                    event!(
405
0
                        Level::DEBUG,
406
                        ?operation_id,
407
0
                        "Clearing operation from state manager"
408
                    );
409
0
                    let awaited_action = tx.borrow().clone();
410
0
                    // Cleanup action_info_hash_key_to_awaited_action if it was marked cached.
411
0
                    match &awaited_action.action_info().unique_qualifier {
412
0
                        ActionUniqueQualifier::Cachable(action_key) => {
413
0
                            let maybe_awaited_action = self
414
0
                                .action_info_hash_key_to_awaited_action
415
0
                                .remove(action_key);
416
0
                            if !awaited_action.state().stage.is_finished()
  Branch (416:32): [True: 0, False: 0]
  Branch (416:32): [Folded - Ignored]
  Branch (416:32): [True: 0, False: 0]
417
0
                                && maybe_awaited_action.is_none()
  Branch (417:36): [True: 0, False: 0]
  Branch (417:36): [Folded - Ignored]
  Branch (417:36): [True: 0, False: 0]
418
                            {
419
0
                                event!(
420
0
                                    Level::ERROR,
421
                                    ?operation_id,
422
                                    ?awaited_action,
423
                                    ?action_key,
424
0
                                    "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
425
                                );
426
0
                            }
427
                        }
428
0
                        ActionUniqueQualifier::Uncachable(_action_key) => {
429
0
                            // This Operation should not be in the hash_key map.
430
0
                        }
431
                    }
432
433
                    // Cleanup sorted_awaited_action.
434
0
                    let sort_key = awaited_action.sort_key();
435
0
                    let sort_btree_for_state = self
436
0
                        .sorted_action_info_hash_keys
437
0
                        .btree_for_state(&awaited_action.state().stage);
438
0
439
0
                    let maybe_sorted_awaited_action =
440
0
                        sort_btree_for_state.take(&SortedAwaitedAction {
441
0
                            sort_key,
442
0
                            operation_id: operation_id.clone(),
443
0
                        });
444
0
                    if maybe_sorted_awaited_action.is_none() {
  Branch (444:24): [True: 0, False: 0]
  Branch (444:24): [Folded - Ignored]
  Branch (444:24): [True: 0, False: 0]
445
0
                        event!(
446
0
                            Level::ERROR,
447
                            ?operation_id,
448
                            ?sort_key,
449
0
                            "Expected maybe_sorted_awaited_action to have {sort_key:?}",
450
                        );
451
0
                    }
452
                }
453
54
                ActionEvent::ClientKeepAlive(client_id) => {
454
54
                    if let Some(client_awaited_action) = self
  Branch (454:28): [True: 0, False: 0]
  Branch (454:28): [Folded - Ignored]
  Branch (454:28): [True: 54, False: 0]
455
54
                        .client_operation_to_awaited_action
456
54
                        .get(&client_id)
457
54
                        .await
458
                    {
459
54
                        if let Some(awaited_action_sender) = self
  Branch (459:32): [True: 0, False: 0]
  Branch (459:32): [Folded - Ignored]
  Branch (459:32): [True: 54, False: 0]
460
54
                            .operation_id_to_awaited_action
461
54
                            .get(&client_awaited_action.operation_id)
462
54
                        {
463
54
                            awaited_action_sender.send_if_modified(|awaited_action| {
464
54
                                awaited_action.update_client_keep_alive((self.now_fn)().now());
465
54
                                false
466
54
                            });
467
54
                        
}0
468
                    } else {
469
0
                        event!(
470
0
                            Level::ERROR,
471
                            ?client_id,
472
0
                            "client_operation_to_awaited_action does not have client_id",
473
                        );
474
                    }
475
                }
476
            }
477
        }
478
55
        NoEarlyReturn
479
55
    }
480
481
0
    fn get_awaited_actions_range(
482
0
        &self,
483
0
        start: Bound<&OperationId>,
484
0
        end: Bound<&OperationId>,
485
0
    ) -> impl Iterator<Item = (&'_ OperationId, MemoryAwaitedActionSubscriber<I, NowFn>)> {
486
0
        self.operation_id_to_awaited_action
487
0
            .range((start, end))
488
0
            .map(|(operation_id, tx)| {
489
0
                (
490
0
                    operation_id,
491
0
                    MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()),
492
0
                )
493
0
            })
494
0
    }
495
496
583
    fn get_by_operation_id(
497
583
        &self,
498
583
        operation_id: &OperationId,
499
583
    ) -> Option<MemoryAwaitedActionSubscriber<I, NowFn>> {
500
583
        self.operation_id_to_awaited_action
501
583
            .get(operation_id)
502
583
            .map(|tx| MemoryAwaitedActionSubscriber::<I, NowFn>::new(tx.subscribe()))
503
583
    }
504
505
1.13k
    fn get_range_of_actions<'a, 'b>(
506
1.13k
        &'a self,
507
1.13k
        state: SortedAwaitedActionState,
508
1.13k
        range: impl RangeBounds<SortedAwaitedAction> + 'b,
509
1.13k
    ) -> impl DoubleEndedIterator<
510
1.13k
        Item = Result<
511
1.13k
            (
512
1.13k
                &'a SortedAwaitedAction,
513
1.13k
                MemoryAwaitedActionSubscriber<I, NowFn>,
514
1.13k
            ),
515
1.13k
            Error,
516
1.13k
        >,
517
1.13k
    > + 'a {
518
1.13k
        let btree = match state {
519
0
            SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
520
1.13k
            SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued,
521
0
            SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing,
522
0
            SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed,
523
        };
524
1.13k
        btree.range(range).map(|sorted_awaited_action| {
525
543
            let operation_id = &sorted_awaited_action.operation_id;
526
543
            self.get_by_operation_id(operation_id)
527
543
                .ok_or_else(|| {
528
0
                    make_err!(
529
0
                        Code::Internal,
530
0
                        "Failed to get operation id {}",
531
0
                        operation_id
532
0
                    )
533
543
                })
534
543
                .map(|subscriber| (sorted_awaited_action, subscriber))
535
1.13k
        })
536
1.13k
    }
537
538
39
    fn process_state_changes_for_hash_key_map(
539
39
        action_info_hash_key_to_awaited_action: &mut HashMap<ActionUniqueKey, OperationId>,
540
39
        new_awaited_action: &AwaitedAction,
541
39
    ) {
542
39
        // Only process changes if the stage is not finished.
543
39
        if !new_awaited_action.state().stage.is_finished() {
  Branch (543:12): [True: 0, False: 0]
  Branch (543:12): [Folded - Ignored]
  Branch (543:12): [True: 33, False: 6]
544
33
            return;
545
6
        }
546
6
        match &new_awaited_action.action_info().unique_qualifier {
547
6
            ActionUniqueQualifier::Cachable(action_key) => {
548
6
                let maybe_awaited_action =
549
6
                    action_info_hash_key_to_awaited_action.remove(action_key);
550
6
                match maybe_awaited_action {
551
6
                    Some(removed_operation_id) => {
552
6
                        if &removed_operation_id != new_awaited_action.operation_id() {
  Branch (552:28): [True: 0, False: 0]
  Branch (552:28): [Folded - Ignored]
  Branch (552:28): [True: 0, False: 6]
553
0
                            event!(
554
0
                                Level::ERROR,
555
                                ?removed_operation_id,
556
                                ?new_awaited_action,
557
                                ?action_key,
558
0
                                "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync",
559
                            );
560
6
                        }
561
                    }
562
                    None => {
563
0
                        event!(
564
0
                            Level::ERROR,
565
                            ?new_awaited_action,
566
                            ?action_key,
567
0
                            "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key",
568
                        );
569
                    }
570
                }
571
            }
572
0
            ActionUniqueQualifier::Uncachable(_action_key) => {
573
0
                // If we are not cachable, the action should not be in the
574
0
                // hash_key map, so we don't need to process anything in
575
0
                // action_info_hash_key_to_awaited_action.
576
0
            }
577
        }
578
39
    }
579
580
39
    fn update_awaited_action(
581
39
        &mut self,
582
39
        mut new_awaited_action: AwaitedAction,
583
39
    ) -> Result<(), Error> {
584
39
        let tx = self
585
39
            .operation_id_to_awaited_action
586
39
            .get(new_awaited_action.operation_id())
587
39
            .ok_or_else(|| {
588
0
                make_err!(
589
0
                    Code::Internal,
590
0
                    "OperationId does not exist in map in AwaitedActionDb::update_awaited_action"
591
0
                )
592
39
            })
?0
;
593
        {
594
            // Note: It's important to drop old_awaited_action before we call
595
            // send_replace or we will have a deadlock.
596
39
            let old_awaited_action = tx.borrow();
597
39
598
39
            // Do not process changes if the action version is not in sync with
599
39
            // what the sender based the update on.
600
39
            if old_awaited_action.version() != new_awaited_action.version() {
  Branch (600:16): [True: 0, False: 0]
  Branch (600:16): [Folded - Ignored]
  Branch (600:16): [True: 0, False: 39]
601
0
                return Err(make_err!(
602
0
                    // From: https://grpc.github.io/grpc/core/md_doc_statuscodes.html
603
0
                    // Use ABORTED if the client should retry at a higher level
604
0
                    // (e.g., when a client-specified test-and-set fails,
605
0
                    // indicating the client should restart a read-modify-write
606
0
                    // sequence)
607
0
                    Code::Aborted,
608
0
                    "{} Expected {} but got {} for operation_id {:?} - {:?}",
609
0
                    "Tried to update an awaited action with an incorrect version.",
610
0
                    old_awaited_action.version(),
611
0
                    new_awaited_action.version(),
612
0
                    old_awaited_action,
613
0
                    new_awaited_action,
614
0
                ));
615
39
            }
616
39
            new_awaited_action.increment_version();
617
618
0
            error_if!(
619
39
                old_awaited_action.action_info().unique_qualifier
  Branch (619:17): [True: 0, False: 0]
  Branch (619:17): [Folded - Ignored]
  Branch (619:17): [True: 0, False: 39]
620
39
                    != new_awaited_action.action_info().unique_qualifier,
621
                "Unique key changed for operation_id {:?} - {:?} - {:?}",
622
0
                new_awaited_action.operation_id(),
623
0
                old_awaited_action.action_info(),
624
0
                new_awaited_action.action_info(),
625
            );
626
39
            let is_same_stage = old_awaited_action
627
39
                .state()
628
39
                .stage
629
39
                .is_same_stage(&new_awaited_action.state().stage);
630
39
631
39
            if !is_same_stage {
  Branch (631:16): [True: 0, False: 0]
  Branch (631:16): [Folded - Ignored]
  Branch (631:16): [True: 39, False: 0]
632
39
                self.sorted_action_info_hash_keys
633
39
                    .process_state_changes(&old_awaited_action, &new_awaited_action)
?0
;
634
39
                Self::process_state_changes_for_hash_key_map(
635
39
                    &mut self.action_info_hash_key_to_awaited_action,
636
39
                    &new_awaited_action,
637
39
                );
638
0
            }
639
        }
640
641
        // Notify all listeners of the new state and ignore if no one is listening.
642
        // Note: Do not use `.send()` as it will not update the state if all listeners
643
        // are dropped.
644
39
        let _ = tx.send_replace(new_awaited_action);
645
39
646
39
        Ok(())
647
39
    }
648
649
    /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to
650
    /// listen for changes. We don't do this in-line because it is important
651
    /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into
652
    /// the map. Failing to do so may result in memory leaks. This is because
653
    /// [`ClientAwaitedAction`] implements a drop function that will trigger
654
    /// cleanup of the other maps on drop.
655
24
    fn make_client_awaited_action(
656
24
        &mut self,
657
24
        operation_id: &OperationId,
658
24
        awaited_action: AwaitedAction,
659
24
    ) -> (Arc<ClientAwaitedAction>, watch::Receiver<AwaitedAction>) {
660
24
        let (tx, rx) = watch::channel(awaited_action);
661
24
        let client_awaited_action = Arc::new(ClientAwaitedAction::new(
662
24
            operation_id.clone(),
663
24
            self.action_event_tx.clone(),
664
24
        ));
665
24
        self.operation_id_to_awaited_action
666
24
            .insert(operation_id.clone(), tx);
667
24
        self.connected_clients_for_operation_id
668
24
            .insert(operation_id.clone(), 1);
669
24
        (client_awaited_action, rx)
670
24
    }
671
672
27
    async fn add_action(
673
27
        &mut self,
674
27
        client_operation_id: OperationId,
675
27
        action_info: Arc<ActionInfo>,
676
27
    ) -> Result<MemoryAwaitedActionSubscriber<I, NowFn>, Error> {
677
        // Check to see if the action is already known and subscribe if it is.
678
27
        let subscription_result = self
679
27
            .try_subscribe(
680
27
                &client_operation_id,
681
27
                &action_info.unique_qualifier,
682
27
                action_info.priority,
683
27
            )
684
27
            .await
685
27
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
);
686
27
        match subscription_result {
687
0
            Err(err) => return Err(err),
688
3
            Ok(Some(subscription)) => return Ok(subscription),
689
24
            Ok(None) => { /* Add item to queue. */ }
690
        }
691
692
24
        let maybe_unique_key = match &action_info.unique_qualifier {
693
24
            ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()),
694
0
            ActionUniqueQualifier::Uncachable(_unique_key) => None,
695
        };
696
24
        let operation_id = OperationId::default();
697
24
        let awaited_action =
698
24
            AwaitedAction::new(operation_id.clone(), action_info, (self.now_fn)().now());
699
24
        debug_assert!(
700
0
            ActionStage::Queued == awaited_action.state().stage,
701
0
            "Expected action to be queued"
702
        );
703
24
        let sort_key = awaited_action.sort_key();
704
24
705
24
        let (client_awaited_action, rx) =
706
24
            self.make_client_awaited_action(&operation_id.clone(), awaited_action);
707
24
708
24
        event!(
709
24
            Level::DEBUG,
710
            ?client_operation_id,
711
            ?operation_id,
712
            ?client_awaited_action,
713
0
            "Adding action"
714
        );
715
716
24
        self.client_operation_to_awaited_action
717
24
            .insert(client_operation_id.clone(), client_awaited_action)
718
24
            .await;
719
720
        // Note: We only put items in the map that are cachable.
721
24
        if let Some(unique_key) = maybe_unique_key {
  Branch (721:16): [True: 0, False: 0]
  Branch (721:16): [Folded - Ignored]
  Branch (721:16): [True: 24, False: 0]
722
24
            let old_value = self
723
24
                .action_info_hash_key_to_awaited_action
724
24
                .insert(unique_key, operation_id.clone());
725
24
            if let Some(
old_value0
) = old_value {
  Branch (725:20): [True: 0, False: 0]
  Branch (725:20): [Folded - Ignored]
  Branch (725:20): [True: 0, False: 24]
726
0
                event!(
727
0
                    Level::ERROR,
728
                    ?operation_id,
729
                    ?old_value,
730
0
                    "action_info_hash_key_to_awaited_action already has unique_key"
731
                );
732
24
            }
733
0
        }
734
735
24
        self.sorted_action_info_hash_keys
736
24
            .insert_sort_map_for_stage(
737
24
                &ActionStage::Queued,
738
24
                &SortedAwaitedAction {
739
24
                    sort_key,
740
24
                    operation_id,
741
24
                },
742
24
            )
743
24
            .err_tip(|| 
"In AwaitedActionDb::subscribe_or_add_action"0
)
?0
;
744
745
24
        Ok(MemoryAwaitedActionSubscriber::new_with_client(
746
24
            rx,
747
24
            client_operation_id,
748
24
            self.action_event_tx.clone(),
749
24
            self.now_fn.clone(),
750
24
        ))
751
27
    }
752
753
27
    async fn try_subscribe(
754
27
        &mut self,
755
27
        client_operation_id: &OperationId,
756
27
        unique_qualifier: &ActionUniqueQualifier,
757
27
        // TODO(allada) To simplify the scheduler 2024 refactor, we
758
27
        // removed the ability to upgrade priorities of actions.
759
27
        // we should add priority upgrades back in.
760
27
        _priority: i32,
761
27
    ) -> Result<Option<MemoryAwaitedActionSubscriber<I, NowFn>>, Error> {
762
27
        let unique_key = match unique_qualifier {
763
27
            ActionUniqueQualifier::Cachable(unique_key) => unique_key,
764
0
            ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None),
765
        };
766
767
27
        let Some(
operation_id3
) = self.action_info_hash_key_to_awaited_action.get(unique_key) else {
  Branch (767:13): [True: 0, False: 0]
  Branch (767:13): [Folded - Ignored]
  Branch (767:13): [True: 3, False: 24]
768
24
            return Ok(None); // Not currently running.
769
        };
770
771
3
        let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else {
  Branch (771:13): [True: 0, False: 0]
  Branch (771:13): [Folded - Ignored]
  Branch (771:13): [True: 3, False: 0]
772
0
            return Err(make_err!(
773
0
                Code::Internal,
774
0
                "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
775
0
            ));
776
        };
777
778
0
        error_if!(
779
3
            tx.borrow().state().stage.is_finished(),
  Branch (779:13): [True: 0, False: 0]
  Branch (779:13): [Folded - Ignored]
  Branch (779:13): [True: 0, False: 3]
780
            "Tried to subscribe to a completed action but it already finished. This should never happen. {:?}",
781
0
            tx.borrow()
782
        );
783
784
3
        let maybe_connected_clients = self
785
3
            .connected_clients_for_operation_id
786
3
            .get_mut(operation_id);
787
3
        let Some(connected_clients) = maybe_connected_clients else {
  Branch (787:13): [True: 0, False: 0]
  Branch (787:13): [Folded - Ignored]
  Branch (787:13): [True: 3, False: 0]
788
0
            return Err(make_err!(
789
0
                Code::Internal,
790
0
                "connected_clients_for_operation_id and operation_id_to_awaited_action are out of sync for {unique_key:?} - {operation_id}"
791
0
            ));
792
        };
793
3
        *connected_clients += 1;
794
3
795
3
        // Immediately mark the keep alive, we don't need to wake anyone
796
3
        // so we always fake that it was not actually changed.
797
3
        // Failing update the client could lead to the client connecting
798
3
        // then not updating the keep alive in time, resulting in the
799
3
        // operation timing out due to async behavior.
800
3
        tx.send_if_modified(|awaited_action| {
801
3
            awaited_action.update_client_keep_alive((self.now_fn)().now());
802
3
            false
803
3
        });
804
3
        let subscription = tx.subscribe();
805
3
806
3
        self.client_operation_to_awaited_action
807
3
            .insert(
808
3
                client_operation_id.clone(),
809
3
                Arc::new(ClientAwaitedAction::new(
810
3
                    operation_id.clone(),
811
3
                    self.action_event_tx.clone(),
812
3
                )),
813
3
            )
814
3
            .await;
815
816
3
        Ok(Some(MemoryAwaitedActionSubscriber::new_with_client(
817
3
            subscription,
818
3
            client_operation_id.clone(),
819
3
            self.action_event_tx.clone(),
820
3
            self.now_fn.clone(),
821
3
        )))
822
27
    }
823
}
824
825
#[derive(MetricsComponent)]
826
pub struct MemoryAwaitedActionDb<I: InstantWrapper, NowFn: Fn() -> I> {
827
    #[metric]
828
    inner: Arc<Mutex<AwaitedActionDbImpl<I, NowFn>>>,
829
    tasks_change_notify: Arc<Notify>,
830
    _handle_awaited_action_events: JoinHandleDropGuard<()>,
831
}
832
833
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static>
834
    MemoryAwaitedActionDb<I, NowFn>
835
{
836
20
    pub fn new(
837
20
        eviction_config: &EvictionPolicy,
838
20
        tasks_change_notify: Arc<Notify>,
839
20
        now_fn: NowFn,
840
20
    ) -> Self {
841
20
        let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel();
842
20
        let inner = Arc::new(Mutex::new(AwaitedActionDbImpl {
843
20
            client_operation_to_awaited_action: EvictingMap::new(eviction_config, (now_fn)()),
844
20
            operation_id_to_awaited_action: BTreeMap::new(),
845
20
            action_info_hash_key_to_awaited_action: HashMap::new(),
846
20
            sorted_action_info_hash_keys: SortedAwaitedActions::default(),
847
20
            connected_clients_for_operation_id: HashMap::new(),
848
20
            action_event_tx,
849
20
            now_fn,
850
20
        }));
851
20
        let weak_inner = Arc::downgrade(&inner);
852
20
        Self {
853
20
            inner,
854
20
            tasks_change_notify,
855
20
            _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move {
856
20
                let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE);
857
                loop {
858
75
                    dropped_operation_ids.clear();
859
75
                    action_event_rx
860
75
                        .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE)
861
75
                        .await;
862
55
                    let Some(inner) = weak_inner.upgrade() else {
  Branch (862:25): [True: 0, False: 0]
  Branch (862:25): [Folded - Ignored]
  Branch (862:25): [True: 55, False: 0]
863
0
                        return; // Nothing to cleanup, our struct is dropped.
864
                    };
865
55
                    let mut inner = inner.lock().await;
866
55
                    inner
867
55
                        .handle_action_events(dropped_operation_ids.drain(..))
868
55
                        .await;
869
                }
870
20
            
}0
),
871
        }
872
20
    }
873
}
874
875
impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> AwaitedActionDb
876
    for MemoryAwaitedActionDb<I, NowFn>
877
{
878
    type Subscriber = MemoryAwaitedActionSubscriber<I, NowFn>;
879
880
3
    async fn get_awaited_action_by_id(
881
3
        &self,
882
3
        client_operation_id: &OperationId,
883
3
    ) -> Result<Option<Self::Subscriber>, Error> {
884
3
        self.inner
885
3
            .lock()
886
3
            .await
887
3
            .get_awaited_action_by_id(client_operation_id)
888
3
            .await
889
3
    }
890
891
0
    async fn get_all_awaited_actions(
892
0
        &self,
893
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
894
0
        Ok(ChunkedStream::new(
895
0
            Bound::Unbounded,
896
0
            Bound::Unbounded,
897
0
            move |start, end, mut output| async move {
898
0
                let inner = self.inner.lock().await;
899
0
                let mut maybe_new_start = None;
900
901
0
                for (operation_id, item) in
902
0
                    inner.get_awaited_actions_range(start.as_ref(), end.as_ref())
903
0
                {
904
0
                    output.push_back(item);
905
0
                    maybe_new_start = Some(operation_id);
906
0
                }
907
908
0
                Ok(maybe_new_start
909
0
                    .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output)))
910
0
            },
911
0
        ))
912
0
    }
913
914
40
    async fn get_by_operation_id(
915
40
        &self,
916
40
        operation_id: &OperationId,
917
40
    ) -> Result<Option<Self::Subscriber>, Error> {
918
40
        Ok(self.inner.lock().await.get_by_operation_id(operation_id))
919
40
    }
920
921
590
    async fn get_range_of_actions(
922
590
        &self,
923
590
        state: SortedAwaitedActionState,
924
590
        start: Bound<SortedAwaitedAction>,
925
590
        end: Bound<SortedAwaitedAction>,
926
590
        desc: bool,
927
590
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
928
590
        Ok(ChunkedStream::new(
929
590
            start,
930
590
            end,
931
1.13k
            move |start, end, mut output| async move {
932
1.13k
                let inner = self.inner.lock().await;
933
1.13k
                let mut done = true;
934
1.13k
                let mut new_start = start.as_ref();
935
1.13k
                let mut new_end = end.as_ref();
936
1.13k
937
1.13k
                let iterator = inner
938
1.13k
                    .get_range_of_actions(state, (start.as_ref(), end.as_ref()))
939
1.13k
                    .map(|res| 
res.err_tip(543
||
"In AwaitedActionDb::get_range_of_actions"0
)543
);
940
1.13k
941
1.13k
                // TODO(allada) This should probably use the `.left()/right()` pattern,
942
1.13k
                // but that doesn't exist in the std or any libraries we use.
943
1.13k
                if desc {
  Branch (943:20): [True: 0, False: 0]
  Branch (943:20): [Folded - Ignored]
  Branch (943:20): [True: 130, False: 1.00k]
944
130
                    for 
result43
in iterator.rev() {
945
43
                        let (sorted_awaited_action, item) =
946
43
                            result.err_tip(|| 
"In AwaitedActionDb::get_range_of_actions"0
)
?0
;
947
43
                        output.push_back(item);
948
43
                        new_end = Bound::Excluded(sorted_awaited_action);
949
43
                        done = false;
950
                    }
951
                } else {
952
1.50k
                    for 
result500
in iterator {
953
500
                        let (sorted_awaited_action, item) =
954
500
                            result.err_tip(|| 
"In AwaitedActionDb::get_range_of_actions"0
)
?0
;
955
500
                        output.push_back(item);
956
500
                        new_start = Bound::Excluded(sorted_awaited_action);
957
500
                        done = false;
958
                    }
959
                }
960
1.13k
                if done {
  Branch (960:20): [True: 0, False: 0]
  Branch (960:20): [Folded - Ignored]
  Branch (960:20): [True: 590, False: 540]
961
590
                    return Ok(None);
962
540
                }
963
540
                Ok(Some(((new_start.cloned(), new_end.cloned()), output)))
964
2.26k
            },
965
590
        ))
966
590
    }
967
968
39
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
969
39
        self.inner
970
39
            .lock()
971
39
            .await
972
39
            .update_awaited_action(new_awaited_action)
?0
;
973
39
        self.tasks_change_notify.notify_one();
974
39
        Ok(())
975
39
    }
976
977
27
    async fn add_action(
978
27
        &self,
979
27
        client_operation_id: OperationId,
980
27
        action_info: Arc<ActionInfo>,
981
27
    ) -> Result<Self::Subscriber, Error> {
982
27
        let subscriber = self
983
27
            .inner
984
27
            .lock()
985
27
            .await
986
27
            .add_action(client_operation_id, action_info)
987
27
            .await
?0
;
988
27
        self.tasks_change_notify.notify_one();
989
27
        Ok(subscriber)
990
27
    }
991
}