Coverage Report

Created: 2025-03-08 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/awaited_action_db/mod.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::cmp;
16
use std::ops::Bound;
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
pub use awaited_action::{AwaitedAction, AwaitedActionSortKey};
21
use futures::{Future, Stream};
22
use nativelink_error::{make_input_err, Error, ResultExt};
23
use nativelink_metric::MetricsComponent;
24
use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId};
25
use serde::{Deserialize, Serialize};
26
27
mod awaited_action;
28
29
/// Interval between keep-alive messages sent to clients (ten seconds).
pub const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_millis(10_000);
31
32
/// A simple enum to represent the state of an `AwaitedAction`.
///
/// This is a payload-free projection of `ActionStage` (see the
/// `TryFrom<&ActionStage>` impl in this module). It is the state selector taken
/// by `AwaitedActionDb::get_range_of_actions`. It derives `PartialEq`, `Eq`
/// and `Hash` so callers can compare states directly or use them as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortedAwaitedActionState {
    /// Corresponds to `ActionStage::CacheCheck`.
    CacheCheck,
    /// Corresponds to `ActionStage::Queued`.
    Queued,
    /// Corresponds to `ActionStage::Executing`.
    Executing,
    /// Corresponds to `ActionStage::Completed(_)`.
    Completed,
}
40
41
impl TryFrom<&ActionStage> for SortedAwaitedActionState {
42
    type Error = Error;
43
2
    fn try_from(value: &ActionStage) -> Result<Self, Error> {
44
2
        match value {
45
0
            ActionStage::CacheCheck => Ok(Self::CacheCheck),
46
1
            ActionStage::Executing => Ok(Self::Executing),
47
0
            ActionStage::Completed(_) => Ok(Self::Completed),
48
1
            ActionStage::Queued => Ok(Self::Queued),
49
0
            _ => Err(make_input_err!("Invalid State")),
50
        }
51
2
    }
52
}
53
54
impl TryFrom<ActionStage> for SortedAwaitedActionState {
55
    type Error = Error;
56
0
    fn try_from(value: ActionStage) -> Result<Self, Error> {
57
0
        Self::try_from(&value)
58
0
    }
59
}
60
61
/// A struct pointing to an `AwaitedAction` that can be sorted.
62
#[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)]
63
pub struct SortedAwaitedAction {
64
    #[metric(help = "The sort key of the AwaitedAction")]
65
    pub sort_key: AwaitedActionSortKey,
66
    #[metric(help = "The operation id")]
67
    pub operation_id: OperationId,
68
}
69
70
impl PartialEq for SortedAwaitedAction {
71
0
    fn eq(&self, other: &Self) -> bool {
72
0
        self.sort_key == other.sort_key && self.operation_id == other.operation_id
  Branch (72:9): [True: 0, False: 0]
  Branch (72:9): [Folded - Ignored]
73
0
    }
74
}
75
76
impl Eq for SortedAwaitedAction {}
77
78
impl PartialOrd for SortedAwaitedAction {
79
0
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
80
0
        Some(self.cmp(other))
81
0
    }
82
}
83
84
impl Ord for SortedAwaitedAction {
85
565
    fn cmp(&self, other: &Self) -> cmp::Ordering {
86
565
        self.sort_key
87
565
            .cmp(&other.sort_key)
88
565
            .then_with(|| 
self.operation_id.cmp(&other.operation_id)557
)
89
565
    }
90
}
91
92
impl std::fmt::Display for SortedAwaitedAction {
93
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94
0
        std::fmt::write(
95
0
            f,
96
0
            format_args!("{}-{}", self.sort_key.as_u64(), self.operation_id),
97
0
        )
98
0
    }
99
}
100
101
impl From<&AwaitedAction> for SortedAwaitedAction {
    /// Builds a sortable handle from an action by cloning its operation id
    /// and reading its sort key.
    fn from(awaited_action: &AwaitedAction) -> Self {
        let operation_id = awaited_action.operation_id().clone();
        let sort_key = awaited_action.sort_key();
        Self {
            sort_key,
            operation_id,
        }
    }
}
109
110
impl From<AwaitedAction> for SortedAwaitedAction {
    /// Owned-value convenience wrapper around the `From<&AwaitedAction>` impl.
    fn from(awaited_action: AwaitedAction) -> Self {
        (&awaited_action).into()
    }
}
115
116
impl TryInto<Vec<u8>> for SortedAwaitedAction {
117
    type Error = Error;
118
0
    fn try_into(self) -> Result<Vec<u8>, Self::Error> {
119
0
        serde_json::to_vec(&self)
120
0
            .map_err(|e| make_input_err!("{}", e.to_string()))
121
0
            .err_tip(|| "In SortedAwaitedAction::TryInto::<Vec<u8>>")
122
0
    }
123
}
124
125
impl TryFrom<&[u8]> for SortedAwaitedAction {
126
    type Error = Error;
127
0
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
128
0
        serde_json::from_slice(value)
129
0
            .map_err(|e| make_input_err!("{}", e.to_string()))
130
0
            .err_tip(|| "In AwaitedAction::TryFrom::&[u8]")
131
0
    }
132
}
133
134
/// Subscriber that can be used to monitor when `AwaitedActions` change.
135
pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static {
136
    /// Wait for `AwaitedAction` to change.
137
    fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
138
139
    /// Get the current awaited action.
140
    fn borrow(&self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
141
}
142
143
/// A trait that defines the interface for an `AwaitedActionDb`.
144
pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static {
145
    type Subscriber: AwaitedActionSubscriber;
146
147
    /// Get the `AwaitedAction` by the client operation id.
148
    fn get_awaited_action_by_id(
149
        &self,
150
        client_operation_id: &OperationId,
151
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;
152
153
    /// Get all `AwaitedActions`. This call should be avoided as much as possible.
154
    fn get_all_awaited_actions(
155
        &self,
156
    ) -> impl Future<
157
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
158
    > + Send;
159
160
    /// Get the `AwaitedAction` by the operation id.
161
    fn get_by_operation_id(
162
        &self,
163
        operation_id: &OperationId,
164
    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send;
165
166
    /// Get a range of `AwaitedActions` of a specific state in sorted order.
167
    fn get_range_of_actions(
168
        &self,
169
        state: SortedAwaitedActionState,
170
        start: Bound<SortedAwaitedAction>,
171
        end: Bound<SortedAwaitedAction>,
172
        desc: bool,
173
    ) -> impl Future<
174
        Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
175
    > + Send;
176
177
    /// Process a change changed `AwaitedAction` and notify any listeners.
178
    fn update_awaited_action(
179
        &self,
180
        new_awaited_action: AwaitedAction,
181
    ) -> impl Future<Output = Result<(), Error>> + Send;
182
183
    /// Add (or join) an action to the `AwaitedActionDb` and subscribe
184
    /// to changes.
185
    fn add_action(
186
        &self,
187
        client_operation_id: OperationId,
188
        action_info: Arc<ActionInfo>,
189
    ) -> impl Future<Output = Result<Self::Subscriber, Error>> + Send;
190
}