Coverage Report

Created: 2025-04-19 16:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/worker.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::HashMap;
16
use std::hash::{Hash, Hasher};
17
use std::sync::Arc;
18
use std::time::{SystemTime, UNIX_EPOCH};
19
20
use nativelink_error::{Code, Error, ResultExt, make_err};
21
use nativelink_metric::MetricsComponent;
22
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
23
    ConnectionResult, StartExecute, UpdateForWorker, update_for_worker,
24
};
25
use nativelink_util::action_messages::{ActionInfo, OperationId, WorkerId};
26
use nativelink_util::metrics_utils::{AsyncCounterWrapper, CounterWithTime, FuncCounterWrapper};
27
use nativelink_util::origin_event::OriginEventContext;
28
use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue};
29
use tokio::sync::mpsc::UnboundedSender;
30
31
pub type WorkerTimestamp = u64;
32
33
/// Represents the action info and the platform properties of the action.
34
/// These platform properties have the type of the properties as well as
35
/// the value of the properties, unlike `ActionInfo`, which only has the
36
/// string value of the properties.
37
#[derive(Clone, Debug, MetricsComponent)]
38
pub struct ActionInfoWithProps {
39
    /// The action info of the action.
40
    #[metric(group = "action_info")]
41
    pub inner: Arc<ActionInfo>,
42
    /// The platform properties of the action.
43
    #[metric(group = "platform_properties")]
44
    pub platform_properties: PlatformProperties,
45
}
46
47
/// Notifications to send worker about a requested state change.
48
#[derive(Debug)]
49
pub enum WorkerUpdate {
50
    /// Requests that the worker begin executing this action.
51
    RunAction((OperationId, ActionInfoWithProps)),
52
53
    /// Request that the worker is no longer in the pool and may discard any jobs.
54
    Disconnect,
55
}
56
57
#[derive(Debug, MetricsComponent)]
58
pub struct PendingActionInfoData {
59
    #[metric]
60
    pub action_info: ActionInfoWithProps,
61
    ctx: OriginEventContext<StartExecute>,
62
}
63
64
/// Represents a connection to a worker and used as the medium to
65
/// interact with the worker from the client/scheduler.
66
#[derive(Debug, MetricsComponent)]
67
pub struct Worker {
68
    /// Unique identifier of the worker.
69
    #[metric(help = "The unique identifier of the worker.")]
70
    pub id: WorkerId,
71
72
    /// Properties that describe the capabilities of this worker.
73
    #[metric(group = "platform_properties")]
74
    pub platform_properties: PlatformProperties,
75
76
    /// Channel to send commands from scheduler to worker.
77
    pub tx: UnboundedSender<UpdateForWorker>,
78
79
    /// The action info of the running actions on the worker.
80
    #[metric(group = "running_action_infos")]
81
    pub running_action_infos: HashMap<OperationId, PendingActionInfoData>,
82
83
    /// Timestamp of last time this worker had been communicated with.
84
    // Warning: Do not update this timestamp without updating the placement of the worker in
85
    // the LRUCache in the Workers struct.
86
    #[metric(help = "Last time this worker was communicated with.")]
87
    pub last_update_timestamp: WorkerTimestamp,
88
89
    /// Whether the worker rejected the last action due to back pressure.
90
    #[metric(help = "If the worker is paused.")]
91
    pub is_paused: bool,
92
93
    /// Whether the worker is draining.
94
    #[metric(help = "If the worker is draining.")]
95
    pub is_draining: bool,
96
97
    /// Stats about the worker.
98
    #[metric]
99
    metrics: Arc<Metrics>,
100
}
101
102
66
fn send_msg_to_worker(
103
66
    tx: &mut UnboundedSender<UpdateForWorker>,
104
66
    msg: update_for_worker::Update,
105
66
) -> Result<(), Error> {
106
66
    tx.send(UpdateForWorker { update: Some(msg) })
107
66
        .map_err(|_| 
make_err!(Code::Internal, "Worker disconnected")3
)
108
66
}
109
110
/// Reduces the platform properties available on the worker based on the platform properties provided.
111
/// This is used because we allow more than 1 job to run on a worker at a time, and this is how the
112
/// scheduler knows if more jobs can run on a given worker.
113
28
fn reduce_platform_properties(
114
28
    parent_props: &mut PlatformProperties,
115
28
    reduction_props: &PlatformProperties,
116
28
) {
117
28
    debug_assert!(
reduction_props.is_satisfied_by(parent_props0
));
118
32
    for (
property, prop_value4
) in &reduction_props.properties {
119
4
        if let PlatformPropertyValue::Minimum(
value3
) = prop_value {
  Branch (119:16): [True: 3, False: 1]
  Branch (119:16): [Folded - Ignored]
120
3
            let worker_props = &mut parent_props.properties;
121
3
            if let &mut PlatformPropertyValue::Minimum(worker_value) =
  Branch (121:20): [True: 3, False: 0]
  Branch (121:20): [Folded - Ignored]
122
3
                &mut worker_props.get_mut(property).unwrap()
123
3
            {
124
3
                *worker_value -= value;
125
3
            
}0
126
1
        }
127
    }
128
28
}
129
130
impl Worker {
131
31
    pub fn new(
132
31
        id: WorkerId,
133
31
        platform_properties: PlatformProperties,
134
31
        tx: UnboundedSender<UpdateForWorker>,
135
31
        timestamp: WorkerTimestamp,
136
31
    ) -> Self {
137
31
        Self {
138
31
            id,
139
31
            platform_properties,
140
31
            tx,
141
31
            running_action_infos: HashMap::new(),
142
31
            last_update_timestamp: timestamp,
143
31
            is_paused: false,
144
31
            is_draining: false,
145
31
            metrics: Arc::new(Metrics {
146
31
                connected_timestamp: SystemTime::now()
147
31
                    .duration_since(UNIX_EPOCH)
148
31
                    .unwrap()
149
31
                    .as_secs(),
150
31
                actions_completed: CounterWithTime::default(),
151
31
                run_action: AsyncCounterWrapper::default(),
152
31
                keep_alive: FuncCounterWrapper::default(),
153
31
                notify_disconnect: CounterWithTime::default(),
154
31
            }),
155
31
        }
156
31
    }
157
158
    /// Sends the initial connection information to the worker. This generally is just meta info.
159
    /// This should only be sent once and should always be the first item in the stream.
160
31
    pub fn send_initial_connection_result(&mut self) -> Result<(), Error> {
161
31
        send_msg_to_worker(
162
31
            &mut self.tx,
163
31
            update_for_worker::Update::ConnectionResult(ConnectionResult {
164
31
                worker_id: self.id.clone().into(),
165
31
            }),
166
31
        )
167
31
        .err_tip(|| 
format!("Failed to send ConnectionResult to worker : {}", self.id)0
)
168
31
    }
169
170
    /// Notifies the worker of a requested state change.
171
34
    pub async fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> {
172
34
        match worker_update {
173
28
            WorkerUpdate::RunAction((operation_id, action_info)) => {
174
28
                self.run_action(operation_id, action_info).await
175
            }
176
            WorkerUpdate::Disconnect => {
177
6
                self.metrics.notify_disconnect.inc();
178
6
                send_msg_to_worker(&mut self.tx, update_for_worker::Update::Disconnect(()))
179
            }
180
        }
181
34
    }
182
183
1
    pub fn keep_alive(&mut self) -> Result<(), Error> {
184
1
        let tx = &mut self.tx;
185
1
        let id = &self.id;
186
1
        self.metrics.keep_alive.wrap(move || {
187
1
            send_msg_to_worker(tx, update_for_worker::Update::KeepAlive(()))
188
1
                .err_tip(|| 
format!("Failed to send KeepAlive to worker : {id}")0
)
189
1
        })
190
1
    }
191
192
28
    async fn run_action(
193
28
        &mut self,
194
28
        operation_id: OperationId,
195
28
        action_info: ActionInfoWithProps,
196
28
    ) -> Result<(), Error> {
197
28
        let tx = &mut self.tx;
198
28
        let worker_platform_properties = &mut self.platform_properties;
199
28
        let running_action_infos = &mut self.running_action_infos;
200
28
        let worker_id = self.id.clone().into();
201
28
        self.metrics
202
28
            .run_action
203
28
            .wrap(async move {
204
28
                let action_info_clone = action_info.clone();
205
28
                let operation_id_string = operation_id.to_string();
206
28
                let start_execute = StartExecute {
207
28
                    execute_request: Some(action_info_clone.inner.as_ref().into()),
208
28
                    operation_id: operation_id_string,
209
28
                    queued_timestamp: Some(action_info.inner.insert_timestamp.into()),
210
28
                    platform: Some((&action_info.platform_properties).into()),
211
28
                    worker_id,
212
28
                };
213
28
                reduce_platform_properties(
214
28
                    worker_platform_properties,
215
28
                    &action_info.platform_properties,
216
                );
217
218
28
                let ctx = OriginEventContext::new(|| 
&start_execute0
).await;
219
28
                running_action_infos
220
28
                    .insert(operation_id, PendingActionInfoData { action_info, ctx });
221
28
222
28
                send_msg_to_worker(tx, update_for_worker::Update::StartAction(start_execute))
223
28
            })
224
28
            .await
225
28
    }
226
227
9
    pub(crate) async fn complete_action(
228
9
        &mut self,
229
9
        operation_id: &OperationId,
230
9
    ) -> Result<(), Error> {
231
9
        let pending_action_info = self.running_action_infos.remove(operation_id).err_tip(|| {
232
0
            format!(
233
0
                "Worker {} tried to complete operation {} that was not running",
234
0
                self.id, operation_id
235
0
            )
236
0
        })?;
237
9
        pending_action_info.ctx.emit(|| 
&()0
).await;
238
9
        self.restore_platform_properties(&pending_action_info.action_info.platform_properties);
239
9
        self.is_paused = false;
240
9
        self.metrics.actions_completed.inc();
241
9
        Ok(())
242
9
    }
243
244
0
    pub fn has_actions(&self) -> bool {
245
0
        !self.running_action_infos.is_empty()
246
0
    }
247
248
9
    fn restore_platform_properties(&mut self, props: &PlatformProperties) {
249
11
        for (
property, prop_value2
) in &props.properties {
250
2
            if let PlatformPropertyValue::Minimum(value) = prop_value {
  Branch (250:20): [True: 2, False: 0]
  Branch (250:20): [Folded - Ignored]
251
2
                let worker_props = &mut self.platform_properties.properties;
252
2
                if let PlatformPropertyValue::Minimum(worker_value) =
  Branch (252:24): [True: 2, False: 0]
  Branch (252:24): [Folded - Ignored]
253
2
                    worker_props.get_mut(property).unwrap()
254
2
                {
255
2
                    *worker_value += value;
256
2
                
}0
257
0
            }
258
        }
259
9
    }
260
261
43
    pub const fn can_accept_work(&self) -> bool {
262
43
        !self.is_paused && !self.is_draining
  Branch (262:9): [True: 43, False: 0]
  Branch (262:9): [Folded - Ignored]
263
43
    }
264
}
265
266
impl PartialEq for Worker {
267
0
    fn eq(&self, other: &Self) -> bool {
268
0
        self.id == other.id
269
0
    }
270
}
271
272
impl Eq for Worker {}
273
274
impl Hash for Worker {
275
0
    fn hash<H: Hasher>(&self, state: &mut H) {
276
0
        self.id.hash(state);
277
0
    }
278
}
279
280
#[derive(Debug, Default, MetricsComponent)]
281
struct Metrics {
282
    #[metric(help = "The timestamp of when this worker connected.")]
283
    connected_timestamp: u64,
284
    #[metric(help = "The number of actions completed for this worker.")]
285
    actions_completed: CounterWithTime,
286
    #[metric(help = "The number of actions started for this worker.")]
287
    run_action: AsyncCounterWrapper,
288
    #[metric(help = "The number of keep_alive sent to this worker.")]
289
    keep_alive: FuncCounterWrapper,
290
    #[metric(help = "The number of notify_disconnect sent to this worker.")]
291
    notify_disconnect: CounterWithTime,
292
}