Coverage Report

Created: 2025-12-16 15:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::mem::Discriminant;
16
use core::ops::Bound;
17
use core::sync::atomic::{AtomicU64, Ordering};
18
use core::time::Duration;
19
use std::borrow::Cow;
20
use std::sync::{Arc, Weak};
21
use std::time::UNIX_EPOCH;
22
23
use bytes::Bytes;
24
use futures::{Stream, TryStreamExt};
25
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
26
use nativelink_metric::MetricsComponent;
27
use nativelink_util::action_messages::{
28
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
29
};
30
use nativelink_util::instant_wrapper::InstantWrapper;
31
use nativelink_util::spawn;
32
use nativelink_util::store_trait::{
33
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
34
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
35
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
36
};
37
use nativelink_util::task::JoinHandleDropGuard;
38
use tokio::sync::Notify;
39
use tracing::{error, warn};
40
41
use crate::awaited_action_db::{
42
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION,
43
    SortedAwaitedAction, SortedAwaitedActionState,
44
};
45
46
type ClientOperationId = OperationId;
47
48
/// Maximum number of retries to update client keep alive.
49
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
50
51
/// Use separate non-versioned Redis key for client keepalives.
52
const USE_SEPARATE_CLIENT_KEEPALIVE_KEY: bool = true;
53
54
enum OperationSubscriberState<Sub> {
55
    Unsubscribed,
56
    Subscribed(Sub),
57
}
58
59
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
60
    maybe_client_operation_id: Option<ClientOperationId>,
61
    subscription_key: OperationIdToAwaitedAction<'static>,
62
    weak_store: Weak<S>,
63
    state: OperationSubscriberState<
64
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
65
    >,
66
    last_known_keepalive_ts: AtomicU64,
67
    now_fn: NowFn,
68
    // If the SchedulerSubscriptionManager is not reliable, then this is populated
69
    // when the state is set to subscribed.  When set it causes the state to be polled
70
    // as well as listening for the publishing.
71
    maybe_last_stage: Option<Discriminant<ActionStage>>,
72
}
73
74
impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug
75
    for OperationSubscriber<S, I, NowFn>
76
where
77
    OperationSubscriberState<
78
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
79
    >: core::fmt::Debug,
80
{
81
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82
0
        f.debug_struct("OperationSubscriber")
83
0
            .field("maybe_client_operation_id", &self.maybe_client_operation_id)
84
0
            .field("subscription_key", &self.subscription_key)
85
0
            .field("weak_store", &self.weak_store)
86
0
            .field("state", &self.state)
87
0
            .field("last_known_keepalive_ts", &self.last_known_keepalive_ts)
88
0
            .field("now_fn", &self.now_fn)
89
0
            .finish()
90
0
    }
91
}
92
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
93
where
94
    S: SchedulerStore,
95
    I: InstantWrapper,
96
    NowFn: Fn() -> I,
97
{
98
17
    const fn new(
99
17
        maybe_client_operation_id: Option<ClientOperationId>,
100
17
        subscription_key: OperationIdToAwaitedAction<'static>,
101
17
        weak_store: Weak<S>,
102
17
        now_fn: NowFn,
103
17
    ) -> Self {
104
17
        Self {
105
17
            maybe_client_operation_id,
106
17
            subscription_key,
107
17
            weak_store,
108
17
            last_known_keepalive_ts: AtomicU64::new(0),
109
17
            state: OperationSubscriberState::Unsubscribed,
110
17
            now_fn,
111
17
            maybe_last_stage: None,
112
17
        }
113
17
    }
114
115
28
    async fn inner_get_awaited_action(
116
28
        store: &S,
117
28
        key: OperationIdToAwaitedAction<'_>,
118
28
        maybe_client_operation_id: Option<ClientOperationId>,
119
28
        last_known_keepalive_ts: &AtomicU64,
120
28
    ) -> Result<AwaitedAction, Error> {
121
28
        let mut awaited_action = store
122
28
            .get_and_decode(key.borrow())
123
28
            .await
124
28
            .err_tip(|| format!(
"In OperationSubscriber::get_awaited_action {key:?}"0
))
?0
125
28
            .ok_or_else(|| 
{0
126
0
                make_err!(
127
0
                    Code::NotFound,
128
                    "Could not find AwaitedAction for the given operation id {key:?}",
129
                )
130
0
            })?;
131
28
        if let Some(
client_operation_id6
) = maybe_client_operation_id {
  Branch (131:16): [True: 0, False: 0]
  Branch (131:16): [Folded - Ignored]
  Branch (131:16): [True: 6, False: 22]
132
6
            awaited_action.set_client_operation_id(client_operation_id);
133
22
        }
134
135
        // Helper to convert SystemTime to unix timestamp
136
28
        let to_unix_ts = |t: std::time::SystemTime| -> u64 {
137
28
            t.duration_since(UNIX_EPOCH)
138
28
                .map(|d| d.as_secs())
139
28
                .unwrap_or(0)
140
28
        };
141
142
        // Check the separate keepalive key for the most recent timestamp.
143
28
        let keepalive_ts = if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
  Branch (143:31): [True: 0, Folded]
  Branch (143:31): [Folded - Ignored]
  Branch (143:31): [True: 28, Folded]
144
28
            let operation_id = key.0.as_ref();
145
28
            match store.get_and_decode(ClientKeepaliveKey(operation_id)).await {
146
0
                Ok(Some(ts)) => {
147
0
                    let awaited_ts = to_unix_ts(awaited_action.last_client_keepalive_timestamp());
148
0
                    if ts > awaited_ts {
  Branch (148:24): [True: 0, False: 0]
  Branch (148:24): [Folded - Ignored]
  Branch (148:24): [True: 0, False: 0]
149
0
                        let timestamp = UNIX_EPOCH + Duration::from_secs(ts);
150
0
                        awaited_action.update_client_keep_alive(timestamp);
151
0
                        ts
152
                    } else {
153
0
                        awaited_ts
154
                    }
155
                }
156
28
                Ok(None) | Err(_) => to_unix_ts(awaited_action.last_client_keepalive_timestamp()),
157
            }
158
        } else {
159
0
            to_unix_ts(awaited_action.last_client_keepalive_timestamp())
160
        };
161
162
28
        last_known_keepalive_ts.store(keepalive_ts, Ordering::Release);
163
28
        Ok(awaited_action)
164
28
    }
165
166
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
167
26
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
168
26
        let store = self
169
26
            .weak_store
170
26
            .upgrade()
171
26
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
172
26
        Self::inner_get_awaited_action(
173
26
            store.as_ref(),
174
26
            self.subscription_key.borrow(),
175
26
            self.maybe_client_operation_id.clone(),
176
26
            &self.last_known_keepalive_ts,
177
26
        )
178
26
        .await
179
26
    }
180
}
181
182
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
183
where
184
    S: SchedulerStore,
185
    I: InstantWrapper,
186
    NowFn: Fn() -> I + Send + Sync + 'static,
187
{
188
2
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
189
2
        let store = self
190
2
            .weak_store
191
2
            .upgrade()
192
2
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
193
2
        let 
subscription1
= match &mut self.state {
194
1
            OperationSubscriberState::Subscribed(subscription) => subscription,
195
            OperationSubscriberState::Unsubscribed => {
196
1
                let subscription = store
197
1
                    .subscription_manager()
198
1
                    .err_tip(|| "In OperationSubscriber::changed::subscription_manager")
?0
199
1
                    .subscribe(self.subscription_key.borrow())
200
1
                    .err_tip(|| "In OperationSubscriber::changed::subscribe")
?0
;
201
1
                self.state = OperationSubscriberState::Subscribed(subscription);
202
                // When we've just subscribed, there may have been changes before now.
203
1
                let action = Self::inner_get_awaited_action(
204
1
                    store.as_ref(),
205
1
                    self.subscription_key.borrow(),
206
1
                    self.maybe_client_operation_id.clone(),
207
1
                    &self.last_known_keepalive_ts,
208
1
                )
209
1
                .await
210
1
                .err_tip(|| "In OperationSubscriber::changed")
?0
;
211
1
                if !<S as SchedulerStore>::SubscriptionManager::is_reliable() {
  Branch (211:20): [True: 0, False: 0]
  Branch (211:20): [Folded - Ignored]
  Branch (211:20): [True: 1, False: 0]
212
1
                    self.maybe_last_stage = Some(core::mem::discriminant(&action.state().stage));
213
1
                
}0
214
                // Existing changes are only interesting if the state is past queued.
215
1
                if !matches!(action.state().stage, ActionStage::Queued) {
  Branch (215:20): [True: 0, False: 0]
  Branch (215:20): [Folded - Ignored]
  Branch (215:20): [True: 1, False: 0]
216
1
                    return Ok(action);
217
0
                }
218
0
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
  Branch (218:21): [True: 0, False: 0]
  Branch (218:21): [Folded - Ignored]
  Branch (218:21): [True: 0, False: 0]
219
0
                    unreachable!("Subscription should be in Subscribed state");
220
                };
221
0
                subscription
222
            }
223
        };
224
225
1
        let changed_fut = subscription.changed();
226
1
        tokio::pin!(changed_fut);
227
        loop {
228
            // This is set if the maybe_last_state doesn't match the state in the store.
229
1
            let mut maybe_changed_action = None;
230
231
1
            let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
232
1
            if I::from_secs(last_known_keepalive_ts).elapsed() > CLIENT_KEEPALIVE_DURATION {
  Branch (232:16): [True: 0, False: 0]
  Branch (232:16): [Folded - Ignored]
  Branch (232:16): [True: 0, False: 1]
233
0
                let now = (self.now_fn)().now();
234
0
                let now_ts = now
235
0
                    .duration_since(UNIX_EPOCH)
236
0
                    .map(|d| d.as_secs())
237
0
                    .unwrap_or(0);
238
239
0
                if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
  Branch (239:20): [True: 0, Folded]
  Branch (239:20): [Folded - Ignored]
  Branch (239:20): [True: 0, Folded]
240
0
                    let operation_id = self.subscription_key.0.as_ref();
241
0
                    let update_result = store
242
0
                        .update_data(UpdateClientKeepalive {
243
0
                            operation_id,
244
0
                            timestamp: now_ts,
245
0
                        })
246
0
                        .await;
247
248
0
                    if let Err(e) = update_result {
  Branch (248:28): [True: 0, False: 0]
  Branch (248:28): [Folded - Ignored]
  Branch (248:28): [True: 0, False: 0]
249
0
                        warn!(
250
                            ?self.subscription_key,
251
                            ?e,
252
0
                            "Failed to update client keepalive (non-versioned)"
253
                        );
254
0
                    }
255
256
                    // Update local timestamp
257
0
                    self.last_known_keepalive_ts
258
0
                        .store(now_ts, Ordering::Release);
259
260
                    // Check if state changed (for unreliable subscription managers)
261
0
                    if self.maybe_last_stage.is_some() {
  Branch (261:24): [True: 0, False: 0]
  Branch (261:24): [Folded - Ignored]
  Branch (261:24): [True: 0, False: 0]
262
0
                        let awaited_action = Self::inner_get_awaited_action(
263
0
                            store.as_ref(),
264
0
                            self.subscription_key.borrow(),
265
0
                            self.maybe_client_operation_id.clone(),
266
0
                            &self.last_known_keepalive_ts,
267
0
                        )
268
0
                        .await
269
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
270
271
0
                        if self.maybe_last_stage.as_ref().is_some_and(|last_stage| {
  Branch (271:28): [True: 0, False: 0]
  Branch (271:28): [Folded - Ignored]
  Branch (271:28): [True: 0, False: 0]
272
0
                            *last_stage != core::mem::discriminant(&awaited_action.state().stage)
273
0
                        }) {
274
0
                            maybe_changed_action = Some(awaited_action);
275
0
                        }
276
0
                    }
277
                } else {
278
0
                    for attempt in 1..=MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
279
0
                        if attempt > 1 {
  Branch (279:28): [True: 0, False: 0]
  Branch (279:28): [Folded - Ignored]
  Branch (279:28): [True: 0, False: 0]
280
0
                            (self.now_fn)().sleep(Duration::from_millis(100)).await;
281
0
                            warn!(
282
                                ?self.subscription_key,
283
                                attempt,
284
0
                                "Client keepalive retry due to version conflict"
285
                            );
286
0
                        }
287
0
                        let mut awaited_action = Self::inner_get_awaited_action(
288
0
                            store.as_ref(),
289
0
                            self.subscription_key.borrow(),
290
0
                            self.maybe_client_operation_id.clone(),
291
0
                            &self.last_known_keepalive_ts,
292
0
                        )
293
0
                        .await
294
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
295
0
                        awaited_action.update_client_keep_alive(now);
296
0
                        maybe_changed_action = self
297
0
                            .maybe_last_stage
298
0
                            .as_ref()
299
0
                            .is_some_and(|last_stage| {
300
0
                                *last_stage
301
0
                                    != core::mem::discriminant(&awaited_action.state().stage)
302
0
                            })
303
0
                            .then(|| awaited_action.clone());
304
0
                        match inner_update_awaited_action(store.as_ref(), awaited_action).await {
305
0
                            Ok(()) => break,
306
0
                            err if attempt == MAX_RETRIES_FOR_CLIENT_KEEPALIVE => {
  Branch (306:36): [True: 0, False: 0]
  Branch (306:36): [Folded - Ignored]
  Branch (306:36): [True: 0, False: 0]
307
0
                                err.err_tip_with_code(|_| {
308
0
                                    (Code::Aborted, "Could not update client keep alive")
309
0
                                })?;
310
                            }
311
0
                            _ => (),
312
                        }
313
                    }
314
                }
315
1
            }
316
317
            // If the polling shows that it's changed state then publish now.
318
1
            if let Some(
changed_action0
) = maybe_changed_action {
  Branch (318:20): [True: 0, False: 0]
  Branch (318:20): [Folded - Ignored]
  Branch (318:20): [True: 0, False: 1]
319
0
                self.maybe_last_stage =
320
0
                    Some(core::mem::discriminant(&changed_action.state().stage));
321
0
                return Ok(changed_action);
322
1
            }
323
            // Determine the sleep time based on the last client keep alive.
324
1
            let sleep_time = CLIENT_KEEPALIVE_DURATION
325
1
                .checked_sub(
326
1
                    I::from_secs(self.last_known_keepalive_ts.load(Ordering::Acquire)).elapsed(),
327
                )
328
1
                .unwrap_or(Duration::from_millis(100));
329
1
            tokio::select! {
330
1
                result = &mut changed_fut => {
331
1
                    result
?0
;
332
1
                    break;
333
                }
334
1
                () = (self.now_fn)().sleep(sleep_time) => {
335
0
                    // If we haven't received any updates for a while, we should
336
0
                    // let the database know that we are still listening to prevent
337
0
                    // the action from being dropped.  Also poll for updates if the
338
0
                    // subscription manager is unreliable.
339
0
                }
340
            }
341
        }
342
343
1
        let awaited_action = Self::inner_get_awaited_action(
344
1
            store.as_ref(),
345
1
            self.subscription_key.borrow(),
346
1
            self.maybe_client_operation_id.clone(),
347
1
            &self.last_known_keepalive_ts,
348
1
        )
349
1
        .await
350
1
        .err_tip(|| "In OperationSubscriber::changed")
?0
;
351
1
        if self.maybe_last_stage.is_some() {
  Branch (351:12): [True: 0, False: 0]
  Branch (351:12): [Folded - Ignored]
  Branch (351:12): [True: 1, False: 0]
352
1
            self.maybe_last_stage = Some(core::mem::discriminant(&awaited_action.state().stage));
353
1
        
}0
354
1
        Ok(awaited_action)
355
2
    }
356
357
26
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
358
26
        self.get_awaited_action()
359
26
            .await
360
26
            .err_tip(|| "In OperationSubscriber::borrow")
361
26
    }
362
}
363
364
36
fn awaited_action_decode(version: i64, data: &Bytes) -> Result<AwaitedAction, Error> {
365
36
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data)
366
36
        .map_err(|e| make_input_err!("In AwaitedAction::decode - {e:?}"))
?0
;
367
36
    awaited_action.set_version(version);
368
36
    Ok(awaited_action)
369
36
}
370
371
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
372
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
373
/// Phase 2: Separate key prefix for client keepalives (non-versioned).
374
const CLIENT_KEEPALIVE_KEY_PREFIX: &str = "ck_";
375
376
#[derive(Debug)]
377
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
378
impl OperationIdToAwaitedAction<'_> {
379
57
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
380
57
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
381
57
    }
382
}
383
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
384
    type Versioned = TrueValue;
385
47
    fn get_key(&self) -> StoreKey<'static> {
386
47
        StoreKey::Str(Cow::Owned(format!(
387
47
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
388
47
            self.0
389
47
        )))
390
47
    }
391
}
392
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
393
    type DecodeOutput = AwaitedAction;
394
29
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
395
29
        awaited_action_decode(version, &data)
396
29
    }
397
}
398
399
struct ClientIdToOperationId<'a>(&'a OperationId);
400
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
401
    type Versioned = FalseValue;
402
5
    fn get_key(&self) -> StoreKey<'static> {
403
5
        StoreKey::Str(Cow::Owned(format!(
404
5
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
405
5
            self.0
406
5
        )))
407
5
    }
408
}
409
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
410
    type DecodeOutput = OperationId;
411
2
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
412
2
        serde_json::from_slice(&data).map_err(|e| 
{0
413
0
            make_input_err!(
414
                "In ClientIdToOperationId::decode - {e:?} (data: {:02x?})",
415
                data
416
            )
417
0
        })
418
2
    }
419
}
420
421
struct ClientKeepaliveKey<'a>(&'a OperationId);
422
impl SchedulerStoreKeyProvider for ClientKeepaliveKey<'_> {
423
    type Versioned = FalseValue;
424
28
    fn get_key(&self) -> StoreKey<'static> {
425
28
        StoreKey::Str(Cow::Owned(format!(
426
28
            "{CLIENT_KEEPALIVE_KEY_PREFIX}{}",
427
28
            self.0
428
28
        )))
429
28
    }
430
}
431
impl SchedulerStoreDecodeTo for ClientKeepaliveKey<'_> {
432
    type DecodeOutput = u64;
433
0
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
434
0
        let s = core::str::from_utf8(&data)
435
0
            .map_err(|e| make_input_err!("In ClientKeepaliveKey::decode utf8 - {e:?}"))?;
436
0
        s.parse::<u64>()
437
0
            .map_err(|e| make_input_err!("In ClientKeepaliveKey::decode parse - {e:?}"))
438
0
    }
439
}
440
441
struct UpdateClientKeepalive<'a> {
442
    operation_id: &'a OperationId,
443
    timestamp: u64,
444
}
445
impl SchedulerStoreKeyProvider for UpdateClientKeepalive<'_> {
446
    type Versioned = FalseValue;
447
0
    fn get_key(&self) -> StoreKey<'static> {
448
0
        ClientKeepaliveKey(self.operation_id).get_key()
449
0
    }
450
}
451
impl SchedulerStoreDataProvider for UpdateClientKeepalive<'_> {
452
0
    fn try_into_bytes(self) -> Result<Bytes, Error> {
453
0
        Ok(Bytes::from(self.timestamp.to_string()))
454
0
    }
455
}
456
457
// TODO(palfrey) We only need operation_id here, it would be nice if we had a way
458
// to tell the decoder we only care about specific fields.
459
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
460
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
461
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
462
    const INDEX_NAME: &'static str = "unique_qualifier";
463
    type Versioned = TrueValue;
464
3
    fn index_value(&self) -> Cow<'_, str> {
465
3
        Cow::Owned(format!("{}", self.0))
466
3
    }
467
}
468
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
469
    type DecodeOutput = AwaitedAction;
470
2
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
471
2
        awaited_action_decode(version, &data)
472
2
    }
473
}
474
475
struct SearchStateToAwaitedAction(&'static str);
476
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
477
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
478
    const INDEX_NAME: &'static str = "state";
479
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
480
    type Versioned = TrueValue;
481
17
    fn index_value(&self) -> Cow<'_, str> {
482
17
        Cow::Borrowed(self.0)
483
17
    }
484
}
485
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
486
    type DecodeOutput = AwaitedAction;
487
5
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
488
5
        awaited_action_decode(version, &data)
489
5
    }
490
}
491
492
30
const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
493
30
    match state {
494
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
495
25
        SortedAwaitedActionState::Queued => "queued",
496
4
        SortedAwaitedActionState::Executing => "executing",
497
1
        SortedAwaitedActionState::Completed => "completed",
498
    }
499
30
}
500
501
struct UpdateOperationIdToAwaitedAction(AwaitedAction);
502
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
503
13
    fn current_version(&self) -> i64 {
504
13
        self.0.version()
505
13
    }
506
}
507
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
508
    type Versioned = TrueValue;
509
13
    fn get_key(&self) -> StoreKey<'static> {
510
13
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
511
13
    }
512
}
513
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
514
13
    fn try_into_bytes(self) -> Result<Bytes, Error> {
515
13
        serde_json::to_string(&self.0)
516
13
            .map(Bytes::from)
517
13
            .map_err(|e| make_input_err!("Could not convert AwaitedAction to json - {e:?}"))
518
13
    }
519
13
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
520
13
        let unique_qualifier = &self.0.action_info().unique_qualifier;
521
13
        let maybe_unique_qualifier = match &unique_qualifier {
522
13
            ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier),
523
0
            ActionUniqueQualifier::Uncacheable(_) => None,
524
        };
525
13
        let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1));
526
13
        if maybe_unique_qualifier.is_some() {
  Branch (526:12): [True: 13, False: 0]
  Branch (526:12): [Folded - Ignored]
527
13
            output.push((
528
13
                "unique_qualifier",
529
13
                Bytes::from(unique_qualifier.to_string()),
530
13
            ));
531
13
        
}0
532
        {
533
13
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
534
13
                .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")
?0
;
535
13
            output.push(("state", Bytes::from(get_state_prefix(state))));
536
13
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
537
13
            output.push((
538
13
                "sort_key",
539
13
                // We encode to hex to ensure that the sort key is lexicographically sorted.
540
13
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
541
13
            ));
542
        }
543
13
        Ok(output)
544
13
    }
545
}
546
547
struct UpdateClientIdToOperationId {
548
    client_operation_id: ClientOperationId,
549
    operation_id: OperationId,
550
}
551
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
552
    type Versioned = FalseValue;
553
3
    fn get_key(&self) -> StoreKey<'static> {
554
3
        ClientIdToOperationId(&self.client_operation_id).get_key()
555
3
    }
556
}
557
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
558
3
    fn try_into_bytes(self) -> Result<Bytes, Error> {
559
3
        serde_json::to_string(&self.operation_id)
560
3
            .map(Bytes::from)
561
3
            .map_err(|e| make_input_err!("Could not convert OperationId to json - {e:?}"))
562
3
    }
563
}
564
565
10
async fn inner_update_awaited_action(
566
10
    store: &impl SchedulerStore,
567
10
    mut new_awaited_action: AwaitedAction,
568
10
) -> Result<(), Error> {
569
10
    let operation_id = new_awaited_action.operation_id().clone();
570
10
    if new_awaited_action.state().client_operation_id != operation_id {
  Branch (570:8): [True: 0, False: 0]
  Branch (570:8): [Folded - Ignored]
  Branch (570:8): [True: 0, False: 10]
571
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
572
10
    }
573
574
10
    let _is_finished = new_awaited_action.state().stage.is_finished();
575
576
10
    let maybe_version = store
577
10
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action))
578
10
        .await
579
10
        .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")
?0
;
580
581
10
    if maybe_version.is_none() {
  Branch (581:8): [True: 0, False: 0]
  Branch (581:8): [Folded - Ignored]
  Branch (581:8): [True: 1, False: 9]
582
1
        warn!(
583
            %operation_id,
584
1
            "Could not update AwaitedAction because the version did not match"
585
        );
586
1
        return Err(make_err!(
587
1
            Code::Aborted,
588
1
            "Could not update AwaitedAction because the version did not match for {operation_id}",
589
1
        ));
590
9
    }
591
592
9
    Ok(())
593
10
}
594
595
/// `AwaitedActionDb` implementation that keeps all awaited-action state in a
/// `SchedulerStore` rather than in process memory.
///
/// Type parameters:
/// * `S` - the backing scheduler store.
/// * `F` - factory producing a fresh `OperationId` for newly created actions.
/// * `I` / `NowFn` - clock abstraction; `NowFn` yields an `InstantWrapper`.
#[derive(Debug, MetricsComponent)]
596
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
597
where
598
    S: SchedulerStore,
599
    F: Fn() -> OperationId,
600
    I: InstantWrapper,
601
    NowFn: Fn() -> I,
602
{
603
    // Shared handle to the backing store; subscribers receive a `Weak` to it.
    store: Arc<S>,
604
    // Clock source, invoked as `(self.now_fn)().now()`.
    now_fn: NowFn,
605
    // Generates operation ids when a new `AwaitedAction` is created.
    operation_id_creator: F,
606
    // Handle to the background task spawned in `new` that forwards store
    // change events to the task-change notifier. Held only to keep the task
    // alive; presumably the guard stops the task on drop - TODO confirm
    // against `JoinHandleDropGuard`.
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
607
}
608
609
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
610
where
611
    S: SchedulerStore,
612
    F: Fn() -> OperationId,
613
    I: InstantWrapper,
614
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
615
{
616
3
    pub fn new(
617
3
        store: Arc<S>,
618
3
        task_change_publisher: Arc<Notify>,
619
3
        now_fn: NowFn,
620
3
        operation_id_creator: F,
621
3
    ) -> Result<Self, Error> {
622
3
        let mut subscription = store
623
3
            .subscription_manager()
624
3
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
625
3
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
626
3
                String::new(),
627
3
            ))))
628
3
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
;
629
3
        let pull_task_change_subscriber = spawn!(
630
            "redis_awaited_action_db_pull_task_change_subscriber",
631
3
            async move {
632
                loop {
633
17
                    let 
changed_res14
= subscription
634
17
                        .changed()
635
17
                        .await
636
14
                        .err_tip(|| "In RedisAwaitedActionDb::new");
637
14
                    if let Err(
err0
) = changed_res {
  Branch (637:28): [True: 0, False: 0]
  Branch (637:28): [Folded - Ignored]
  Branch (637:28): [True: 0, False: 0]
  Branch (637:28): [True: 0, False: 1]
  Branch (637:28): [True: 0, False: 12]
  Branch (637:28): [True: 0, False: 1]
638
0
                        error!(
639
0
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
640
                        );
641
                        // Sleep for a second to avoid a busy loop, then trigger the notify
642
                        // so if a reconnect happens we let local resources know that things
643
                        // might have changed.
644
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
645
14
                    }
646
14
                    task_change_publisher.as_ref().notify_one();
647
                }
648
            }
649
        );
650
3
        Ok(Self {
651
3
            store,
652
3
            now_fn,
653
3
            operation_id_creator,
654
3
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
655
3
        })
656
3
    }
657
658
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
659
3
    async fn try_subscribe(
660
3
        &self,
661
3
        client_operation_id: &ClientOperationId,
662
3
        unique_qualifier: &ActionUniqueQualifier,
663
3
        no_event_action_timeout: Duration,
664
3
        // TODO(palfrey) To simplify the scheduler 2024 refactor, we
665
3
        // removed the ability to upgrade priorities of actions.
666
3
        // we should add priority upgrades back in.
667
3
        _priority: i32,
668
3
    ) -> Result<Option<AwaitedAction>, Error> {
669
3
        match unique_qualifier {
670
3
            ActionUniqueQualifier::Cacheable(_) => {}
671
0
            ActionUniqueQualifier::Uncacheable(_) => return Ok(None),
672
        }
673
3
        let stream = self
674
3
            .store
675
3
            .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
676
3
            .await
677
3
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
678
3
        tokio::pin!(stream);
679
3
        let maybe_awaited_action = stream
680
3
            .try_next()
681
3
            .await
682
3
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
683
3
        match maybe_awaited_action {
684
2
            Some(awaited_action) => {
685
                // TODO(palfrey) We don't support joining completed jobs because we
686
                // need to also check that all the data is still in the cache.
687
                // If the existing job failed then we need to set back to queued or we get
688
                // a version mismatch.  Equally we need to check the timeout as the job
689
                // may be abandoned in the store.
690
2
                let worker_should_update_before = (awaited_action.state().stage
691
2
                    == ActionStage::Executing)
692
2
                    .then_some(())
693
2
                    .map(|()| 
awaited_action0
.
last_worker_updated_timestamp0
())
694
2
                    .and_then(|last_worker_updated| 
{0
695
0
                        last_worker_updated.checked_add(no_event_action_timeout)
696
0
                    });
697
2
                let awaited_action = if awaited_action.state().stage.is_finished()
  Branch (697:41): [True: 0, False: 0]
  Branch (697:41): [Folded - Ignored]
  Branch (697:41): [True: 0, False: 0]
  Branch (697:41): [True: 1, False: 1]
698
1
                    || worker_should_update_before
  Branch (698:24): [True: 0, False: 0]
  Branch (698:24): [Folded - Ignored]
  Branch (698:24): [True: 0, False: 0]
  Branch (698:24): [True: 0, False: 1]
699
1
                        .is_some_and(|timestamp| 
timestamp0
<
(self.now_fn)().now()0
)
700
                {
701
1
                    tracing::debug!(
702
1
                        "Recreating action {:?} for operation {client_operation_id}",
703
1
                        awaited_action.action_info().digest()
704
                    );
705
                    // The version is reset because we have a new operation ID.
706
1
                    AwaitedAction::new(
707
1
                        (self.operation_id_creator)(),
708
1
                        awaited_action.action_info().clone(),
709
1
                        (self.now_fn)().now(),
710
                    )
711
                } else {
712
1
                    tracing::debug!(
713
1
                        "Subscribing to existing action {:?} for operation {client_operation_id}",
714
1
                        awaited_action.action_info().digest()
715
                    );
716
1
                    awaited_action
717
                };
718
2
                Ok(Some(awaited_action))
719
            }
720
1
            None => Ok(None),
721
        }
722
3
    }
723
724
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
725
2
    async fn inner_get_awaited_action_by_id(
726
2
        &self,
727
2
        client_operation_id: &ClientOperationId,
728
2
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
729
2
        let maybe_operation_id = self
730
2
            .store
731
2
            .get_and_decode(ClientIdToOperationId(client_operation_id))
732
2
            .await
733
2
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")
?0
;
734
2
        let Some(operation_id) = maybe_operation_id else {
  Branch (734:13): [True: 0, False: 0]
  Branch (734:13): [Folded - Ignored]
  Branch (734:13): [True: 0, False: 0]
  Branch (734:13): [True: 1, False: 0]
  Branch (734:13): [True: 1, False: 0]
735
0
            return Ok(None);
736
        };
737
738
        // Validate that the internal operation actually exists.
739
        // If it doesn't, this is an orphaned client operation mapping that should be cleaned up.
740
        // This can happen when an operation is deleted (completed/timed out) but the
741
        // client_id -> operation_id mapping persists in the store.
742
2
        let maybe_awaited_action = match self
743
2
            .store
744
2
            .get_and_decode(OperationIdToAwaitedAction(Cow::Borrowed(&operation_id)))
745
2
            .await
746
        {
747
2
            Ok(maybe_action) => maybe_action,
748
0
            Err(err) if err.code == Code::NotFound => {
  Branch (748:25): [True: 0, False: 0]
  Branch (748:25): [Folded - Ignored]
  Branch (748:25): [True: 0, False: 0]
  Branch (748:25): [True: 0, False: 0]
  Branch (748:25): [True: 0, False: 0]
749
0
                tracing::warn!(
750
0
                    "Orphaned client operation mapping detected: client_id={} maps to operation_id={}, \
751
0
                    but the operation does not exist in the store (NotFound). This typically happens when \
752
0
                    an operation completes or times out but the client mapping persists.",
753
                    client_operation_id,
754
                    operation_id
755
                );
756
0
                None
757
            }
758
0
            Err(err) => {
759
                // Some other error occurred
760
0
                return Err(err).err_tip(
761
                    || "In RedisAwaitedActionDb::get_awaited_action_by_id::validate_operation",
762
                );
763
            }
764
        };
765
766
2
        if maybe_awaited_action.is_none() {
  Branch (766:12): [True: 0, False: 0]
  Branch (766:12): [Folded - Ignored]
  Branch (766:12): [True: 0, False: 0]
  Branch (766:12): [True: 1, False: 0]
  Branch (766:12): [True: 0, False: 1]
767
1
            tracing::warn!(
768
1
                "Found orphaned client operation mapping: client_id={} -> operation_id={}, \
769
1
                but operation no longer exists. Returning None to prevent client from polling \
770
1
                a non-existent operation.",
771
                client_operation_id,
772
                operation_id
773
            );
774
1
            return Ok(None);
775
1
        }
776
777
1
        Ok(Some(OperationSubscriber::new(
778
1
            Some(client_operation_id.clone()),
779
1
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
780
1
            Arc::downgrade(&self.store),
781
1
            self.now_fn.clone(),
782
1
        )))
783
2
    }
784
}
785
786
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
787
where
788
    S: SchedulerStore,
789
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
790
    I: InstantWrapper,
791
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
792
{
793
    type Subscriber = OperationSubscriber<S, I, NowFn>;
794
795
2
    /// Trait entry point for looking up a subscriber by client operation id.
    /// Delegates entirely to the inherent `inner_get_awaited_action_by_id`.
    async fn get_awaited_action_by_id(
796
2
        &self,
797
2
        client_operation_id: &ClientOperationId,
798
2
    ) -> Result<Option<Self::Subscriber>, Error> {
799
2
        self.inner_get_awaited_action_by_id(client_operation_id)
800
2
            .await
801
2
    }
802
803
8
    async fn get_by_operation_id(
804
8
        &self,
805
8
        operation_id: &OperationId,
806
8
    ) -> Result<Option<Self::Subscriber>, Error> {
807
8
        Ok(Some(OperationSubscriber::new(
808
8
            None,
809
8
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
810
8
            Arc::downgrade(&self.store),
811
8
            self.now_fn.clone(),
812
8
        )))
813
8
    }
814
815
10
    /// Persists `new_awaited_action` by delegating to the free function
    /// `inner_update_awaited_action` with a reference to the backing store.
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
816
10
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await
817
10
    }
818
819
3
    async fn add_action(
820
3
        &self,
821
3
        client_operation_id: ClientOperationId,
822
3
        action_info: Arc<ActionInfo>,
823
3
        no_event_action_timeout: Duration,
824
3
    ) -> Result<Self::Subscriber, Error> {
825
        loop {
826
            // Check to see if the action is already known and subscribe if it is.
827
3
            let mut awaited_action = self
828
3
                .try_subscribe(
829
3
                    &client_operation_id,
830
3
                    &action_info.unique_qualifier,
831
3
                    no_event_action_timeout,
832
3
                    action_info.priority,
833
3
                )
834
3
                .await
835
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
836
3
                .unwrap_or_else(|| 
{1
837
1
                    tracing::debug!(
838
1
                        "Creating new action {:?} for operation {client_operation_id}",
839
1
                        action_info.digest()
840
                    );
841
1
                    AwaitedAction::new(
842
1
                        (self.operation_id_creator)(),
843
1
                        action_info.clone(),
844
1
                        (self.now_fn)().now(),
845
                    )
846
1
                });
847
848
3
            debug_assert!(
849
0
                ActionStage::Queued == awaited_action.state().stage,
850
0
                "Expected action to be queued"
851
            );
852
853
3
            let operation_id = awaited_action.operation_id().clone();
854
3
            if awaited_action.state().client_operation_id != operation_id {
  Branch (854:16): [True: 0, False: 0]
  Branch (854:16): [Folded - Ignored]
  Branch (854:16): [True: 0, False: 0]
  Branch (854:16): [True: 0, False: 3]
855
0
                // Just in case the client_operation_id was set to something else
856
0
                // we put it back to the underlying operation_id.
857
0
                awaited_action.set_client_operation_id(operation_id.clone());
858
3
            }
859
3
            awaited_action.update_client_keep_alive((self.now_fn)().now());
860
861
3
            let version = awaited_action.version();
862
3
            if self
  Branch (862:16): [True: 0, False: 0]
  Branch (862:16): [Folded - Ignored]
  Branch (862:16): [True: 0, False: 0]
  Branch (862:16): [True: 0, False: 3]
863
3
                .store
864
3
                .update_data(UpdateOperationIdToAwaitedAction(awaited_action))
865
3
                .await
866
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
867
3
                .is_none()
868
            {
869
                // The version was out of date, try again.
870
0
                tracing::info!(
871
0
                    "Version out of date for {:?} {operation_id} {version}, retrying.",
872
0
                    action_info.digest()
873
                );
874
0
                continue;
875
3
            }
876
877
            // Add the client_operation_id to operation_id mapping
878
3
            self.store
879
3
                .update_data(UpdateClientIdToOperationId {
880
3
                    client_operation_id: client_operation_id.clone(),
881
3
                    operation_id: operation_id.clone(),
882
3
                })
883
3
                .await
884
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action while adding client mapping")
?0
;
885
886
3
            return Ok(OperationSubscriber::new(
887
3
                Some(client_operation_id),
888
3
                OperationIdToAwaitedAction(Cow::Owned(operation_id)),
889
3
                Arc::downgrade(&self.store),
890
3
                self.now_fn.clone(),
891
3
            ));
892
        }
893
3
    }
894
895
17
    async fn get_range_of_actions(
896
17
        &self,
897
17
        state: SortedAwaitedActionState,
898
17
        start: Bound<SortedAwaitedAction>,
899
17
        end: Bound<SortedAwaitedAction>,
900
17
        desc: bool,
901
17
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
902
17
        if !
matches!0
(start, Bound::Unbounded) {
  Branch (902:12): [True: 0, False: 0]
  Branch (902:12): [Folded - Ignored]
  Branch (902:12): [True: 0, False: 17]
903
0
            return Err(make_err!(
904
0
                Code::Unimplemented,
905
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
906
0
            ));
907
17
        }
908
17
        if !
matches!0
(end, Bound::Unbounded) {
  Branch (908:12): [True: 0, False: 0]
  Branch (908:12): [Folded - Ignored]
  Branch (908:12): [True: 0, False: 17]
909
0
            return Err(make_err!(
910
0
                Code::Unimplemented,
911
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
912
0
            ));
913
17
        }
914
        // TODO(palfrey) This API is not difficult to implement, but there is no code path
915
        // that uses it, so no reason to implement it yet.
916
17
        if !desc {
  Branch (916:12): [True: 0, False: 0]
  Branch (916:12): [Folded - Ignored]
  Branch (916:12): [True: 0, False: 17]
917
0
            return Err(make_err!(
918
0
                Code::Unimplemented,
919
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
920
0
            ));
921
17
        }
922
17
        Ok(self
923
17
            .store
924
17
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
925
17
            .await
926
17
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")
?0
927
17
            .map_ok(move |awaited_action| 
{5
928
5
                OperationSubscriber::new(
929
5
                    None,
930
5
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
931
5
                    Arc::downgrade(&self.store),
932
5
                    self.now_fn.clone(),
933
                )
934
5
            }))
935
17
    }
936
937
0
    async fn get_all_awaited_actions(
938
0
        &self,
939
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
940
0
        Ok(self
941
0
            .store
942
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
943
0
            .await
944
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
945
0
            .map_ok(move |awaited_action| {
946
0
                OperationSubscriber::new(
947
0
                    None,
948
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
949
0
                    Arc::downgrade(&self.store),
950
0
                    self.now_fn.clone(),
951
                )
952
0
            }))
953
0
    }
954
}