Coverage Report

Created: 2025-06-24 08:57

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::ops::Bound;
16
use core::sync::atomic::{AtomicU64, Ordering};
17
use core::time::Duration;
18
use std::borrow::Cow;
19
use std::sync::{Arc, Weak};
20
21
use bytes::Bytes;
22
use futures::{Stream, TryStreamExt};
23
use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err};
24
use nativelink_metric::MetricsComponent;
25
use nativelink_util::action_messages::{
26
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
27
};
28
use nativelink_util::instant_wrapper::InstantWrapper;
29
use nativelink_util::spawn;
30
use nativelink_util::store_trait::{
31
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
32
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
33
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
34
};
35
use nativelink_util::task::JoinHandleDropGuard;
36
use tokio::sync::Notify;
37
use tracing::{error, warn};
38
39
use crate::awaited_action_db::{
40
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION,
41
    SortedAwaitedAction, SortedAwaitedActionState,
42
};
43
44
type ClientOperationId = OperationId;
45
46
/// Maximum number of retries to update client keep alive.
47
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
48
49
/// Records whether an `OperationSubscriber` has opened its store
/// subscription yet.
///
/// `Debug` is derived because the `Debug` implementation of
/// `OperationSubscriber` is bounded on this type being `Debug`; without an
/// implementation here that bound could never be satisfied.
#[derive(Debug)]
enum OperationSubscriberState<Sub> {
    /// No subscription has been created so far.
    Unsubscribed,
    /// Holds the subscription once it has been created.
    Subscribed(Sub),
}
53
54
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
55
    maybe_client_operation_id: Option<ClientOperationId>,
56
    subscription_key: OperationIdToAwaitedAction<'static>,
57
    weak_store: Weak<S>,
58
    state: OperationSubscriberState<
59
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
60
    >,
61
    last_known_keepalive_ts: AtomicU64,
62
    now_fn: NowFn,
63
}
64
65
impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug
66
    for OperationSubscriber<S, I, NowFn>
67
where
68
    OperationSubscriberState<
69
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
70
    >: core::fmt::Debug,
71
{
72
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73
0
        f.debug_struct("OperationSubscriber")
74
0
            .field("maybe_client_operation_id", &self.maybe_client_operation_id)
75
0
            .field("subscription_key", &self.subscription_key)
76
0
            .field("weak_store", &self.weak_store)
77
0
            .field("state", &self.state)
78
0
            .field("last_known_keepalive_ts", &self.last_known_keepalive_ts)
79
0
            .field("now_fn", &self.now_fn)
80
0
            .finish()
81
0
    }
82
}
83
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
84
where
85
    S: SchedulerStore,
86
    I: InstantWrapper,
87
    NowFn: Fn() -> I,
88
{
89
1
    const fn new(
90
1
        maybe_client_operation_id: Option<ClientOperationId>,
91
1
        subscription_key: OperationIdToAwaitedAction<'static>,
92
1
        weak_store: Weak<S>,
93
1
        now_fn: NowFn,
94
1
    ) -> Self {
95
1
        Self {
96
1
            maybe_client_operation_id,
97
1
            subscription_key,
98
1
            weak_store,
99
1
            last_known_keepalive_ts: AtomicU64::new(0),
100
1
            state: OperationSubscriberState::Unsubscribed,
101
1
            now_fn,
102
1
        }
103
1
    }
104
105
2
    async fn inner_get_awaited_action(
106
2
        store: &S,
107
2
        key: OperationIdToAwaitedAction<'_>,
108
2
        maybe_client_operation_id: Option<ClientOperationId>,
109
2
        last_known_keepalive_ts: &AtomicU64,
110
2
    ) -> Result<AwaitedAction, Error> {
111
2
        let mut awaited_action = store
112
2
            .get_and_decode(key.borrow())
113
2
            .await
114
2
            .err_tip(|| format!(
"In OperationSubscriber::get_awaited_action {key:?}"0
))
?0
115
2
            .ok_or_else(|| 
{0
116
0
                make_err!(
117
0
                    Code::NotFound,
118
                    "Could not find AwaitedAction for the given operation id {key:?}",
119
                )
120
0
            })?;
121
2
        if let Some(client_operation_id) = maybe_client_operation_id {
  Branch (121:16): [True: 0, False: 0]
  Branch (121:16): [Folded - Ignored]
  Branch (121:16): [True: 2, False: 0]
122
2
            awaited_action.set_client_operation_id(client_operation_id);
123
2
        
}0
124
2
        last_known_keepalive_ts.store(
125
2
            awaited_action
126
2
                .last_client_keepalive_timestamp()
127
2
                .unix_timestamp(),
128
2
            Ordering::Release,
129
        );
130
2
        Ok(awaited_action)
131
2
    }
132
133
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
134
0
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
135
0
        let store = self
136
0
            .weak_store
137
0
            .upgrade()
138
0
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?;
139
0
        Self::inner_get_awaited_action(
140
0
            store.as_ref(),
141
0
            self.subscription_key.borrow(),
142
0
            self.maybe_client_operation_id.clone(),
143
0
            &self.last_known_keepalive_ts,
144
0
        )
145
0
        .await
146
0
    }
147
}
148
149
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
150
where
151
    S: SchedulerStore,
152
    I: InstantWrapper,
153
    NowFn: Fn() -> I + Send + Sync + 'static,
154
{
155
2
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
156
2
        let store = self
157
2
            .weak_store
158
2
            .upgrade()
159
2
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")
?0
;
160
2
        let subscription = match &mut self.state {
161
            OperationSubscriberState::Unsubscribed => {
162
1
                let subscription = store
163
1
                    .subscription_manager()
164
1
                    .err_tip(|| "In OperationSubscriber::changed::subscription_manager")
?0
165
1
                    .subscribe(self.subscription_key.borrow())
166
1
                    .err_tip(|| "In OperationSubscriber::changed::subscribe")
?0
;
167
1
                self.state = OperationSubscriberState::Subscribed(subscription);
168
1
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
  Branch (168:21): [True: 0, False: 0]
  Branch (168:21): [Folded - Ignored]
  Branch (168:21): [True: 1, False: 0]
169
0
                    unreachable!("Subscription should be in Subscribed state");
170
                };
171
1
                subscription
172
            }
173
1
            OperationSubscriberState::Subscribed(subscription) => subscription,
174
        };
175
176
2
        let changed_fut = subscription.changed();
177
2
        tokio::pin!(changed_fut);
178
        loop {
179
2
            let mut retries = 0;
180
            loop {
181
2
                let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
182
2
                if I::from_secs(last_known_keepalive_ts).elapsed() <= CLIENT_KEEPALIVE_DURATION {
  Branch (182:20): [True: 0, False: 0]
  Branch (182:20): [Folded - Ignored]
  Branch (182:20): [True: 2, False: 0]
183
2
                    break; // We are still within the keep alive duration.
184
0
                }
185
0
                if retries > MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
  Branch (185:20): [True: 0, False: 0]
  Branch (185:20): [Folded - Ignored]
  Branch (185:20): [True: 0, False: 0]
186
0
                    return Err(make_err!(
187
0
                        Code::Aborted,
188
0
                        "Could not update client keep alive for AwaitedAction",
189
0
                    ));
190
0
                }
191
0
                let mut awaited_action = Self::inner_get_awaited_action(
192
0
                    store.as_ref(),
193
0
                    self.subscription_key.borrow(),
194
0
                    self.maybe_client_operation_id.clone(),
195
0
                    &self.last_known_keepalive_ts,
196
0
                )
197
0
                .await
198
0
                .err_tip(|| "In OperationSubscriber::changed")?;
199
0
                awaited_action.update_client_keep_alive((self.now_fn)().now());
200
0
                let update_res = inner_update_awaited_action(store.as_ref(), awaited_action)
201
0
                    .await
202
0
                    .err_tip(|| "In OperationSubscriber::changed");
203
0
                if update_res.is_ok() {
  Branch (203:20): [True: 0, False: 0]
  Branch (203:20): [Folded - Ignored]
  Branch (203:20): [True: 0, False: 0]
204
0
                    break;
205
0
                }
206
0
                retries += 1;
207
                // Wait a tick before retrying.
208
0
                (self.now_fn)().sleep(Duration::from_millis(100)).await;
209
            }
210
2
            let sleep_fut = (self.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
211
2
            tokio::select! {
212
2
                result = &mut changed_fut => {
213
2
                    result
?0
;
214
2
                    break;
215
                }
216
2
                () = sleep_fut => {
217
0
                    // If we haven't received any updates for a while, we should
218
0
                    // let the database know that we are still listening to prevent
219
0
                    // the action from being dropped.
220
0
                }
221
            }
222
        }
223
224
2
        Self::inner_get_awaited_action(
225
2
            store.as_ref(),
226
2
            self.subscription_key.borrow(),
227
2
            self.maybe_client_operation_id.clone(),
228
2
            &self.last_known_keepalive_ts,
229
2
        )
230
2
        .await
231
2
        .err_tip(|| "In OperationSubscriber::changed")
232
2
    }
233
234
0
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
235
0
        self.get_awaited_action()
236
0
            .await
237
0
            .err_tip(|| "In OperationSubscriber::borrow")
238
0
    }
239
}
240
241
2
fn awaited_action_decode(version: u64, data: &Bytes) -> Result<AwaitedAction, Error> {
242
2
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data)
243
2
        .map_err(|e| make_input_err!("In AwaitedAction::decode - {e:?}"))
?0
;
244
2
    awaited_action.set_version(version);
245
2
    Ok(awaited_action)
246
2
}
247
248
/// Prefix of the store keys that map an operation id to its awaited action.
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
/// Prefix of the store keys that map a client operation id to an operation id.
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
250
251
#[derive(Debug)]
252
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
253
impl OperationIdToAwaitedAction<'_> {
254
5
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
255
5
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
256
5
    }
257
}
258
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
259
    type Versioned = TrueValue;
260
6
    fn get_key(&self) -> StoreKey<'static> {
261
6
        StoreKey::Str(Cow::Owned(format!(
262
6
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
263
6
            self.0
264
6
        )))
265
6
    }
266
}
267
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
268
    type DecodeOutput = AwaitedAction;
269
2
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
270
2
        awaited_action_decode(version, &data)
271
2
    }
272
}
273
274
struct ClientIdToOperationId<'a>(&'a OperationId);
275
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
276
    type Versioned = FalseValue;
277
1
    fn get_key(&self) -> StoreKey<'static> {
278
1
        StoreKey::Str(Cow::Owned(format!(
279
1
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
280
1
            self.0
281
1
        )))
282
1
    }
283
}
284
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
285
    type DecodeOutput = OperationId;
286
0
    fn decode(_version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
287
0
        OperationId::try_from(data).err_tip(|| "In ClientIdToOperationId::decode")
288
0
    }
289
}
290
291
// TODO(aaronmondal) We only need operation_id here, it would be nice if we had a way
292
// to tell the decoder we only care about specific fields.
293
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
294
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
295
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
296
    const INDEX_NAME: &'static str = "unique_qualifier";
297
    type Versioned = TrueValue;
298
1
    fn index_value(&self) -> Cow<'_, str> {
299
1
        Cow::Owned(format!("{}", self.0))
300
1
    }
301
}
302
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
303
    type DecodeOutput = AwaitedAction;
304
0
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
305
0
        awaited_action_decode(version, &data)
306
0
    }
307
}
308
309
struct SearchStateToAwaitedAction(&'static str);
310
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
311
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
312
    const INDEX_NAME: &'static str = "state";
313
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
314
    type Versioned = TrueValue;
315
0
    fn index_value(&self) -> Cow<'_, str> {
316
0
        Cow::Borrowed(self.0)
317
0
    }
318
}
319
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
320
    type DecodeOutput = AwaitedAction;
321
0
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
322
0
        awaited_action_decode(version, &data)
323
0
    }
324
}
325
326
2
const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
327
2
    match state {
328
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
329
1
        SortedAwaitedActionState::Queued => "queued",
330
1
        SortedAwaitedActionState::Executing => "executing",
331
0
        SortedAwaitedActionState::Completed => "completed",
332
    }
333
2
}
334
335
struct UpdateOperationIdToAwaitedAction(AwaitedAction);
336
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
337
2
    fn current_version(&self) -> u64 {
338
2
        self.0.version()
339
2
    }
340
}
341
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
342
    type Versioned = TrueValue;
343
2
    fn get_key(&self) -> StoreKey<'static> {
344
2
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
345
2
    }
346
}
347
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
348
2
    fn try_into_bytes(self) -> Result<Bytes, Error> {
349
2
        serde_json::to_string(&self.0)
350
2
            .map(Bytes::from)
351
2
            .map_err(|e| make_input_err!("Could not convert AwaitedAction to json - {e:?}"))
352
2
    }
353
2
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
354
2
        let unique_qualifier = &self.0.action_info().unique_qualifier;
355
2
        let maybe_unique_qualifier = match &unique_qualifier {
356
2
            ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier),
357
0
            ActionUniqueQualifier::Uncacheable(_) => None,
358
        };
359
2
        let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1));
360
2
        if maybe_unique_qualifier.is_some() {
  Branch (360:12): [True: 2, False: 0]
  Branch (360:12): [Folded - Ignored]
361
2
            output.push((
362
2
                "unique_qualifier",
363
2
                Bytes::from(unique_qualifier.to_string()),
364
2
            ));
365
2
        
}0
366
        {
367
2
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
368
2
                .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")
?0
;
369
2
            output.push(("state", Bytes::from(get_state_prefix(state))));
370
2
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
371
2
            output.push((
372
2
                "sort_key",
373
2
                // We encode to hex to ensure that the sort key is lexicographically sorted.
374
2
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
375
2
            ));
376
        }
377
2
        Ok(output)
378
2
    }
379
}
380
381
struct UpdateClientIdToOperationId {
382
    client_operation_id: ClientOperationId,
383
    operation_id: OperationId,
384
}
385
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
386
    type Versioned = FalseValue;
387
1
    fn get_key(&self) -> StoreKey<'static> {
388
1
        ClientIdToOperationId(&self.client_operation_id).get_key()
389
1
    }
390
}
391
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
392
1
    fn try_into_bytes(self) -> Result<Bytes, Error> {
393
1
        serde_json::to_string(&self.operation_id)
394
1
            .map(Bytes::from)
395
1
            .map_err(|e| make_input_err!("Could not convert OperationId to json - {e:?}"))
396
1
    }
397
}
398
399
1
async fn inner_update_awaited_action(
400
1
    store: &impl SchedulerStore,
401
1
    mut new_awaited_action: AwaitedAction,
402
1
) -> Result<(), Error> {
403
1
    let operation_id = new_awaited_action.operation_id().clone();
404
1
    if new_awaited_action.state().client_operation_id != operation_id {
  Branch (404:8): [True: 0, False: 0]
  Branch (404:8): [Folded - Ignored]
  Branch (404:8): [True: 0, False: 1]
405
0
        // Just in case the client_operation_id was set to something else
406
0
        // we put it back to the underlying operation_id.
407
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
408
1
    }
409
1
    let maybe_version = store
410
1
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action))
411
1
        .await
412
1
        .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")
?0
;
413
1
    if maybe_version.is_none() {
  Branch (413:8): [True: 0, False: 0]
  Branch (413:8): [Folded - Ignored]
  Branch (413:8): [True: 0, False: 1]
414
0
        return Err(make_err!(
415
0
            Code::Aborted,
416
0
            "Could not update AwaitedAction because the version did not match for {operation_id:?}",
417
0
        ));
418
1
    }
419
1
    Ok(())
420
1
}
421
422
#[derive(Debug, MetricsComponent)]
423
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
424
where
425
    S: SchedulerStore,
426
    F: Fn() -> OperationId,
427
    I: InstantWrapper,
428
    NowFn: Fn() -> I,
429
{
430
    store: Arc<S>,
431
    now_fn: NowFn,
432
    operation_id_creator: F,
433
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
434
}
435
436
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
437
where
438
    S: SchedulerStore,
439
    F: Fn() -> OperationId,
440
    I: InstantWrapper,
441
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
442
{
443
1
    pub fn new(
444
1
        store: Arc<S>,
445
1
        task_change_publisher: Arc<Notify>,
446
1
        now_fn: NowFn,
447
1
        operation_id_creator: F,
448
1
    ) -> Result<Self, Error> {
449
1
        let mut subscription = store
450
1
            .subscription_manager()
451
1
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
452
1
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
453
1
                String::new(),
454
1
            ))))
455
1
            .err_tip(|| "In RedisAwaitedActionDb::new")
?0
;
456
1
        let pull_task_change_subscriber = spawn!(
457
            "redis_awaited_action_db_pull_task_change_subscriber",
458
1
            async move {
459
                loop {
460
5
                    let 
changed_res4
= subscription
461
5
                        .changed()
462
5
                        .await
463
4
                        .err_tip(|| "In RedisAwaitedActionDb::new");
464
4
                    if let Err(
err0
) = changed_res {
  Branch (464:28): [True: 0, False: 0]
  Branch (464:28): [Folded - Ignored]
  Branch (464:28): [True: 0, False: 4]
465
0
                        error!(
466
0
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
467
                        );
468
                        // Sleep for a second to avoid a busy loop, then trigger the notify
469
                        // so if a reconnect happens we let local resources know that things
470
                        // might have changed.
471
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
472
4
                    }
473
4
                    task_change_publisher.as_ref().notify_one();
474
                }
475
            }
476
        );
477
1
        Ok(Self {
478
1
            store,
479
1
            now_fn,
480
1
            operation_id_creator,
481
1
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
482
1
        })
483
1
    }
484
485
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
486
1
    async fn try_subscribe(
487
1
        &self,
488
1
        client_operation_id: &ClientOperationId,
489
1
        unique_qualifier: &ActionUniqueQualifier,
490
1
        // TODO(aaronmondal) To simplify the scheduler 2024 refactor, we
491
1
        // removed the ability to upgrade priorities of actions.
492
1
        // we should add priority upgrades back in.
493
1
        _priority: i32,
494
1
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
495
1
        match unique_qualifier {
496
1
            ActionUniqueQualifier::Cacheable(_) => {}
497
0
            ActionUniqueQualifier::Uncacheable(_) => return Ok(None),
498
        }
499
1
        let stream = self
500
1
            .store
501
1
            .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
502
1
            .await
503
1
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
504
1
        tokio::pin!(stream);
505
1
        let maybe_awaited_action = stream
506
1
            .try_next()
507
1
            .await
508
1
            .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")
?0
;
509
1
        match maybe_awaited_action {
510
0
            Some(mut awaited_action) => {
511
                // TODO(aaronmondal) We don't support joining completed jobs because we
512
                // need to also check that all the data is still in the cache.
513
0
                if awaited_action.state().stage.is_finished() {
  Branch (513:20): [True: 0, False: 0]
  Branch (513:20): [Folded - Ignored]
  Branch (513:20): [True: 0, False: 0]
514
0
                    return Ok(None);
515
0
                }
516
                // TODO(aaronmondal) We only care about the operation_id here, we should
517
                // have a way to tell the decoder we only care about specific fields.
518
0
                let operation_id = awaited_action.operation_id().clone();
519
520
0
                awaited_action.update_client_keep_alive((self.now_fn)().now());
521
0
                let update_res = inner_update_awaited_action(self.store.as_ref(), awaited_action)
522
0
                    .await
523
0
                    .err_tip(|| "In OperationSubscriber::changed");
524
0
                if let Err(err) = update_res {
  Branch (524:24): [True: 0, False: 0]
  Branch (524:24): [Folded - Ignored]
  Branch (524:24): [True: 0, False: 0]
525
0
                    warn!(
526
0
                        "Error updating client keep alive in RedisAwaitedActionDb::try_subscribe - {err:?} - This is not a critical error, but we did decide to create a new action instead of joining an existing one."
527
                    );
528
0
                    return Ok(None);
529
0
                }
530
531
0
                Ok(Some(OperationSubscriber::new(
532
0
                    Some(client_operation_id.clone()),
533
0
                    OperationIdToAwaitedAction(Cow::Owned(operation_id)),
534
0
                    Arc::downgrade(&self.store),
535
0
                    self.now_fn.clone(),
536
0
                )))
537
            }
538
1
            None => Ok(None),
539
        }
540
1
    }
541
542
    #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this
543
0
    async fn inner_get_awaited_action_by_id(
544
0
        &self,
545
0
        client_operation_id: &ClientOperationId,
546
0
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
547
0
        let maybe_operation_id = self
548
0
            .store
549
0
            .get_and_decode(ClientIdToOperationId(client_operation_id))
550
0
            .await
551
0
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")?;
552
0
        let Some(operation_id) = maybe_operation_id else {
  Branch (552:13): [True: 0, False: 0]
  Branch (552:13): [Folded - Ignored]
553
0
            return Ok(None);
554
        };
555
0
        Ok(Some(OperationSubscriber::new(
556
0
            Some(client_operation_id.clone()),
557
0
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
558
0
            Arc::downgrade(&self.store),
559
0
            self.now_fn.clone(),
560
0
        )))
561
0
    }
562
}
563
564
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
565
where
566
    S: SchedulerStore,
567
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
568
    I: InstantWrapper,
569
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
570
{
571
    type Subscriber = OperationSubscriber<S, I, NowFn>;
572
573
0
    async fn get_awaited_action_by_id(
574
0
        &self,
575
0
        client_operation_id: &ClientOperationId,
576
0
    ) -> Result<Option<Self::Subscriber>, Error> {
577
0
        self.inner_get_awaited_action_by_id(client_operation_id)
578
0
            .await
579
0
    }
580
581
0
    async fn get_by_operation_id(
582
0
        &self,
583
0
        operation_id: &OperationId,
584
0
    ) -> Result<Option<Self::Subscriber>, Error> {
585
0
        Ok(Some(OperationSubscriber::new(
586
0
            None,
587
0
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
588
0
            Arc::downgrade(&self.store),
589
0
            self.now_fn.clone(),
590
0
        )))
591
0
    }
592
593
1
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
594
1
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await
595
1
    }
596
597
1
    async fn add_action(
598
1
        &self,
599
1
        client_operation_id: ClientOperationId,
600
1
        action_info: Arc<ActionInfo>,
601
1
    ) -> Result<Self::Subscriber, Error> {
602
        // Check to see if the action is already known and subscribe if it is.
603
1
        let subscription = self
604
1
            .try_subscribe(
605
1
                &client_operation_id,
606
1
                &action_info.unique_qualifier,
607
1
                action_info.priority,
608
1
            )
609
1
            .await
610
1
            .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
;
611
1
        if let Some(
sub0
) = subscription {
  Branch (611:16): [True: 0, False: 0]
  Branch (611:16): [Folded - Ignored]
  Branch (611:16): [True: 0, False: 1]
612
0
            return Ok(sub);
613
1
        }
614
615
1
        let new_operation_id = (self.operation_id_creator)();
616
1
        let awaited_action =
617
1
            AwaitedAction::new(new_operation_id.clone(), action_info, (self.now_fn)().now());
618
1
        debug_assert!(
619
0
            ActionStage::Queued == awaited_action.state().stage,
620
0
            "Expected action to be queued"
621
        );
622
623
        // Note: Version is not needed with this api.
624
1
        let _version = self
625
1
            .store
626
1
            .update_data(UpdateOperationIdToAwaitedAction(awaited_action))
627
1
            .await
628
1
            .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
629
1
            .err_tip(
630
                || "Version match failed for new action insert in RedisAwaitedActionDb::add_action",
631
0
            )?;
632
633
1
        self.store
634
1
            .update_data(UpdateClientIdToOperationId {
635
1
                client_operation_id: client_operation_id.clone(),
636
1
                operation_id: new_operation_id.clone(),
637
1
            })
638
1
            .await
639
1
            .err_tip(|| "In RedisAwaitedActionDb::add_action")
?0
;
640
641
1
        Ok(OperationSubscriber::new(
642
1
            Some(client_operation_id),
643
1
            OperationIdToAwaitedAction(Cow::Owned(new_operation_id)),
644
1
            Arc::downgrade(&self.store),
645
1
            self.now_fn.clone(),
646
1
        ))
647
1
    }
648
649
0
    async fn get_range_of_actions(
650
0
        &self,
651
0
        state: SortedAwaitedActionState,
652
0
        start: Bound<SortedAwaitedAction>,
653
0
        end: Bound<SortedAwaitedAction>,
654
0
        desc: bool,
655
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
656
0
        if !matches!(start, Bound::Unbounded) {
  Branch (656:12): [True: 0, False: 0]
  Branch (656:12): [Folded - Ignored]
657
0
            return Err(make_err!(
658
0
                Code::Unimplemented,
659
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
660
0
            ));
661
0
        }
662
0
        if !matches!(end, Bound::Unbounded) {
  Branch (662:12): [True: 0, False: 0]
  Branch (662:12): [Folded - Ignored]
663
0
            return Err(make_err!(
664
0
                Code::Unimplemented,
665
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
666
0
            ));
667
0
        }
668
        // TODO(aaronmondal) This API is not difficult to implement, but there is no code path
669
        // that uses it, so no reason to implement it yet.
670
0
        if !desc {
  Branch (670:12): [True: 0, False: 0]
  Branch (670:12): [Folded - Ignored]
671
0
            return Err(make_err!(
672
0
                Code::Unimplemented,
673
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
674
0
            ));
675
0
        }
676
0
        Ok(self
677
0
            .store
678
0
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
679
0
            .await
680
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
681
0
            .map_ok(move |awaited_action| {
682
0
                OperationSubscriber::new(
683
0
                    None,
684
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
685
0
                    Arc::downgrade(&self.store),
686
0
                    self.now_fn.clone(),
687
                )
688
0
            }))
689
0
    }
690
691
0
    async fn get_all_awaited_actions(
692
0
        &self,
693
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
694
0
        Ok(self
695
0
            .store
696
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
697
0
            .await
698
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
699
0
            .map_ok(move |awaited_action| {
700
0
                OperationSubscriber::new(
701
0
                    None,
702
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
703
0
                    Arc::downgrade(&self.store),
704
0
                    self.now_fn.clone(),
705
                )
706
0
            }))
707
0
    }
708
}