Coverage Report

Created: 2025-04-19 16:54

/build/source/nativelink-service/src/worker_api_server.rs
 Line|  Count|Source
    1|       |// Copyright 2024 The NativeLink Authors. All rights reserved.
    2|       |//
    3|       |// Licensed under the Apache License, Version 2.0 (the "License");
    4|       |// you may not use this file except in compliance with the License.
    5|       |// You may obtain a copy of the License at
    6|       |//
    7|       |//    http://www.apache.org/licenses/LICENSE-2.0
    8|       |//
    9|       |// Unless required by applicable law or agreed to in writing, software
   10|       |// distributed under the License is distributed on an "AS IS" BASIS,
   11|       |// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   12|       |// See the License for the specific language governing permissions and
   13|       |// limitations under the License.
   14|       |
   15|       |use std::collections::HashMap;
   16|       |use std::convert::Into;
   17|       |use std::pin::Pin;
   18|       |use std::sync::Arc;
   19|       |use std::time::{Duration, SystemTime, UNIX_EPOCH};
   20|       |
   21|       |use futures::stream::unfold;
   22|       |use futures::Stream;
   23|       |use nativelink_config::cas_server::WorkerApiConfig;
   24|       |use nativelink_error::{make_err, Code, Error, ResultExt};
   25|       |use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_server::{
   26|       |    WorkerApi, WorkerApiServer as Server,
   27|       |};
   28|       |use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
   29|       |    execute_result, ConnectWorkerRequest, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker
   30|       |};
   31|       |use nativelink_scheduler::worker::Worker;
   32|       |use nativelink_scheduler::worker_scheduler::WorkerScheduler;
   33|       |use nativelink_util::background_spawn;
   34|       |use nativelink_util::action_messages::{OperationId, WorkerId};
   35|       |use nativelink_util::operation_state_manager::UpdateOperationType;
   36|       |use nativelink_util::platform_properties::PlatformProperties;
   37|       |use rand::RngCore;
   38|       |use tokio::sync::mpsc;
   39|       |use tokio::time::interval;
   40|       |use tonic::{Request, Response, Status};
   41|       |use tracing::{event, instrument, Level};
   42|       |use uuid::Uuid;
   43|       |
   44|       |pub type ConnectWorkerStream =
   45|       |    Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>;
   46|       |
   47|       |pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>;
   48|       |
   49|       |pub struct WorkerApiServer {
   50|       |    scheduler: Arc<dyn WorkerScheduler>,
   51|       |    now_fn: NowFn,
   52|       |    node_id: [u8; 6],
   53|       |}
   54|       |
   55|       |impl std::fmt::Debug for WorkerApiServer {
   56|      0|    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
   57|      0|        f.debug_struct("WorkerApiServer")
   58|      0|            .field("node_id", &self.node_id)
   59|      0|            .finish_non_exhaustive()
   60|      0|    }
   61|       |}
   62|       |
   63|       |impl WorkerApiServer {
   64|      0|    pub fn new(
   65|      0|        config: &WorkerApiConfig,
   66|      0|        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
   67|      0|    ) -> Result<Self, Error> {
   68|      0|        let node_id = {
   69|      0|            let mut out = [0; 6];
   70|      0|            rand::rng().fill_bytes(&mut out);
   71|      0|            out
   72|       |        };
   73|      0|        for scheduler in schedulers.values() {
   74|       |            // This will protect us from holding a reference to the scheduler forever in the
   75|       |            // event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will
   76|       |            // eventually see the Arc went away and return.
   77|      0|            let weak_scheduler = Arc::downgrade(scheduler);
   78|      0|            background_spawn!("worker_api_server", async move {
   79|      0|                let mut ticker = interval(Duration::from_secs(1));
   80|       |                loop {
   81|      0|                    ticker.tick().await;
   82|      0|                    let timestamp = SystemTime::now()
   83|      0|                        .duration_since(UNIX_EPOCH)
   84|      0|                        .expect("Error: system time is now behind unix epoch");
   85|      0|                    match weak_scheduler.upgrade() {
   86|      0|                        Some(scheduler) => {
   87|      0|                            if let Err(err) =
     |       |  Branch (87:36): [True: 0, False: 0]
     |       |  Branch (87:36): [Folded - Ignored]
   88|      0|                                scheduler.remove_timedout_workers(timestamp.as_secs()).await
   89|       |                            {
   90|      0|                                event!(Level::ERROR, ?err, "Failed to remove_timedout_workers",);
   91|      0|                            }
   92|       |                        }
   93|       |                        // If we fail to upgrade, our service is probably destroyed, so return.
   94|      0|                        None => return,
   95|      0|                    }
   96|      0|                }
   97|      0|            });
   98|       |        }
   99|       |
  100|      0|        Self::new_with_now_fn(
  101|      0|            config,
  102|      0|            schedulers,
  103|      0|            Box::new(move || {
  104|      0|                SystemTime::now()
  105|      0|                    .duration_since(UNIX_EPOCH)
  106|      0|                    .map_err(|_| make_err!(Code::Internal, "System time is now behind unix epoch"))
  107|      0|            }),
  108|      0|            node_id,
  109|      0|        )
  110|      0|    }
  111|       |
  112|       |    /// Same as `new()`, but you can pass a custom `now_fn`, that returns a Duration since `UNIX_EPOCH`
  113|       |    /// representing the current time. Used mostly in  unit tests.
  114|      6|    pub fn new_with_now_fn(
  115|      6|        config: &WorkerApiConfig,
  116|      6|        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
  117|      6|        now_fn: NowFn,
  118|      6|        node_id: [u8; 6],
  119|      6|    ) -> Result<Self, Error> {
  120|      6|        let scheduler = schedulers
  121|      6|            .get(&config.scheduler)
  122|      6|            .err_tip(|| {
  123|      0|                format!(
  124|      0|                    "Scheduler needs config for '{}' because it exists in worker_api",
  125|      0|                    config.scheduler
  126|      0|                )
  127|      0|            })?
  128|      6|            .clone();
  129|      6|        Ok(Self {
  130|      6|            scheduler,
  131|      6|            now_fn,
  132|      6|            node_id,
  133|      6|        })
  134|      6|    }
  135|       |
  136|      0|    pub fn into_service(self) -> Server<WorkerApiServer> {
  137|      0|        Server::new(self)
  138|      0|    }
  139|       |
  140|      6|    async fn inner_connect_worker(
  141|      6|        &self,
  142|      6|        connect_worker_request: ConnectWorkerRequest,
  143|      6|    ) -> Result<Response<ConnectWorkerStream>, Error> {
  144|      6|        let (tx, rx) = mpsc::unbounded_channel();
  145|       |
  146|       |        // First convert our proto platform properties into one our scheduler understands.
  147|      6|        let platform_properties = {
  148|      6|            let mut platform_properties = PlatformProperties::default();
  149|      6|            for property in connect_worker_request.properties {
  150|      0|                let platform_property_value = self
  151|      0|                    .scheduler
  152|      0|                    .get_platform_property_manager()
  153|      0|                    .make_prop_value(&property.name, &property.value)
  154|      0|                    .err_tip(|| "Bad Property during connect_worker()")?;
  155|      0|                platform_properties
  156|      0|                    .properties
  157|      0|                    .insert(property.name.clone(), platform_property_value);
  158|       |            }
  159|      6|            platform_properties
  160|       |        };
  161|       |
  162|       |        // Now register the worker with the scheduler.
  163|      6|        let worker_id = {
  164|      6|            let worker_id = WorkerId(format!(
  165|      6|                "{}{}",
  166|      6|                connect_worker_request.worker_id_prefix,
  167|      6|                Uuid::now_v6(&self.node_id).hyphenated()
  168|      6|            ));
  169|      6|            let worker = Worker::new(
  170|      6|                worker_id.clone(),
  171|      6|                platform_properties,
  172|      6|                tx,
  173|      6|                (self.now_fn)()?.as_secs(),
  174|      6|            );
  175|      6|            self.scheduler
  176|      6|                .add_worker(worker)
  177|      6|                .await
  178|      6|                .err_tip(|| "Failed to add worker in inner_connect_worker()")?;
  179|      6|            worker_id
  180|      6|        };
  181|      6|
  182|      6|        Ok(Response::new(Box::pin(unfold(
  183|      6|            (rx, worker_id),
  184|      8|            move |state| async move {
  185|      8|                let (mut rx, worker_id) = state;
  186|      8|                if let Some(update_for_worker) = rx.recv().await {
     |       |  Branch (186:24): [True: 8, False: 0]
     |       |  Branch (186:24): [Folded - Ignored]
  187|      8|                    return Some((Ok(update_for_worker), (rx, worker_id)));
  188|      0|                }
  189|      0|                event!(
  190|      0|                    Level::WARN,
  191|       |                    ?worker_id,
  192|      0|                    "UpdateForWorker channel was closed, thus closing connection to worker node",
  193|       |                );
  194|       |
  195|      0|                None
  196|     16|            },
  197|       |        ))))
  198|      6|    }
  199|       |
  200|      1|    async fn inner_keep_alive(
  201|      1|        &self,
  202|      1|        keep_alive_request: KeepAliveRequest,
  203|      1|    ) -> Result<Response<()>, Error> {
  204|      1|        let worker_id: WorkerId = keep_alive_request.worker_id.into();
  205|      1|        self.scheduler
  206|      1|            .worker_keep_alive_received(&worker_id, (self.now_fn)()?.as_secs())
  207|      1|            .await
  208|      1|            .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()")?;
  209|      1|        Ok(Response::new(()))
  210|      1|    }
  211|       |
  212|      0|    async fn inner_going_away(
  213|      0|        &self,
  214|      0|        going_away_request: GoingAwayRequest,
  215|      0|    ) -> Result<Response<()>, Error> {
  216|      0|        let worker_id: WorkerId = going_away_request.worker_id.into();
  217|      0|        self.scheduler
  218|      0|            .remove_worker(&worker_id)
  219|      0|            .await
  220|      0|            .err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
  221|      0|        Ok(Response::new(()))
  222|      0|    }
  223|       |
  224|      1|    async fn inner_execution_response(
  225|      1|        &self,
  226|      1|        execute_result: ExecuteResult,
  227|      1|    ) -> Result<Response<()>, Error> {
  228|      1|        let worker_id: WorkerId = execute_result.worker_id.into();
  229|      1|        let operation_id = OperationId::from(execute_result.operation_id);
  230|      1|
  231|      1|        match execute_result
  232|      1|            .result
  233|      1|            .err_tip(|| "Expected result to exist in ExecuteResult")?
  234|       |        {
  235|      1|            execute_result::Result::ExecuteResponse(finished_result) => {
  236|      1|                let action_stage = finished_result
  237|      1|                    .try_into()
  238|      1|                    .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")?;
  239|      1|                self.scheduler
  240|      1|                    .update_action(
  241|      1|                        &worker_id,
  242|      1|                        &operation_id,
  243|      1|                        UpdateOperationType::UpdateWithActionStage(action_stage),
  244|      1|                    )
  245|      1|                    .await
  246|      1|                    .err_tip(|| format!("Failed to operation {operation_id:?}"))?;
  247|       |            }
  248|      0|            execute_result::Result::InternalError(e) => {
  249|      0|                self.scheduler
  250|      0|                    .update_action(
  251|      0|                        &worker_id,
  252|      0|                        &operation_id,
  253|      0|                        UpdateOperationType::UpdateWithError(e.into()),
  254|      0|                    )
  255|      0|                    .await
  256|      0|                    .err_tip(|| format!("Failed to operation {operation_id:?}"))?;
  257|       |            }
  258|       |        }
  259|      1|        Ok(Response::new(()))
  260|      1|    }
  261|       |}
  262|       |
  263|       |#[tonic::async_trait]
  264|       |impl WorkerApi for WorkerApiServer {
  265|       |    type ConnectWorkerStream = ConnectWorkerStream;
  266|       |
  267|       |    #[instrument(
  268|       |        err,
  269|       |        level = Level::ERROR,
  270|       |        skip_all,
  271|       |        fields(request = ?grpc_request.get_ref())
  272|       |    )]
  273|       |    async fn connect_worker(
  274|       |        &self,
  275|       |        grpc_request: Request<ConnectWorkerRequest>,
  276|     12|    ) -> Result<Response<Self::ConnectWorkerStream>, Status> {
  277|      6|        let resp = self
  278|      6|            .inner_connect_worker(grpc_request.into_inner())
  279|      6|            .await
  280|      6|            .map_err(Into::into);
  281|      6|        if resp.is_ok() {
     |       |  Branch (281:12): [True: 6, False: 0]
     |       |  Branch (281:12): [Folded - Ignored]
  282|      6|            event!(Level::DEBUG, return = "Ok(<stream>)");
  283|      0|        }
  284|      6|        resp
  285|     12|    }
  286|       |
  287|       |    #[instrument(
  288|       |        err,
  289|       |        ret(level = Level::INFO),
  290|       |        level = Level::INFO,
  291|       |        skip_all,
  292|       |        fields(request = ?grpc_request.get_ref())
  293|       |    )]
  294|       |    async fn keep_alive(
  295|       |        &self,
  296|       |        grpc_request: Request<KeepAliveRequest>,
  297|      2|    ) -> Result<Response<()>, Status> {
  298|      1|        self.inner_keep_alive(grpc_request.into_inner())
  299|      1|            .await
  300|      1|            .map_err(Into::into)
  301|      2|    }
  302|       |
  303|       |    #[instrument(
  304|       |        err,
  305|       |        ret(level = Level::INFO),
  306|       |        level = Level::ERROR,
  307|       |        skip_all,
  308|       |        fields(request = ?grpc_request.get_ref())
  309|       |    )]
  310|       |    async fn going_away(
  311|       |        &self,
  312|       |        grpc_request: Request<GoingAwayRequest>,
  313|      0|    ) -> Result<Response<()>, Status> {
  314|      0|        self.inner_going_away(grpc_request.into_inner())
  315|      0|            .await
  316|      0|            .map_err(Into::into)
  317|      0|    }
  318|       |
  319|       |    #[instrument(
  320|       |        err,
  321|       |        ret(level = Level::INFO),
  322|       |        level = Level::ERROR,
  323|       |        skip_all,
  324|       |        fields(request = ?grpc_request.get_ref())
  325|       |    )]
  326|       |    async fn execution_response(
  327|       |        &self,
  328|       |        grpc_request: Request<ExecuteResult>,
  329|      2|    ) -> Result<Response<()>, Status> {
  330|      1|        self.inner_execution_response(grpc_request.into_inner())
  331|      1|            .await
  332|      1|            .map_err(Into::into)
  333|      2|    }
  334|       |}
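
Note on the only partially covered region of connect_worker (source lines 182-197): the receiver half of the mpsc channel becomes the ConnectWorkerStream via futures::stream::unfold, and the branch at line 186 shows False: 0, meaning the "channel closed" path (lines 188-195) is never reached in the covered runs. Below is a minimal, standalone sketch of that same unfold pattern, not NativeLink code: plain u32 messages stand in for UpdateForWorker and the worker-id state is omitted.

use futures::stream::{unfold, StreamExt};
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Hypothetical stand-in for the UpdateForWorker messages.
    let (tx, rx) = mpsc::unbounded_channel::<u32>();
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    drop(tx); // Dropping the sender drives the "channel closed" path that the report shows as uncovered.

    // Same shape as source lines 182-197: the receiver is the unfold state, and the
    // stream ends (the closure returns None) once recv() observes the closed channel.
    let mut stream = Box::pin(unfold(rx, |mut rx| async move {
        rx.recv().await.map(|update| (update, rx))
    }));

    while let Some(update) = stream.next().await {
        println!("got update: {update}");
    }
    println!("stream ended because the channel closed");
}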