Coverage Report

Created: 2026-06-04 10:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::mem::Discriminant;
16
use core::ops::Bound;
17
use core::sync::atomic::{AtomicU64, Ordering};
18
use core::time::Duration;
19
use std::borrow::Cow;
20
use std::sync::{Arc, Weak};
21
use std::time::UNIX_EPOCH;
22
23
use bytes::Bytes;
24
use futures::{Stream, TryStreamExt};
25
use nativelink_error::{Code, Error, ResultExt, make_err};
26
use nativelink_metric::MetricsComponent;
27
use nativelink_util::action_messages::{
28
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
29
};
30
use nativelink_util::instant_wrapper::InstantWrapper;
31
use nativelink_util::spawn;
32
use nativelink_util::store_trait::{
33
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
34
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
35
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
36
};
37
use nativelink_util::task::JoinHandleDropGuard;
38
use tokio::sync::Notify;
39
use tracing::{error, warn};
40
41
use crate::awaited_action_db::{
42
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION,
43
    SortedAwaitedAction, SortedAwaitedActionState,
44
};
45
46
type ClientOperationId = OperationId;
47
48
/// Maximum number of retries to update client keep alive.
49
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
50
51
/// Use separate non-versioned Redis key for client keepalives.
52
const USE_SEPARATE_CLIENT_KEEPALIVE_KEY: bool = true;
53
54
enum OperationSubscriberState<Sub> {
55
    Unsubscribed,
56
    Subscribed(Sub),
57
}
58
59
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
60
    maybe_client_operation_id: Option<ClientOperationId>,
61
    subscription_key: OperationIdToAwaitedAction<'static>,
62
    weak_store: Weak<S>,
63
    state: OperationSubscriberState<
64
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
65
    >,
66
    last_known_keepalive_ts: AtomicU64,
67
    now_fn: NowFn,
68
    // If the SchedulerSubscriptionManager is not reliable, then this is populated
69
    // when the state is set to subscribed.  When set it causes the state to be polled
70
    // as well as listening for the publishing.
71
    maybe_last_stage: Option<Discriminant<ActionStage>>,
72
    retain_completed_for: Duration,
73
}
74
75
impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug
76
    for OperationSubscriber<S, I, NowFn>
77
where
78
    OperationSubscriberState<
79
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
80
    >: core::fmt::Debug,
81
{
82
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
83
0
        f.debug_struct("OperationSubscriber")
84
0
            .field("maybe_client_operation_id", &self.maybe_client_operation_id)
85
0
            .field("subscription_key", &self.subscription_key)
86
0
            .field("weak_store", &self.weak_store)
87
0
            .field("state", &self.state)
88
0
            .field("last_known_keepalive_ts", &self.last_known_keepalive_ts)
89
0
            .field("now_fn", &self.now_fn)
90
0
            .finish()
91
0
    }
92
}
93
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
94
where
95
    S: SchedulerStore,
96
    I: InstantWrapper,
97
    NowFn: Fn() -> I,
98
{
99
27
    const fn new(
100
27
        maybe_client_operation_id: Option<ClientOperationId>,
101
27
        subscription_key: OperationIdToAwaitedAction<'static>,
102
27
        weak_store: Weak<S>,
103
27
        now_fn: NowFn,
104
27
        retain_completed_for: Duration,
105
27
    ) -> Self {
106
27
        Self {
107
27
            maybe_client_operation_id,
108
27
            subscription_key,
109
27
            weak_store,
110
27
            last_known_keepalive_ts: AtomicU64::new(0),
111
27
            state: OperationSubscriberState::Unsubscribed,
112
27
            now_fn,
113
27
            maybe_last_stage: None,
114
27
            retain_completed_for,
115
27
        }
116
27
    }
117
118
45
    async fn inner_get_awaited_action(
119
45
        store: &S,
120
45
        key: OperationIdToAwaitedAction<'_>,
121
45
        maybe_client_operation_id: Option<ClientOperationId>,
122
45
        last_known_keepalive_ts: &AtomicU64,
123
45
    ) -> Result<AwaitedAction, Error> {
124
45
        let mut awaited_action = store
125
45
            .get_and_decode(key.borrow())
126
45
            .await
127
45
            .err_tip(|| 
format!0
("In OperationSubscriber::get_awaited_action {key:?}"))
?0
128
45
            .ok_or_else(|| 
{0
129
0
                make_err!(
130
0
                    Code::NotFound,
131
                    "Could not find AwaitedAction for the given operation id {key:?}",
132
                )
133
0
            })?;
134
45
        if let Some(
client_operation_id6
) = maybe_client_operation_id {
135
6
            awaited_action.set_client_operation_id(client_operation_id);
136
39
        }
137
138
        // Helper to convert SystemTime to unix timestamp
139
45
        let to_unix_ts = |t: std::time::SystemTime| -> u64 {
140
45
            t.duration_since(UNIX_EPOCH).map_or(0, |d| d.as_secs())
141
45
        };
142
143
        // Check the separate keepalive key for the most recent timestamp.
144
45
        let keepalive_ts = if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
145
45
            let operation_id = key.0.as_ref();
146
45
            match store.get_and_decode(ClientKeepaliveKey(operation_id)).await {
147
0
                Ok(Some(ts)) => {
148
0
                    let awaited_ts = to_unix_ts(awaited_action.last_client_keepalive_timestamp());
149
0
                    if ts > awaited_ts {
150
0
                        let timestamp = UNIX_EPOCH + Duration::from_secs(ts);
151
0
                        awaited_action.update_client_keep_alive(timestamp);
152
0
                        ts
153
                    } else {
154
0
                        awaited_ts
155
                    }
156
                }
157
45
                Ok(None) | Err(_) => to_unix_ts(awaited_action.last_client_keepalive_timestamp()),
158
            }
159
        } else {
160
0
            to_unix_ts(awaited_action.last_client_keepalive_timestamp())
161
        };
162
163
45
        last_known_keepalive_ts.store(keepalive_ts, Ordering::Release);
164
45
        Ok(awaited_action)
165
45
    }
166
167
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
168
43
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
169
43
        let store = self
170
43
            .weak_store
171
43
            .upgrade()
172
43
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
173
43
        Self::inner_get_awaited_action(
174
43
            store.as_ref(),
175
43
            self.subscription_key.borrow(),
176
43
            self.maybe_client_operation_id.clone(),
177
43
            &self.last_known_keepalive_ts,
178
43
        )
179
43
        .await
180
43
    }
181
}
182
183
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
184
where
185
    S: SchedulerStore,
186
    I: InstantWrapper,
187
    NowFn: Fn() -> I + Send + Sync + 'static,
188
{
189
2
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
190
2
        let store = self
191
2
            .weak_store
192
2
            .upgrade()
193
2
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
194
2
        let 
subscription1
= match &mut self.state {
195
1
            OperationSubscriberState::Subscribed(subscription) => subscription,
196
            OperationSubscriberState::Unsubscribed => {
197
1
                let subscription = store
198
1
                    .subscription_manager()
199
1
                    .await
200
1
                    .err_tip(|| "In OperationSubscriber::changed::subscription_manager")
?0
201
1
                    .subscribe(self.subscription_key.borrow())
202
1
                    .err_tip(|| "In OperationSubscriber::changed::subscribe")
?0
;
203
1
                self.state = OperationSubscriberState::Subscribed(subscription);
204
                // When we've just subscribed, there may have been changes before now.
205
1
                let action = Self::inner_get_awaited_action(
206
1
                    store.as_ref(),
207
1
                    self.subscription_key.borrow(),
208
1
                    self.maybe_client_operation_id.clone(),
209
1
                    &self.last_known_keepalive_ts,
210
1
                )
211
1
                .await
212
1
                .err_tip(|| "In OperationSubscriber::changed")
?0
;
213
1
                if !<S as SchedulerStore>::SubscriptionManager::is_reliable() {
214
1
                    self.maybe_last_stage = Some(core::mem::discriminant(&action.state().stage));
215
1
                
}0
216
                // Existing changes are only interesting if the state is past queued.
217
1
                if !matches!(action.state().stage, ActionStage::Queued) {
218
1
                    return Ok(action);
219
0
                }
220
0
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
221
0
                    unreachable!("Subscription should be in Subscribed state");
222
                };
223
0
                subscription
224
            }
225
        };
226
227
1
        let changed_fut = subscription.changed();
228
1
        tokio::pin!(changed_fut);
229
        loop {
230
            // This is set if the maybe_last_state doesn't match the state in the store.
231
1
            let mut maybe_changed_action = None;
232
233
1
            let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
234
1
            if I::from_secs(last_known_keepalive_ts).elapsed() > CLIENT_KEEPALIVE_DURATION {
235
0
                let now = (self.now_fn)().now();
236
0
                let now_ts = now.duration_since(UNIX_EPOCH).map_or(0, |d| d.as_secs());
237
238
0
                if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
239
0
                    let operation_id = self.subscription_key.0.as_ref();
240
0
                    let update_result = store
241
0
                        .update_data(
242
0
                            UpdateClientKeepalive {
243
0
                                operation_id,
244
0
                                timestamp: now_ts,
245
0
                            },
246
0
                            None,
247
0
                        )
248
0
                        .await;
249
250
0
                    if let Err(e) = update_result {
251
0
                        warn!(
252
                            ?self.subscription_key,
253
                            ?e,
254
                            "Failed to update client keepalive (non-versioned)"
255
                        );
256
0
                    }
257
258
                    // Update local timestamp
259
0
                    self.last_known_keepalive_ts
260
0
                        .store(now_ts, Ordering::Release);
261
262
                    // Check if state changed (for unreliable subscription managers)
263
0
                    if self.maybe_last_stage.is_some() {
264
0
                        let awaited_action = Self::inner_get_awaited_action(
265
0
                            store.as_ref(),
266
0
                            self.subscription_key.borrow(),
267
0
                            self.maybe_client_operation_id.clone(),
268
0
                            &self.last_known_keepalive_ts,
269
0
                        )
270
0
                        .await
271
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
272
273
0
                        if self.maybe_last_stage.as_ref().is_some_and(|last_stage| {
274
0
                            *last_stage != core::mem::discriminant(&awaited_action.state().stage)
275
0
                        }) {
276
0
                            maybe_changed_action = Some(awaited_action);
277
0
                        }
278
0
                    }
279
                } else {
280
0
                    for attempt in 1..=MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
281
0
                        if attempt > 1 {
282
0
                            (self.now_fn)().sleep(Duration::from_millis(100)).await;
283
0
                            warn!(
284
                                ?self.subscription_key,
285
                                attempt,
286
                                "Client keepalive retry due to version conflict"
287
                            );
288
0
                        }
289
0
                        let mut awaited_action = Self::inner_get_awaited_action(
290
0
                            store.as_ref(),
291
0
                            self.subscription_key.borrow(),
292
0
                            self.maybe_client_operation_id.clone(),
293
0
                            &self.last_known_keepalive_ts,
294
0
                        )
295
0
                        .await
296
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
297
0
                        awaited_action.update_client_keep_alive(now);
298
0
                        maybe_changed_action = self
299
0
                            .maybe_last_stage
300
0
                            .as_ref()
301
0
                            .is_some_and(|last_stage| {
302
0
                                *last_stage
303
0
                                    != core::mem::discriminant(&awaited_action.state().stage)
304
0
                            })
305
0
                            .then(|| awaited_action.clone());
306
0
                        let expiry = if awaited_action.is_complete() {
307
0
                            Some(self.retain_completed_for)
308
                        } else {
309
0
                            None
310
                        };
311
0
                        match inner_update_awaited_action(store.as_ref(), awaited_action, expiry)
312
0
                            .await
313
                        {
314
0
                            Ok(()) => break,
315
0
                            err if attempt == MAX_RETRIES_FOR_CLIENT_KEEPALIVE => {
316
0
                                err.err_tip_with_code(|_| {
317
0
                                    (Code::Aborted, "Could not update client keep alive")
318
0
                                })?;
319
                            }
320
0
                            _ => (),
321
                        }
322
                    }
323
                }
324
1
            }
325
326
            // If the polling shows that it's changed state then publish now.
327
1
            if let Some(
changed_action0
) = maybe_changed_action {
328
0
                self.maybe_last_stage =
329
0
                    Some(core::mem::discriminant(&changed_action.state().stage));
330
0
                return Ok(changed_action);
331
1
            }
332
            // Determine the sleep time based on the last client keep alive.
333
1
            let sleep_time = CLIENT_KEEPALIVE_DURATION
334
1
                .checked_sub(
335
1
                    I::from_secs(self.last_known_keepalive_ts.load(Ordering::Acquire)).elapsed(),
336
                )
337
1
                .unwrap_or(Duration::from_millis(100));
338
1
            tokio::select! {
339
1
                result = &mut changed_fut => {
340
1
                    result
?0
;
341
1
                    break;
342
                }
343
1
                () = (self.now_fn)().sleep(sleep_time) => {
344
0
                    // If we haven't received any updates for a while, we should
345
0
                    // let the database know that we are still listening to prevent
346
0
                    // the action from being dropped.  Also poll for updates if the
347
0
                    // subscription manager is unreliable.
348
0
                }
349
            }
350
        }
351
352
1
        let awaited_action = Self::inner_get_awaited_action(
353
1
            store.as_ref(),
354
1
            self.subscription_key.borrow(),
355
1
            self.maybe_client_operation_id.clone(),
356
1
            &self.last_known_keepalive_ts,
357
1
        )
358
1
        .await
359
1
        .err_tip(|| "In OperationSubscriber::changed")
?0
;
360
1
        if self.maybe_last_stage.is_some() {
361
1
            self.maybe_last_stage = Some(core::mem::discriminant(&awaited_action.state().stage));
362
1
        
}0
363
1
        Ok(awaited_action)
364
2
    }
365
366
43
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
367
43
        self.get_awaited_action()
368
43
            .await
369
43
            .err_tip(|| "In OperationSubscriber::borrow")
370
43
    }
371
}
372
373
57
fn awaited_action_decode(version: i64, data: &Bytes) -> Result<AwaitedAction, Error> {
374
57
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data).map_err(|e| 
{0
375
0
        Error::from_std_err(Code::InvalidArgument, &e).append("In AwaitedAction::decode")
376
0
    })?;
377
57
    awaited_action.set_version(version);
378
57
    Ok(awaited_action)
379
57
}
380
381
/// Key prefix of the versioned `AwaitedAction` records, followed by the operation id.
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
/// Key prefix of the client-operation-id to operation-id mapping.
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
/// Key prefix of the non-versioned per-operation client keepalive timestamps.
const CLIENT_KEEPALIVE_KEY_PREFIX: &str = "ck_";
385
386
#[derive(Debug)]
387
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
388
impl OperationIdToAwaitedAction<'_> {
389
91
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
390
91
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
391
91
    }
392
}
393
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
394
    type Versioned = TrueValue;
395
67
    fn get_key(&self) -> StoreKey<'static> {
396
67
        StoreKey::Str(Cow::Owned(format!(
397
67
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
398
67
            self.0
399
67
        )))
400
67
    }
401
}
402
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
403
    type DecodeOutput = AwaitedAction;
404
46
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
405
46
        awaited_action_decode(version, &data)
406
46
    }
407
}
408
409
struct ClientIdToOperationId<'a>(&'a OperationId);
410
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
411
    type Versioned = FalseValue;
412
5
    fn get_key(&self) -> StoreKey<'static> {
413
5
        StoreKey::Str(Cow::Owned(format!(
414
5
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
415
5
            self.0
416
5
        )))
417
5
    }
418
}
419
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
420
    type DecodeOutput = OperationId;
421
2
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
422
2
        serde_json::from_slice(&data).map_err(|e| 
{0
423
0
            Error::from_std_err(Code::InvalidArgument, &e).append(format!(
424
                "In ClientIdToOperationId::decode (data: {data:02x?})",
425
            ))
426
0
        })
427
2
    }
428
}
429
430
struct ClientKeepaliveKey<'a>(&'a OperationId);
431
impl SchedulerStoreKeyProvider for ClientKeepaliveKey<'_> {
432
    type Versioned = FalseValue;
433
45
    fn get_key(&self) -> StoreKey<'static> {
434
45
        StoreKey::Str(Cow::Owned(format!(
435
45
            "{CLIENT_KEEPALIVE_KEY_PREFIX}{}",
436
45
            self.0
437
45
        )))
438
45
    }
439
}
440
impl SchedulerStoreDecodeTo for ClientKeepaliveKey<'_> {
441
    type DecodeOutput = u64;
442
0
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
443
0
        let s = core::str::from_utf8(&data).map_err(|e| {
444
0
            Error::from_std_err(Code::InvalidArgument, &e)
445
0
                .append("In ClientKeepaliveKey::decode utf8")
446
0
        })?;
447
0
        s.parse::<u64>().map_err(|e| {
448
0
            Error::from_std_err(Code::InvalidArgument, &e)
449
0
                .append("In ClientKeepaliveKey::decode parse")
450
0
        })
451
0
    }
452
}
453
454
struct UpdateClientKeepalive<'a> {
455
    operation_id: &'a OperationId,
456
    timestamp: u64,
457
}
458
impl SchedulerStoreKeyProvider for UpdateClientKeepalive<'_> {
459
    type Versioned = FalseValue;
460
0
    fn get_key(&self) -> StoreKey<'static> {
461
0
        ClientKeepaliveKey(self.operation_id).get_key()
462
0
    }
463
}
464
impl SchedulerStoreDataProvider for UpdateClientKeepalive<'_> {
465
0
    fn try_into_bytes(self) -> Result<Bytes, Error> {
466
0
        Ok(Bytes::from(self.timestamp.to_string()))
467
0
    }
468
}
469
470
// TODO(palfrey) We only need operation_id here, it would be nice if we had a way
471
// to tell the decoder we only care about specific fields.
472
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
473
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
474
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
475
    const INDEX_NAME: &'static str = "unique_qualifier";
476
    type Versioned = TrueValue;
477
4
    fn index_value(&self) -> Cow<'_, str> {
478
4
        Cow::Owned(format!("{}", self.0))
479
4
    }
480
}
481
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
482
    type DecodeOutput = AwaitedAction;
483
3
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
484
3
        awaited_action_decode(version, &data)
485
3
    }
486
}
487
488
struct SearchStateToAwaitedAction(&'static str);
489
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
490
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
491
    const INDEX_NAME: &'static str = "state";
492
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
493
    type Versioned = TrueValue;
494
17
    fn index_value(&self) -> Cow<'_, str> {
495
17
        Cow::Borrowed(self.0)
496
17
    }
497
}
498
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
499
    type DecodeOutput = AwaitedAction;
500
8
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
501
8
        awaited_action_decode(version, &data)
502
8
    }
503
}
504
505
33
const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
506
33
    match state {
507
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
508
25
        SortedAwaitedActionState::Queued => "queued",
509
7
        SortedAwaitedActionState::Executing => "executing",
510
1
        SortedAwaitedActionState::Completed => "completed",
511
    }
512
33
}
513
514
#[derive(Debug)]
515
pub struct UpdateOperationIdToAwaitedAction(AwaitedAction);
516
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
517
16
    fn current_version(&self) -> i64 {
518
16
        self.0.version()
519
16
    }
520
}
521
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
522
    type Versioned = TrueValue;
523
16
    fn get_key(&self) -> StoreKey<'static> {
524
16
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
525
16
    }
526
}
527
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
528
17
    fn try_into_bytes(self) -> Result<Bytes, Error> {
529
17
        serde_json::to_string(&self.0)
530
17
            .map(Bytes::from)
531
17
            .map_err(|e| 
{0
532
0
                Error::from_std_err(Code::InvalidArgument, &e)
533
0
                    .append("Could not convert AwaitedAction to json")
534
0
            })
535
17
    }
536
16
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
537
16
        let unique_qualifier = &self.0.action_info().unique_qualifier;
538
16
        let maybe_unique_qualifier = match &unique_qualifier {
539
16
            ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier),
540
0
            ActionUniqueQualifier::Uncacheable(_) => None,
541
        };
542
16
        let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1));
543
16
        if maybe_unique_qualifier.is_some() {
544
16
            output.push((
545
16
                "unique_qualifier",
546
16
                Bytes::from(unique_qualifier.to_string()),
547
16
            ));
548
16
        
}0
549
        {
550
16
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
551
16
                .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")
?0
;
552
16
            output.push(("state", Bytes::from(get_state_prefix(state))));
553
16
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
554
16
            output.push((
555
16
                "sort_key",
556
16
                // We encode to hex to ensure that the sort key is lexicographically sorted.
557
16
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
558
16
            ));
559
        }
560
16
        Ok(output)
561
16
    }
562
}
563
564
struct UpdateClientIdToOperationId {
565
    client_operation_id: ClientOperationId,
566
    operation_id: OperationId,
567
}
568
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
569
    type Versioned = FalseValue;
570
3
    fn get_key(&self) -> StoreKey<'static> {
571
3
        ClientIdToOperationId(&self.client_operation_id).get_key()
572
3
    }
573
}
574
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
575
3
    fn try_into_bytes(self) -> Result<Bytes, Error> {
576
3
        serde_json::to_string(&self.operation_id)
577
3
            .map(Bytes::from)
578
3
            .map_err(|e| 
{0
579
0
                Error::from_std_err(Code::InvalidArgument, &e)
580
0
                    .append("Could not convert OperationId to json")
581
0
            })
582
3
    }
583
}
584
585
14
pub async fn inner_update_awaited_action(
586
14
    store: &impl SchedulerStore,
587
14
    mut new_awaited_action: AwaitedAction,
588
14
    expiry: Option<Duration>,
589
14
) -> Result<(), Error> {
590
14
    let operation_id = new_awaited_action.operation_id().clone();
591
14
    if new_awaited_action.state().client_operation_id != operation_id {
592
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
593
14
    }
594
595
14
    let _is_finished = new_awaited_action.state().stage.is_finished();
596
597
14
    let maybe_version = store
598
14
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action), expiry)
599
14
        .await
600
14
        .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")
?0
;
601
602
14
    if maybe_version.is_none() {
603
4
        warn!(
604
            %operation_id,
605
            "Could not update AwaitedAction because the version did not match"
606
        );
607
4
        return Err(make_err!(
608
4
            Code::Aborted,
609
4
            "Could not update AwaitedAction because the version did not match for {operation_id}",
610
4
        ));
611
10
    }
612
613
10
    Ok(())
614
14
}
615
616
#[derive(Debug, MetricsComponent)]
617
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
618
where
619
    S: SchedulerStore,
620
    F: Fn() -> OperationId,
621
    I: InstantWrapper,
622
    NowFn: Fn() -> I,
623
{
624
    store: Arc<S>,
625
    now_fn: NowFn,
626
    operation_id_creator: F,
627
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
628
    retain_completed_for: Duration,
629
}
630
631
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
632
where
633
    S: SchedulerStore,
634
    F: Fn() -> OperationId,
635
    I: InstantWrapper,
636
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
637
{
638
6
    pub async fn new(
639
6
        store: Arc<S>,
640
6
        task_change_publisher: Arc<Notify>,
641
6
        now_fn: NowFn,
642
6
        operation_id_creator: F,
643
6
        retain_completed_for_s: u32,
644
6
    ) -> Result<Self, Error> {
645
6
        let mut subscription = store
646
6
            .subscription_manager()
647
6
            .await
648
6
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
649
6
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
650
6
                String::new(),
651
6
            ))))
652
6
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
;
653
6
        let pull_task_change_subscriber = spawn!(
654
            "redis_awaited_action_db_pull_task_change_subscriber",
655
5
            async move {
656
                loop {
657
19
                    let 
changed_res14
= subscription
658
19
                        .changed()
659
19
                        .await
660
14
                        .err_tip(|| "In RedisAwaitedActionDb::new");
661
14
                    if let Err(
err0
) = changed_res {
662
0
                        error!(
663
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
664
                        );
665
                        // Sleep for a second to avoid a busy loop, then trigger the notify
666
                        // so if a reconnect happens we let local resources know that things
667
                        // might have changed.
668
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
669
14
                    }
670
14
                    task_change_publisher.as_ref().notify_one();
671
                }
672
            }
673
        );
674
6
        Ok(Self {
675
6
            store,
676
6
            now_fn,
677
6
            operation_id_creator,
678
6
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
679
6
            retain_completed_for: Duration::from_secs(retain_completed_for_s.into()),
680
6
        })
681
6
    }
682
683
    // `pub` so integration tests in `tests/` can drive this directly;
684
    // matches the precedent of `inner_update_awaited_action` below.
685
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
686
6
    pub async fn try_subscribe(
687
6
        &self,
688
6
        client_operation_id: &ClientOperationId,
689
6
        unique_qualifier: &ActionUniqueQualifier,
690
6
        no_event_action_timeout: Duration,
691
6
        // TODO(palfrey) To simplify the scheduler 2024 refactor, we
692
6
        // removed the ability to upgrade priorities of actions.
693
6
        // we should add priority upgrades back in.
694
6
        _priority: i32,
695
6
    ) -> Result<Option<AwaitedAction>, Error> {
696
        // Retry once on miss: closes the RediSearch index-visibility
697
        // window where two concurrent `add_action` calls can both see
698
        // empty and create duplicate scheduler operations.
699
        const SUBSCRIBE_RACE_RETRY_DELAY: Duration = Duration::from_millis(20);
700
6
        match unique_qualifier {
701
5
            ActionUniqueQualifier::Cacheable(_) => {}
702
1
            ActionUniqueQualifier::Uncacheable(_) => return Ok(None),
703
        }
704
5
        let mut maybe_awaited_action: Option<AwaitedAction> = None;
705
8
        for attempt in 
0..2_u325
{
706
8
            if attempt > 0 {
707
3
                tokio::time::sleep(SUBSCRIBE_RACE_RETRY_DELAY).await;
708
5
            }
709
8
            let stream = self
710
8
                .store
711
8
                .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
712
8
                .await
713
8
                .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
714
8
            tokio::pin!(stream);
715
8
            maybe_awaited_action = stream
716
8
                .try_next()
717
8
                .await
718
8
                .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
719
8
            if maybe_awaited_action.is_some() {
720
3
                break;
721
5
            }
722
        }
723
5
        match maybe_awaited_action {
724
3
            Some(awaited_action) => {
725
                // TODO(palfrey) We don't support joining completed jobs because we
726
                // need to also check that all the data is still in the cache.
727
                // If the existing job failed then we need to set back to queued or we get
728
                // a version mismatch.  Equally we need to check the timeout as the job
729
                // may be abandoned in the store.
730
3
                let worker_should_update_before = (awaited_action.state().stage
731
3
                    == ActionStage::Executing)
732
3
                    .then_some(())
733
3
                    .map(|()| 
awaited_action0
.
last_worker_updated_timestamp0
())
734
3
                    .and_then(|last_worker_updated| 
{0
735
0
                        last_worker_updated.checked_add(no_event_action_timeout)
736
0
                    });
737
3
                let awaited_action = if awaited_action.state().stage.is_finished()
738
2
                    || worker_should_update_before
739
2
                        .is_some_and(|timestamp| 
timestamp0
<
(self.now_fn)().now()0
)
740
                {
741
1
                    tracing::debug!(
742
                        "Recreating action {:?} for operation {client_operation_id}",
743
1
                        awaited_action.action_info().digest()
744
                    );
745
                    // The version is reset because we have a new operation ID.
746
1
                    AwaitedAction::new(
747
1
                        (self.operation_id_creator)(),
748
1
                        awaited_action.action_info().clone(),
749
1
                        (self.now_fn)().now(),
750
                    )
751
                } else {
752
2
                    tracing::debug!(
753
                        "Subscribing to existing action {:?} for operation {client_operation_id}",
754
2
                        awaited_action.action_info().digest()
755
                    );
756
2
                    awaited_action
757
                };
758
3
                Ok(Some(awaited_action))
759
            }
760
2
            None => Ok(None),
761
        }
762
6
    }
763
764
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
765
2
    async fn inner_get_awaited_action_by_id(
766
2
        &self,
767
2
        client_operation_id: &ClientOperationId,
768
2
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
769
2
        let maybe_operation_id = self
770
2
            .store
771
2
            .get_and_decode(ClientIdToOperationId(client_operation_id))
772
2
            .await
773
2
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")
?0
;
774
2
        let Some(operation_id) = maybe_operation_id else {
775
0
            return Ok(None);
776
        };
777
778
        // Validate that the internal operation actually exists.
779
        // If it doesn't, this is an orphaned client operation mapping that should be cleaned up.
780
        // This can happen when an operation is deleted (completed/timed out) but the
781
        // client_id -> operation_id mapping persists in the store.
782
2
        let maybe_awaited_action = match self
783
2
            .store
784
2
            .get_and_decode(OperationIdToAwaitedAction(Cow::Borrowed(&operation_id)))
785
2
            .await
786
        {
787
2
            Ok(maybe_action) => maybe_action,
788
0
            Err(err) if err.code == Code::NotFound => {
789
0
                tracing::warn!(
790
                    "Orphaned client operation mapping detected: client_id={} maps to operation_id={}, \
791
                    but the operation does not exist in the store (NotFound). This typically happens when \
792
                    an operation completes or times out but the client mapping persists.",
793
                    client_operation_id,
794
                    operation_id
795
                );
796
0
                None
797
            }
798
0
            Err(err) => {
799
                // Some other error occurred
800
0
                return Err(err).err_tip(
801
                    || "In RedisAwaitedActionDb::get_awaited_action_by_id::validate_operation",
802
                );
803
            }
804
        };
805
806
2
        if maybe_awaited_action.is_none() {
807
1
            tracing::warn!(
808
                "Found orphaned client operation mapping: client_id={} -> operation_id={}, \
809
                but operation no longer exists. Returning None to prevent client from polling \
810
                a non-existent operation.",
811
                client_operation_id,
812
                operation_id
813
            );
814
1
            return Ok(None);
815
1
        }
816
817
1
        Ok(Some(OperationSubscriber::new(
818
1
            Some(client_operation_id.clone()),
819
1
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
820
1
            Arc::downgrade(&self.store),
821
1
            self.now_fn.clone(),
822
1
            self.retain_completed_for,
823
1
        )))
824
2
    }
825
}
826
827
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
828
where
829
    S: SchedulerStore,
830
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
831
    I: InstantWrapper,
832
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
833
{
834
    type Subscriber = OperationSubscriber<S, I, NowFn>;
835
836
2
    async fn get_awaited_action_by_id(
837
2
        &self,
838
2
        client_operation_id: &ClientOperationId,
839
2
    ) -> Result<Option<Self::Subscriber>, Error> {
840
2
        self.inner_get_awaited_action_by_id(client_operation_id)
841
2
            .await
842
2
    }
843
844
15
    async fn get_by_operation_id(
845
15
        &self,
846
15
        operation_id: &OperationId,
847
15
    ) -> Result<Option<Self::Subscriber>, Error> {
848
15
        Ok(Some(OperationSubscriber::new(
849
15
            None,
850
15
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
851
15
            Arc::downgrade(&self.store),
852
15
            self.now_fn.clone(),
853
15
            self.retain_completed_for,
854
15
        )))
855
15
    }
856
857
13
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
858
13
        let expiry = if new_awaited_action.is_complete() {
859
1
            Some(self.retain_completed_for)
860
        } else {
861
12
            None
862
        };
863
13
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action, expiry).await
864
13
    }
865
866
3
    async fn add_action(
867
3
        &self,
868
3
        client_operation_id: ClientOperationId,
869
3
        action_info: Arc<ActionInfo>,
870
3
        no_event_action_timeout: Duration,
871
3
    ) -> Result<Self::Subscriber, Error> {
872
        loop {
873
            // Check to see if the action is already known and subscribe if it is.
874
3
            let mut awaited_action = self
875
3
                .try_subscribe(
876
3
                    &client_operation_id,
877
3
                    &action_info.unique_qualifier,
878
3
                    no_event_action_timeout,
879
3
                    action_info.priority,
880
3
                )
881
3
                .await
882
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
883
3
                .unwrap_or_else(|| 
{1
884
1
                    tracing::debug!(
885
                        "Creating new action {:?} for operation {client_operation_id}",
886
1
                        action_info.digest()
887
                    );
888
1
                    AwaitedAction::new(
889
1
                        (self.operation_id_creator)(),
890
1
                        action_info.clone(),
891
1
                        (self.now_fn)().now(),
892
                    )
893
1
                });
894
895
3
            debug_assert!(
896
0
                ActionStage::Queued == awaited_action.state().stage,
897
                "Expected action to be queued"
898
            );
899
900
3
            let operation_id = awaited_action.operation_id().clone();
901
3
            if awaited_action.state().client_operation_id != operation_id {
902
0
                // Just in case the client_operation_id was set to something else
903
0
                // we put it back to the underlying operation_id.
904
0
                awaited_action.set_client_operation_id(operation_id.clone());
905
3
            }
906
3
            awaited_action.update_client_keep_alive((self.now_fn)().now());
907
908
3
            let version = awaited_action.version();
909
3
            let expiry = if awaited_action.is_complete() {
910
0
                Some(self.retain_completed_for)
911
            } else {
912
3
                None
913
            };
914
3
            if self
915
3
                .store
916
3
                .update_data(UpdateOperationIdToAwaitedAction(awaited_action), expiry)
917
3
                .await
918
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
919
3
                .is_none()
920
            {
921
                // The version was out of date, try again.
922
0
                tracing::info!(
923
                    "Version out of date for {:?} {operation_id} {version}, retrying.",
924
0
                    action_info.digest()
925
                );
926
0
                continue;
927
3
            }
928
929
            // Add the client_operation_id to operation_id mapping
930
3
            self.store
931
3
                .update_data(
932
3
                    UpdateClientIdToOperationId {
933
3
                        client_operation_id: client_operation_id.clone(),
934
3
                        operation_id: operation_id.clone(),
935
3
                    },
936
3
                    None,
937
3
                )
938
3
                .await
939
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action while adding client mapping")
?0
;
940
941
3
            return Ok(OperationSubscriber::new(
942
3
                Some(client_operation_id),
943
3
                OperationIdToAwaitedAction(Cow::Owned(operation_id)),
944
3
                Arc::downgrade(&self.store),
945
3
                self.now_fn.clone(),
946
3
                self.retain_completed_for,
947
3
            ));
948
        }
949
3
    }
950
951
17
    async fn get_range_of_actions(
952
17
        &self,
953
17
        state: SortedAwaitedActionState,
954
17
        start: Bound<SortedAwaitedAction>,
955
17
        end: Bound<SortedAwaitedAction>,
956
17
        desc: bool,
957
17
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
958
17
        if !
matches!0
(start, Bound::Unbounded) {
959
0
            return Err(make_err!(
960
0
                Code::Unimplemented,
961
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
962
0
            ));
963
17
        }
964
17
        if !
matches!0
(end, Bound::Unbounded) {
965
0
            return Err(make_err!(
966
0
                Code::Unimplemented,
967
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
968
0
            ));
969
17
        }
970
        // TODO(palfrey) This API is not difficult to implement, but there is no code path
971
        // that uses it, so no reason to implement it yet.
972
17
        if !desc {
973
0
            return Err(make_err!(
974
0
                Code::Unimplemented,
975
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
976
0
            ));
977
17
        }
978
17
        Ok(self
979
17
            .store
980
17
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
981
17
            .await
982
17
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")
?0
983
17
            .map_ok(move |awaited_action| 
{8
984
8
                OperationSubscriber::new(
985
8
                    None,
986
8
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
987
8
                    Arc::downgrade(&self.store),
988
8
                    self.now_fn.clone(),
989
8
                    self.retain_completed_for,
990
                )
991
8
            }))
992
17
    }
993
994
0
    async fn get_all_awaited_actions(
995
0
        &self,
996
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
997
0
        Ok(self
998
0
            .store
999
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
1000
0
            .await
1001
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
1002
0
            .map_ok(move |awaited_action| {
1003
0
                OperationSubscriber::new(
1004
0
                    None,
1005
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
1006
0
                    Arc::downgrade(&self.store),
1007
0
                    self.now_fn.clone(),
1008
0
                    self.retain_completed_for,
1009
                )
1010
0
            }))
1011
0
    }
1012
}