/build/source/nativelink-scheduler/src/awaited_action_db/mod.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::cmp; |
16 | | use std::ops::Bound; |
17 | | use std::sync::Arc; |
18 | | use std::time::Duration; |
19 | | |
20 | | pub use awaited_action::{AwaitedAction, AwaitedActionSortKey}; |
21 | | use futures::{Future, Stream}; |
22 | | use nativelink_error::{make_input_err, Error, ResultExt}; |
23 | | use nativelink_metric::MetricsComponent; |
24 | | use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId}; |
25 | | use serde::{Deserialize, Serialize}; |
26 | | |
27 | | mod awaited_action; |
28 | | |
29 | | /// Duration to wait before sending client keep alive messages. |
30 | | pub const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10); |
31 | | |
32 | | /// A simple enum to represent the state of an `AwaitedAction`. |
33 | | #[derive(Debug, Clone, Copy)] |
34 | | pub enum SortedAwaitedActionState { |
35 | | CacheCheck, |
36 | | Queued, |
37 | | Executing, |
38 | | Completed, |
39 | | } |
40 | | |
41 | | impl TryFrom<&ActionStage> for SortedAwaitedActionState { |
42 | | type Error = Error; |
43 | 2 | fn try_from(value: &ActionStage) -> Result<Self, Error> { |
44 | 2 | match value { |
45 | 0 | ActionStage::CacheCheck => Ok(Self::CacheCheck), |
46 | 1 | ActionStage::Executing => Ok(Self::Executing), |
47 | 0 | ActionStage::Completed(_) => Ok(Self::Completed), |
48 | 1 | ActionStage::Queued => Ok(Self::Queued), |
49 | 0 | _ => Err(make_input_err!("Invalid State")), |
50 | | } |
51 | 2 | } |
52 | | } |
53 | | |
54 | | impl TryFrom<ActionStage> for SortedAwaitedActionState { |
55 | | type Error = Error; |
56 | 0 | fn try_from(value: ActionStage) -> Result<Self, Error> { |
57 | 0 | Self::try_from(&value) |
58 | 0 | } |
59 | | } |
60 | | |
61 | | /// A struct pointing to an `AwaitedAction` that can be sorted. |
62 | | #[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)] |
63 | | pub struct SortedAwaitedAction { |
64 | | #[metric(help = "The sort key of the AwaitedAction")] |
65 | | pub sort_key: AwaitedActionSortKey, |
66 | | #[metric(help = "The operation id")] |
67 | | pub operation_id: OperationId, |
68 | | } |
69 | | |
70 | | impl PartialEq for SortedAwaitedAction { |
71 | 0 | fn eq(&self, other: &Self) -> bool { |
72 | 0 | self.sort_key == other.sort_key && self.operation_id == other.operation_id Branch (72:9): [True: 0, False: 0]
Branch (72:9): [Folded - Ignored]
|
73 | 0 | } |
74 | | } |
75 | | |
76 | | impl Eq for SortedAwaitedAction {} |
77 | | |
78 | | impl PartialOrd for SortedAwaitedAction { |
79 | 0 | fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> { |
80 | 0 | Some(self.cmp(other)) |
81 | 0 | } |
82 | | } |
83 | | |
84 | | impl Ord for SortedAwaitedAction { |
85 | 565 | fn cmp(&self, other: &Self) -> cmp::Ordering { |
86 | 565 | self.sort_key |
87 | 565 | .cmp(&other.sort_key) |
88 | 565 | .then_with(|| self.operation_id.cmp(&other.operation_id)557 ) |
89 | 565 | } |
90 | | } |
91 | | |
92 | | impl std::fmt::Display for SortedAwaitedAction { |
93 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
94 | 0 | std::fmt::write( |
95 | 0 | f, |
96 | 0 | format_args!("{}-{}", self.sort_key.as_u64(), self.operation_id), |
97 | 0 | ) |
98 | 0 | } |
99 | | } |
100 | | |
101 | | impl From<&AwaitedAction> for SortedAwaitedAction { |
102 | 2 | fn from(value: &AwaitedAction) -> Self { |
103 | 2 | Self { |
104 | 2 | operation_id: value.operation_id().clone(), |
105 | 2 | sort_key: value.sort_key(), |
106 | 2 | } |
107 | 2 | } |
108 | | } |
109 | | |
110 | | impl From<AwaitedAction> for SortedAwaitedAction { |
111 | 0 | fn from(value: AwaitedAction) -> Self { |
112 | 0 | Self::from(&value) |
113 | 0 | } |
114 | | } |
115 | | |
116 | | impl TryInto<Vec<u8>> for SortedAwaitedAction { |
117 | | type Error = Error; |
118 | 0 | fn try_into(self) -> Result<Vec<u8>, Self::Error> { |
119 | 0 | serde_json::to_vec(&self) |
120 | 0 | .map_err(|e| make_input_err!("{}", e.to_string())) |
121 | 0 | .err_tip(|| "In SortedAwaitedAction::TryInto::<Vec<u8>>") |
122 | 0 | } |
123 | | } |
124 | | |
125 | | impl TryFrom<&[u8]> for SortedAwaitedAction { |
126 | | type Error = Error; |
127 | 0 | fn try_from(value: &[u8]) -> Result<Self, Self::Error> { |
128 | 0 | serde_json::from_slice(value) |
129 | 0 | .map_err(|e| make_input_err!("{}", e.to_string())) |
130 | 0 | .err_tip(|| "In AwaitedAction::TryFrom::&[u8]") |
131 | 0 | } |
132 | | } |
133 | | |
134 | | /// Subscriber that can be used to monitor when `AwaitedActions` change. |
135 | | pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static { |
136 | | /// Wait for `AwaitedAction` to change. |
137 | | fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send; |
138 | | |
139 | | /// Get the current awaited action. |
140 | | fn borrow(&self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send; |
141 | | } |
142 | | |
143 | | /// A trait that defines the interface for an `AwaitedActionDb`. |
144 | | pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static { |
145 | | type Subscriber: AwaitedActionSubscriber; |
146 | | |
147 | | /// Get the `AwaitedAction` by the client operation id. |
148 | | fn get_awaited_action_by_id( |
149 | | &self, |
150 | | client_operation_id: &OperationId, |
151 | | ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send; |
152 | | |
153 | | /// Get all `AwaitedActions`. This call should be avoided as much as possible. |
154 | | fn get_all_awaited_actions( |
155 | | &self, |
156 | | ) -> impl Future< |
157 | | Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>, |
158 | | > + Send; |
159 | | |
160 | | /// Get the `AwaitedAction` by the operation id. |
161 | | fn get_by_operation_id( |
162 | | &self, |
163 | | operation_id: &OperationId, |
164 | | ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send; |
165 | | |
166 | | /// Get a range of `AwaitedActions` of a specific state in sorted order. |
167 | | fn get_range_of_actions( |
168 | | &self, |
169 | | state: SortedAwaitedActionState, |
170 | | start: Bound<SortedAwaitedAction>, |
171 | | end: Bound<SortedAwaitedAction>, |
172 | | desc: bool, |
173 | | ) -> impl Future< |
174 | | Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>, |
175 | | > + Send; |
176 | | |
177 | | /// Process a change changed `AwaitedAction` and notify any listeners. |
178 | | fn update_awaited_action( |
179 | | &self, |
180 | | new_awaited_action: AwaitedAction, |
181 | | ) -> impl Future<Output = Result<(), Error>> + Send; |
182 | | |
183 | | /// Add (or join) an action to the `AwaitedActionDb` and subscribe |
184 | | /// to changes. |
185 | | fn add_action( |
186 | | &self, |
187 | | client_operation_id: OperationId, |
188 | | action_info: Arc<ActionInfo>, |
189 | | ) -> impl Future<Output = Result<Self::Subscriber, Error>> + Send; |
190 | | } |