Coverage Report

Created: 2024-11-22 20:17

/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::borrow::Cow;
16
use std::ops::Bound;
17
use std::sync::atomic::{AtomicU64, Ordering};
18
use std::sync::{Arc, Weak};
19
use std::time::Duration;
20
21
use bytes::Bytes;
22
use futures::{Stream, TryStreamExt};
23
use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
24
use nativelink_metric::MetricsComponent;
25
use nativelink_util::action_messages::{
26
    ActionInfo, ActionStage, ActionUniqueQualifier, OperationId,
27
};
28
use nativelink_util::instant_wrapper::InstantWrapper;
29
use nativelink_util::spawn;
30
use nativelink_util::store_trait::{
31
    FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore,
32
    SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider,
33
    SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue,
34
};
35
use nativelink_util::task::JoinHandleDropGuard;
36
use tokio::sync::Notify;
37
use tracing::{event, Level};
38
39
use crate::awaited_action_db::{
40
    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction,
41
    SortedAwaitedActionState,
42
};
43
44
/// Operation id as handed to (and reported back to) a client.
///
/// It can differ from the operation id keying the stored `AwaitedAction`, e.g.
/// when a client is attached to an already-known action in `try_subscribe`.
type ClientOperationId = OperationId;
45
46
/// How stale the last observed client keep-alive may be before a subscriber
/// writes a fresh one; also the idle wait between keep-alive checks.
const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10);

/// Number of retries for a failed client keep-alive write before giving up.
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
51
52
enum OperationSubscriberState<Sub> {
53
    Unsubscribed,
54
    Subscribed(Sub),
55
}
56
57
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
58
    maybe_client_operation_id: Option<ClientOperationId>,
59
    subscription_key: OperationIdToAwaitedAction<'static>,
60
    weak_store: Weak<S>,
61
    state: OperationSubscriberState<
62
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
63
    >,
64
    last_known_keepalive_ts: AtomicU64,
65
    now_fn: NowFn,
66
}
67
impl<S, I, NowFn> OperationSubscriber<S, I, NowFn>
68
where
69
    S: SchedulerStore,
70
    I: InstantWrapper,
71
    NowFn: Fn() -> I,
72
{
73
1
    fn new(
74
1
        maybe_client_operation_id: Option<ClientOperationId>,
75
1
        subscription_key: OperationIdToAwaitedAction<'static>,
76
1
        weak_store: Weak<S>,
77
1
        now_fn: NowFn,
78
1
    ) -> Self {
79
1
        Self {
80
1
            maybe_client_operation_id,
81
1
            subscription_key,
82
1
            weak_store,
83
1
            last_known_keepalive_ts: AtomicU64::new(0),
84
1
            state: OperationSubscriberState::Unsubscribed,
85
1
            now_fn,
86
1
        }
87
1
    }
88
89
2
    async fn inner_get_awaited_action(
90
2
        store: &S,
91
2
        key: OperationIdToAwaitedAction<'_>,
92
2
        maybe_client_operation_id: Option<ClientOperationId>,
93
2
        last_known_keepalive_ts: &AtomicU64,
94
2
    ) -> Result<AwaitedAction, Error> {
95
2
        let mut awaited_action = store
96
2
            .get_and_decode(key.borrow())
97
2
            .await
98
2
            .err_tip(|| 
format!("In OperationSubscriber::get_awaited_action {key:?}")0
)
?0
99
2
            .ok_or_else(|| {
100
0
                make_err!(
101
0
                    Code::NotFound,
102
0
                    "Could not find AwaitedAction for the given operation id {key:?}",
103
0
                )
104
2
            })
?0
;
105
2
        if let Some(client_operation_id) = maybe_client_operation_id {
  Branch (105:16): [True: 0, False: 0]
  Branch (105:16): [Folded - Ignored]
  Branch (105:16): [True: 2, False: 0]
106
2
            awaited_action.set_client_operation_id(client_operation_id);
107
2
        }
0
108
2
        last_known_keepalive_ts.store(
109
2
            awaited_action
110
2
                .last_client_keepalive_timestamp()
111
2
                .unix_timestamp(),
112
2
            Ordering::Release,
113
2
        );
114
2
        Ok(awaited_action)
115
2
    }
116
117
0
    async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> {
118
0
        let store = self
119
0
            .weak_store
120
0
            .upgrade()
121
0
            .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?;
122
0
        Self::inner_get_awaited_action(
123
0
            store.as_ref(),
124
0
            self.subscription_key.borrow(),
125
0
            self.maybe_client_operation_id.clone(),
126
0
            &self.last_known_keepalive_ts,
127
0
        )
128
0
        .await
129
0
    }
130
}
131
132
impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn>
133
where
134
    S: SchedulerStore,
135
    I: InstantWrapper,
136
    NowFn: Fn() -> I + Send + Sync + 'static,
137
{
138
2
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
139
2
        let store = self
140
2
            .weak_store
141
2
            .upgrade()
142
2
            .err_tip(|| 
"Store gone in OperationSubscriber::get_awaited_action"0
)
?0
;
143
2
        let subscription = match &mut self.state {
144
            OperationSubscriberState::Unsubscribed => {
145
1
                let subscription = store
146
1
                    .subscription_manager()
147
1
                    .err_tip(|| 
"In OperationSubscriber::changed::subscription_manager"0
)
?0
148
1
                    .subscribe(self.subscription_key.borrow())
149
1
                    .err_tip(|| 
"In OperationSubscriber::changed::subscribe"0
)
?0
;
150
1
                self.state = OperationSubscriberState::Subscribed(subscription);
151
1
                let OperationSubscriberState::Subscribed(subscription) = &mut self.state else {
  Branch (151:21): [True: 0, False: 0]
  Branch (151:21): [Folded - Ignored]
  Branch (151:21): [True: 1, False: 0]
152
0
                    unreachable!("Subscription should be in Subscribed state");
153
                };
154
1
                subscription
155
            }
156
1
            OperationSubscriberState::Subscribed(subscription) => subscription,
157
        };
158
159
2
        let changed_fut = subscription.changed();
160
2
        tokio::pin!(changed_fut);
161
        loop {
162
2
            let mut retries = 0;
163
            loop {
164
2
                let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire);
165
2
                if I::from_secs(last_known_keepalive_ts).elapsed() <= CLIENT_KEEPALIVE_DURATION {
  Branch (165:20): [True: 0, False: 0]
  Branch (165:20): [Folded - Ignored]
  Branch (165:20): [True: 2, False: 0]
166
2
                    break; // We are still within the keep alive duration.
167
0
                }
168
0
                if retries > MAX_RETRIES_FOR_CLIENT_KEEPALIVE {
  Branch (168:20): [True: 0, False: 0]
  Branch (168:20): [Folded - Ignored]
  Branch (168:20): [True: 0, False: 0]
169
0
                    return Err(make_err!(
170
0
                        Code::Aborted,
171
0
                        "Could not update client keep alive for AwaitedAction",
172
0
                    ));
173
0
                }
174
0
                let mut awaited_action = Self::inner_get_awaited_action(
175
0
                    store.as_ref(),
176
0
                    self.subscription_key.borrow(),
177
0
                    self.maybe_client_operation_id.clone(),
178
0
                    &self.last_known_keepalive_ts,
179
0
                )
180
0
                .await
181
0
                .err_tip(|| "In OperationSubscriber::changed")?;
182
0
                awaited_action.update_client_keep_alive((self.now_fn)().now());
183
0
                let update_res = inner_update_awaited_action(store.as_ref(), awaited_action)
184
0
                    .await
185
0
                    .err_tip(|| "In OperationSubscriber::changed");
186
0
                if update_res.is_ok() {
  Branch (186:20): [True: 0, False: 0]
  Branch (186:20): [Folded - Ignored]
  Branch (186:20): [True: 0, False: 0]
187
0
                    break;
188
0
                }
189
0
                retries += 1;
190
0
                // Wait a tick before retrying.
191
0
                (self.now_fn)().sleep(Duration::from_millis(100)).await;
192
            }
193
2
            let sleep_fut = (self.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION);
194
2
            tokio::select! {
195
2
                result = &mut changed_fut => {
196
2
                    result
?0
;
197
2
                    break;
198
                }
199
2
                _ = sleep_fut => {
200
0
                    // If we haven't received any updates for a while, we should
201
0
                    // let the database know that we are still listening to prevent
202
0
                    // the action from being dropped.
203
0
                }
204
            }
205
        }
206
207
2
        Self::inner_get_awaited_action(
208
2
            store.as_ref(),
209
2
            self.subscription_key.borrow(),
210
2
            self.maybe_client_operation_id.clone(),
211
2
            &self.last_known_keepalive_ts,
212
2
        )
213
2
        .await
214
2
        .err_tip(|| 
"In OperationSubscriber::changed"0
)
215
2
    }
216
217
0
    async fn borrow(&self) -> Result<AwaitedAction, Error> {
218
0
        self.get_awaited_action()
219
0
            .await
220
0
            .err_tip(|| "In OperationSubscriber::borrow")
221
0
    }
222
}
223
224
2
fn awaited_action_decode(version: u64, data: &Bytes) -> Result<AwaitedAction, Error> {
225
2
    let mut awaited_action: AwaitedAction = serde_json::from_slice(data)
226
2
        .map_err(|e| 
make_input_err!("In AwaitedAction::decode - {e:?}")0
)
?0
;
227
2
    awaited_action.set_version(version);
228
2
    Ok(awaited_action)
229
2
}
230
231
/// Key prefix of entries holding a serialized `AwaitedAction`, keyed by operation id.
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
/// Key prefix of entries mapping a client operation id to an operation id.
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
233
234
#[derive(Debug)]
235
struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>);
236
impl OperationIdToAwaitedAction<'_> {
237
5
    fn borrow(&self) -> OperationIdToAwaitedAction<'_> {
238
5
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref()))
239
5
    }
240
}
241
impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> {
242
    type Versioned = TrueValue;
243
6
    fn get_key(&self) -> StoreKey<'static> {
244
6
        StoreKey::Str(Cow::Owned(format!(
245
6
            "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}",
246
6
            self.0
247
6
        )))
248
6
    }
249
}
250
impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> {
251
    type DecodeOutput = AwaitedAction;
252
2
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
253
2
        awaited_action_decode(version, &data)
254
2
    }
255
}
256
257
struct ClientIdToOperationId<'a>(&'a OperationId);
258
impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> {
259
    type Versioned = FalseValue;
260
1
    fn get_key(&self) -> StoreKey<'static> {
261
1
        StoreKey::Str(Cow::Owned(format!(
262
1
            "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}",
263
1
            self.0
264
1
        )))
265
1
    }
266
}
267
impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> {
268
    type DecodeOutput = OperationId;
269
0
    fn decode(_version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
270
0
        OperationId::try_from(data).err_tip(|| "In ClientIdToOperationId::decode")
271
0
    }
272
}
273
274
// TODO(allada) We only need operation_id here, it would be nice if we had a way
275
// to tell the decoder we only care about specific fields.
276
struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier);
277
impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> {
278
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
279
    const INDEX_NAME: &'static str = "unique_qualifier";
280
    type Versioned = TrueValue;
281
1
    fn index_value(&self) -> Cow<'_, str> {
282
1
        Cow::Owned(format!("{}", self.0))
283
1
    }
284
}
285
impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> {
286
    type DecodeOutput = AwaitedAction;
287
0
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
288
0
        awaited_action_decode(version, &data)
289
0
    }
290
}
291
292
struct SearchStateToAwaitedAction(&'static str);
293
impl SchedulerIndexProvider for SearchStateToAwaitedAction {
294
    const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX;
295
    const INDEX_NAME: &'static str = "state";
296
    const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key");
297
    type Versioned = TrueValue;
298
0
    fn index_value(&self) -> Cow<'_, str> {
299
0
        Cow::Borrowed(self.0)
300
0
    }
301
}
302
impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction {
303
    type DecodeOutput = AwaitedAction;
304
0
    fn decode(version: u64, data: Bytes) -> Result<Self::DecodeOutput, Error> {
305
0
        awaited_action_decode(version, &data)
306
0
    }
307
}
308
309
2
fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str {
310
2
    match state {
311
0
        SortedAwaitedActionState::CacheCheck => "cache_check",
312
1
        SortedAwaitedActionState::Queued => "queued",
313
1
        SortedAwaitedActionState::Executing => "executing",
314
0
        SortedAwaitedActionState::Completed => "completed",
315
    }
316
2
}
317
318
struct UpdateOperationIdToAwaitedAction(AwaitedAction);
319
impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction {
320
2
    fn current_version(&self) -> u64 {
321
2
        self.0.version()
322
2
    }
323
}
324
impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction {
325
    type Versioned = TrueValue;
326
2
    fn get_key(&self) -> StoreKey<'static> {
327
2
        OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key()
328
2
    }
329
}
330
impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction {
331
2
    fn try_into_bytes(self) -> Result<Bytes, Error> {
332
2
        serde_json::to_string(&self.0)
333
2
            .map(Bytes::from)
334
2
            .map_err(|e| 
make_input_err!("Could not convert AwaitedAction to json - {e:?}")0
)
335
2
    }
336
2
    fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> {
337
2
        let unique_qualifier = &self.0.action_info().unique_qualifier;
338
2
        let maybe_unique_qualifier = match &unique_qualifier {
339
2
            ActionUniqueQualifier::Cachable(_) => Some(unique_qualifier),
340
0
            ActionUniqueQualifier::Uncachable(_) => None,
341
        };
342
2
        let mut output = Vec::with_capacity(1 + maybe_unique_qualifier.map_or(0, |_| 1));
343
2
        if maybe_unique_qualifier.is_some() {
  Branch (343:12): [True: 2, False: 0]
  Branch (343:12): [Folded - Ignored]
344
2
            output.push((
345
2
                "unique_qualifier",
346
2
                Bytes::from(unique_qualifier.to_string()),
347
2
            ));
348
2
        }
0
349
        {
350
2
            let state = SortedAwaitedActionState::try_from(&self.0.state().stage)
351
2
                .err_tip(|| 
"In UpdateOperationIdToAwaitedAction::get_index"0
)
?0
;
352
2
            output.push(("state", Bytes::from(get_state_prefix(state))));
353
2
            let sorted_awaited_action = SortedAwaitedAction::from(&self.0);
354
2
            output.push((
355
2
                "sort_key",
356
2
                // We encode to hex to ensure that the sort key is lexicographically sorted.
357
2
                Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())),
358
2
            ));
359
2
        }
360
2
        Ok(output)
361
2
    }
362
}
363
364
struct UpdateClientIdToOperationId {
365
    client_operation_id: ClientOperationId,
366
    operation_id: OperationId,
367
}
368
impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId {
369
    type Versioned = FalseValue;
370
1
    fn get_key(&self) -> StoreKey<'static> {
371
1
        ClientIdToOperationId(&self.client_operation_id).get_key()
372
1
    }
373
}
374
impl SchedulerStoreDataProvider for UpdateClientIdToOperationId {
375
1
    fn try_into_bytes(self) -> Result<Bytes, Error> {
376
1
        serde_json::to_string(&self.operation_id)
377
1
            .map(Bytes::from)
378
1
            .map_err(|e| 
make_input_err!("Could not convert OperationId to json - {e:?}")0
)
379
1
    }
380
}
381
382
1
async fn inner_update_awaited_action(
383
1
    store: &impl SchedulerStore,
384
1
    mut new_awaited_action: AwaitedAction,
385
1
) -> Result<(), Error> {
386
1
    let operation_id = new_awaited_action.operation_id().clone();
387
1
    if new_awaited_action.state().client_operation_id != operation_id {
  Branch (387:8): [True: 0, False: 0]
  Branch (387:8): [Folded - Ignored]
  Branch (387:8): [True: 0, False: 1]
388
0
        // Just in case the client_operation_id was set to something else
389
0
        // we put it back to the underlying operation_id.
390
0
        new_awaited_action.set_client_operation_id(operation_id.clone());
391
1
    }
392
1
    let maybe_version = store
393
1
        .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action))
394
2
        .await
395
1
        .err_tip(|| 
"In RedisAwaitedActionDb::update_awaited_action"0
)
?0
;
396
1
    if maybe_version.is_none() {
  Branch (396:8): [True: 0, False: 0]
  Branch (396:8): [Folded - Ignored]
  Branch (396:8): [True: 0, False: 1]
397
0
        return Err(make_err!(
398
0
            Code::Aborted,
399
0
            "Could not update AwaitedAction because the version did not match for {operation_id:?}",
400
0
        ));
401
1
    }
402
1
    Ok(())
403
1
}
404
405
0
#[derive(MetricsComponent)]
406
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
407
where
408
    S: SchedulerStore,
409
    F: Fn() -> OperationId,
410
    I: InstantWrapper,
411
    NowFn: Fn() -> I,
412
{
413
    store: Arc<S>,
414
    now_fn: NowFn,
415
    operation_id_creator: F,
416
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
417
}
418
419
impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn>
420
where
421
    S: SchedulerStore,
422
    F: Fn() -> OperationId,
423
    I: InstantWrapper,
424
    NowFn: Fn() -> I + Send + Sync + Clone + 'static,
425
{
426
1
    pub fn new(
427
1
        store: Arc<S>,
428
1
        task_change_publisher: Arc<Notify>,
429
1
        now_fn: NowFn,
430
1
        operation_id_creator: F,
431
1
    ) -> Result<Self, Error> {
432
1
        let mut subscription = store
433
1
            .subscription_manager()
434
1
            .err_tip(|| 
"In RedisAwaitedActionDb::new"0
)
?0
435
1
            .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String(
436
1
                String::new(),
437
1
            ))))
438
1
            .err_tip(|| 
"In RedisAwaitedActionDb::new"0
)
?0
;
439
1
        let pull_task_change_subscriber = spawn!(
440
1
            "redis_awaited_action_db_pull_task_change_subscriber",
441
1
            async move {
442
                loop {
443
5
                    let 
changed_res4
= subscription
444
5
                        .changed()
445
3
                        .await
446
4
                        .err_tip(|| 
"In RedisAwaitedActionDb::new"0
);
447
4
                    if let Err(
err0
) = changed_res {
  Branch (447:28): [True: 0, False: 0]
  Branch (447:28): [Folded - Ignored]
  Branch (447:28): [True: 0, False: 4]
448
0
                        event!(
449
0
                            Level::ERROR,
450
0
                            "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new  - {err:?}"
451
                        );
452
                        // Sleep for a second to avoid a busy loop, then trigger the notify
453
                        // so if a reconnect happens we let local resources know that things
454
                        // might have changed.
455
0
                        tokio::time::sleep(Duration::from_secs(1)).await;
456
4
                    }
457
4
                    task_change_publisher.as_ref().notify_one();
458
                }
459
1
            }
460
1
        );
461
1
        Ok(Self {
462
1
            store,
463
1
            now_fn,
464
1
            operation_id_creator,
465
1
            _pull_task_change_subscriber_spawn: pull_task_change_subscriber,
466
1
        })
467
1
    }
468
469
1
    async fn try_subscribe(
470
1
        &self,
471
1
        client_operation_id: &ClientOperationId,
472
1
        unique_qualifier: &ActionUniqueQualifier,
473
1
        // TODO(allada) To simplify the scheduler 2024 refactor, we
474
1
        // removed the ability to upgrade priorities of actions.
475
1
        // we should add priority upgrades back in.
476
1
        _priority: i32,
477
1
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
478
1
        match unique_qualifier {
479
1
            ActionUniqueQualifier::Cachable(_) => {}
480
0
            ActionUniqueQualifier::Uncachable(_) => return Ok(None),
481
        }
482
1
        let stream = self
483
1
            .store
484
1
            .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier))
485
3
            .await
486
1
            .err_tip(|| 
"In RedisAwaitedActionDb::try_subscribe"0
)
?0
;
487
1
        tokio::pin!(stream);
488
1
        let maybe_awaited_action = stream
489
1
            .try_next()
490
0
            .await
491
1
            .err_tip(|| 
"In RedisAwaitedActionDb::try_subscribe"0
)
?0
;
492
1
        match maybe_awaited_action {
493
0
            Some(awaited_action) => {
494
0
                // TODO(allada) We don't support joining completed jobs because we
495
0
                // need to also check that all the data is still in the cache.
496
0
                if awaited_action.state().stage.is_finished() {
  Branch (496:20): [True: 0, False: 0]
  Branch (496:20): [Folded - Ignored]
  Branch (496:20): [True: 0, False: 0]
497
0
                    return Ok(None);
498
0
                }
499
0
                // TODO(allada) We only care about the operation_id here, we should
500
0
                // have a way to tell the decoder we only care about specific fields.
501
0
                let operation_id = awaited_action.operation_id();
502
0
                Ok(Some(OperationSubscriber::new(
503
0
                    Some(client_operation_id.clone()),
504
0
                    OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
505
0
                    Arc::downgrade(&self.store),
506
0
                    self.now_fn.clone(),
507
0
                )))
508
            }
509
1
            None => Ok(None),
510
        }
511
1
    }
512
513
0
    async fn inner_get_awaited_action_by_id(
514
0
        &self,
515
0
        client_operation_id: &ClientOperationId,
516
0
    ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> {
517
0
        let maybe_operation_id = self
518
0
            .store
519
0
            .get_and_decode(ClientIdToOperationId(client_operation_id))
520
0
            .await
521
0
            .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")?;
522
0
        let Some(operation_id) = maybe_operation_id else {
  Branch (522:13): [True: 0, False: 0]
  Branch (522:13): [Folded - Ignored]
523
0
            return Ok(None);
524
        };
525
0
        Ok(Some(OperationSubscriber::new(
526
0
            Some(client_operation_id.clone()),
527
0
            OperationIdToAwaitedAction(Cow::Owned(operation_id)),
528
0
            Arc::downgrade(&self.store),
529
0
            self.now_fn.clone(),
530
0
        )))
531
0
    }
532
}
533
534
impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn>
535
where
536
    S: SchedulerStore,
537
    F: Fn() -> OperationId + Send + Sync + Unpin + 'static,
538
    I: InstantWrapper,
539
    NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static,
540
{
541
    type Subscriber = OperationSubscriber<S, I, NowFn>;
542
543
0
    async fn get_awaited_action_by_id(
544
0
        &self,
545
0
        client_operation_id: &ClientOperationId,
546
0
    ) -> Result<Option<Self::Subscriber>, Error> {
547
0
        self.inner_get_awaited_action_by_id(client_operation_id)
548
0
            .await
549
0
    }
550
551
0
    async fn get_by_operation_id(
552
0
        &self,
553
0
        operation_id: &OperationId,
554
0
    ) -> Result<Option<Self::Subscriber>, Error> {
555
0
        Ok(Some(OperationSubscriber::new(
556
0
            None,
557
0
            OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())),
558
0
            Arc::downgrade(&self.store),
559
0
            self.now_fn.clone(),
560
0
        )))
561
0
    }
562
563
1
    async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
564
2
        inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await
565
1
    }
566
567
1
    async fn add_action(
568
1
        &self,
569
1
        client_operation_id: ClientOperationId,
570
1
        action_info: Arc<ActionInfo>,
571
1
    ) -> Result<Self::Subscriber, Error> {
572
        // Check to see if the action is already known and subscribe if it is.
573
1
        let subscription = self
574
1
            .try_subscribe(
575
1
                &client_operation_id,
576
1
                &action_info.unique_qualifier,
577
1
                action_info.priority,
578
1
            )
579
3
            .await
580
1
            .err_tip(|| 
"In RedisAwaitedActionDb::add_action"0
)
?0
;
581
1
        if let Some(
sub0
) = subscription {
  Branch (581:16): [True: 0, False: 0]
  Branch (581:16): [Folded - Ignored]
  Branch (581:16): [True: 0, False: 1]
582
0
            return Ok(sub);
583
1
        }
584
1
585
1
        let new_operation_id = (self.operation_id_creator)();
586
1
        let awaited_action =
587
1
            AwaitedAction::new(new_operation_id.clone(), action_info, (self.now_fn)().now());
588
1
        debug_assert!(
589
0
            ActionStage::Queued == awaited_action.state().stage,
590
0
            "Expected action to be queued"
591
        );
592
593
        // Note: Version is not needed with this api.
594
1
        let _version = self
595
1
            .store
596
1
            .update_data(UpdateOperationIdToAwaitedAction(awaited_action))
597
2
            .await
598
1
            .err_tip(|| 
"In RedisAwaitedActionDb::add_action"0
)
?0
599
1
            .err_tip(|| {
600
0
                "Version match failed for new action insert in RedisAwaitedActionDb::add_action"
601
1
            })
?0
;
602
603
1
        self.store
604
1
            .update_data(UpdateClientIdToOperationId {
605
1
                client_operation_id: client_operation_id.clone(),
606
1
                operation_id: new_operation_id.clone(),
607
1
            })
608
2
            .await
609
1
            .err_tip(|| 
"In RedisAwaitedActionDb::add_action"0
)
?0
;
610
611
1
        Ok(OperationSubscriber::new(
612
1
            Some(client_operation_id),
613
1
            OperationIdToAwaitedAction(Cow::Owned(new_operation_id)),
614
1
            Arc::downgrade(&self.store),
615
1
            self.now_fn.clone(),
616
1
        ))
617
1
    }
618
619
0
    async fn get_range_of_actions(
620
0
        &self,
621
0
        state: SortedAwaitedActionState,
622
0
        start: Bound<SortedAwaitedAction>,
623
0
        end: Bound<SortedAwaitedAction>,
624
0
        desc: bool,
625
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
626
0
        if !matches!(start, Bound::Unbounded) {
  Branch (626:12): [True: 0, False: 0]
  Branch (626:12): [Folded - Ignored]
627
0
            return Err(make_err!(
628
0
                Code::Unimplemented,
629
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
630
0
            ));
631
0
        }
632
0
        if !matches!(end, Bound::Unbounded) {
  Branch (632:12): [True: 0, False: 0]
  Branch (632:12): [Folded - Ignored]
633
0
            return Err(make_err!(
634
0
                Code::Unimplemented,
635
0
                "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions",
636
0
            ));
637
0
        }
638
0
        // TODO(allada) This API is not difficult to implement, but there is no code path
639
0
        // that uses it, so no reason to implement it yet.
640
0
        if !desc {
  Branch (640:12): [True: 0, False: 0]
  Branch (640:12): [Folded - Ignored]
641
0
            return Err(make_err!(
642
0
                Code::Unimplemented,
643
0
                "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions",
644
0
            ));
645
0
        }
646
0
        Ok(self
647
0
            .store
648
0
            .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state)))
649
0
            .await
650
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
651
0
            .map_ok(move |awaited_action| {
652
0
                OperationSubscriber::new(
653
0
                    None,
654
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
655
0
                    Arc::downgrade(&self.store),
656
0
                    self.now_fn.clone(),
657
0
                )
658
0
            }))
659
0
    }
660
661
0
    async fn get_all_awaited_actions(
662
0
        &self,
663
0
    ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
664
0
        Ok(self
665
0
            .store
666
0
            .search_by_index_prefix(SearchStateToAwaitedAction(""))
667
0
            .await
668
0
            .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?
669
0
            .map_ok(move |awaited_action| {
670
0
                OperationSubscriber::new(
671
0
                    None,
672
0
                    OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())),
673
0
                    Arc::downgrade(&self.store),
674
0
                    self.now_fn.clone(),
675
0
                )
676
0
            }))
677
0
    }
678
}