Coverage Report

Created: 2025-10-30 00:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::mem::Discriminant;
16
use core::ops::Bound;
17
use core::sync::atomic::{AtomicU64, Ordering};
18
use core::time::Duration;
19
use std::borrow::Cow;
20
use std::sync::{Arc, Weak};
21
22
use bytes::Bytes;
23
use futures::{Stream, TryStreamExt};
24
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
25
use nativelink_metric::MetricsComponent;
26
use nativelink_util::action_messages::{
27
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
28
};
29
use nativelink_util::instant_wrapper::InstantWrapper;
30
use nativelink_util::spawn;
31
use nativelink_util::store_trait::{
32
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
33
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
34
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
35
};
36
use nativelink_util::task::JoinHandleDropGuard;
37
use tokio::sync::Notify;
38
use tracing::error;
39
40
use crate::awaited_action_db::{
41
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION,
42
    SortedAwaitedAction, SortedAwaitedActionState,
43
};
44
45
type ClientOperationId = OperationId;
46
47
/// Maximum number of retries to update client keep alive.
48
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
49
50
/// Whether an [`OperationSubscriber`] has registered with the store's
/// subscription manager yet.
enum OperationSubscriberState<Sub> {
    /// No subscription exists; the first call to `changed()` creates one.
    Unsubscribed,
    /// Live subscription to the observed operation's store key.
    Subscribed(Sub),
}
54
55
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
56
    maybe_client_operation_id: Option<ClientOperationId>,
57
    subscription_key: OperationIdToAwaitedAction<'static>,
58
    weak_store: Weak<S>,
59
    state: OperationSubscriberState<
60
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
61
    >,
62
    last_known_keepalive_ts: AtomicU64,
63
    now_fn: NowFn,
64
    // If the SchedulerSubscriptionManager is not reliable, then this is populated
65
    // when the state is set to subscribed.  When set it causes the state to be polled
66
    // as well as listening for the publishing.
67
    maybe_last_stage: Option<Discriminant<ActionStage>>,
68
}
69
70
impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug
71
    for OperationSubscriber<S, I, NowFn>
72
where
73
    OperationSubscriberState<
74
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
75
    >: core::fmt::Debug,
76
{
77
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78
0
        f.debug_struct("OperationSubscriber")
79
0
            .field("maybe_client_operation_id", &self.maybe_client_operation_id)
80
0
            .field("subscription_key", &self.subscription_key)
81
0
            .field("weak_store", &self.weak_store)
82
0
            .field("state", &self.state)
83
0
            .field("last_known_keepalive_ts", &self.last_known_keepalive_ts)
84
0
            .field("now_fn", &self.now_fn)
85
0
            .finish()
86
0
    }
87
}
88
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
89
where
90
    S: SchedulerStore,
91
    I: InstantWrapper,
92
    NowFn: Fn() -> I,
93
{
94
19
    const fn new(
95
19
        maybe_client_operation_id: Option<ClientOperationId>,
96
19
        subscription_key: OperationIdToAwaitedAction<'static>,
97
19
        weak_store: Weak<S>,
98
19
        now_fn: NowFn,
99
19
    ) -> Self {
100
19
        Self {
101
19
            maybe_client_operation_id,
102
19
            subscription_key,
103
19
            weak_store,
104
19
            last_known_keepalive_ts: AtomicU64::new(0),
105
19
            state: OperationSubscriberState::Unsubscribed,
106
19
            now_fn,
107
19
            maybe_last_stage: None,
108
19
        }
109
19
    }
110
111
32
    async fn inner_get_awaited_action(
112
32
        store: &S,
113
32
        key: OperationIdToAwaitedAction<'_>,
114
32
        maybe_client_operation_id: Option<ClientOperationId>,
115
32
        last_known_keepalive_ts: &AtomicU64,
116
32
    ) -> Result<AwaitedAction, Error> {
117
32
        let mut awaited_action = store
118
32
            .get_and_decode(key.borrow())
119
32
            .await
120
32
            .err_tip(|| format!(
"In OperationSubscriber::get_awaited_action {key:?}"0
))
?0
121
32
            .ok_or_else(|| 
{0
122
0
                make_err!(
123
0
                    Code::NotFound,
124
                    "Could not find AwaitedAction for the given operation id {key:?}",
125
                )
126
0
            })?;
127
32
        if let Some(
client_operation_id10
) = maybe_client_operation_id {
  Branch (127:16): [True: 0, False: 0]
  Branch (127:16): [Folded - Ignored]
  Branch (127:16): [True: 10, False: 22]
128
10
            awaited_action.set_client_operation_id(client_operation_id);
129
22
        }
130
32
        last_known_keepalive_ts.store(
131
32
            awaited_action
132
32
                .last_client_keepalive_timestamp()
133
32
                .unix_timestamp(),
134
32
            Ordering::Release,
135
        );
136
32
        Ok(awaited_action)
137
32
    }
138
139
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
140
27
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
141
27
        let store = self
142
27
            .weak_store
143
27
            .upgrade()
144
27
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
145
27
        Self::inner_get_awaited_action(
146
27
            store.as_ref(),
147
27
            self.subscription_key.borrow(),
148
27
            self.maybe_client_operation_id.clone(),
149
27
            &self.last_known_keepalive_ts,
150
27
        )
151
27
        .await
152
27
    }
153
}
154
155
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
156
where
157
    S: SchedulerStore,
158
    I: InstantWrapper,
159
    NowFn: Fn() -> I + Send + Sync + 'static,
160
{
161
4
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
162
4
        let store = self
163
4
            .weak_store
164
4
            .upgrade()
165
4
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
166
4
        let 
subscription3
= match &mut self.state {
167
2
            OperationSubscriberState::Subscribed(subscription) => subscription,
168
            OperationSubscriberState::Unsubscribed => {
169
2
                let subscription = store
170
2
                    .subscription_manager()
171
2
                    .err_tip(|| "In OperationSubscriber::changed::subscription_manager")
?0
172
2
                    .subscribe(self.subscription_key.borrow())
173
2
                    .err_tip(|| "In OperationSubscriber::changed::subscribe")
?0
;
174
2
                self.state = OperationSubscriberState::Subscribed(subscription);
175
                // When we've just subscribed, there may have been changes before now.
176
2
                let action = Self::inner_get_awaited_action(
177
2
                    store.as_ref(),
178
2
                    self.subscription_key.borrow(),
179
2
                    self.maybe_client_operation_id.clone(),
180
2
                    &self.last_known_keepalive_ts,
181
2
                )
182
2
                .await
183
2
                .err_tip(|| "In OperationSubscriber::changed")
?0
;
184
2
                if !<S as SchedulerStore>::SubscriptionManager::is_reliable() {
  Branch (184:20): [True: 0, False: 0]
  Branch (184:20): [Folded - Ignored]
  Branch (184:20): [True: 2, False: 0]
185
2
                    self.maybe_last_stage = Some(core::mem::discriminant(&action.state().stage));
186
2
                
}0
187
                // Existing changes are only interesting if the state is past queued.
188
2
                if !
matches!1
(action.state().stage, ActionStage::Queued) {
  Branch (188:20): [True: 0, False: 0]
  Branch (188:20): [Folded - Ignored]
  Branch (188:20): [True: 1, False: 1]
189
1
                    return Ok(action);
190
1
                }
191
1
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
  Branch (191:21): [True: 0, False: 0]
  Branch (191:21): [Folded - Ignored]
  Branch (191:21): [True: 1, False: 0]
192
0
                    unreachable!("Subscription should be in Subscribed state");
193
                };
194
1
                subscription
195
            }
196
        };
197
198
3
        let changed_fut = subscription.changed();
199
3
        tokio::pin!(changed_fut);
200
        loop {
201
            // This is set if the maybe_last_state doesn't match the state in the store.
202
3
            let mut maybe_changed_action = None;
203
3
            for attempt in 1..=MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
204
3
                let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
205
3
                if I::from_secs(last_known_keepalive_ts).elapsed() <= CLIENT_KEEPALIVE_DURATION {
  Branch (205:20): [True: 0, False: 0]
  Branch (205:20): [Folded - Ignored]
  Branch (205:20): [True: 3, False: 0]
206
3
                    break; // We are still within the keep alive duration.
207
0
                }
208
0
                if attempt > 1 {
  Branch (208:20): [True: 0, False: 0]
  Branch (208:20): [Folded - Ignored]
  Branch (208:20): [True: 0, False: 0]
209
                    // Wait a tick before retrying.
210
0
                    (self.now_fn)().sleep(Duration::from_millis(100)).await;
211
0
                }
212
0
                let mut awaited_action = Self::inner_get_awaited_action(
213
0
                    store.as_ref(),
214
0
                    self.subscription_key.borrow(),
215
0
                    self.maybe_client_operation_id.clone(),
216
0
                    &self.last_known_keepalive_ts,
217
0
                )
218
0
                .await
219
0
                .err_tip(|| "In OperationSubscriber::changed")?;
220
0
                awaited_action.update_client_keep_alive((self.now_fn)().now());
221
                // If this is set to Some then the action changed without being published.
222
0
                maybe_changed_action = self
223
0
                    .maybe_last_stage
224
0
                    .as_ref()
225
0
                    .is_some_and(|last_stage| {
226
0
                        *last_stage != core::mem::discriminant(&awaited_action.state().stage)
227
0
                    })
228
0
                    .then(|| awaited_action.clone());
229
0
                match inner_update_awaited_action(store.as_ref(), awaited_action).await {
230
0
                    Ok(()) => break,
231
0
                    err if attempt == MAX_RETRIES_FOR_CLIENT_KEEPALIVE => {
  Branch (231:28): [True: 0, False: 0]
  Branch (231:28): [Folded - Ignored]
  Branch (231:28): [True: 0, False: 0]
232
0
                        err.err_tip_with_code(|_| {
233
0
                            (Code::Aborted, "Could not update client keep alive")
234
0
                        })?;
235
                    }
236
0
                    _ => (),
237
                }
238
            }
239
            // If the polling shows that it's changed state then publish now.
240
3
            if let Some(
changed_action0
) = maybe_changed_action {
  Branch (240:20): [True: 0, False: 0]
  Branch (240:20): [Folded - Ignored]
  Branch (240:20): [True: 0, False: 3]
241
0
                self.maybe_last_stage =
242
0
                    Some(core::mem::discriminant(&changed_action.state().stage));
243
0
                return Ok(changed_action);
244
3
            }
245
            // Determine the sleep time based on the last client keep alive.
246
3
            let sleep_time = CLIENT_KEEPALIVE_DURATION
247
3
                .checked_sub(
248
3
                    I::from_secs(self.last_known_keepalive_ts.load(Ordering::Acquire)).elapsed(),
249
                )
250
3
                .unwrap_or(Duration::from_millis(100));
251
3
            tokio::select! {
252
3
                result = &mut changed_fut => {
253
3
                    result
?0
;
254
3
                    break;
255
                }
256
3
                () = (self.now_fn)().sleep(sleep_time) => {
257
0
                    // If we haven't received any updates for a while, we should
258
0
                    // let the database know that we are still listening to prevent
259
0
                    // the action from being dropped.  Also poll for updates if the
260
0
                    // subscription manager is unreliable.
261
0
                }
262
            }
263
        }
264
265
3
        let awaited_action = Self::inner_get_awaited_action(
266
3
            store.as_ref(),
267
3
            self.subscription_key.borrow(),
268
3
            self.maybe_client_operation_id.clone(),
269
3
            &self.last_known_keepalive_ts,
270
3
        )
271
3
        .await
272
3
        .err_tip(|| "In OperationSubscriber::changed")
?0
;
273
3
        if self.maybe_last_stage.is_some() {
  Branch (273:12): [True: 0, False: 0]
  Branch (273:12): [Folded - Ignored]
  Branch (273:12): [True: 3, False: 0]
274
3
            self.maybe_last_stage = Some(core::mem::discriminant(&awaited_action.state().stage));
275
3
        
}0
276
3
        Ok(awaited_action)
277
4
    }
278
279
27
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
280
27
        self.get_awaited_action()
281
27
            .await
282
27
            .err_tip(|| "In OperationSubscriber::borrow")
283
27
    }
284
}
285
286
41
fn awaited_action_decode(version: i64, data: &Bytes) -> Result<AwaitedAction, Error> {
287
41
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data)
288
41
        .map_err(|e| make_input_err!("In AwaitedAction::decode - {e:?}"))
?0
;
289
41
    awaited_action.set_version(version);
290
41
    Ok(awaited_action)
291
41
}
292
293
/// Store key prefix for entries mapping an operation id to its awaited action.
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
/// Store key prefix for entries mapping a client operation id to an operation id.
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
295
296
#[derive(Debug)]
297
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
298
impl OperationIdToAwaitedAction<'_> {
299
66
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
300
66
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
301
66
    }
302
}
303
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
304
    type Versioned = TrueValue;
305
56
    fn get_key(&self) -> StoreKey<'static> {
306
56
        StoreKey::Str(Cow::Owned(format!(
307
56
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
308
56
            self.0
309
56
        )))
310
56
    }
311
}
312
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
313
    type DecodeOutput = AwaitedAction;
314
34
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
315
34
        awaited_action_decode(version, &data)
316
34
    }
317
}
318
319
struct ClientIdToOperationId<'a>(&'a OperationId);
320
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
321
    type Versioned = FalseValue;
322
7
    fn get_key(&self) -> StoreKey<'static> {
323
7
        StoreKey::Str(Cow::Owned(format!(
324
7
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
325
7
            self.0
326
7
        )))
327
7
    }
328
}
329
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
330
    type DecodeOutput = OperationId;
331
3
    fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
332
3
        serde_json::from_slice(&data).map_err(|e| 
{0
333
0
            make_input_err!(
334
                "In ClientIdToOperationId::decode - {e:?} (data: {:02x?})",
335
                data
336
            )
337
0
        })
338
3
    }
339
}
340
341
// TODO(palfrey) We only need operation_id here, it would be nice if we had a way
342
// to tell the decoder we only care about specific fields.
343
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
344
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
345
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
346
    const INDEX_NAME: &'static str = "unique_qualifier";
347
    type Versioned = TrueValue;
348
4
    fn index_value(&self) -> Cow<'_, str> {
349
4
        Cow::Owned(format!("{}", self.0))
350
4
    }
351
}
352
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
353
    type DecodeOutput = AwaitedAction;
354
2
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
355
2
        awaited_action_decode(version, &data)
356
2
    }
357
}
358
359
struct SearchStateToAwaitedAction(&'static str);
360
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
361
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
362
    const INDEX_NAME: &'static str = "state";
363
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
364
    type Versioned = TrueValue;
365
17
    fn index_value(&self) -> Cow<'_, str> {
366
17
        Cow::Borrowed(self.0)
367
17
    }
368
}
369
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
370
    type DecodeOutput = AwaitedAction;
371
5
    fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
372
5
        awaited_action_decode(version, &data)
373
5
    }
374
}
375
376
32
const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
377
32
    match state {
378
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
379
26
        SortedAwaitedActionState::Queued => "queued",
380
5
        SortedAwaitedActionState::Executing => "executing",
381
1
        SortedAwaitedActionState::Completed => "completed",
382
    }
383
32
}
384
385
struct UpdateOperationIdToAwaitedAction(AwaitedAction);
386
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
387
15
    fn current_version(&self) -> i64 {
388
15
        self.0.version()
389
15
    }
390
}
391
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
392
    type Versioned = TrueValue;
393
15
    fn get_key(&self) -> StoreKey<'static> {
394
15
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
395
15
    }
396
}
397
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
398
15
    fn try_into_bytes(self) -> Result<Bytes, Error> {
399
15
        serde_json::to_string(&self.0)
400
15
            .map(Bytes::from)
401
15
            .map_err(|e| make_input_err!("Could not convert AwaitedAction to json - {e:?}"))
402
15
    }
403
15
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
404
15
        let unique_qualifier = &self.0.action_info().unique_qualifier;
405
15
        let maybe_unique_qualifier = match &unique_qualifier {
406
15
            ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier),
407
0
            ActionUniqueQualifier::Uncacheable(_) => None,
408
        };
409
15
        let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1));
410
15
        if maybe_unique_qualifier.is_some() {
  Branch (410:12): [True: 15, False: 0]
  Branch (410:12): [Folded - Ignored]
411
15
            output.push((
412
15
                "unique_qualifier",
413
15
                Bytes::from(unique_qualifier.to_string()),
414
15
            ));
415
15
        
}0
416
        {
417
15
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
418
15
                .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")
?0
;
419
15
            output.push(("state", Bytes::from(get_state_prefix(state))));
420
15
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
421
15
            output.push((
422
15
                "sort_key",
423
15
                // We encode to hex to ensure that the sort key is lexicographically sorted.
424
15
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
425
15
            ));
426
        }
427
15
        Ok(output)
428
15
    }
429
}
430
431
struct UpdateClientIdToOperationId {
432
    client_operation_id: ClientOperationId,
433
    operation_id: OperationId,
434
}
435
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
436
    type Versioned = FalseValue;
437
4
    fn get_key(&self) -> StoreKey<'static> {
438
4
        ClientIdToOperationId(&self.client_operation_id).get_key()
439
4
    }
440
}
441
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
442
4
    fn try_into_bytes(self) -> Result<Bytes, Error> {
443
4
        serde_json::to_string(&self.operation_id)
444
4
            .map(Bytes::from)
445
4
            .map_err(|e| make_input_err!("Could not convert OperationId to json - {e:?}"))
446
4
    }
447
}
448
449
11
async fn inner_update_awaited_action(
450
11
    store: &impl SchedulerStore,
451
11
    mut new_awaited_action: AwaitedAction,
452
11
) -> Result<(), Error> {
453
11
    let operation_id = new_awaited_action.operation_id().clone();
454
11
    if new_awaited_action.state().client_operation_id != operation_id {
  Branch (454:8): [True: 0, False: 0]
  Branch (454:8): [Folded - Ignored]
  Branch (454:8): [True: 0, False: 11]
455
0
        // Just in case the client_operation_id was set to something else
456
0
        // we put it back to the underlying operation_id.
457
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
458
11
    }
459
11
    let maybe_version = store
460
11
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action))
461
11
        .await
462
11
        .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")
?0
;
463
11
    if maybe_version.is_none() {
  Branch (463:8): [True: 0, False: 0]
  Branch (463:8): [Folded - Ignored]
  Branch (463:8): [True: 1, False: 10]
464
1
        tracing::warn!(
465
1
            "Could not update AwaitedAction because the version did not match for {operation_id}"
466
        );
467
1
        return Err(make_err!(
468
1
            Code::Aborted,
469
1
            "Could not update AwaitedAction because the version did not match for {operation_id}",
470
1
        ));
471
10
    }
472
10
    Ok(())
473
11
}
474
475
#[derive(Debug, MetricsComponent)]
476
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
477
where
478
    S: SchedulerStore,
479
    F: Fn() -> OperationId,
480
    I: InstantWrapper,
481
    NowFn: Fn() -> I,
482
{
483
    store: Arc<S>,
484
    now_fn: NowFn,
485
    operation_id_creator: F,
486
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
487
}
488
489
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
490
where
491
    S: SchedulerStore,
492
    F: Fn() -> OperationId,
493
    I: InstantWrapper,
494
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
495
{
496
4
    pub fn new(
497
4
        store: Arc<S>,
498
4
        task_change_publisher: Arc<Notify>,
499
4
        now_fn: NowFn,
500
4
        operation_id_creator: F,
501
4
    ) -> Result<Self, Error> {
502
4
        let mut subscription = store
503
4
            .subscription_manager()
504
4
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
505
4
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
506
4
                String::new(),
507
4
            ))))
508
4
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
;
509
4
        let pull_task_change_subscriber = spawn!(
510
            "redis_awaited_action_db_pull_task_change_subscriber",
511
4
            async move {
512
                loop {
513
22
                    let 
changed_res18
= subscription
514
22
                        .changed()
515
22
                        .await
516
18
                        .err_tip(|| "In RedisAwaitedActionDb::new");
517
18
                    if let Err(
err0
) = changed_res {
  Branch (517:28): [True: 0, False: 0]
  Branch (517:28): [Folded - Ignored]
  Branch (517:28): [True: 0, False: 4]
  Branch (517:28): [True: 0, False: 1]
  Branch (517:28): [True: 0, False: 12]
  Branch (517:28): [True: 0, False: 1]
518
0
                        error!(
519
0
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
520
                        );
521
                        // Sleep for a second to avoid a busy loop, then trigger the notify
522
                        // so if a reconnect happens we let local resources know that things
523
                        // might have changed.
524
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
525
18
                    }
526
18
                    task_change_publisher.as_ref().notify_one();
527
                }
528
            }
529
        );
530
4
        Ok(Self {
531
4
            store,
532
4
            now_fn,
533
4
            operation_id_creator,
534
4
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
535
4
        })
536
4
    }
537
538
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
539
4
    async fn try_subscribe(
540
4
        &self,
541
4
        client_operation_id: &ClientOperationId,
542
4
        unique_qualifier: &ActionUniqueQualifier,
543
4
        no_event_action_timeout: Duration,
544
4
        // TODO(palfrey) To simplify the scheduler 2024 refactor, we
545
4
        // removed the ability to upgrade priorities of actions.
546
4
        // we should add priority upgrades back in.
547
4
        _priority: i32,
548
4
    ) -> Result<Option<AwaitedAction>, Error> {
549
4
        match unique_qualifier {
550
4
            ActionUniqueQualifier::Cacheable(_) => {}
551
0
            ActionUniqueQualifier::Uncacheable(_) => return Ok(None),
552
        }
553
4
        let stream = self
554
4
            .store
555
4
            .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
556
4
            .await
557
4
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
558
4
        tokio::pin!(stream);
559
4
        let maybe_awaited_action = stream
560
4
            .try_next()
561
4
            .await
562
4
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
563
4
        match maybe_awaited_action {
564
2
            Some(awaited_action) => {
565
                // TODO(palfrey) We don't support joining completed jobs because we
566
                // need to also check that all the data is still in the cache.
567
                // If the existing job failed then we need to set back to queued or we get
568
                // a version mismatch.  Equally we need to check the timeout as the job
569
                // may be abandoned in the store.
570
2
                let worker_should_update_before = (awaited_action.state().stage
571
2
                    == ActionStage::Executing)
572
2
                    .then_some(())
573
2
                    .map(|()| 
awaited_action0
.
last_worker_updated_timestamp0
())
574
2
                    .and_then(|last_worker_updated| 
{0
575
0
                        last_worker_updated.checked_add(no_event_action_timeout)
576
0
                    });
577
2
                let awaited_action = if awaited_action.state().stage.is_finished()
  Branch (577:41): [True: 0, False: 0]
  Branch (577:41): [Folded - Ignored]
  Branch (577:41): [True: 0, False: 0]
  Branch (577:41): [True: 1, False: 1]
578
1
                    || worker_should_update_before
  Branch (578:24): [True: 0, False: 0]
  Branch (578:24): [Folded - Ignored]
  Branch (578:24): [True: 0, False: 0]
  Branch (578:24): [True: 0, False: 1]
579
1
                        .is_some_and(|timestamp| 
timestamp0
<
(self.now_fn)().now()0
)
580
                {
581
1
                    tracing::debug!(
582
1
                        "Recreating action {:?} for operation {client_operation_id}",
583
1
                        awaited_action.action_info().digest()
584
                    );
585
                    // The version is reset because we have a new operation ID.
586
1
                    AwaitedAction::new(
587
1
                        (self.operation_id_creator)(),
588
1
                        awaited_action.action_info().clone(),
589
1
                        (self.now_fn)().now(),
590
                    )
591
                } else {
592
1
                    tracing::debug!(
593
1
                        "Subscribing to existing action {:?} for operation {client_operation_id}",
594
1
                        awaited_action.action_info().digest()
595
                    );
596
1
                    awaited_action
597
                };
598
2
                Ok(Some(awaited_action))
599
            }
600
2
            None => Ok(None),
601
        }
602
4
    }
603
604
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
605
3
    async fn inner_get_awaited_action_by_id(
606
3
        &self,
607
3
        client_operation_id: &ClientOperationId,
608
3
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
609
3
        let maybe_operation_id = self
610
3
            .store
611
3
            .get_and_decode(ClientIdToOperationId(client_operation_id))
612
3
            .await
613
3
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")
?0
;
614
3
        let Some(operation_id) = maybe_operation_id else {
  Branch (614:13): [True: 0, False: 0]
  Branch (614:13): [Folded - Ignored]
  Branch (614:13): [True: 1, False: 0]
  Branch (614:13): [True: 1, False: 0]
  Branch (614:13): [True: 1, False: 0]
615
0
            return Ok(None);
616
        };
617
618
        // Validate that the internal operation actually exists.
619
        // If it doesn't, this is an orphaned client operation mapping that should be cleaned up.
620
        // This can happen when an operation is deleted (completed/timed out) but the
621
        // client_id -> operation_id mapping persists in the store.
622
3
        let maybe_awaited_action = match self
623
3
            .store
624
3
            .get_and_decode(OperationIdToAwaitedAction(Cow::Borrowed(&operation_id)))
625
3
            .await
626
        {
627
2
            Ok(maybe_action) => maybe_action,
628
1
            Err(err) if err.code == Code::NotFound => {
  Branch (628:25): [True: 0, False: 0]
  Branch (628:25): [Folded - Ignored]
  Branch (628:25): [True: 0, False: 0]
  Branch (628:25): [True: 1, False: 0]
  Branch (628:25): [True: 0, False: 0]
629
1
                tracing::warn!(
630
1
                    "Orphaned client operation mapping detected: client_id={} maps to operation_id={}, \
631
1
                    but the operation does not exist in the store (NotFound). This typically happens when \
632
1
                    an operation completes or times out but the client mapping persists.",
633
                    client_operation_id,
634
                    operation_id
635
                );
636
1
                None
637
            }
638
0
            Err(err) => {
639
                // Some other error occurred
640
0
                return Err(err).err_tip(
641
                    || "In RedisAwaitedActionDb::get_awaited_action_by_id::validate_operation",
642
                );
643
            }
644
        };
645
646
3
        if maybe_awaited_action.is_none() {
  Branch (646:12): [True: 0, False: 0]
  Branch (646:12): [Folded - Ignored]
  Branch (646:12): [True: 0, False: 1]
  Branch (646:12): [True: 1, False: 0]
  Branch (646:12): [True: 0, False: 1]
647
1
            tracing::warn!(
648
1
                "Found orphaned client operation mapping: client_id={} -> operation_id={}, \
649
1
                but operation no longer exists. Returning None to prevent client from polling \
650
1
                a non-existent operation.",
651
                client_operation_id,
652
                operation_id
653
            );
654
1
            return Ok(None);
655
2
        }
656
657
2
        Ok(Some(OperationSubscriber::new(
658
2
            Some(client_operation_id.clone()),
659
2
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
660
2
            Arc::downgrade(&self.store),
661
2
            self.now_fn.clone(),
662
2
        )))
663
3
    }
664
}
665
666
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
667
where
668
    S: SchedulerStore,
669
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
670
    I: InstantWrapper,
671
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
672
{
673
    type Subscriber = OperationSubscriber<S, I, NowFn>;
674
675
3
    async fn get_awaited_action_by_id(
676
3
        &self,
677
3
        client_operation_id: &ClientOperationId,
678
3
    ) -> Result<Option<Self::Subscriber>, Error> {
679
3
        self.inner_get_awaited_action_by_id(client_operation_id)
680
3
            .await
681
3
    }
682
683
8
    async fn get_by_operation_id(
684
8
        &self,
685
8
        operation_id: &OperationId,
686
8
    ) -> Result<Option<Self::Subscriber>, Error> {
687
8
        Ok(Some(OperationSubscriber::new(
688
8
            None,
689
8
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
690
8
            Arc::downgrade(&self.store),
691
8
            self.now_fn.clone(),
692
8
        )))
693
8
    }
694
695
11
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
696
11
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await
697
11
    }
698
699
4
    async fn add_action(
700
4
        &self,
701
4
        client_operation_id: ClientOperationId,
702
4
        action_info: Arc<ActionInfo>,
703
4
        no_event_action_timeout: Duration,
704
4
    ) -> Result<Self::Subscriber, Error> {
705
        loop {
706
            // Check to see if the action is already known and subscribe if it is.
707
4
            let mut awaited_action = self
708
4
                .try_subscribe(
709
4
                    &client_operation_id,
710
4
                    &action_info.unique_qualifier,
711
4
                    no_event_action_timeout,
712
4
                    action_info.priority,
713
4
                )
714
4
                .await
715
4
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
716
4
                .unwrap_or_else(|| 
{2
717
2
                    tracing::debug!(
718
2
                        "Creating new action {:?} for operation {client_operation_id}",
719
2
                        action_info.digest()
720
                    );
721
2
                    AwaitedAction::new(
722
2
                        (self.operation_id_creator)(),
723
2
                        action_info.clone(),
724
2
                        (self.now_fn)().now(),
725
                    )
726
2
                });
727
728
4
            debug_assert!(
729
0
                ActionStage::Queued == awaited_action.state().stage,
730
0
                "Expected action to be queued"
731
            );
732
733
4
            let operation_id = awaited_action.operation_id().clone();
734
4
            if awaited_action.state().client_operation_id != operation_id {
  Branch (734:16): [True: 0, False: 0]
  Branch (734:16): [Folded - Ignored]
  Branch (734:16): [True: 0, False: 1]
  Branch (734:16): [True: 0, False: 3]
735
0
                // Just in case the client_operation_id was set to something else
736
0
                // we put it back to the underlying operation_id.
737
0
                awaited_action.set_client_operation_id(operation_id.clone());
738
4
            }
739
4
            awaited_action.update_client_keep_alive((self.now_fn)().now());
740
741
4
            let version = awaited_action.version();
742
4
            if self
  Branch (742:16): [True: 0, False: 0]
  Branch (742:16): [Folded - Ignored]
  Branch (742:16): [True: 0, False: 1]
  Branch (742:16): [True: 0, False: 3]
743
4
                .store
744
4
                .update_data(UpdateOperationIdToAwaitedAction(awaited_action))
745
4
                .await
746
4
                .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
747
4
                .is_none()
748
            {
749
                // The version was out of date, try again.
750
0
                tracing::info!(
751
0
                    "Version out of date for {:?} {operation_id} {version}, retrying.",
752
0
                    action_info.digest()
753
                );
754
0
                continue;
755
4
            }
756
757
            // Add the client_operation_id to operation_id mapping
758
4
            self.store
759
4
                .update_data(UpdateClientIdToOperationId {
760
4
                    client_operation_id: client_operation_id.clone(),
761
4
                    operation_id: operation_id.clone(),
762
4
                })
763
4
                .await
764
4
                .err_tip(|| "In RedisAwaitedActionDb::add_action while adding client mapping")
?0
;
765
766
4
            return Ok(OperationSubscriber::new(
767
4
                Some(client_operation_id),
768
4
                OperationIdToAwaitedAction(Cow::Owned(operation_id)),
769
4
                Arc::downgrade(&self.store),
770
4
                self.now_fn.clone(),
771
4
            ));
772
        }
773
4
    }
774
775
17
    async fn get_range_of_actions(
776
17
        &self,
777
17
        state: SortedAwaitedActionState,
778
17
        start: Bound<SortedAwaitedAction>,
779
17
        end: Bound<SortedAwaitedAction>,
780
17
        desc: bool,
781
17
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
782
17
        if !
matches!0
(start, Bound::Unbounded) {
  Branch (782:12): [True: 0, False: 0]
  Branch (782:12): [Folded - Ignored]
  Branch (782:12): [True: 0, False: 17]
783
0
            return Err(make_err!(
784
0
                Code::Unimplemented,
785
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
786
0
            ));
787
17
        }
788
17
        if !
matches!0
(end, Bound::Unbounded) {
  Branch (788:12): [True: 0, False: 0]
  Branch (788:12): [Folded - Ignored]
  Branch (788:12): [True: 0, False: 17]
789
0
            return Err(make_err!(
790
0
                Code::Unimplemented,
791
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
792
0
            ));
793
17
        }
794
        // TODO(palfrey) This API is not difficult to implement, but there is no code path
795
        // that uses it, so no reason to implement it yet.
796
17
        if !desc {
  Branch (796:12): [True: 0, False: 0]
  Branch (796:12): [Folded - Ignored]
  Branch (796:12): [True: 0, False: 17]
797
0
            return Err(make_err!(
798
0
                Code::Unimplemented,
799
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
800
0
            ));
801
17
        }
802
17
        Ok(self
803
17
            .store
804
17
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
805
17
            .await
806
17
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")
?0
807
17
            .map_ok(move |awaited_action| 
{5
808
5
                OperationSubscriber::new(
809
5
                    None,
810
5
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
811
5
                    Arc::downgrade(&self.store),
812
5
                    self.now_fn.clone(),
813
                )
814
5
            }))
815
17
    }
816
817
0
    async fn get_all_awaited_actions(
818
0
        &self,
819
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
820
0
        Ok(self
821
0
            .store
822
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
823
0
            .await
824
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
825
0
            .map_ok(move |awaited_action| {
826
0
                OperationSubscriber::new(
827
0
                    None,
828
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
829
0
                    Arc::downgrade(&self.store),
830
0
                    self.now_fn.clone(),
831
                )
832
0
            }))
833
0
    }
834
}