Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-scheduler/src/api_worker_scheduler.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::ops::{Deref, DerefMut};
16
use std::sync::Arc;
17
18
use async_lock::Mutex;
19
use lru::LruCache;
20
use nativelink_config::schedulers::WorkerAllocationStrategy;
21
use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt};
22
use nativelink_metric::{
23
    group, MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
24
    RootMetricsComponent,
25
};
26
use nativelink_util::action_messages::{OperationId, WorkerId};
27
use nativelink_util::operation_state_manager::{UpdateOperationType, WorkerStateManager};
28
use nativelink_util::platform_properties::PlatformProperties;
29
use nativelink_util::spawn;
30
use nativelink_util::task::JoinHandleDropGuard;
31
use tokio::sync::mpsc::{self, UnboundedSender};
32
use tokio::sync::Notify;
33
use tonic::async_trait;
34
use tracing::{event, Level};
35
36
use crate::platform_property_manager::PlatformPropertyManager;
37
use crate::worker::{ActionInfoWithProps, Worker, WorkerTimestamp, WorkerUpdate};
38
use crate::worker_scheduler::WorkerScheduler;
39
40
/// Newtype wrapper around the worker pool's `LruCache`.
///
/// It exists so this crate can implement `MetricsComponent` for the pool: the orphan
/// rule forbids implementing that foreign trait directly on the foreign `LruCache`.
/// `Deref`/`DerefMut` expose the inner cache to the rest of this module.
struct Workers(LruCache<WorkerId, Worker>);
41
42
impl Deref for Workers {
43
    type Target = LruCache<WorkerId, Worker>;
44
45
54
    fn deref(&self) -> &Self::Target {
46
54
        &self.0
47
54
    }
48
}
49
50
impl DerefMut for Workers {
51
109
    fn deref_mut(&mut self) -> &mut Self::Target {
52
109
        &mut self.0
53
109
    }
54
}
55
56
// Note: This could not be a derive macro because this derive-macro
57
// does not support LruCache or nameless-field (tuple) structs.
58
impl MetricsComponent for Workers {
59
0
    fn publish(
60
0
        &self,
61
0
        _kind: MetricKind,
62
0
        _field_metadata: MetricFieldData,
63
0
    ) -> Result<MetricPublishKnownKindData, nativelink_metric::Error> {
64
0
        let _enter = group!("workers").entered();
65
0
        for (worker_id, worker) in self.iter() {
66
0
            let _enter = group!(worker_id).entered();
67
0
            worker.publish(MetricKind::Component, MetricFieldData::default())?;
68
        }
69
0
        Ok(MetricPublishKnownKindData::Component)
70
0
    }
71
}
72
73
/// A collection of workers that are available to run tasks.
74
0
#[derive(MetricsComponent)]
75
struct ApiWorkerSchedulerImpl {
76
    /// A `LruCache` of workers availabled based on `allocation_strategy`.
77
    #[metric(group = "workers")]
78
    workers: Workers,
79
80
    /// The worker state manager.
81
    #[metric(group = "worker_state_manager")]
82
    worker_state_manager: Arc<dyn WorkerStateManager>,
83
    /// The allocation strategy for workers.
84
    allocation_strategy: WorkerAllocationStrategy,
85
    /// A channel to notify the matching engine that the worker pool has changed.
86
    worker_change_notify: Arc<Notify>,
87
    /// A channel to notify that an operation is still alive.
88
    operation_keep_alive_tx: UnboundedSender<(OperationId, WorkerId)>,
89
}
90
91
impl ApiWorkerSchedulerImpl {
92
    /// Refreshes the lifetime of the worker with the given timestamp.
93
2
    fn refresh_lifetime(
94
2
        &mut self,
95
2
        worker_id: &WorkerId,
96
2
        timestamp: WorkerTimestamp,
97
2
    ) -> Result<(), Error> {
98
2
        let worker = self.workers.0.peek_mut(worker_id).ok_or_else(|| {
99
0
            make_input_err!(
100
0
                "Worker not found in worker map in refresh_lifetime() {}",
101
0
                worker_id
102
0
            )
103
2
        })
?0
;
104
0
        error_if!(
105
2
            worker.last_update_timestamp > timestamp,
  Branch (105:13): [True: 0, False: 2]
  Branch (105:13): [Folded - Ignored]
106
            "Worker already had a timestamp of {}, but tried to update it with {}",
107
            worker.last_update_timestamp,
108
            timestamp
109
        );
110
2
        worker.last_update_timestamp = timestamp;
111
2
        for 
operation_id0
in worker.running_action_infos.keys() {
112
0
            if self
  Branch (112:16): [True: 0, False: 0]
  Branch (112:16): [Folded - Ignored]
113
0
                .operation_keep_alive_tx
114
0
                .send((operation_id.clone(), *worker_id))
115
0
                .is_err()
116
            {
117
0
                event!(
118
0
                    Level::ERROR,
119
                    ?operation_id,
120
                    ?worker_id,
121
0
                    "OperationKeepAliveTx stream closed"
122
                );
123
0
            }
124
        }
125
2
        Ok(())
126
2
    }
127
128
    /// Adds a worker to the pool.
129
    /// Note: This function will not do any task matching.
130
31
    fn add_worker(&mut self, worker: Worker) -> Result<(), Error> {
131
31
        let worker_id = worker.id;
132
31
        self.workers.put(worker_id, worker);
133
31
134
31
        // Worker is not cloneable, and we do not want to send the initial connection results until
135
31
        // we have added it to the map, or we might get some strange race conditions due to the way
136
31
        // the multi-threaded runtime works.
137
31
        let worker = self.workers.peek_mut(&worker_id).unwrap();
138
31
        let res = worker
139
31
            .send_initial_connection_result()
140
31
            .err_tip(|| 
"Failed to send initial connection result to worker"0
);
141
31
        if let Err(
err0
) = &res {
  Branch (141:16): [True: 0, False: 31]
  Branch (141:16): [Folded - Ignored]
142
0
            event!(
143
0
                Level::ERROR,
144
                ?worker_id,
145
                ?err,
146
0
                "Worker connection appears to have been closed while adding to pool"
147
            );
148
31
        }
149
31
        self.worker_change_notify.notify_one();
150
31
        res
151
31
    }
152
153
    /// Removes worker from pool.
154
    /// Note: The caller is responsible for any rescheduling of any tasks that might be
155
    /// running.
156
6
    fn remove_worker(&mut self, worker_id: &WorkerId) -> Option<Worker> {
157
6
        let result = self.workers.pop(worker_id);
158
6
        self.worker_change_notify.notify_one();
159
6
        result
160
6
    }
161
162
    /// Sets if the worker is draining or not.
163
2
    async fn set_drain_worker(
164
2
        &mut self,
165
2
        worker_id: &WorkerId,
166
2
        is_draining: bool,
167
2
    ) -> Result<(), Error> {
168
2
        let worker = self
169
2
            .workers
170
2
            .get_mut(worker_id)
171
2
            .err_tip(|| 
format!("Worker {worker_id} doesn't exist in the pool")0
)
?0
;
172
2
        worker.is_draining = is_draining;
173
2
        self.worker_change_notify.notify_one();
174
2
        Ok(())
175
2
    }
176
177
42
    fn inner_find_worker_for_action(
178
42
        &self,
179
42
        platform_properties: &PlatformProperties,
180
42
    ) -> Option<WorkerId> {
181
42
        let mut workers_iter = self.workers.iter();
182
42
        let workers_iter = match self.allocation_strategy {
183
            // Use rfind to get the least recently used that satisfies the properties.
184
42
            WorkerAllocationStrategy::least_recently_used => workers_iter.rfind(|(_, w)| {
185
34
                w.can_accept_work() && 
platform_properties.is_satisfied_by(&w.platform_properties)33
  Branch (185:17): [True: 33, False: 1]
  Branch (185:17): [Folded - Ignored]
186
42
            
}34
),
187
            // Use find to get the most recently used that satisfies the properties.
188
0
            WorkerAllocationStrategy::most_recently_used => workers_iter.find(|(_, w)| {
189
0
                w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties)
  Branch (189:17): [True: 0, False: 0]
  Branch (189:17): [Folded - Ignored]
190
0
            }),
191
        };
192
42
        workers_iter.map(|(_, w)| 
&w.id28
).copied()
193
42
    }
194
195
10
    async fn update_action(
196
10
        &mut self,
197
10
        worker_id: &WorkerId,
198
10
        operation_id: &OperationId,
199
10
        update: UpdateOperationType,
200
10
    ) -> Result<(), Error> {
201
10
        let worker = self.workers.get_mut(worker_id).err_tip(|| {
202
0
            format!("Worker {worker_id} does not exist in SimpleScheduler::update_action")
203
10
        })
?0
;
204
205
        // Ensure the worker is supposed to be running the operation.
206
10
        if !worker.running_action_infos.contains_key(operation_id) {
  Branch (206:12): [True: 1, False: 9]
  Branch (206:12): [Folded - Ignored]
207
1
            let err = make_err!(
208
1
                Code::Internal,
209
1
                "Operation {operation_id} should not be running on worker {worker_id} in SimpleScheduler::update_action"
210
1
            );
211
1
            return Result::<(), _>::Err(err.clone())
212
1
                .merge(self.immediate_evict_worker(worker_id, err).
await0
);
213
9
        }
214
215
9
        let (is_finished, due_to_backpressure) = match &update {
216
6
            UpdateOperationType::UpdateWithActionStage(action_stage) => {
217
6
                (action_stage.is_finished(), false)
218
            }
219
0
            UpdateOperationType::KeepAlive => (false, false),
220
3
            UpdateOperationType::UpdateWithError(err) => {
221
3
                (true, err.code == Code::ResourceExhausted)
222
            }
223
        };
224
225
        // Update the operation in the worker state manager.
226
        {
227
9
            let update_operation_res = self
228
9
                .worker_state_manager
229
9
                .update_operation(operation_id, worker_id, update)
230
1
                .await
231
9
                .err_tip(|| 
"in update_operation on SimpleScheduler::update_action"0
);
232
9
            if let Err(
err0
) = update_operation_res {
  Branch (232:20): [True: 0, False: 9]
  Branch (232:20): [Folded - Ignored]
233
0
                event!(
234
0
                    Level::ERROR,
235
                    ?operation_id,
236
                    ?worker_id,
237
                    ?err,
238
0
                    "Failed to update_operation on update_action"
239
                );
240
0
                return Err(err);
241
9
            }
242
9
        }
243
9
244
9
        if !is_finished {
  Branch (244:12): [True: 0, False: 9]
  Branch (244:12): [Folded - Ignored]
245
0
            return Ok(());
246
9
        }
247
248
        // Clear this action from the current worker if finished.
249
9
        let complete_action_res = {
250
9
            let was_paused = !worker.can_accept_work();
251
9
252
9
            // Note: We need to run this before dealing with backpressure logic.
253
9
            let complete_action_res = worker.complete_action(operation_id);
254
9
255
9
            // Only pause if there's an action still waiting that will unpause.
256
9
            if (was_paused || due_to_backpressure) && 
worker.has_actions()0
{
  Branch (256:17): [True: 0, False: 9]
  Branch (256:31): [True: 0, False: 9]
  Branch (256:55): [True: 0, False: 0]
  Branch (256:17): [Folded - Ignored]
  Branch (256:31): [Folded - Ignored]
  Branch (256:55): [Folded - Ignored]
257
0
                worker.is_paused = true;
258
9
            }
259
9
            complete_action_res
260
9
        };
261
9
262
9
        self.worker_change_notify.notify_one();
263
9
264
9
        complete_action_res
265
10
    }
266
267
    /// Notifies the specified worker to run the given action and handles errors by evicting
268
    /// the worker if the notification fails.
269
28
    async fn worker_notify_run_action(
270
28
        &mut self,
271
28
        worker_id: WorkerId,
272
28
        operation_id: OperationId,
273
28
        action_info: ActionInfoWithProps,
274
28
    ) -> Result<(), Error> {
275
28
        if let Some(worker) = self.workers.get_mut(&worker_id) {
  Branch (275:16): [True: 0, False: 0]
  Branch (275:16): [Folded - Ignored]
  Branch (275:16): [True: 27, False: 0]
  Branch (275:16): [True: 1, False: 0]
276
28
            let notify_worker_result =
277
28
                worker.notify_update(WorkerUpdate::RunAction((operation_id, action_info.clone())));
278
28
279
28
            if notify_worker_result.is_err() {
  Branch (279:16): [True: 0, False: 0]
  Branch (279:16): [Folded - Ignored]
  Branch (279:16): [True: 1, False: 26]
  Branch (279:16): [True: 0, False: 1]
280
1
                event!(
281
1
                    Level::WARN,
282
                    ?worker_id,
283
                    ?action_info,
284
                    ?notify_worker_result,
285
1
                    "Worker command failed, removing worker",
286
                );
287
288
1
                let err = make_err!(
289
1
                    Code::Internal,
290
1
                    "Worker command failed, removing worker {worker_id} -- {notify_worker_result:?}",
291
1
                );
292
1
293
1
                return Result::<(), _>::Err(err.clone())
294
1
                    .merge(self.immediate_evict_worker(&worker_id, err).
await0
);
295
27
            }
296
        } else {
297
0
            event!(
298
0
                Level::WARN,
299
                ?worker_id,
300
                ?operation_id,
301
                ?action_info,
302
0
                "Worker not found in worker map in worker_notify_run_action"
303
            );
304
        }
305
27
        Ok(())
306
28
    }
307
308
    /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
309
6
    async fn immediate_evict_worker(
310
6
        &mut self,
311
6
        worker_id: &WorkerId,
312
6
        err: Error,
313
6
    ) -> Result<(), Error> {
314
6
        let mut result = Ok(());
315
6
        if let Some(mut worker) = self.remove_worker(worker_id) {
  Branch (315:16): [True: 5, False: 0]
  Branch (315:16): [Folded - Ignored]
  Branch (315:16): [True: 1, False: 0]
  Branch (315:16): [True: 0, False: 0]
316
            // We don't care if we fail to send message to worker, this is only a best attempt.
317
6
            let _ = worker.notify_update(WorkerUpdate::Disconnect);
318
6
            for (
operation_id4
, _) in worker.running_action_infos.drain() {
319
4
                result = result.merge(
320
4
                    self.worker_state_manager
321
4
                        .update_operation(
322
4
                            &operation_id,
323
4
                            worker_id,
324
4
                            UpdateOperationType::UpdateWithError(err.clone()),
325
4
                        )
326
0
                        .await,
327
                );
328
            }
329
0
        }
330
        // Note: Calling this many time is very cheap, it'll only trigger `do_try_match` once.
331
        // TODO(allada) This should be moved to inside the Workers struct.
332
6
        self.worker_change_notify.notify_one();
333
6
        result
334
6
    }
335
}
336
337
0
#[derive(MetricsComponent)]
338
pub struct ApiWorkerScheduler {
339
    #[metric]
340
    inner: Mutex<ApiWorkerSchedulerImpl>,
341
    #[metric(group = "platform_property_manager")]
342
    platform_property_manager: Arc<PlatformPropertyManager>,
343
344
    #[metric(
345
        help = "Timeout of how long to evict workers if no response in this given amount of time in seconds."
346
    )]
347
    worker_timeout_s: u64,
348
    _operation_keep_alive_spawn: JoinHandleDropGuard<()>,
349
}
350
351
impl ApiWorkerScheduler {
352
27
    pub fn new(
353
27
        worker_state_manager: Arc<dyn WorkerStateManager>,
354
27
        platform_property_manager: Arc<PlatformPropertyManager>,
355
27
        allocation_strategy: WorkerAllocationStrategy,
356
27
        worker_change_notify: Arc<Notify>,
357
27
        worker_timeout_s: u64,
358
27
    ) -> Arc<Self> {
359
27
        let (operation_keep_alive_tx, mut operation_keep_alive_rx) = mpsc::unbounded_channel();
360
27
        Arc::new(Self {
361
27
            inner: Mutex::new(ApiWorkerSchedulerImpl {
362
27
                workers: Workers(LruCache::unbounded()),
363
27
                worker_state_manager: worker_state_manager.clone(),
364
27
                allocation_strategy,
365
27
                worker_change_notify,
366
27
                operation_keep_alive_tx,
367
27
            }),
368
27
            platform_property_manager,
369
27
            worker_timeout_s,
370
27
            _operation_keep_alive_spawn: spawn!(
371
27
                "simple_scheduler_operation_keep_alive",
372
27
                async move 
{22
373
                    const RECV_MANY_LIMIT: usize = 256;
374
22
                    let mut messages = Vec::with_capacity(RECV_MANY_LIMIT);
375
                    loop {
376
22
                        messages.clear();
377
22
                        operation_keep_alive_rx
378
22
                            .recv_many(&mut messages, RECV_MANY_LIMIT)
379
0
                            .await;
380
0
                        if messages.is_empty() {
  Branch (380:28): [True: 0, False: 0]
  Branch (380:28): [Folded - Ignored]
381
0
                            return; // Looks like our sender has been dropped.
382
0
                        }
383
0
                        for (operation_id, worker_id) in messages.drain(..) {
384
0
                            let update_operation_res = worker_state_manager
385
0
                                .update_operation(
386
0
                                    &operation_id,
387
0
                                    &worker_id,
388
0
                                    UpdateOperationType::KeepAlive,
389
0
                                )
390
0
                                .await;
391
0
                            if let Err(err) = update_operation_res {
  Branch (391:36): [True: 0, False: 0]
  Branch (391:36): [Folded - Ignored]
392
0
                                event!(Level::WARN, ?err, "Error while running worker_keep_alive_received, maybe job is done?");
393
0
                            }
394
                        }
395
                    }
396
27
                
}0
397
27
            ),
398
        })
399
27
    }
400
401
28
    pub async fn worker_notify_run_action(
402
28
        &self,
403
28
        worker_id: WorkerId,
404
28
        operation_id: OperationId,
405
28
        action_info: ActionInfoWithProps,
406
28
    ) -> Result<(), Error> {
407
28
        let mut inner = self.inner.lock().
await0
;
408
28
        inner
409
28
            .worker_notify_run_action(worker_id, operation_id, action_info)
410
0
            .await
411
28
    }
412
413
    /// Attempts to find a worker that is capable of running this action.
414
    // TODO(blaise.bruer) This algorithm is not very efficient. Simple testing using a tree-like
415
    // structure showed worse performance on a 10_000 worker * 7 properties * 1000 queued tasks
416
    // simulation of worst cases in a single threaded environment.
417
42
    pub async fn find_worker_for_action(
418
42
        &self,
419
42
        platform_properties: &PlatformProperties,
420
42
    ) -> Option<WorkerId> {
421
42
        let inner = self.inner.lock().
await0
;
422
42
        inner.inner_find_worker_for_action(platform_properties)
423
42
    }
424
425
    /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
426
    #[must_use]
427
7
    pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
428
7
        let inner = self.inner.lock().
await0
;
429
7
        inner.workers.contains(worker_id)
430
7
    }
431
432
    /// A unit test function used to send the keep alive message to the worker from the server.
433
1
    pub async fn send_keep_alive_to_worker_for_test(
434
1
        &self,
435
1
        worker_id: &WorkerId,
436
1
    ) -> Result<(), Error> {
437
1
        let mut inner = self.inner.lock().
await0
;
438
1
        let worker = inner.workers.get_mut(worker_id).ok_or_else(|| {
439
0
            make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
440
1
        })
?0
;
441
1
        worker.keep_alive()
442
1
    }
443
}
444
445
#[async_trait]
446
impl WorkerScheduler for ApiWorkerScheduler {
447
1
    fn get_platform_property_manager(&self) -> &PlatformPropertyManager {
448
1
        self.platform_property_manager.as_ref()
449
1
    }
450
451
31
    async fn add_worker(&self, worker: Worker) -> Result<(), Error> {
452
31
        let mut inner = self.inner.lock().
await0
;
453
31
        let worker_id = worker.id;
454
31
        let result = inner
455
31
            .add_worker(worker)
456
31
            .err_tip(|| 
"Error while adding worker, removing from pool"0
);
457
31
        if let Err(
err0
) = result {
  Branch (457:16): [True: 0, False: 31]
  Branch (457:16): [Folded - Ignored]
458
0
            return Result::<(), _>::Err(err.clone())
459
0
                .merge(inner.immediate_evict_worker(&worker_id, err).await);
460
31
        }
461
31
        Ok(())
462
62
    }
463
464
    async fn update_action(
465
        &self,
466
        worker_id: &WorkerId,
467
        operation_id: &OperationId,
468
        update: UpdateOperationType,
469
10
    ) -> Result<(), Error> {
470
10
        let mut inner = self.inner.lock().
await0
;
471
10
        inner.update_action(worker_id, operation_id, update).
await1
472
20
    }
473
474
    async fn worker_keep_alive_received(
475
        &self,
476
        worker_id: &WorkerId,
477
        timestamp: WorkerTimestamp,
478
2
    ) -> Result<(), Error> {
479
2
        let mut inner = self.inner.lock().
await0
;
480
2
        inner
481
2
            .refresh_lifetime(worker_id, timestamp)
482
2
            .err_tip(|| 
"Error refreshing lifetime in worker_keep_alive_received()"0
)
483
4
    }
484
485
2
    async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> {
486
2
        let mut inner = self.inner.lock().
await0
;
487
2
        inner
488
2
            .immediate_evict_worker(
489
2
                worker_id,
490
2
                make_err!(Code::Internal, "Received request to remove worker"),
491
2
            )
492
0
            .await
493
4
    }
494
495
5
    async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> {
496
5
        let mut inner = self.inner.lock().
await0
;
497
498
5
        let mut result = Ok(());
499
5
        // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire
500
5
        // map most of the time.
501
5
        let worker_ids_to_remove: Vec<WorkerId> = inner
502
5
            .workers
503
5
            .iter()
504
5
            .rev()
505
6
            .map_while(|(worker_id, worker)| {
506
6
                if worker.last_update_timestamp <= now_timestamp - self.worker_timeout_s {
  Branch (506:20): [True: 2, False: 4]
  Branch (506:20): [Folded - Ignored]
507
2
                    Some(*worker_id)
508
                } else {
509
4
                    None
510
                }
511
6
            })
512
5
            .collect();
513
7
        for 
worker_id2
in &worker_ids_to_remove {
514
2
            event!(
515
2
                Level::WARN,
516
                ?worker_id,
517
2
                "Worker timed out, removing from pool"
518
            );
519
2
            result = result.merge(
520
2
                inner
521
2
                    .immediate_evict_worker(
522
2
                        worker_id,
523
2
                        make_err!(
524
2
                            Code::Internal,
525
2
                            "Worker {worker_id} timed out, removing from pool"
526
2
                        ),
527
2
                    )
528
0
                    .await,
529
            );
530
        }
531
532
5
        result
533
10
    }
534
535
2
    async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> {
536
2
        let mut inner = self.inner.lock().
await0
;
537
2
        inner.set_drain_worker(worker_id, is_draining).
await0
538
4
    }
539
}
540
541
// Marks `ApiWorkerScheduler` as a `RootMetricsComponent`; the empty body relies entirely
// on the trait's default items.
impl RootMetricsComponent for ApiWorkerScheduler {}