Coverage Report

Created: 2025-03-08 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/worker.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::HashMap;
16
use std::hash::{Hash, Hasher};
17
use std::sync::Arc;
18
use std::time::{SystemTime, UNIX_EPOCH};
19
20
use nativelink_error::{make_err, Code, Error, ResultExt};
21
use nativelink_metric::MetricsComponent;
22
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
23
    update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
24
};
25
use nativelink_util::action_messages::{ActionInfo, OperationId, WorkerId};
26
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime, FuncCounterWrapper};
27
use nativelink_util::origin_event::OriginEventContext;
28
use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue};
29
use tokio::sync::mpsc::UnboundedSender;
30
31
pub type WorkerTimestamp = u64;
32
33
/// Represents the action info and the platform properties of the action.
34
/// These platform properties have the type of the properties as well as
35
/// the value of the properties, unlike `ActionInfo`, which only has the
36
/// string value of the properties.
37
#[derive(Clone, Debug, MetricsComponent)]
38
pub struct ActionInfoWithProps {
39
    /// The action info of the action.
40
    #[metric(group = "action_info")]
41
    pub inner: Arc<ActionInfo>,
42
    /// The platform properties of the action.
43
    #[metric(group = "platform_properties")]
44
    pub platform_properties: PlatformProperties,
45
}
46
47
/// Notifications to send worker about a requested state change.
48
pub enum WorkerUpdate {
49
    /// Requests that the worker begin executing this action.
50
    RunAction((OperationId, ActionInfoWithProps)),
51
52
    /// Request that the worker is no longer in the pool and may discard any jobs.
53
    Disconnect,
54
}
55
56
#[derive(MetricsComponent)]
57
pub struct PendingActionInfoData {
58
    #[metric]
59
    pub action_info: ActionInfoWithProps,
60
    ctx: OriginEventContext<StartExecute>,
61
}
62
63
/// Represents a connection to a worker and used as the medium to
64
/// interact with the worker from the client/scheduler.
65
#[derive(MetricsComponent)]
66
pub struct Worker {
67
    /// Unique identifier of the worker.
68
    #[metric(help = "The unique identifier of the worker.")]
69
    pub id: WorkerId,
70
71
    /// Properties that describe the capabilities of this worker.
72
    #[metric(group = "platform_properties")]
73
    pub platform_properties: PlatformProperties,
74
75
    /// Channel to send commands from scheduler to worker.
76
    pub tx: UnboundedSender<UpdateForWorker>,
77
78
    /// The action info of the running actions on the worker.
79
    #[metric(group = "running_action_infos")]
80
    pub running_action_infos: HashMap<OperationId, PendingActionInfoData>,
81
82
    /// Timestamp of last time this worker had been communicated with.
83
    // Warning: Do not update this timestamp without updating the placement of the worker in
84
    // the LRUCache in the Workers struct.
85
    #[metric(help = "Last time this worker was communicated with.")]
86
    pub last_update_timestamp: WorkerTimestamp,
87
88
    /// Whether the worker rejected the last action due to back pressure.
89
    #[metric(help = "If the worker is paused.")]
90
    pub is_paused: bool,
91
92
    /// Whether the worker is draining.
93
    #[metric(help = "If the worker is draining.")]
94
    pub is_draining: bool,
95
96
    /// Stats about the worker.
97
    #[metric]
98
    metrics: Arc<Metrics>,
99
}
100
101
66
fn send_msg_to_worker(
102
66
    tx: &mut UnboundedSender<UpdateForWorker>,
103
66
    msg: update_for_worker::Update,
104
66
) -> Result<(), Error> {
105
66
    tx.send(UpdateForWorker { update: Some(msg) })
106
66
        .map_err(|_| 
make_err!(Code::Internal, "Worker disconnected")3
)
107
66
}
108
109
/// Reduces the platform properties available on the worker based on the platform properties provided.
110
/// This is used because we allow more than 1 job to run on a worker at a time, and this is how the
111
/// scheduler knows if more jobs can run on a given worker.
112
28
fn reduce_platform_properties(
113
28
    parent_props: &mut PlatformProperties,
114
28
    reduction_props: &PlatformProperties,
115
28
) {
116
28
    debug_assert!(
reduction_props.is_satisfied_by(parent_props)0
);
117
32
    for (
property, prop_value4
) in &reduction_props.properties {
118
4
        if let PlatformPropertyValue::Minimum(
value3
) = prop_value {
  Branch (118:16): [True: 3, False: 1]
  Branch (118:16): [Folded - Ignored]
119
3
            let worker_props = &mut parent_props.properties;
120
3
            if let &mut PlatformPropertyValue::Minimum(worker_value) =
  Branch (120:20): [True: 3, False: 0]
  Branch (120:20): [Folded - Ignored]
121
3
                &mut worker_props.get_mut(property).unwrap()
122
3
            {
123
3
                *worker_value -= value;
124
3
            
}0
125
1
        }
126
    }
127
28
}
128
129
impl Worker {
130
31
    pub fn new(
131
31
        id: WorkerId,
132
31
        platform_properties: PlatformProperties,
133
31
        tx: UnboundedSender<UpdateForWorker>,
134
31
        timestamp: WorkerTimestamp,
135
31
    ) -> Self {
136
31
        Self {
137
31
            id,
138
31
            platform_properties,
139
31
            tx,
140
31
            running_action_infos: HashMap::new(),
141
31
            last_update_timestamp: timestamp,
142
31
            is_paused: false,
143
31
            is_draining: false,
144
31
            metrics: Arc::new(Metrics {
145
31
                connected_timestamp: SystemTime::now()
146
31
                    .duration_since(UNIX_EPOCH)
147
31
                    .unwrap()
148
31
                    .as_secs(),
149
31
                actions_completed: CounterWithTime::default(),
150
31
                run_action: AsyncCounterWrapper::default(),
151
31
                keep_alive: FuncCounterWrapper::default(),
152
31
                notify_disconnect: CounterWithTime::default(),
153
31
            }),
154
31
        }
155
31
    }
156
157
    /// Sends the initial connection information to the worker. This generally is just meta info.
158
    /// This should only be sent once and should always be the first item in the stream.
159
31
    pub fn send_initial_connection_result(&mut self) -> Result<(), Error> {
160
31
        send_msg_to_worker(
161
31
            &mut self.tx,
162
31
            update_for_worker::Update::ConnectionResult(ConnectionResult {
163
31
                worker_id: self.id.clone().into(),
164
31
            }),
165
31
        )
166
31
        .err_tip(|| 
format!("Failed to send ConnectionResult to worker : {}", self.id)0
)
167
31
    }
168
169
    /// Notifies the worker of a requested state change.
170
34
    pub async fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> {
171
34
        match worker_update {
172
28
            WorkerUpdate::RunAction((operation_id, action_info)) => {
173
28
                self.run_action(operation_id, action_info).await
174
            }
175
            WorkerUpdate::Disconnect => {
176
6
                self.metrics.notify_disconnect.inc();
177
6
                send_msg_to_worker(&mut self.tx, update_for_worker::Update::Disconnect(()))
178
            }
179
        }
180
34
    }
181
182
1
    pub fn keep_alive(&mut self) -> Result<(), Error> {
183
1
        let tx = &mut self.tx;
184
1
        let id = &self.id;
185
1
        self.metrics.keep_alive.wrap(move || {
186
1
            send_msg_to_worker(tx, update_for_worker::Update::KeepAlive(()))
187
1
                .err_tip(|| 
format!("Failed to send KeepAlive to worker : {id}")0
)
188
1
        })
189
1
    }
190
191
28
    async fn run_action(
192
28
        &mut self,
193
28
        operation_id: OperationId,
194
28
        action_info: ActionInfoWithProps,
195
28
    ) -> Result<(), Error> {
196
28
        let tx = &mut self.tx;
197
28
        let worker_platform_properties = &mut self.platform_properties;
198
28
        let running_action_infos = &mut self.running_action_infos;
199
28
        let worker_id = self.id.clone().into();
200
28
        self.metrics
201
28
            .run_action
202
28
            .wrap(async move {
203
28
                let action_info_clone = action_info.clone();
204
28
                let operation_id_string = operation_id.to_string();
205
28
                let start_execute = StartExecute {
206
28
                    execute_request: Some(action_info_clone.inner.as_ref().into()),
207
28
                    operation_id: operation_id_string,
208
28
                    queued_timestamp: Some(action_info.inner.insert_timestamp.into()),
209
28
                    platform: Some((&action_info.platform_properties).into()),
210
28
                    worker_id,
211
28
                };
212
28
                reduce_platform_properties(
213
28
                    worker_platform_properties,
214
28
                    &action_info.platform_properties,
215
28
                );
216
217
28
                let ctx = OriginEventContext::new(|| 
&start_execute0
).await;
218
28
                running_action_infos
219
28
                    .insert(operation_id, PendingActionInfoData { action_info, ctx });
220
28
221
28
                send_msg_to_worker(tx, update_for_worker::Update::StartAction(start_execute))
222
28
            })
223
28
            .await
224
28
    }
225
226
9
    pub(crate) async fn complete_action(
227
9
        &mut self,
228
9
        operation_id: &OperationId,
229
9
    ) -> Result<(), Error> {
230
9
        let pending_action_info = self.running_action_infos.remove(operation_id).err_tip(|| {
231
0
            format!(
232
0
                "Worker {} tried to complete operation {} that was not running",
233
0
                self.id, operation_id
234
0
            )
235
9
        })
?0
;
236
9
        pending_action_info.ctx.emit(|| 
&()0
).await;
237
9
        self.restore_platform_properties(&pending_action_info.action_info.platform_properties);
238
9
        self.is_paused = false;
239
9
        self.metrics.actions_completed.inc();
240
9
        Ok(())
241
9
    }
242
243
0
    pub fn has_actions(&self) -> bool {
244
0
        !self.running_action_infos.is_empty()
245
0
    }
246
247
9
    fn restore_platform_properties(&mut self, props: &PlatformProperties) {
248
11
        for (
property, prop_value2
) in &props.properties {
249
2
            if let PlatformPropertyValue::Minimum(value) = prop_value {
  Branch (249:20): [True: 2, False: 0]
  Branch (249:20): [Folded - Ignored]
250
2
                let worker_props = &mut self.platform_properties.properties;
251
2
                if let PlatformPropertyValue::Minimum(worker_value) =
  Branch (251:24): [True: 2, False: 0]
  Branch (251:24): [Folded - Ignored]
252
2
                    worker_props.get_mut(property).unwrap()
253
2
                {
254
2
                    *worker_value += value;
255
2
                
}0
256
0
            }
257
        }
258
9
    }
259
260
43
    pub const fn can_accept_work(&self) -> bool {
261
43
        !self.is_paused && !self.is_draining
  Branch (261:9): [True: 43, False: 0]
  Branch (261:9): [Folded - Ignored]
262
43
    }
263
}
264
265
impl PartialEq for Worker {
266
0
    fn eq(&self, other: &Self) -> bool {
267
0
        self.id == other.id
268
0
    }
269
}
270
271
impl Eq for Worker {}
272
273
impl Hash for Worker {
274
0
    fn hash<H: Hasher>(&self, state: &mut H) {
275
0
        self.id.hash(state);
276
0
    }
277
}
278
279
#[derive(Default, MetricsComponent)]
280
struct Metrics {
281
    #[metric(help = "The timestamp of when this worker connected.")]
282
    connected_timestamp: u64,
283
    #[metric(help = "The number of actions completed for this worker.")]
284
    actions_completed: CounterWithTime,
285
    #[metric(help = "The number of actions started for this worker.")]
286
    run_action: AsyncCounterWrapper,
287
    #[metric(help = "The number of keep_alive sent to this worker.")]
288
    keep_alive: FuncCounterWrapper,
289
    #[metric(help = "The number of notify_disconnect sent to this worker.")]
290
    notify_disconnect: CounterWithTime,
291
}