/build/source/nativelink-scheduler/src/store_awaited_action_db.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::ops::Bound; |
16 | | use core::sync::atomic::{AtomicU64, Ordering}; |
17 | | use core::time::Duration; |
18 | | use std::borrow::Cow; |
19 | | use std::sync::{Arc, Weak}; |
20 | | |
21 | | use bytes::Bytes; |
22 | | use futures::{Stream, TryStreamExt}; |
23 | | use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err}; |
24 | | use nativelink_metric::MetricsComponent; |
25 | | use nativelink_util::action_messages::{ |
26 | | ActionInfo, ActionStage, ActionUniqueQualifier, OperationId, |
27 | | }; |
28 | | use nativelink_util::instant_wrapper::InstantWrapper; |
29 | | use nativelink_util::spawn; |
30 | | use nativelink_util::store_trait::{ |
31 | | FalseValue, SchedulerCurrentVersionProvider, SchedulerIndexProvider, SchedulerStore, |
32 | | SchedulerStoreDataProvider, SchedulerStoreDecodeTo, SchedulerStoreKeyProvider, |
33 | | SchedulerSubscription, SchedulerSubscriptionManager, StoreKey, TrueValue, |
34 | | }; |
35 | | use nativelink_util::task::JoinHandleDropGuard; |
36 | | use tokio::sync::Notify; |
37 | | use tracing::error; |
38 | | |
39 | | use crate::awaited_action_db::{ |
40 | | AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, CLIENT_KEEPALIVE_DURATION, |
41 | | SortedAwaitedAction, SortedAwaitedActionState, |
42 | | }; |
43 | | |
/// Alias for an [`OperationId`] used in positions where the id is the one
/// supplied by a client rather than the id the `AwaitedAction` is keyed by.
type ClientOperationId = OperationId;

/// Maximum number of retries to update client keep alive.
///
/// `OperationSubscriber::changed` gives up with `Code::Aborted` once its
/// retry counter exceeds this value.
const MAX_RETRIES_FOR_CLIENT_KEEPALIVE: u32 = 8;
48 | | |
/// Subscription state of an `OperationSubscriber`.
enum OperationSubscriberState<Sub> {
    /// No store subscription has been created yet.
    Unsubscribed,
    /// Holds the subscription obtained from the store's subscription manager.
    Subscribed(Sub),
}
53 | | |
/// Subscriber handle for a single `AwaitedAction` kept in a [`SchedulerStore`].
///
/// Only a weak reference to the store is held; the action is re-read from the
/// store each time it is requested.
pub struct OperationSubscriber<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I> {
    /// When set, applied to every fetched action via `set_client_operation_id`.
    maybe_client_operation_id: Option<ClientOperationId>,
    /// Store key of the action this subscriber follows.
    subscription_key: OperationIdToAwaitedAction<'static>,
    /// Weak handle to the backing store; lookups fail once the store is gone.
    weak_store: Weak<S>,
    /// Store subscription, created on the first call to `changed()`.
    state: OperationSubscriberState<
        <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription,
    >,
    /// Client keep-alive timestamp (as returned by `unix_timestamp()`) observed
    /// on the most recent fetch of the action. Starts at 0.
    last_known_keepalive_ts: AtomicU64,
    /// Produces the `InstantWrapper` used for reading the time and sleeping.
    now_fn: NowFn,
}
64 | | |
65 | | impl<S: SchedulerStore, I: InstantWrapper, NowFn: Fn() -> I + core::fmt::Debug> core::fmt::Debug |
66 | | for OperationSubscriber<S, I, NowFn> |
67 | | where |
68 | | OperationSubscriberState< |
69 | | <S::SubscriptionManager as SchedulerSubscriptionManager>::Subscription, |
70 | | >: core::fmt::Debug, |
71 | | { |
72 | 0 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
73 | 0 | f.debug_struct("OperationSubscriber") |
74 | 0 | .field("maybe_client_operation_id", &self.maybe_client_operation_id) |
75 | 0 | .field("subscription_key", &self.subscription_key) |
76 | 0 | .field("weak_store", &self.weak_store) |
77 | 0 | .field("state", &self.state) |
78 | 0 | .field("last_known_keepalive_ts", &self.last_known_keepalive_ts) |
79 | 0 | .field("now_fn", &self.now_fn) |
80 | 0 | .finish() |
81 | 0 | } |
82 | | } |
83 | | impl<S, I, NowFn> OperationSubscriber<S, I, NowFn> |
84 | | where |
85 | | S: SchedulerStore, |
86 | | I: InstantWrapper, |
87 | | NowFn: Fn() -> I, |
88 | | { |
89 | 19 | const fn new( |
90 | 19 | maybe_client_operation_id: Option<ClientOperationId>, |
91 | 19 | subscription_key: OperationIdToAwaitedAction<'static>, |
92 | 19 | weak_store: Weak<S>, |
93 | 19 | now_fn: NowFn, |
94 | 19 | ) -> Self { |
95 | 19 | Self { |
96 | 19 | maybe_client_operation_id, |
97 | 19 | subscription_key, |
98 | 19 | weak_store, |
99 | 19 | last_known_keepalive_ts: AtomicU64::new(0), |
100 | 19 | state: OperationSubscriberState::Unsubscribed, |
101 | 19 | now_fn, |
102 | 19 | } |
103 | 19 | } |
104 | | |
105 | 31 | async fn inner_get_awaited_action( |
106 | 31 | store: &S, |
107 | 31 | key: OperationIdToAwaitedAction<'_>, |
108 | 31 | maybe_client_operation_id: Option<ClientOperationId>, |
109 | 31 | last_known_keepalive_ts: &AtomicU64, |
110 | 31 | ) -> Result<AwaitedAction, Error> { |
111 | 31 | let mut awaited_action = store |
112 | 31 | .get_and_decode(key.borrow()) |
113 | 31 | .await |
114 | 31 | .err_tip(|| format!("In OperationSubscriber::get_awaited_action {key:?}"0 ))?0 |
115 | 31 | .ok_or_else(|| {0 |
116 | 0 | make_err!( |
117 | 0 | Code::NotFound, |
118 | | "Could not find AwaitedAction for the given operation id {key:?}", |
119 | | ) |
120 | 0 | })?; |
121 | 31 | if let Some(client_operation_id9 ) = maybe_client_operation_id { Branch (121:16): [True: 0, False: 0]
Branch (121:16): [Folded - Ignored]
Branch (121:16): [True: 9, False: 22]
|
122 | 9 | awaited_action.set_client_operation_id(client_operation_id); |
123 | 22 | } |
124 | 31 | last_known_keepalive_ts.store( |
125 | 31 | awaited_action |
126 | 31 | .last_client_keepalive_timestamp() |
127 | 31 | .unix_timestamp(), |
128 | 31 | Ordering::Release, |
129 | | ); |
130 | 31 | Ok(awaited_action) |
131 | 31 | } |
132 | | |
133 | | #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this |
134 | 27 | async fn get_awaited_action(&self) -> Result<AwaitedAction, Error> { |
135 | 27 | let store = self |
136 | 27 | .weak_store |
137 | 27 | .upgrade() |
138 | 27 | .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?0 ; |
139 | 27 | Self::inner_get_awaited_action( |
140 | 27 | store.as_ref(), |
141 | 27 | self.subscription_key.borrow(), |
142 | 27 | self.maybe_client_operation_id.clone(), |
143 | 27 | &self.last_known_keepalive_ts, |
144 | 27 | ) |
145 | 27 | .await |
146 | 27 | } |
147 | | } |
148 | | |
149 | | impl<S, I, NowFn> AwaitedActionSubscriber for OperationSubscriber<S, I, NowFn> |
150 | | where |
151 | | S: SchedulerStore, |
152 | | I: InstantWrapper, |
153 | | NowFn: Fn() -> I + Send + Sync + 'static, |
154 | | { |
155 | 4 | async fn changed(&mut self) -> Result<AwaitedAction, Error> { |
156 | 4 | let store = self |
157 | 4 | .weak_store |
158 | 4 | .upgrade() |
159 | 4 | .err_tip(|| "Store gone in OperationSubscriber::get_awaited_action")?0 ; |
160 | 4 | let subscription = match &mut self.state { |
161 | | OperationSubscriberState::Unsubscribed => { |
162 | 2 | let subscription = store |
163 | 2 | .subscription_manager() |
164 | 2 | .err_tip(|| "In OperationSubscriber::changed::subscription_manager")?0 |
165 | 2 | .subscribe(self.subscription_key.borrow()) |
166 | 2 | .err_tip(|| "In OperationSubscriber::changed::subscribe")?0 ; |
167 | 2 | self.state = OperationSubscriberState::Subscribed(subscription); |
168 | 2 | let OperationSubscriberState::Subscribed(subscription) = &mut self.state else { Branch (168:21): [True: 0, False: 0]
Branch (168:21): [Folded - Ignored]
Branch (168:21): [True: 2, False: 0]
|
169 | 0 | unreachable!("Subscription should be in Subscribed state"); |
170 | | }; |
171 | 2 | subscription |
172 | | } |
173 | 2 | OperationSubscriberState::Subscribed(subscription) => subscription, |
174 | | }; |
175 | | |
176 | 4 | let changed_fut = subscription.changed(); |
177 | 4 | tokio::pin!(changed_fut); |
178 | | loop { |
179 | 4 | let mut retries = 0; |
180 | | loop { |
181 | 4 | let last_known_keepalive_ts = self.last_known_keepalive_ts.load(Ordering::Acquire); |
182 | 4 | if I::from_secs(last_known_keepalive_ts).elapsed() <= CLIENT_KEEPALIVE_DURATION { Branch (182:20): [True: 0, False: 0]
Branch (182:20): [Folded - Ignored]
Branch (182:20): [True: 4, False: 0]
|
183 | 4 | break; // We are still within the keep alive duration. |
184 | 0 | } |
185 | 0 | if retries > MAX_RETRIES_FOR_CLIENT_KEEPALIVE { Branch (185:20): [True: 0, False: 0]
Branch (185:20): [Folded - Ignored]
Branch (185:20): [True: 0, False: 0]
|
186 | 0 | return Err(make_err!( |
187 | 0 | Code::Aborted, |
188 | 0 | "Could not update client keep alive for AwaitedAction", |
189 | 0 | )); |
190 | 0 | } |
191 | 0 | let mut awaited_action = Self::inner_get_awaited_action( |
192 | 0 | store.as_ref(), |
193 | 0 | self.subscription_key.borrow(), |
194 | 0 | self.maybe_client_operation_id.clone(), |
195 | 0 | &self.last_known_keepalive_ts, |
196 | 0 | ) |
197 | 0 | .await |
198 | 0 | .err_tip(|| "In OperationSubscriber::changed")?; |
199 | 0 | awaited_action.update_client_keep_alive((self.now_fn)().now()); |
200 | 0 | let update_res = inner_update_awaited_action(store.as_ref(), awaited_action) |
201 | 0 | .await |
202 | 0 | .err_tip(|| "In OperationSubscriber::changed"); |
203 | 0 | if update_res.is_ok() { Branch (203:20): [True: 0, False: 0]
Branch (203:20): [Folded - Ignored]
Branch (203:20): [True: 0, False: 0]
|
204 | 0 | break; |
205 | 0 | } |
206 | 0 | retries += 1; |
207 | | // Wait a tick before retrying. |
208 | 0 | (self.now_fn)().sleep(Duration::from_millis(100)).await; |
209 | | } |
210 | 4 | let sleep_fut = (self.now_fn)().sleep(CLIENT_KEEPALIVE_DURATION); |
211 | 4 | tokio::select! { |
212 | 4 | result = &mut changed_fut => { |
213 | 4 | result?0 ; |
214 | 4 | break; |
215 | | } |
216 | 4 | () = sleep_fut => { |
217 | 0 | // If we haven't received any updates for a while, we should |
218 | 0 | // let the database know that we are still listening to prevent |
219 | 0 | // the action from being dropped. |
220 | 0 | } |
221 | | } |
222 | | } |
223 | | |
224 | 4 | Self::inner_get_awaited_action( |
225 | 4 | store.as_ref(), |
226 | 4 | self.subscription_key.borrow(), |
227 | 4 | self.maybe_client_operation_id.clone(), |
228 | 4 | &self.last_known_keepalive_ts, |
229 | 4 | ) |
230 | 4 | .await |
231 | 4 | .err_tip(|| "In OperationSubscriber::changed") |
232 | 4 | } |
233 | | |
234 | 27 | async fn borrow(&self) -> Result<AwaitedAction, Error> { |
235 | 27 | self.get_awaited_action() |
236 | 27 | .await |
237 | 27 | .err_tip(|| "In OperationSubscriber::borrow") |
238 | 27 | } |
239 | | } |
240 | | |
241 | 38 | fn awaited_action_decode(version: i64, data: &Bytes) -> Result<AwaitedAction, Error> { |
242 | 38 | let mut awaited_action: AwaitedAction = serde_json::from_slice(data) |
243 | 38 | .map_err(|e| make_input_err!("In AwaitedAction::decode - {e:?}"))?0 ; |
244 | 38 | awaited_action.set_version(version); |
245 | 38 | Ok(awaited_action) |
246 | 38 | } |
247 | | |
/// Store-key prefix for entries mapping an operation id to its `AwaitedAction`.
const OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX: &str = "aa_";
/// Store-key prefix for entries mapping a client operation id to an operation id.
const CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX: &str = "cid_";
250 | | |
251 | | #[derive(Debug)] |
252 | | struct OperationIdToAwaitedAction<'a>(Cow<'a, OperationId>); |
253 | | impl OperationIdToAwaitedAction<'_> { |
254 | 64 | fn borrow(&self) -> OperationIdToAwaitedAction<'_> { |
255 | 64 | OperationIdToAwaitedAction(Cow::Borrowed(self.0.as_ref())) |
256 | 64 | } |
257 | | } |
258 | | impl SchedulerStoreKeyProvider for OperationIdToAwaitedAction<'_> { |
259 | | type Versioned = TrueValue; |
260 | 48 | fn get_key(&self) -> StoreKey<'static> { |
261 | 48 | StoreKey::Str(Cow::Owned(format!( |
262 | 48 | "{OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX}{}", |
263 | 48 | self.0 |
264 | 48 | ))) |
265 | 48 | } |
266 | | } |
267 | | impl SchedulerStoreDecodeTo for OperationIdToAwaitedAction<'_> { |
268 | | type DecodeOutput = AwaitedAction; |
269 | 31 | fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> { |
270 | 31 | awaited_action_decode(version, &data) |
271 | 31 | } |
272 | | } |
273 | | |
274 | | struct ClientIdToOperationId<'a>(&'a OperationId); |
275 | | impl SchedulerStoreKeyProvider for ClientIdToOperationId<'_> { |
276 | | type Versioned = FalseValue; |
277 | 6 | fn get_key(&self) -> StoreKey<'static> { |
278 | 6 | StoreKey::Str(Cow::Owned(format!( |
279 | 6 | "{CLIENT_ID_TO_OPERATION_ID_KEY_PREFIX}{}", |
280 | 6 | self.0 |
281 | 6 | ))) |
282 | 6 | } |
283 | | } |
284 | | impl SchedulerStoreDecodeTo for ClientIdToOperationId<'_> { |
285 | | type DecodeOutput = OperationId; |
286 | 2 | fn decode(_version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> { |
287 | 2 | serde_json::from_slice(&data).map_err(|e| {0 |
288 | 0 | make_input_err!( |
289 | | "In ClientIdToOperationId::decode - {e:?} (data: {:02x?})", |
290 | | data |
291 | | ) |
292 | 0 | }) |
293 | 2 | } |
294 | | } |
295 | | |
296 | | // TODO(palfrey) We only need operation_id here, it would be nice if we had a way |
297 | | // to tell the decoder we only care about specific fields. |
298 | | struct SearchUniqueQualifierToAwaitedAction<'a>(&'a ActionUniqueQualifier); |
299 | | impl SchedulerIndexProvider for SearchUniqueQualifierToAwaitedAction<'_> { |
300 | | const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX; |
301 | | const INDEX_NAME: &'static str = "unique_qualifier"; |
302 | | type Versioned = TrueValue; |
303 | 4 | fn index_value(&self) -> Cow<'_, str> { |
304 | 4 | Cow::Owned(format!("{}", self.0)) |
305 | 4 | } |
306 | | } |
307 | | impl SchedulerStoreDecodeTo for SearchUniqueQualifierToAwaitedAction<'_> { |
308 | | type DecodeOutput = AwaitedAction; |
309 | 2 | fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> { |
310 | 2 | awaited_action_decode(version, &data) |
311 | 2 | } |
312 | | } |
313 | | |
314 | | struct SearchStateToAwaitedAction(&'static str); |
315 | | impl SchedulerIndexProvider for SearchStateToAwaitedAction { |
316 | | const KEY_PREFIX: &'static str = OPERATION_ID_TO_AWAITED_ACTION_KEY_PREFIX; |
317 | | const INDEX_NAME: &'static str = "state"; |
318 | | const MAYBE_SORT_KEY: Option<&'static str> = Some("sort_key"); |
319 | | type Versioned = TrueValue; |
320 | 17 | fn index_value(&self) -> Cow<'_, str> { |
321 | 17 | Cow::Borrowed(self.0) |
322 | 17 | } |
323 | | } |
324 | | impl SchedulerStoreDecodeTo for SearchStateToAwaitedAction { |
325 | | type DecodeOutput = AwaitedAction; |
326 | 5 | fn decode(version: i64, data: Bytes) -> Result<Self::DecodeOutput, Error> { |
327 | 5 | awaited_action_decode(version, &data) |
328 | 5 | } |
329 | | } |
330 | | |
331 | 30 | const fn get_state_prefix(state: SortedAwaitedActionState) -> &'static str { |
332 | 30 | match state { |
333 | 0 | SortedAwaitedActionState::CacheCheck => "cache_check", |
334 | 24 | SortedAwaitedActionState::Queued => "queued", |
335 | 5 | SortedAwaitedActionState::Executing => "executing", |
336 | 1 | SortedAwaitedActionState::Completed => "completed", |
337 | | } |
338 | 30 | } |
339 | | |
340 | | struct UpdateOperationIdToAwaitedAction(AwaitedAction); |
341 | | impl SchedulerCurrentVersionProvider for UpdateOperationIdToAwaitedAction { |
342 | 13 | fn current_version(&self) -> i64 { |
343 | 13 | self.0.version() |
344 | 13 | } |
345 | | } |
346 | | impl SchedulerStoreKeyProvider for UpdateOperationIdToAwaitedAction { |
347 | | type Versioned = TrueValue; |
348 | 13 | fn get_key(&self) -> StoreKey<'static> { |
349 | 13 | OperationIdToAwaitedAction(Cow::Borrowed(self.0.operation_id())).get_key() |
350 | 13 | } |
351 | | } |
352 | | impl SchedulerStoreDataProvider for UpdateOperationIdToAwaitedAction { |
353 | 13 | fn try_into_bytes(self) -> Result<Bytes, Error> { |
354 | 13 | serde_json::to_string(&self.0) |
355 | 13 | .map(Bytes::from) |
356 | 13 | .map_err(|e| make_input_err!("Could not convert AwaitedAction to json - {e:?}")) |
357 | 13 | } |
358 | 13 | fn get_indexes(&self) -> Result<Vec<(&'static str, Bytes)>, Error> { |
359 | 13 | let unique_qualifier = &self.0.action_info().unique_qualifier; |
360 | 13 | let maybe_unique_qualifier = match &unique_qualifier { |
361 | 13 | ActionUniqueQualifier::Cacheable(_) => Some(unique_qualifier), |
362 | 0 | ActionUniqueQualifier::Uncacheable(_) => None, |
363 | | }; |
364 | 13 | let mut output = Vec::with_capacity(2 + maybe_unique_qualifier.map_or(0, |_| 1)); |
365 | 13 | if maybe_unique_qualifier.is_some() { Branch (365:12): [True: 13, False: 0]
Branch (365:12): [Folded - Ignored]
|
366 | 13 | output.push(( |
367 | 13 | "unique_qualifier", |
368 | 13 | Bytes::from(unique_qualifier.to_string()), |
369 | 13 | )); |
370 | 13 | }0 |
371 | | { |
372 | 13 | let state = SortedAwaitedActionState::try_from(&self.0.state().stage) |
373 | 13 | .err_tip(|| "In UpdateOperationIdToAwaitedAction::get_index")?0 ; |
374 | 13 | output.push(("state", Bytes::from(get_state_prefix(state)))); |
375 | 13 | let sorted_awaited_action = SortedAwaitedAction::from(&self.0); |
376 | 13 | output.push(( |
377 | 13 | "sort_key", |
378 | 13 | // We encode to hex to ensure that the sort key is lexicographically sorted. |
379 | 13 | Bytes::from(format!("{:016x}", sorted_awaited_action.sort_key.as_u64())), |
380 | 13 | )); |
381 | | } |
382 | 13 | Ok(output) |
383 | 13 | } |
384 | | } |
385 | | |
386 | | struct UpdateClientIdToOperationId { |
387 | | client_operation_id: ClientOperationId, |
388 | | operation_id: OperationId, |
389 | | } |
390 | | impl SchedulerStoreKeyProvider for UpdateClientIdToOperationId { |
391 | | type Versioned = FalseValue; |
392 | 4 | fn get_key(&self) -> StoreKey<'static> { |
393 | 4 | ClientIdToOperationId(&self.client_operation_id).get_key() |
394 | 4 | } |
395 | | } |
396 | | impl SchedulerStoreDataProvider for UpdateClientIdToOperationId { |
397 | 4 | fn try_into_bytes(self) -> Result<Bytes, Error> { |
398 | 4 | serde_json::to_string(&self.operation_id) |
399 | 4 | .map(Bytes::from) |
400 | 4 | .map_err(|e| make_input_err!("Could not convert OperationId to json - {e:?}")) |
401 | 4 | } |
402 | | } |
403 | | |
404 | 9 | async fn inner_update_awaited_action( |
405 | 9 | store: &impl SchedulerStore, |
406 | 9 | mut new_awaited_action: AwaitedAction, |
407 | 9 | ) -> Result<(), Error> { |
408 | 9 | let operation_id = new_awaited_action.operation_id().clone(); |
409 | 9 | if new_awaited_action.state().client_operation_id != operation_id { Branch (409:8): [True: 0, False: 0]
Branch (409:8): [Folded - Ignored]
Branch (409:8): [True: 0, False: 9]
|
410 | 0 | // Just in case the client_operation_id was set to something else |
411 | 0 | // we put it back to the underlying operation_id. |
412 | 0 | new_awaited_action.set_client_operation_id(operation_id.clone()); |
413 | 9 | } |
414 | 9 | let maybe_version = store |
415 | 9 | .update_data(UpdateOperationIdToAwaitedAction(new_awaited_action)) |
416 | 9 | .await |
417 | 9 | .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")?0 ; |
418 | 9 | if maybe_version.is_none() { Branch (418:8): [True: 0, False: 0]
Branch (418:8): [Folded - Ignored]
Branch (418:8): [True: 0, False: 9]
|
419 | 0 | return Err(make_err!( |
420 | 0 | Code::Aborted, |
421 | 0 | "Could not update AwaitedAction because the version did not match for {operation_id:?}", |
422 | 0 | )); |
423 | 9 | } |
424 | 9 | Ok(()) |
425 | 9 | } |
426 | | |
/// Awaited-action database that keeps its state in a [`SchedulerStore`].
#[derive(Debug, MetricsComponent)]
pub struct StoreAwaitedActionDb<S, F, I, NowFn>
where
    S: SchedulerStore,
    F: Fn() -> OperationId,
    I: InstantWrapper,
    NowFn: Fn() -> I,
{
    // Backing store; subscribers created by this db hold weak references to it.
    store: Arc<S>,
    // Produces the `InstantWrapper` used for timestamps and sleeping.
    now_fn: NowFn,
    // Generates the operation id for each newly created `AwaitedAction`.
    operation_id_creator: F,
    // Guard for the background task spawned in `new()`; held only so the
    // task's lifetime is tied to this struct.
    _pull_task_change_subscriber_spawn: JoinHandleDropGuard<()>,
}
440 | | |
441 | | impl<S, F, I, NowFn> StoreAwaitedActionDb<S, F, I, NowFn> |
442 | | where |
443 | | S: SchedulerStore, |
444 | | F: Fn() -> OperationId, |
445 | | I: InstantWrapper, |
446 | | NowFn: Fn() -> I + Send + Sync + Clone + 'static, |
447 | | { |
448 | 2 | pub fn new( |
449 | 2 | store: Arc<S>, |
450 | 2 | task_change_publisher: Arc<Notify>, |
451 | 2 | now_fn: NowFn, |
452 | 2 | operation_id_creator: F, |
453 | 2 | ) -> Result<Self, Error> { |
454 | 2 | let mut subscription = store |
455 | 2 | .subscription_manager() |
456 | 2 | .err_tip(|| "In RedisAwaitedActionDb::new")?0 |
457 | 2 | .subscribe(OperationIdToAwaitedAction(Cow::Owned(OperationId::String( |
458 | 2 | String::new(), |
459 | 2 | )))) |
460 | 2 | .err_tip(|| "In RedisAwaitedActionDb::new")?0 ; |
461 | 2 | let pull_task_change_subscriber = spawn!( |
462 | | "redis_awaited_action_db_pull_task_change_subscriber", |
463 | 2 | async move { |
464 | | loop { |
465 | 18 | let changed_res16 = subscription |
466 | 18 | .changed() |
467 | 18 | .await |
468 | 16 | .err_tip(|| "In RedisAwaitedActionDb::new"); |
469 | 16 | if let Err(err0 ) = changed_res { Branch (469:28): [True: 0, False: 0]
Branch (469:28): [Folded - Ignored]
Branch (469:28): [True: 0, False: 4]
Branch (469:28): [True: 0, False: 12]
|
470 | 0 | error!( |
471 | 0 | "Error waiting for pull task change subscriber in RedisAwaitedActionDb::new - {err:?}" |
472 | | ); |
473 | | // Sleep for a second to avoid a busy loop, then trigger the notify |
474 | | // so if a reconnect happens we let local resources know that things |
475 | | // might have changed. |
476 | 0 | tokio::time::sleep(Duration::from_secs(1)).await; |
477 | 16 | } |
478 | 16 | task_change_publisher.as_ref().notify_one(); |
479 | | } |
480 | | } |
481 | | ); |
482 | 2 | Ok(Self { |
483 | 2 | store, |
484 | 2 | now_fn, |
485 | 2 | operation_id_creator, |
486 | 2 | _pull_task_change_subscriber_spawn: pull_task_change_subscriber, |
487 | 2 | }) |
488 | 2 | } |
489 | | |
490 | | #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this |
491 | 4 | async fn try_subscribe( |
492 | 4 | &self, |
493 | 4 | client_operation_id: &ClientOperationId, |
494 | 4 | unique_qualifier: &ActionUniqueQualifier, |
495 | 4 | no_event_action_timeout: Duration, |
496 | 4 | // TODO(palfrey) To simplify the scheduler 2024 refactor, we |
497 | 4 | // removed the ability to upgrade priorities of actions. |
498 | 4 | // we should add priority upgrades back in. |
499 | 4 | _priority: i32, |
500 | 4 | ) -> Result<Option<AwaitedAction>, Error> { |
501 | 4 | match unique_qualifier { |
502 | 4 | ActionUniqueQualifier::Cacheable(_) => {} |
503 | 0 | ActionUniqueQualifier::Uncacheable(_) => return Ok(None), |
504 | | } |
505 | 4 | let stream = self |
506 | 4 | .store |
507 | 4 | .search_by_index_prefix(SearchUniqueQualifierToAwaitedAction(unique_qualifier)) |
508 | 4 | .await |
509 | 4 | .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")?0 ; |
510 | 4 | tokio::pin!(stream); |
511 | 4 | let maybe_awaited_action = stream |
512 | 4 | .try_next() |
513 | 4 | .await |
514 | 4 | .err_tip(|| "In RedisAwaitedActionDb::try_subscribe")?0 ; |
515 | 4 | match maybe_awaited_action { |
516 | 2 | Some(awaited_action) => { |
517 | | // TODO(palfrey) We don't support joining completed jobs because we |
518 | | // need to also check that all the data is still in the cache. |
519 | | // If the existing job failed then we need to set back to queued or we get |
520 | | // a version mismatch. Equally we need to check the timeout as the job |
521 | | // may be abandoned in the store. |
522 | 2 | let worker_should_update_before = (awaited_action.state().stage |
523 | 2 | == ActionStage::Executing) |
524 | 2 | .then_some(()) |
525 | 2 | .map(|()| awaited_action0 .last_worker_updated_timestamp0 ()) |
526 | 2 | .and_then(|last_worker_updated| {0 |
527 | 0 | last_worker_updated.checked_add(no_event_action_timeout) |
528 | 0 | }); |
529 | 2 | let awaited_action = if awaited_action.state().stage.is_finished() Branch (529:41): [True: 0, False: 0]
Branch (529:41): [Folded - Ignored]
Branch (529:41): [True: 0, False: 0]
Branch (529:41): [True: 1, False: 1]
|
530 | 1 | || worker_should_update_before Branch (530:24): [True: 0, False: 0]
Branch (530:24): [Folded - Ignored]
Branch (530:24): [True: 0, False: 0]
Branch (530:24): [True: 0, False: 1]
|
531 | 1 | .is_some_and(|timestamp| timestamp0 < (self.now_fn)().now()0 ) |
532 | | { |
533 | 1 | tracing::debug!( |
534 | 1 | "Recreating action {:?} for operation {client_operation_id}", |
535 | 1 | awaited_action.action_info().digest() |
536 | | ); |
537 | | // The version is reset because we have a new operation ID. |
538 | 1 | AwaitedAction::new( |
539 | 1 | (self.operation_id_creator)(), |
540 | 1 | awaited_action.action_info().clone(), |
541 | 1 | (self.now_fn)().now(), |
542 | | ) |
543 | | } else { |
544 | 1 | tracing::debug!( |
545 | 1 | "Subscribing to existing action {:?} for operation {client_operation_id}", |
546 | 1 | awaited_action.action_info().digest() |
547 | | ); |
548 | 1 | awaited_action |
549 | | }; |
550 | 2 | Ok(Some(awaited_action)) |
551 | | } |
552 | 2 | None => Ok(None), |
553 | | } |
554 | 4 | } |
555 | | |
556 | | #[expect(clippy::future_not_send)] // TODO(jhpratt) remove this |
557 | 2 | async fn inner_get_awaited_action_by_id( |
558 | 2 | &self, |
559 | 2 | client_operation_id: &ClientOperationId, |
560 | 2 | ) -> Result<Option<OperationSubscriber<S, I, NowFn>>, Error> { |
561 | 2 | let maybe_operation_id = self |
562 | 2 | .store |
563 | 2 | .get_and_decode(ClientIdToOperationId(client_operation_id)) |
564 | 2 | .await |
565 | 2 | .err_tip(|| "In RedisAwaitedActionDb::get_awaited_action_by_id")?0 ; |
566 | 2 | let Some(operation_id) = maybe_operation_id else { Branch (566:13): [True: 0, False: 0]
Branch (566:13): [Folded - Ignored]
Branch (566:13): [True: 1, False: 0]
Branch (566:13): [True: 1, False: 0]
|
567 | 0 | return Ok(None); |
568 | | }; |
569 | 2 | Ok(Some(OperationSubscriber::new( |
570 | 2 | Some(client_operation_id.clone()), |
571 | 2 | OperationIdToAwaitedAction(Cow::Owned(operation_id)), |
572 | 2 | Arc::downgrade(&self.store), |
573 | 2 | self.now_fn.clone(), |
574 | 2 | ))) |
575 | 2 | } |
576 | | } |
577 | | |
578 | | impl<S, F, I, NowFn> AwaitedActionDb for StoreAwaitedActionDb<S, F, I, NowFn> |
579 | | where |
580 | | S: SchedulerStore, |
581 | | F: Fn() -> OperationId + Send + Sync + Unpin + 'static, |
582 | | I: InstantWrapper, |
583 | | NowFn: Fn() -> I + Send + Sync + Unpin + Clone + 'static, |
584 | | { |
585 | | type Subscriber = OperationSubscriber<S, I, NowFn>; |
586 | | |
587 | 2 | async fn get_awaited_action_by_id( |
588 | 2 | &self, |
589 | 2 | client_operation_id: &ClientOperationId, |
590 | 2 | ) -> Result<Option<Self::Subscriber>, Error> { |
591 | 2 | self.inner_get_awaited_action_by_id(client_operation_id) |
592 | 2 | .await |
593 | 2 | } |
594 | | |
595 | 8 | async fn get_by_operation_id( |
596 | 8 | &self, |
597 | 8 | operation_id: &OperationId, |
598 | 8 | ) -> Result<Option<Self::Subscriber>, Error> { |
599 | 8 | Ok(Some(OperationSubscriber::new( |
600 | 8 | None, |
601 | 8 | OperationIdToAwaitedAction(Cow::Owned(operation_id.clone())), |
602 | 8 | Arc::downgrade(&self.store), |
603 | 8 | self.now_fn.clone(), |
604 | 8 | ))) |
605 | 8 | } |
606 | | |
607 | 9 | async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { |
608 | 9 | inner_update_awaited_action(self.store.as_ref(), new_awaited_action).await |
609 | 9 | } |
610 | | |
611 | 4 | async fn add_action( |
612 | 4 | &self, |
613 | 4 | client_operation_id: ClientOperationId, |
614 | 4 | action_info: Arc<ActionInfo>, |
615 | 4 | no_event_action_timeout: Duration, |
616 | 4 | ) -> Result<Self::Subscriber, Error> { |
617 | | loop { |
618 | | // Check to see if the action is already known and subscribe if it is. |
619 | 4 | let mut awaited_action = self |
620 | 4 | .try_subscribe( |
621 | 4 | &client_operation_id, |
622 | 4 | &action_info.unique_qualifier, |
623 | 4 | no_event_action_timeout, |
624 | 4 | action_info.priority, |
625 | 4 | ) |
626 | 4 | .await |
627 | 4 | .err_tip(|| "In RedisAwaitedActionDb::add_action")?0 |
628 | 4 | .unwrap_or_else(|| {2 |
629 | 2 | tracing::debug!( |
630 | 2 | "Creating new action {:?} for operation {client_operation_id}", |
631 | 2 | action_info.digest() |
632 | | ); |
633 | 2 | AwaitedAction::new( |
634 | 2 | (self.operation_id_creator)(), |
635 | 2 | action_info.clone(), |
636 | 2 | (self.now_fn)().now(), |
637 | | ) |
638 | 2 | }); |
639 | | |
640 | 4 | debug_assert!( |
641 | 0 | ActionStage::Queued == awaited_action.state().stage, |
642 | 0 | "Expected action to be queued" |
643 | | ); |
644 | | |
645 | 4 | let operation_id = awaited_action.operation_id().clone(); |
646 | 4 | if awaited_action.state().client_operation_id != operation_id { Branch (646:16): [True: 0, False: 0]
Branch (646:16): [Folded - Ignored]
Branch (646:16): [True: 0, False: 1]
Branch (646:16): [True: 0, False: 3]
|
647 | 0 | // Just in case the client_operation_id was set to something else |
648 | 0 | // we put it back to the underlying operation_id. |
649 | 0 | awaited_action.set_client_operation_id(operation_id.clone()); |
650 | 4 | } |
651 | 4 | awaited_action.update_client_keep_alive((self.now_fn)().now()); |
652 | | |
653 | 4 | let version = awaited_action.version(); |
654 | 4 | if self Branch (654:16): [True: 0, False: 0]
Branch (654:16): [Folded - Ignored]
Branch (654:16): [True: 0, False: 1]
Branch (654:16): [True: 0, False: 3]
|
655 | 4 | .store |
656 | 4 | .update_data(UpdateOperationIdToAwaitedAction(awaited_action)) |
657 | 4 | .await |
658 | 4 | .err_tip(|| "In RedisAwaitedActionDb::update_awaited_action")?0 |
659 | 4 | .is_none() |
660 | | { |
661 | | // The version was out of date, try again. |
662 | 0 | tracing::debug!( |
663 | 0 | "Version out of date for {:?} {operation_id} {version}, retrying.", |
664 | 0 | action_info.digest() |
665 | | ); |
666 | 0 | continue; |
667 | 4 | } |
668 | | |
669 | | // Add the client_operation_id to operation_id mapping |
670 | 4 | self.store |
671 | 4 | .update_data(UpdateClientIdToOperationId { |
672 | 4 | client_operation_id: client_operation_id.clone(), |
673 | 4 | operation_id: operation_id.clone(), |
674 | 4 | }) |
675 | 4 | .await |
676 | 4 | .err_tip(|| "In RedisAwaitedActionDb::try_subscribe while adding client mapping")?0 ; |
677 | | |
678 | 4 | return Ok(OperationSubscriber::new( |
679 | 4 | Some(client_operation_id), |
680 | 4 | OperationIdToAwaitedAction(Cow::Owned(operation_id)), |
681 | 4 | Arc::downgrade(&self.store), |
682 | 4 | self.now_fn.clone(), |
683 | 4 | )); |
684 | | } |
685 | 4 | } |
686 | | |
687 | 17 | async fn get_range_of_actions( |
688 | 17 | &self, |
689 | 17 | state: SortedAwaitedActionState, |
690 | 17 | start: Bound<SortedAwaitedAction>, |
691 | 17 | end: Bound<SortedAwaitedAction>, |
692 | 17 | desc: bool, |
693 | 17 | ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> { |
694 | 17 | if !matches!0 (start, Bound::Unbounded) { Branch (694:12): [True: 0, False: 0]
Branch (694:12): [Folded - Ignored]
Branch (694:12): [True: 0, False: 17]
|
695 | 0 | return Err(make_err!( |
696 | 0 | Code::Unimplemented, |
697 | 0 | "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions", |
698 | 0 | )); |
699 | 17 | } |
700 | 17 | if !matches!0 (end, Bound::Unbounded) { Branch (700:12): [True: 0, False: 0]
Branch (700:12): [Folded - Ignored]
Branch (700:12): [True: 0, False: 17]
|
701 | 0 | return Err(make_err!( |
702 | 0 | Code::Unimplemented, |
703 | 0 | "Start bound is not supported in RedisAwaitedActionDb::get_range_of_actions", |
704 | 0 | )); |
705 | 17 | } |
706 | | // TODO(palfrey) This API is not difficult to implement, but there is no code path |
707 | | // that uses it, so no reason to implement it yet. |
708 | 17 | if !desc { Branch (708:12): [True: 0, False: 0]
Branch (708:12): [Folded - Ignored]
Branch (708:12): [True: 0, False: 17]
|
709 | 0 | return Err(make_err!( |
710 | 0 | Code::Unimplemented, |
711 | 0 | "Descending order is not supported in RedisAwaitedActionDb::get_range_of_actions", |
712 | 0 | )); |
713 | 17 | } |
714 | 17 | Ok(self |
715 | 17 | .store |
716 | 17 | .search_by_index_prefix(SearchStateToAwaitedAction(get_state_prefix(state))) |
717 | 17 | .await |
718 | 17 | .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")?0 |
719 | 17 | .map_ok(move |awaited_action| {5 |
720 | 5 | OperationSubscriber::new( |
721 | 5 | None, |
722 | 5 | OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())), |
723 | 5 | Arc::downgrade(&self.store), |
724 | 5 | self.now_fn.clone(), |
725 | | ) |
726 | 5 | })) |
727 | 17 | } |
728 | | |
729 | 0 | async fn get_all_awaited_actions( |
730 | 0 | &self, |
731 | 0 | ) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> { |
732 | 0 | Ok(self |
733 | 0 | .store |
734 | 0 | .search_by_index_prefix(SearchStateToAwaitedAction("")) |
735 | 0 | .await |
736 | 0 | .err_tip(|| "In RedisAwaitedActionDb::get_range_of_actions")? |
737 | 0 | .map_ok(move |awaited_action| { |
738 | 0 | OperationSubscriber::new( |
739 | 0 | None, |
740 | 0 | OperationIdToAwaitedAction(Cow::Owned(awaited_action.operation_id().clone())), |
741 | 0 | Arc::downgrade(&self.store), |
742 | 0 | self.now_fn.clone(), |
743 | | ) |
744 | 0 | })) |
745 | 0 | } |
746 | | } |