Coverage Report

Created: 2026-04-14 11:55

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::mem::Discriminant;
16
use core::ops::Bound;
17
use core::sync::atomic::{AtomicU64, Ordering};
18
use core::time::Duration;
19
use std::borrow::Cow;
20
use std::sync::{Arc, Weak};
21
use std::time::UNIX_EPOCH;
22
23
use bytes::Bytes;
24
use futures::{Stream, TryStreamExt};
25
use nativelink_error::{Code, Error, ResultExt, make_err};
26
use nativelink_metric::MetricsComponent;
27
use nativelink_util::action_messages::{
28
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
29
};
30
use nativelink_util::instant_wrapper::InstantWrapper;
31
use nativelink_util::spawn;
32
use nativelink_util::store_trait::{
33
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
34
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
35
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
36
};
37
use nativelink_util::task::JoinHandleDropGuard;
38
use tokio::sync::Notify;
39
use tracing::{error, warn};
40
41
use crate::awaited_action_db::{
42
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION,
43
    SortedAwaitedAction, SortedAwaitedActionState,
44
};
45
46
type ClientOperationId = OperationId;
47
48
/// Maximum number of retries to update client keep alive.
49
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
50
51
/// Use separate non-versioned Redis key for client keepalives.
52
const USE_SEPARATE_CLIENT_KEEPALIVE_KEY: bool = true;
53
54
enum OperationSubscriberState<Sub> {
55
    Unsubscribed,
56
    Subscribed(Sub),
57
}
58
59
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
60
    maybe_client_operation_id: Option<ClientOperationId>,
61
    subscription_key: OperationIdToAwaitedAction<'static>,
62
    weak_store: Weak<S>,
63
    state: OperationSubscriberState<
64
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
65
    >,
66
    last_known_keepalive_ts: AtomicU64,
67
    now_fn: NowFn,
68
    // If the SchedulerSubscriptionManager is not reliable, then this is populated
69
    // when the state is set to subscribed.  When set it causes the state to be polled
70
    // as well as listening for the publishing.
71
    maybe_last_stage: Option<Discriminant<ActionStage>>,
72
}
73
74
impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug
75
    for OperationSubscriber<S, I, NowFn>
76
where
77
    OperationSubscriberState<
78
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
79
    >: core::fmt::Debug,
80
{
81
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82
0
        f.debug_struct("OperationSubscriber")
83
0
            .field("maybe_client_operation_id", &self.maybe_client_operation_id)
84
0
            .field("subscription_key", &self.subscription_key)
85
0
            .field("weak_store", &self.weak_store)
86
0
            .field("state", &self.state)
87
0
            .field("last_known_keepalive_ts", &self.last_known_keepalive_ts)
88
0
            .field("now_fn", &self.now_fn)
89
0
            .finish()
90
0
    }
91
}
92
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
93
where
94
    S: SchedulerStore,
95
    I: InstantWrapper,
96
    NowFn: Fn() -> I,
97
{
98
27
    const fn new(
99
27
        maybe_client_operation_id: Option<ClientOperationId>,
100
27
        subscription_key: OperationIdToAwaitedAction<'static>,
101
27
        weak_store: Weak<S>,
102
27
        now_fn: NowFn,
103
27
    ) -> Self {
104
27
        Self {
105
27
            maybe_client_operation_id,
106
27
            subscription_key,
107
27
            weak_store,
108
27
            last_known_keepalive_ts: AtomicU64::new(0),
109
27
            state: OperationSubscriberState::Unsubscribed,
110
27
            now_fn,
111
27
            maybe_last_stage: None,
112
27
        }
113
27
    }
114
115
45
    async fn inner_get_awaited_action(
116
45
        store: &S,
117
45
        key: OperationIdToAwaitedAction<'_>,
118
45
        maybe_client_operation_id: Option<ClientOperationId>,
119
45
        last_known_keepalive_ts: &AtomicU64,
120
45
    ) -> Result<AwaitedAction, Error> {
121
45
        let mut awaited_action = store
122
45
            .get_and_decode(key.borrow())
123
45
            .await
124
45
            .err_tip(|| 
format!0
("In OperationSubscriber::get_awaited_action {key:?}"))
?0
125
45
            .ok_or_else(|| 
{0
126
0
                make_err!(
127
0
                    Code::NotFound,
128
                    "Could not find AwaitedAction for the given operation id {key:?}",
129
                )
130
0
            })?;
131
45
        if let Some(
client_operation_id6
) = maybe_client_operation_id {
132
6
            awaited_action.set_client_operation_id(client_operation_id);
133
39
        }
134
135
        // Helper to convert SystemTime to unix timestamp
136
45
        let to_unix_ts = |t: std::time::SystemTime| -> u64 {
137
45
            t.duration_since(UNIX_EPOCH).map_or(0, |d| d.as_secs())
138
45
        };
139
140
        // Check the separate keepalive key for the most recent timestamp.
141
45
        let keepalive_ts = if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
142
45
            let operation_id = key.0.as_ref();
143
45
            match store.get_and_decode(ClientKeepaliveKey(operation_id)).await {
144
0
                Ok(Some(ts)) => {
145
0
                    let awaited_ts = to_unix_ts(awaited_action.last_client_keepalive_timestamp());
146
0
                    if ts > awaited_ts {
147
0
                        let timestamp = UNIX_EPOCH + Duration::from_secs(ts);
148
0
                        awaited_action.update_client_keep_alive(timestamp);
149
0
                        ts
150
                    } else {
151
0
                        awaited_ts
152
                    }
153
                }
154
45
                Ok(None) | Err(_) => to_unix_ts(awaited_action.last_client_keepalive_timestamp()),
155
            }
156
        } else {
157
0
            to_unix_ts(awaited_action.last_client_keepalive_timestamp())
158
        };
159
160
45
        last_known_keepalive_ts.store(keepalive_ts, Ordering::Release);
161
45
        Ok(awaited_action)
162
45
    }
163
164
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
165
43
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
166
43
        let store = self
167
43
            .weak_store
168
43
            .upgrade()
169
43
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
170
43
        Self::inner_get_awaited_action(
171
43
            store.as_ref(),
172
43
            self.subscription_key.borrow(),
173
43
            self.maybe_client_operation_id.clone(),
174
43
            &self.last_known_keepalive_ts,
175
43
        )
176
43
        .await
177
43
    }
178
}
179
180
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
181
where
182
    S: SchedulerStore,
183
    I: InstantWrapper,
184
    NowFn: Fn() -> I + Send + Sync + 'static,
185
{
186
2
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
187
2
        let store = self
188
2
            .weak_store
189
2
            .upgrade()
190
2
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
191
2
        let 
subscription1
= match &mut self.state {
192
1
            OperationSubscriberState::Subscribed(subscription) => subscription,
193
            OperationSubscriberState::Unsubscribed => {
194
1
                let subscription = store
195
1
                    .subscription_manager()
196
1
                    .await
197
1
                    .err_tip(|| "In OperationSubscriber::changed::subscription_manager")
?0
198
1
                    .subscribe(self.subscription_key.borrow())
199
1
                    .err_tip(|| "In OperationSubscriber::changed::subscribe")
?0
;
200
1
                self.state = OperationSubscriberState::Subscribed(subscription);
201
                // When we've just subscribed, there may have been changes before now.
202
1
                let action = Self::inner_get_awaited_action(
203
1
                    store.as_ref(),
204
1
                    self.subscription_key.borrow(),
205
1
                    self.maybe_client_operation_id.clone(),
206
1
                    &self.last_known_keepalive_ts,
207
1
                )
208
1
                .await
209
1
                .err_tip(|| "In OperationSubscriber::changed")
?0
;
210
1
                if !<S as SchedulerStore>::SubscriptionManager::is_reliable() {
211
1
                    self.maybe_last_stage = Some(core::mem::discriminant(&action.state().stage));
212
1
                
}0
213
                // Existing changes are only interesting if the state is past queued.
214
1
                if !matches!(action.state().stage, ActionStage::Queued) {
215
1
                    return Ok(action);
216
0
                }
217
0
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
218
0
                    unreachable!("Subscription should be in Subscribed state");
219
                };
220
0
                subscription
221
            }
222
        };
223
224
1
        let changed_fut = subscription.changed();
225
1
        tokio::pin!(changed_fut);
226
        loop {
227
            // This is set if the maybe_last_state doesn't match the state in the store.
228
1
            let mut maybe_changed_action = None;
229
230
1
            let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
231
1
            if I::from_secs(last_known_keepalive_ts).elapsed() > CLIENT_KEEPALIVE_DURATION {
232
0
                let now = (self.now_fn)().now();
233
0
                let now_ts = now.duration_since(UNIX_EPOCH).map_or(0, |d| d.as_secs());
234
235
0
                if USE_SEPARATE_CLIENT_KEEPALIVE_KEY {
236
0
                    let operation_id = self.subscription_key.0.as_ref();
237
0
                    let update_result = store
238
0
                        .update_data(UpdateClientKeepalive {
239
0
                            operation_id,
240
0
                            timestamp: now_ts,
241
0
                        })
242
0
                        .await;
243
244
0
                    if let Err(e) = update_result {
245
0
                        warn!(
246
                            ?self.subscription_key,
247
                            ?e,
248
                            "Failed to update client keepalive (non-versioned)"
249
                        );
250
0
                    }
251
252
                    // Update local timestamp
253
0
                    self.last_known_keepalive_ts
254
0
                        .store(now_ts, Ordering::Release);
255
256
                    // Check if state changed (for unreliable subscription managers)
257
0
                    if self.maybe_last_stage.is_some() {
258
0
                        let awaited_action = Self::inner_get_awaited_action(
259
0
                            store.as_ref(),
260
0
                            self.subscription_key.borrow(),
261
0
                            self.maybe_client_operation_id.clone(),
262
0
                            &self.last_known_keepalive_ts,
263
0
                        )
264
0
                        .await
265
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
266
267
0
                        if self.maybe_last_stage.as_ref().is_some_and(|last_stage| {
268
0
                            *last_stage != core::mem::discriminant(&awaited_action.state().stage)
269
0
                        }) {
270
0
                            maybe_changed_action = Some(awaited_action);
271
0
                        }
272
0
                    }
273
                } else {
274
0
                    for attempt in 1..=MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
275
0
                        if attempt > 1 {
276
0
                            (self.now_fn)().sleep(Duration::from_millis(100)).await;
277
0
                            warn!(
278
                                ?self.subscription_key,
279
                                attempt,
280
                                "Client keepalive retry due to version conflict"
281
                            );
282
0
                        }
283
0
                        let mut awaited_action = Self::inner_get_awaited_action(
284
0
                            store.as_ref(),
285
0
                            self.subscription_key.borrow(),
286
0
                            self.maybe_client_operation_id.clone(),
287
0
                            &self.last_known_keepalive_ts,
288
0
                        )
289
0
                        .await
290
0
                        .err_tip(|| "In OperationSubscriber::changed")?;
291
0
                        awaited_action.update_client_keep_alive(now);
292
0
                        maybe_changed_action = self
293
0
                            .maybe_last_stage
294
0
                            .as_ref()
295
0
                            .is_some_and(|last_stage| {
296
0
                                *last_stage
297
0
                                    != core::mem::discriminant(&awaited_action.state().stage)
298
0
                            })
299
0
                            .then(|| awaited_action.clone());
300
0
                        match inner_update_awaited_action(store.as_ref(), awaited_action).await {
301
0
                            Ok(()) => break,
302
0
                            err if attempt == MAX_RETRIES_FOR_CLIENT_KEEPALIVE => {
303
0
                                err.err_tip_with_code(|_| {
304
0
                                    (Code::Aborted, "Could not update client keep alive")
305
0
                                })?;
306
                            }
307
0
                            _ => (),
308
                        }
309
                    }
310
                }
311
1
            }
312
313
            // If the polling shows that it's changed state then publish now.
314
1
            if let Some(
changed_action0
) = maybe_changed_action {
315
0
                self.maybe_last_stage =
316
0
                    Some(core::mem::discriminant(&changed_action.state().stage));
317
0
                return Ok(changed_action);
318
1
            }
319
            // Determine the sleep time based on the last client keep alive.
320
1
            let sleep_time = CLIENT_KEEPALIVE_DURATION
321
1
                .checked_sub(
322
1
                    I::from_secs(self.last_known_keepalive_ts.load(Ordering::Acquire)).elapsed(),
323
                )
324
1
                .unwrap_or(Duration::from_millis(100));
325
1
            tokio::select! {
326
1
                result = &mut changed_fut => {
327
1
                    result
?0
;
328
1
                    break;
329
                }
330
1
                () = (self.now_fn)().sleep(sleep_time) => {
331
0
                    // If we haven't received any updates for a while, we should
332
0
                    // let the database know that we are still listening to prevent
333
0
                    // the action from being dropped.  Also poll for updates if the
334
0
                    // subscription manager is unreliable.
335
0
                }
336
            }
337
        }
338
339
1
        let awaited_action = Self::inner_get_awaited_action(
340
1
            store.as_ref(),
341
1
            self.subscription_key.borrow(),
342
1
            self.maybe_client_operation_id.clone(),
343
1
            &self.last_known_keepalive_ts,
344
1
        )
345
1
        .await
346
1
        .err_tip(|| "In OperationSubscriber::changed")
?0
;
347
1
        if self.maybe_last_stage.is_some() {
348
1
            self.maybe_last_stage = Some(core::mem::discriminant(&awaited_action.state().stage));
349
1
        
}0
350
1
        Ok(awaited_action)
351
2
    }
352
353
43
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
354
43
        self.get_awaited_action()
355
43
            .await
356
43
            .err_tip(|| "In OperationSubscriber::borrow")
357
43
    }
358
}
359
360
56
fn awaited_action_decode(version: i64, data: &Bytes) -> Result<AwaitedAction, Error> {
361
56
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data).map_err(|e| 
{0
362
0
        Error::from_std_err(Code::InvalidArgument, &e).append("In AwaitedAction::decode")
363
0
    })?;
364
56
    awaited_action.set_version(version);
365
56
    Ok(awaited_action)
366
56
}
367
368
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
369
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
370
/// Phase 2: Separate key prefix for client keepalives (non-versioned).
371
const CLIENT_KEEPALIVE_KEY_PREFIX: &str = "ck_";
372
373
#[derive(Debug)]
374
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
375
impl OperationIdToAwaitedAction<'_> {
376
91
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
377
91
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
378
91
    }
379
}
380
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
381
    type Versioned = TrueValue;
382
67
    fn get_key(&self) -> StoreKey<'static> {
383
67
        StoreKey::Str(Cow::Owned(format!(
384
67
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
385
67
            self.0
386
67
        )))
387
67
    }
388
}
389
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
390
    type DecodeOutput = AwaitedAction;
391
46
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
392
46
        awaited_action_decode(version, &data)
393
46
    }
394
}
395
396
struct ClientIdToOperationId<'a>(&'a OperationId);
397
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
398
    type Versioned = FalseValue;
399
5
    fn get_key(&self) -> StoreKey<'static> {
400
5
        StoreKey::Str(Cow::Owned(format!(
401
5
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
402
5
            self.0
403
5
        )))
404
5
    }
405
}
406
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
407
    type DecodeOutput = OperationId;
408
2
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
409
2
        serde_json::from_slice(&data).map_err(|e| 
{0
410
0
            Error::from_std_err(Code::InvalidArgument, &e).append(format!(
411
                "In ClientIdToOperationId::decode (data: {data:02x?})",
412
            ))
413
0
        })
414
2
    }
415
}
416
417
struct ClientKeepaliveKey<'a>(&'a OperationId);
418
impl SchedulerStoreKeyProvider for ClientKeepaliveKey<'_> {
419
    type Versioned = FalseValue;
420
45
    fn get_key(&self) -> StoreKey<'static> {
421
45
        StoreKey::Str(Cow::Owned(format!(
422
45
            "{CLIENT_KEEPALIVE_KEY_PREFIX}{}",
423
45
            self.0
424
45
        )))
425
45
    }
426
}
427
impl SchedulerStoreDecodeTo for ClientKeepaliveKey<'_> {
428
    type DecodeOutput = u64;
429
0
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
430
0
        let s = core::str::from_utf8(&data).map_err(|e| {
431
0
            Error::from_std_err(Code::InvalidArgument, &e)
432
0
                .append("In ClientKeepaliveKey::decode utf8")
433
0
        })?;
434
0
        s.parse::<u64>().map_err(|e| {
435
0
            Error::from_std_err(Code::InvalidArgument, &e)
436
0
                .append("In ClientKeepaliveKey::decode parse")
437
0
        })
438
0
    }
439
}
440
441
struct UpdateClientKeepalive<'a> {
442
    operation_id: &'a OperationId,
443
    timestamp: u64,
444
}
445
impl SchedulerStoreKeyProvider for UpdateClientKeepalive<'_> {
446
    type Versioned = FalseValue;
447
0
    fn get_key(&self) -> StoreKey<'static> {
448
0
        ClientKeepaliveKey(self.operation_id).get_key()
449
0
    }
450
}
451
impl SchedulerStoreDataProvider for UpdateClientKeepalive<'_> {
452
0
    fn try_into_bytes(self) -> Result<Bytes, Error> {
453
0
        Ok(Bytes::from(self.timestamp.to_string()))
454
0
    }
455
}
456
457
// TODO(palfrey) We only need operation_id here, it would be nice if we had a way
458
// to tell the decoder we only care about specific fields.
459
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
460
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
461
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
462
    const INDEX_NAME: &'static str = "unique_qualifier";
463
    type Versioned = TrueValue;
464
3
    fn index_value(&self) -> Cow<'_, str> {
465
3
        Cow::Owned(format!("{}", self.0))
466
3
    }
467
}
468
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
469
    type DecodeOutput = AwaitedAction;
470
2
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
471
2
        awaited_action_decode(version, &data)
472
2
    }
473
}
474
475
struct SearchStateToAwaitedAction(&'static str);
476
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
477
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
478
    const INDEX_NAME: &'static str = "state";
479
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
480
    type Versioned = TrueValue;
481
17
    fn index_value(&self) -> Cow<'_, str> {
482
17
        Cow::Borrowed(self.0)
483
17
    }
484
}
485
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
486
    type DecodeOutput = AwaitedAction;
487
8
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
488
8
        awaited_action_decode(version, &data)
489
8
    }
490
}
491
492
33
const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
493
33
    match state {
494
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
495
25
        SortedAwaitedActionState::Queued => "queued",
496
7
        SortedAwaitedActionState::Executing => "executing",
497
1
        SortedAwaitedActionState::Completed => "completed",
498
    }
499
33
}
500
501
struct UpdateOperationIdToAwaitedAction(AwaitedAction);
502
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
503
16
    fn current_version(&self) -> i64 {
504
16
        self.0.version()
505
16
    }
506
}
507
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
508
    type Versioned = TrueValue;
509
16
    fn get_key(&self) -> StoreKey<'static> {
510
16
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
511
16
    }
512
}
513
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
514
16
    fn try_into_bytes(self) -> Result<Bytes, Error> {
515
16
        serde_json::to_string(&self.0)
516
16
            .map(Bytes::from)
517
16
            .map_err(|e| 
{0
518
0
                Error::from_std_err(Code::InvalidArgument, &e)
519
0
                    .append("Could not convert AwaitedAction to json")
520
0
            })
521
16
    }
522
16
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
523
16
        let unique_qualifier = &self.0.action_info().unique_qualifier;
524
16
        let maybe_unique_qualifier = match &unique_qualifier {
525
16
            ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier),
526
0
            ActionUniqueQualifier::Uncacheable(_) => None,
527
        };
528
16
        let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1));
529
16
        if maybe_unique_qualifier.is_some() {
530
16
            output.push((
531
16
                "unique_qualifier",
532
16
                Bytes::from(unique_qualifier.to_string()),
533
16
            ));
534
16
        
}0
535
        {
536
16
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
537
16
                .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")
?0
;
538
16
            output.push(("state", Bytes::from(get_state_prefix(state))));
539
16
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
540
16
            output.push((
541
16
                "sort_key",
542
16
                // We encode to hex to ensure that the sort key is lexicographically sorted.
543
16
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
544
16
            ));
545
        }
546
16
        Ok(output)
547
16
    }
548
}
549
550
struct UpdateClientIdToOperationId {
551
    client_operation_id: ClientOperationId,
552
    operation_id: OperationId,
553
}
554
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
555
    type Versioned = FalseValue;
556
3
    fn get_key(&self) -> StoreKey<'static> {
557
3
        ClientIdToOperationId(&self.client_operation_id).get_key()
558
3
    }
559
}
560
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
561
3
    fn try_into_bytes(self) -> Result<Bytes, Error> {
562
3
        serde_json::to_string(&self.operation_id)
563
3
            .map(Bytes::from)
564
3
            .map_err(|e| 
{0
565
0
                Error::from_std_err(Code::InvalidArgument, &e)
566
0
                    .append("Could not convert OperationId to json")
567
0
            })
568
3
    }
569
}
570
571
13
async fn inner_update_awaited_action(
572
13
    store: &impl SchedulerStore,
573
13
    mut new_awaited_action: AwaitedAction,
574
13
) -> Result<(), Error> {
575
13
    let operation_id = new_awaited_action.operation_id().clone();
576
13
    if new_awaited_action.state().client_operation_id != operation_id {
577
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
578
13
    }
579
580
13
    let _is_finished = new_awaited_action.state().stage.is_finished();
581
582
13
    let maybe_version = store
583
13
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action))
584
13
        .await
585
13
        .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")
?0
;
586
587
13
    if maybe_version.is_none() {
588
4
        warn!(
589
            %operation_id,
590
            "Could not update AwaitedAction because the version did not match"
591
        );
592
4
        return Err(make_err!(
593
4
            Code::Aborted,
594
4
            "Could not update AwaitedAction because the version did not match for {operation_id}",
595
4
        ));
596
9
    }
597
598
9
    Ok(())
599
13
}
600
601
#[derive(Debug, MetricsComponent)]
602
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
603
where
604
    S: SchedulerStore,
605
    F: Fn() -> OperationId,
606
    I: InstantWrapper,
607
    NowFn: Fn() -> I,
608
{
609
    store: Arc<S>,
610
    now_fn: NowFn,
611
    operation_id_creator: F,
612
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
613
}
614
615
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
616
where
617
    S: SchedulerStore,
618
    F: Fn() -> OperationId,
619
    I: InstantWrapper,
620
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
621
{
622
3
    pub async fn new(
623
3
        store: Arc<S>,
624
3
        task_change_publisher: Arc<Notify>,
625
3
        now_fn: NowFn,
626
3
        operation_id_creator: F,
627
3
    ) -> Result<Self, Error> {
628
3
        let mut subscription = store
629
3
            .subscription_manager()
630
3
            .await
631
3
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
632
3
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
633
3
                String::new(),
634
3
            ))))
635
3
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
;
636
3
        let pull_task_change_subscriber = spawn!(
637
            "redis_awaited_action_db_pull_task_change_subscriber",
638
3
            async move {
639
                loop {
640
17
                    let 
changed_res14
= subscription
641
17
                        .changed()
642
17
                        .await
643
14
                        .err_tip(|| "In RedisAwaitedActionDb::new");
644
14
                    if let Err(
err0
) = changed_res {
645
0
                        error!(
646
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
647
                        );
648
                        // Sleep for a second to avoid a busy loop, then trigger the notify
649
                        // so if a reconnect happens we let local resources know that things
650
                        // might have changed.
651
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
652
14
                    }
653
14
                    task_change_publisher.as_ref().notify_one();
654
                }
655
            }
656
        );
657
3
        Ok(Self {
658
3
            store,
659
3
            now_fn,
660
3
            operation_id_creator,
661
3
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
662
3
        })
663
3
    }
664
665
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
666
3
    async fn try_subscribe(
667
3
        &self,
668
3
        client_operation_id: &ClientOperationId,
669
3
        unique_qualifier: &ActionUniqueQualifier,
670
3
        no_event_action_timeout: Duration,
671
3
        // TODO(palfrey) To simplify the scheduler 2024 refactor, we
672
3
        // removed the ability to upgrade priorities of actions.
673
3
        // we should add priority upgrades back in.
674
3
        _priority: i32,
675
3
    ) -> Result<Option<AwaitedAction>, Error> {
676
3
        match unique_qualifier {
677
3
            ActionUniqueQualifier::Cacheable(_) => {}
678
0
            ActionUniqueQualifier::Uncacheable(_) => return Ok(None),
679
        }
680
3
        let stream = self
681
3
            .store
682
3
            .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
683
3
            .await
684
3
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
685
3
        tokio::pin!(stream);
686
3
        let maybe_awaited_action = stream
687
3
            .try_next()
688
3
            .await
689
3
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
690
3
        match maybe_awaited_action {
691
2
            Some(awaited_action) => {
692
                // TODO(palfrey) We don't support joining completed jobs because we
693
                // need to also check that all the data is still in the cache.
694
                // If the existing job failed then we need to set back to queued or we get
695
                // a version mismatch.  Equally we need to check the timeout as the job
696
                // may be abandoned in the store.
697
2
                let worker_should_update_before = (awaited_action.state().stage
698
2
                    == ActionStage::Executing)
699
2
                    .then_some(())
700
2
                    .map(|()| 
awaited_action0
.
last_worker_updated_timestamp0
())
701
2
                    .and_then(|last_worker_updated| 
{0
702
0
                        last_worker_updated.checked_add(no_event_action_timeout)
703
0
                    });
704
2
                let awaited_action = if awaited_action.state().stage.is_finished()
705
1
                    || worker_should_update_before
706
1
                        .is_some_and(|timestamp| 
timestamp0
<
(self.now_fn)().now()0
)
707
                {
708
1
                    tracing::debug!(
709
                        "Recreating action {:?} for operation {client_operation_id}",
710
1
                        awaited_action.action_info().digest()
711
                    );
712
                    // The version is reset because we have a new operation ID.
713
1
                    AwaitedAction::new(
714
1
                        (self.operation_id_creator)(),
715
1
                        awaited_action.action_info().clone(),
716
1
                        (self.now_fn)().now(),
717
                    )
718
                } else {
719
1
                    tracing::debug!(
720
                        "Subscribing to existing action {:?} for operation {client_operation_id}",
721
1
                        awaited_action.action_info().digest()
722
                    );
723
1
                    awaited_action
724
                };
725
2
                Ok(Some(awaited_action))
726
            }
727
1
            None => Ok(None),
728
        }
729
3
    }
730
731
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
732
2
    async fn inner_get_awaited_action_by_id(
733
2
        &self,
734
2
        client_operation_id: &ClientOperationId,
735
2
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
736
2
        let maybe_operation_id = self
737
2
            .store
738
2
            .get_and_decode(ClientIdToOperationId(client_operation_id))
739
2
            .await
740
2
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")
?0
;
741
2
        let Some(operation_id) = maybe_operation_id else {
742
0
            return Ok(None);
743
        };
744
745
        // Validate that the internal operation actually exists.
746
        // If it doesn't, this is an orphaned client operation mapping that should be cleaned up.
747
        // This can happen when an operation is deleted (completed/timed out) but the
748
        // client_id -> operation_id mapping persists in the store.
749
2
        let maybe_awaited_action = match self
750
2
            .store
751
2
            .get_and_decode(OperationIdToAwaitedAction(Cow::Borrowed(&operation_id)))
752
2
            .await
753
        {
754
2
            Ok(maybe_action) => maybe_action,
755
0
            Err(err) if err.code == Code::NotFound => {
756
0
                tracing::warn!(
757
                    "Orphaned client operation mapping detected: client_id={} maps to operation_id={}, \
758
                    but the operation does not exist in the store (NotFound). This typically happens when \
759
                    an operation completes or times out but the client mapping persists.",
760
                    client_operation_id,
761
                    operation_id
762
                );
763
0
                None
764
            }
765
0
            Err(err) => {
766
                // Some other error occurred
767
0
                return Err(err).err_tip(
768
                    || "In RedisAwaitedActionDb::get_awaited_action_by_id::validate_operation",
769
                );
770
            }
771
        };
772
773
2
        if maybe_awaited_action.is_none() {
774
1
            tracing::warn!(
775
                "Found orphaned client operation mapping: client_id={} -> operation_id={}, \
776
                but operation no longer exists. Returning None to prevent client from polling \
777
                a non-existent operation.",
778
                client_operation_id,
779
                operation_id
780
            );
781
1
            return Ok(None);
782
1
        }
783
784
1
        Ok(Some(OperationSubscriber::new(
785
1
            Some(client_operation_id.clone()),
786
1
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
787
1
            Arc::downgrade(&self.store),
788
1
            self.now_fn.clone(),
789
1
        )))
790
2
    }
791
}
792
793
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
794
where
795
    S: SchedulerStore,
796
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
797
    I: InstantWrapper,
798
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
799
{
800
    type Subscriber = OperationSubscriber<S, I, NowFn>;
801
802
2
    async fn get_awaited_action_by_id(
803
2
        &self,
804
2
        client_operation_id: &ClientOperationId,
805
2
    ) -> Result<Option<Self::Subscriber>, Error> {
806
2
        self.inner_get_awaited_action_by_id(client_operation_id)
807
2
            .await
808
2
    }
809
810
15
    async fn get_by_operation_id(
811
15
        &self,
812
15
        operation_id: &OperationId,
813
15
    ) -> Result<Option<Self::Subscriber>, Error> {
814
15
        Ok(Some(OperationSubscriber::new(
815
15
            None,
816
15
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
817
15
            Arc::downgrade(&self.store),
818
15
            self.now_fn.clone(),
819
15
        )))
820
15
    }
821
822
13
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
823
13
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await
824
13
    }
825
826
3
    async fn add_action(
827
3
        &self,
828
3
        client_operation_id: ClientOperationId,
829
3
        action_info: Arc<ActionInfo>,
830
3
        no_event_action_timeout: Duration,
831
3
    ) -> Result<Self::Subscriber, Error> {
832
        loop {
833
            // Check to see if the action is already known and subscribe if it is.
834
3
            let mut awaited_action = self
835
3
                .try_subscribe(
836
3
                    &client_operation_id,
837
3
                    &action_info.unique_qualifier,
838
3
                    no_event_action_timeout,
839
3
                    action_info.priority,
840
3
                )
841
3
                .await
842
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
843
3
                .unwrap_or_else(|| 
{1
844
1
                    tracing::debug!(
845
                        "Creating new action {:?} for operation {client_operation_id}",
846
1
                        action_info.digest()
847
                    );
848
1
                    AwaitedAction::new(
849
1
                        (self.operation_id_creator)(),
850
1
                        action_info.clone(),
851
1
                        (self.now_fn)().now(),
852
                    )
853
1
                });
854
855
3
            debug_assert!(
856
0
                ActionStage::Queued == awaited_action.state().stage,
857
                "Expected action to be queued"
858
            );
859
860
3
            let operation_id = awaited_action.operation_id().clone();
861
3
            if awaited_action.state().client_operation_id != operation_id {
862
0
                // Just in case the client_operation_id was set to something else
863
0
                // we put it back to the underlying operation_id.
864
0
                awaited_action.set_client_operation_id(operation_id.clone());
865
3
            }
866
3
            awaited_action.update_client_keep_alive((self.now_fn)().now());
867
868
3
            let version = awaited_action.version();
869
3
            if self
870
3
                .store
871
3
                .update_data(UpdateOperationIdToAwaitedAction(awaited_action))
872
3
                .await
873
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
874
3
                .is_none()
875
            {
876
                // The version was out of date, try again.
877
0
                tracing::info!(
878
                    "Version out of date for {:?} {operation_id} {version}, retrying.",
879
0
                    action_info.digest()
880
                );
881
0
                continue;
882
3
            }
883
884
            // Add the client_operation_id to operation_id mapping
885
3
            self.store
886
3
                .update_data(UpdateClientIdToOperationId {
887
3
                    client_operation_id: client_operation_id.clone(),
888
3
                    operation_id: operation_id.clone(),
889
3
                })
890
3
                .await
891
3
                .err_tip(|| "In RedisAwaitedActionDb::add_action while adding client mapping")
?0
;
892
893
3
            return Ok(OperationSubscriber::new(
894
3
                Some(client_operation_id),
895
3
                OperationIdToAwaitedAction(Cow::Owned(operation_id)),
896
3
                Arc::downgrade(&self.store),
897
3
                self.now_fn.clone(),
898
3
            ));
899
        }
900
3
    }
901
902
17
    async fn get_range_of_actions(
903
17
        &self,
904
17
        state: SortedAwaitedActionState,
905
17
        start: Bound<SortedAwaitedAction>,
906
17
        end: Bound<SortedAwaitedAction>,
907
17
        desc: bool,
908
17
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
909
17
        if !
matches!0
(start, Bound::Unbounded) {
910
0
            return Err(make_err!(
911
0
                Code::Unimplemented,
912
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
913
0
            ));
914
17
        }
915
17
        if !
matches!0
(end, Bound::Unbounded) {
916
0
            return Err(make_err!(
917
0
                Code::Unimplemented,
918
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
919
0
            ));
920
17
        }
921
        // TODO(palfrey) This API is not difficult to implement, but there is no code path
922
        // that uses it, so no reason to implement it yet.
923
17
        if !desc {
924
0
            return Err(make_err!(
925
0
                Code::Unimplemented,
926
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
927
0
            ));
928
17
        }
929
17
        Ok(self
930
17
            .store
931
17
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
932
17
            .await
933
17
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")
?0
934
17
            .map_ok(move |awaited_action| 
{8
935
8
                OperationSubscriber::new(
936
8
                    None,
937
8
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
938
8
                    Arc::downgrade(&self.store),
939
8
                    self.now_fn.clone(),
940
                )
941
8
            }))
942
17
    }
943
944
0
    async fn get_all_awaited_actions(
945
0
        &self,
946
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
947
0
        Ok(self
948
0
            .store
949
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
950
0
            .await
951
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
952
0
            .map_ok(move |awaited_action| {
953
0
                OperationSubscriber::new(
954
0
                    None,
955
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
956
0
                    Arc::downgrade(&self.store),
957
0
                    self.now_fn.clone(),
958
                )
959
0
            }))
960
0
    }
961
}