Coverage Report

Created: 2026-04-14 11:55

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-service/src/worker_api_server.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::convert::Into;
16
use core::pin::Pin;
17
use core::time::Duration;
18
use std::collections::HashMap;
19
use std::sync::Arc;
20
use std::time::{SystemTime, UNIX_EPOCH};
21
22
use futures::stream::unfold;
23
use futures::{Stream, StreamExt};
24
use nativelink_config::cas_server::WorkerApiConfig;
25
use nativelink_error::{make_err, Code, Error, ResultExt};
26
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_scheduler::Update;
27
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_server::{
28
    WorkerApi, WorkerApiServer as Server,
29
};
30
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
31
    execute_result, ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForScheduler, UpdateForWorker
32
};
33
use nativelink_scheduler::worker::Worker;
34
use nativelink_scheduler::worker_scheduler::WorkerScheduler;
35
use nativelink_util::background_spawn;
36
use nativelink_util::action_messages::{OperationId, WorkerId};
37
use nativelink_util::operation_state_manager::UpdateOperationType;
38
use nativelink_util::platform_properties::PlatformProperties;
39
use rand::RngCore;
40
use tokio::sync::mpsc;
41
use tokio::time::interval;
42
use tonic::{Response, Status};
43
use tracing::{debug, error, warn, instrument, Level};
44
use uuid::Uuid;
45
46
pub type ConnectWorkerStream =
47
    Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>;
48
49
pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>;
50
51
pub struct WorkerApiServer {
52
    scheduler: Arc<dyn WorkerScheduler>,
53
    now_fn: Arc<NowFn>,
54
    node_id: [u8; 6],
55
}
56
57
impl core::fmt::Debug for WorkerApiServer {
58
0
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59
0
        f.debug_struct("WorkerApiServer")
60
0
            .field("node_id", &self.node_id)
61
0
            .finish_non_exhaustive()
62
0
    }
63
}
64
65
impl WorkerApiServer {
66
0
    pub fn new(
67
0
        config: &WorkerApiConfig,
68
0
        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
69
0
    ) -> Result<Self, Error> {
70
0
        let node_id = {
71
0
            let mut out = [0; 6];
72
0
            rand::rng().fill_bytes(&mut out);
73
0
            out
74
        };
75
0
        for scheduler in schedulers.values() {
76
            // This will protect us from holding a reference to the scheduler forever in the
77
            // event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will
78
            // eventually see the Arc went away and return.
79
0
            let weak_scheduler = Arc::downgrade(scheduler);
80
0
            background_spawn!("worker_api_server", async move {
81
0
                let mut ticker = interval(Duration::from_secs(1));
82
                loop {
83
0
                    ticker.tick().await;
84
0
                    let timestamp = SystemTime::now()
85
0
                        .duration_since(UNIX_EPOCH)
86
0
                        .expect("Error: system time is now behind unix epoch");
87
0
                    match weak_scheduler.upgrade() {
88
0
                        Some(scheduler) => {
89
0
                            if let Err(err) =
90
0
                                scheduler.remove_timedout_workers(timestamp.as_secs()).await
91
                            {
92
0
                                error!(?err, "Failed to remove_timedout_workers",);
93
0
                            }
94
                        }
95
                        // If we fail to upgrade, our service is probably destroyed, so return.
96
0
                        None => return,
97
                    }
98
                }
99
0
            });
100
        }
101
102
0
        Self::new_with_now_fn(
103
0
            config,
104
0
            schedulers,
105
0
            Box::new(move || {
106
0
                SystemTime::now().duration_since(UNIX_EPOCH).map_err(|err| {
107
0
                    Error::from_std_err(Code::Internal, &err)
108
0
                        .append("System time is now behind unix epoch")
109
0
                })
110
0
            }),
111
0
            node_id,
112
        )
113
0
    }
114
115
    /// Same as `new()`, but you can pass a custom `now_fn`, that returns a Duration since `UNIX_EPOCH`
116
    /// representing the current time. Used mostly in  unit tests.
117
7
    pub fn new_with_now_fn(
118
7
        config: &WorkerApiConfig,
119
7
        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
120
7
        now_fn: NowFn,
121
7
        node_id: [u8; 6],
122
7
    ) -> Result<Self, Error> {
123
7
        let scheduler = schedulers
124
7
            .get(&config.scheduler)
125
7
            .err_tip(|| 
{0
126
0
                format!(
127
                    "Scheduler needs config for '{}' because it exists in worker_api",
128
                    config.scheduler
129
                )
130
0
            })?
131
7
            .clone();
132
7
        Ok(Self {
133
7
            scheduler,
134
7
            now_fn: Arc::new(now_fn),
135
7
            node_id,
136
7
        })
137
7
    }
138
139
0
    pub fn into_service(self) -> Server<Self> {
140
0
        Server::new(self)
141
0
    }
142
143
7
    async fn inner_connect_worker(
144
7
        &self,
145
7
        mut update_stream: impl Stream<Item = Result<UpdateForScheduler, Status>>
146
7
        + Unpin
147
7
        + Send
148
7
        + 'static,
149
7
    ) -> Result<Response<ConnectWorkerStream>, Error> {
150
7
        let first_message = update_stream
151
7
            .next()
152
7
            .await
153
7
            .err_tip(|| "Missing first message for connect_worker")
?0
154
7
            .err_tip(|| "Error reading first message for connect_worker")
?0
;
155
7
        let Some(Update::ConnectWorkerRequest(connect_worker_request)) = first_message.update
156
        else {
157
0
            return Err(make_err!(
158
0
                Code::Internal,
159
0
                "First message was not a ConnectWorkerRequest"
160
0
            ));
161
        };
162
163
7
        let (tx, rx) = mpsc::unbounded_channel();
164
165
        // First convert our proto platform properties into one our scheduler understands.
166
7
        let platform_properties = {
167
7
            let mut platform_properties = PlatformProperties::default();
168
7
            for 
property0
in connect_worker_request.properties {
169
0
                let platform_property_value = self
170
0
                    .scheduler
171
0
                    .get_platform_property_manager()
172
0
                    .make_prop_value(&property.name, &property.value)
173
0
                    .err_tip(|| "Bad Property during connect_worker()")?;
174
0
                platform_properties
175
0
                    .properties
176
0
                    .insert(property.name.clone(), platform_property_value);
177
            }
178
7
            platform_properties
179
        };
180
181
        // Now register the worker with the scheduler.
182
7
        let worker_id = {
183
7
            let worker_id = WorkerId(format!(
184
7
                "{}{}",
185
7
                connect_worker_request.worker_id_prefix,
186
7
                Uuid::now_v6(&self.node_id).hyphenated()
187
7
            ));
188
7
            let worker = Worker::new(
189
7
                worker_id.clone(),
190
7
                platform_properties,
191
7
                tx,
192
7
                (self.now_fn)()
?0
.as_secs(),
193
7
                connect_worker_request.max_inflight_tasks,
194
            );
195
7
            self.scheduler
196
7
                .add_worker(worker)
197
7
                .await
198
7
                .err_tip(|| "Failed to add worker in inner_connect_worker()")
?0
;
199
7
            worker_id
200
        };
201
202
7
        WorkerConnection::start(
203
7
            self.scheduler.clone(),
204
7
            self.now_fn.clone(),
205
7
            worker_id.clone(),
206
7
            update_stream,
207
        );
208
209
7
        Ok(Response::new(Box::pin(unfold(
210
7
            (rx, worker_id),
211
9
            move |state| async move {
212
9
                let (mut rx, worker_id) = state;
213
9
                if let Some(update_for_worker) = rx.recv().await {
214
9
                    return Some((Ok(update_for_worker), (rx, worker_id)));
215
0
                }
216
0
                warn!(
217
                    ?worker_id,
218
                    "UpdateForWorker channel was closed, thus closing connection to worker node",
219
                );
220
221
0
                None
222
18
            },
223
        ))))
224
7
    }
225
226
7
    pub async fn inner_connect_worker_for_testing(
227
7
        &self,
228
7
        update_stream: impl Stream<Item = Result<UpdateForScheduler, Status>> + Unpin + Send + 'static,
229
7
    ) -> Result<Response<ConnectWorkerStream>, Error> {
230
7
        self.inner_connect_worker(update_stream).await
231
7
    }
232
}
233
234
#[tonic::async_trait]
235
impl WorkerApi for WorkerApiServer {
236
    type ConnectWorkerStream = ConnectWorkerStream;
237
238
    #[instrument(
239
        err,
240
        level = Level::ERROR,
241
        skip_all,
242
        fields(request = ?grpc_request.get_ref())
243
    )]
244
    async fn connect_worker(
245
        &self,
246
        grpc_request: tonic::Request<tonic::Streaming<UpdateForScheduler>>,
247
    ) -> Result<Response<Self::ConnectWorkerStream>, Status> {
248
        let resp = self
249
            .inner_connect_worker(grpc_request.into_inner())
250
            .await
251
            .map_err(Into::into);
252
        if resp.is_ok() {
253
            debug!(return = "Ok(<stream>)");
254
        }
255
        resp
256
    }
257
}
258
259
struct WorkerConnection {
260
    scheduler: Arc<dyn WorkerScheduler>,
261
    now_fn: Arc<NowFn>,
262
    worker_id: WorkerId,
263
}
264
265
impl WorkerConnection {
266
7
    fn start(
267
7
        scheduler: Arc<dyn WorkerScheduler>,
268
7
        now_fn: Arc<NowFn>,
269
7
        worker_id: WorkerId,
270
7
        mut connection: impl Stream<Item = Result<UpdateForScheduler, Status>> + Unpin + Send + 'static,
271
7
    ) {
272
7
        let instance = Self {
273
7
            scheduler,
274
7
            now_fn,
275
7
            worker_id,
276
7
        };
277
278
7
        background_spawn!("worker_api", async move 
{2
279
2
            let mut had_going_away = false;
280
3
            while let Some(
maybe_update2
) = connection.next().await {
281
2
                let update = match maybe_update.map(|u| u.update) {
282
2
                    Ok(Some(update)) => update,
283
                    Ok(None) => {
284
0
                        tracing::warn!(worker_id=?instance.worker_id, "Empty update");
285
0
                        continue;
286
                    }
287
0
                    Err(err) => {
288
0
                        tracing::warn!(worker_id=?instance.worker_id, ?err, "Error from worker");
289
0
                        break;
290
                    }
291
                };
292
2
                let 
result1
= match update {
293
0
                    Update::ConnectWorkerRequest(_connect_worker_request) => Err(make_err!(
294
0
                        Code::Internal,
295
0
                        "Got ConnectWorkerRequest after initial message for {}",
296
0
                        instance.worker_id
297
0
                    )),
298
1
                    Update::KeepAliveRequest(keep_alive_request) => {
299
1
                        instance.inner_keep_alive(keep_alive_request).await
300
                    }
301
0
                    Update::GoingAwayRequest(going_away_request) => {
302
0
                        had_going_away = true;
303
0
                        instance.inner_going_away(going_away_request).await
304
                    }
305
1
                    Update::ExecuteResult(execute_result) => {
306
1
                        instance.inner_execution_response(execute_result).await
307
                    }
308
0
                    Update::ExecuteComplete(execute_complete) => {
309
0
                        instance.execution_complete(execute_complete).await
310
                    }
311
                };
312
1
                if let Err(
err0
) = result {
313
0
                    tracing::warn!(worker_id=?instance.worker_id, ?err, "Error processing worker message");
314
1
                }
315
            }
316
0
            tracing::debug!(worker_id=?instance.worker_id, "Update for scheduler dropped");
317
0
            if !had_going_away {
318
0
                drop(instance.scheduler.remove_worker(&instance.worker_id).await);
319
0
            }
320
0
        });
321
7
    }
322
323
1
    async fn inner_keep_alive(&self, _keep_alive_request: KeepAliveRequest) -> Result<(), Error> {
324
1
        self.scheduler
325
1
            .worker_keep_alive_received(&self.worker_id, (self.now_fn)()
?0
.as_secs())
326
1
            .await
327
1
            .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()")
?0
;
328
1
        Ok(())
329
1
    }
330
331
0
    async fn inner_going_away(&self, _going_away_request: GoingAwayRequest) -> Result<(), Error> {
332
0
        self.scheduler
333
0
            .remove_worker(&self.worker_id)
334
0
            .await
335
0
            .err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
336
0
        Ok(())
337
0
    }
338
339
1
    async fn inner_execution_response(&self, execute_result: ExecuteResult) -> Result<(), Error> {
340
1
        let operation_id = OperationId::from(execute_result.operation_id);
341
342
1
        match execute_result
343
1
            .result
344
1
            .err_tip(|| "Expected result to exist in ExecuteResult")
?0
345
        {
346
1
            execute_result::Result::ExecuteResponse(finished_result) => {
347
1
                let action_stage = finished_result
348
1
                    .try_into()
349
1
                    .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")
?0
;
350
1
                self.scheduler
351
1
                    .update_action(
352
1
                        &self.worker_id,
353
1
                        &operation_id,
354
1
                        UpdateOperationType::UpdateWithActionStage(action_stage),
355
1
                    )
356
1
                    .await
357
0
                    .err_tip(|| format!("Failed to operation {operation_id}"))?;
358
            }
359
0
            execute_result::Result::InternalError(e) => {
360
0
                self.scheduler
361
0
                    .update_action(
362
0
                        &self.worker_id,
363
0
                        &operation_id,
364
0
                        UpdateOperationType::UpdateWithError(e.into()),
365
0
                    )
366
0
                    .await
367
0
                    .err_tip(|| format!("Failed to operation {operation_id}"))?;
368
            }
369
        }
370
0
        Ok(())
371
0
    }
372
373
0
    async fn execution_complete(&self, execute_complete: ExecuteComplete) -> Result<(), Error> {
374
0
        let operation_id = OperationId::from(execute_complete.operation_id);
375
0
        self.scheduler
376
0
            .update_action(
377
0
                &self.worker_id,
378
0
                &operation_id,
379
0
                UpdateOperationType::ExecutionComplete,
380
0
            )
381
0
            .await
382
0
            .err_tip(|| format!("Failed to operation {operation_id}"))?;
383
0
        Ok(())
384
0
    }
385
}