Coverage Report

Created: 2025-03-08 07:13

/build/source/nativelink-service/src/worker_api_server.rs
 Line| Count|Source
    1|      |// Copyright 2024 The NativeLink Authors. All rights reserved.
    2|      |//
    3|      |// Licensed under the Apache License, Version 2.0 (the "License");
    4|      |// you may not use this file except in compliance with the License.
    5|      |// You may obtain a copy of the License at
    6|      |//
    7|      |//    http://www.apache.org/licenses/LICENSE-2.0
    8|      |//
    9|      |// Unless required by applicable law or agreed to in writing, software
   10|      |// distributed under the License is distributed on an "AS IS" BASIS,
   11|      |// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   12|      |// See the License for the specific language governing permissions and
   13|      |// limitations under the License.
   14|      |
   15|      |use std::collections::HashMap;
   16|      |use std::convert::Into;
   17|      |use std::pin::Pin;
   18|      |use std::sync::Arc;
   19|      |use std::time::{Duration, SystemTime, UNIX_EPOCH};
   20|      |
   21|      |use futures::stream::unfold;
   22|      |use futures::Stream;
   23|      |use nativelink_config::cas_server::WorkerApiConfig;
   24|      |use nativelink_error::{make_err, Code, Error, ResultExt};
   25|      |use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_server::{
   26|      |    WorkerApi, WorkerApiServer as Server,
   27|      |};
   28|      |use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
   29|      |    execute_result, ConnectWorkerRequest, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker
   30|      |};
   31|      |use nativelink_scheduler::worker::Worker;
   32|      |use nativelink_scheduler::worker_scheduler::WorkerScheduler;
   33|      |use nativelink_util::background_spawn;
   34|      |use nativelink_util::action_messages::{OperationId, WorkerId};
   35|      |use nativelink_util::operation_state_manager::UpdateOperationType;
   36|      |use nativelink_util::platform_properties::PlatformProperties;
   37|      |use rand::RngCore;
   38|      |use tokio::sync::mpsc;
   39|      |use tokio::time::interval;
   40|      |use tonic::{Request, Response, Status};
   41|      |use tracing::{event, instrument, Level};
   42|      |use uuid::Uuid;
   43|      |
   44|      |pub type ConnectWorkerStream =
   45|      |    Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>;
   46|      |
   47|      |pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>;
   48|      |
   49|      |pub struct WorkerApiServer {
   50|      |    scheduler: Arc<dyn WorkerScheduler>,
   51|      |    now_fn: NowFn,
   52|      |    node_id: [u8; 6],
   53|      |}
   54|      |
   55|      |impl WorkerApiServer {
   56|     0|    pub fn new(
   57|     0|        config: &WorkerApiConfig,
   58|     0|        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
   59|     0|    ) -> Result<Self, Error> {
   60|     0|        let node_id = {
   61|     0|            let mut out = [0; 6];
   62|     0|            rand::rng().fill_bytes(&mut out);
   63|     0|            out
   64|      |        };
   65|     0|        for scheduler in schedulers.values() {
   66|      |            // This will protect us from holding a reference to the scheduler forever in the
   67|      |            // event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will
   68|      |            // eventually see the Arc went away and return.
   69|     0|            let weak_scheduler = Arc::downgrade(scheduler);
   70|     0|            background_spawn!("worker_api_server", async move {
   71|     0|                let mut ticker = interval(Duration::from_secs(1));
   72|      |                loop {
   73|     0|                    ticker.tick().await;
   74|     0|                    let timestamp = SystemTime::now()
   75|     0|                        .duration_since(UNIX_EPOCH)
   76|     0|                        .expect("Error: system time is now behind unix epoch");
   77|     0|                    match weak_scheduler.upgrade() {
   78|     0|                        Some(scheduler) => {
   79|     0|                            if let Err(err) =
  Branch (79:36): [True: 0, False: 0]
  Branch (79:36): [Folded - Ignored]
   80|     0|                                scheduler.remove_timedout_workers(timestamp.as_secs()).await
   81|      |                            {
   82|     0|                                event!(Level::ERROR, ?err, "Failed to remove_timedout_workers",);
   83|     0|                            }
   84|      |                        }
   85|      |                        // If we fail to upgrade, our service is probably destroyed, so return.
   86|     0|                        None => return,
   87|     0|                    }
   88|     0|                }
   89|     0|            });
   90|      |        }
   91|      |
   92|     0|        Self::new_with_now_fn(
   93|     0|            config,
   94|     0|            schedulers,
   95|     0|            Box::new(move || {
   96|     0|                SystemTime::now()
   97|     0|                    .duration_since(UNIX_EPOCH)
   98|     0|                    .map_err(|_| make_err!(Code::Internal, "System time is now behind unix epoch"))
   99|     0|            }),
  100|     0|            node_id,
  101|     0|        )
  102|     0|    }
  103|      |
  104|      |    /// Same as `new()`, but you can pass a custom `now_fn` that returns a Duration since `UNIX_EPOCH`
  105|      |    /// representing the current time. Used mostly in unit tests.
  106|     6|    pub fn new_with_now_fn(
  107|     6|        config: &WorkerApiConfig,
  108|     6|        schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>,
  109|     6|        now_fn: NowFn,
  110|     6|        node_id: [u8; 6],
  111|     6|    ) -> Result<Self, Error> {
  112|     6|        let scheduler = schedulers
  113|     6|            .get(&config.scheduler)
  114|     6|            .err_tip(|| {
  115|     0|                format!(
  116|     0|                    "Scheduler needs config for '{}' because it exists in worker_api",
  117|     0|                    config.scheduler
  118|     0|                )
  119|     6|            })?
  120|     6|            .clone();
  121|     6|        Ok(Self {
  122|     6|            scheduler,
  123|     6|            now_fn,
  124|     6|            node_id,
  125|     6|        })
  126|     6|    }
  127|      |
  128|     0|    pub fn into_service(self) -> Server<WorkerApiServer> {
  129|     0|        Server::new(self)
  130|     0|    }
  131|      |
  132|     6|    async fn inner_connect_worker(
  133|     6|        &self,
  134|     6|        connect_worker_request: ConnectWorkerRequest,
  135|     6|    ) -> Result<Response<ConnectWorkerStream>, Error> {
  136|     6|        let (tx, rx) = mpsc::unbounded_channel();
  137|      |
  138|      |        // First convert our proto platform properties into one our scheduler understands.
  139|     6|        let platform_properties = {
  140|     6|            let mut platform_properties = PlatformProperties::default();
  141|     6|            for property in connect_worker_request.properties {
  142|     0|                let platform_property_value = self
  143|     0|                    .scheduler
  144|     0|                    .get_platform_property_manager()
  145|     0|                    .make_prop_value(&property.name, &property.value)
  146|     0|                    .err_tip(|| "Bad Property during connect_worker()")?;
  147|     0|                platform_properties
  148|     0|                    .properties
  149|     0|                    .insert(property.name.clone(), platform_property_value);
  150|      |            }
  151|     6|            platform_properties
  152|      |        };
  153|      |
  154|      |        // Now register the worker with the scheduler.
  155|     6|        let worker_id = {
  156|     6|            let worker_id = WorkerId(format!(
  157|     6|                "{}{}",
  158|     6|                connect_worker_request.worker_id_prefix,
  159|     6|                Uuid::now_v6(&self.node_id).hyphenated()
  160|     6|            ));
  161|     6|            let worker = Worker::new(
  162|     6|                worker_id.clone(),
  163|     6|                platform_properties,
  164|     6|                tx,
  165|     6|                (self.now_fn)()?.as_secs(),
  166|     6|            );
  167|     6|            self.scheduler
  168|     6|                .add_worker(worker)
  169|     6|                .await
  170|     6|                .err_tip(|| "Failed to add worker in inner_connect_worker()")?;
  171|     6|            worker_id
  172|     6|        };
  173|     6|
  174|     6|        Ok(Response::new(Box::pin(unfold(
  175|     6|            (rx, worker_id),
  176|     8|            move |state| async move {
  177|     8|                let (mut rx, worker_id) = state;
  178|     8|                if let Some(update_for_worker) = rx.recv().await {
  Branch (178:24): [True: 8, False: 0]
  Branch (178:24): [Folded - Ignored]
  179|     8|                    return Some((Ok(update_for_worker), (rx, worker_id)));
  180|     0|                }
  181|     0|                event!(
  182|     0|                    Level::WARN,
  183|      |                    ?worker_id,
  184|     0|                    "UpdateForWorker channel was closed, thus closing connection to worker node",
  185|      |                );
  186|      |
  187|     0|                None
  188|    16|            },
  189|     6|        ))))
  190|     6|    }
  191|      |
  192|     1|    async fn inner_keep_alive(
  193|     1|        &self,
  194|     1|        keep_alive_request: KeepAliveRequest,
  195|     1|    ) -> Result<Response<()>, Error> {
  196|     1|        let worker_id: WorkerId = keep_alive_request.worker_id.into();
  197|     1|        self.scheduler
  198|     1|            .worker_keep_alive_received(&worker_id, (self.now_fn)()?.as_secs())
  199|     1|            .await
  200|     1|            .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()")?;
  201|     1|        Ok(Response::new(()))
  202|     1|    }
  203|      |
  204|     0|    async fn inner_going_away(
  205|     0|        &self,
  206|     0|        going_away_request: GoingAwayRequest,
  207|     0|    ) -> Result<Response<()>, Error> {
  208|     0|        let worker_id: WorkerId = going_away_request.worker_id.into();
  209|     0|        self.scheduler
  210|     0|            .remove_worker(&worker_id)
  211|     0|            .await
  212|     0|            .err_tip(|| "While calling WorkerApiServer::inner_going_away")?;
  213|     0|        Ok(Response::new(()))
  214|     0|    }
  215|      |
  216|     1|    async fn inner_execution_response(
  217|     1|        &self,
  218|     1|        execute_result: ExecuteResult,
  219|     1|    ) -> Result<Response<()>, Error> {
  220|     1|        let worker_id: WorkerId = execute_result.worker_id.into();
  221|     1|        let operation_id = OperationId::from(execute_result.operation_id);
  222|     1|
  223|     1|        match execute_result
  224|     1|            .result
  225|     1|            .err_tip(|| "Expected result to exist in ExecuteResult")?
  226|      |        {
  227|     1|            execute_result::Result::ExecuteResponse(finished_result) => {
  228|     1|                let action_stage = finished_result
  229|     1|                    .try_into()
  230|     1|                    .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")?;
  231|     1|                self.scheduler
  232|     1|                    .update_action(
  233|     1|                        &worker_id,
  234|     1|                        &operation_id,
  235|     1|                        UpdateOperationType::UpdateWithActionStage(action_stage),
  236|     1|                    )
  237|     1|                    .await
  238|     1|                    .err_tip(|| format!("Failed to update operation {operation_id:?}"))?;
  239|      |            }
  240|     0|            execute_result::Result::InternalError(e) => {
  241|     0|                self.scheduler
  242|     0|                    .update_action(
  243|     0|                        &worker_id,
  244|     0|                        &operation_id,
  245|     0|                        UpdateOperationType::UpdateWithError(e.into()),
  246|     0|                    )
  247|     0|                    .await
  248|     0|                    .err_tip(|| format!("Failed to update operation {operation_id:?}"))?;
  249|      |            }
  250|      |        }
  251|     1|        Ok(Response::new(()))
  252|     1|    }
  253|      |}
  254|      |
  255|      |#[tonic::async_trait]
  256|      |impl WorkerApi for WorkerApiServer {
  257|      |    type ConnectWorkerStream = ConnectWorkerStream;
  258|      |
  259|      |    #[allow(clippy::blocks_in_conditions)]
  260|      |    #[instrument(
  261|      |        err,
  262|      |        level = Level::ERROR,
  263|      |        skip_all,
  264|      |        fields(request = ?grpc_request.get_ref())
  265|      |    )]
  266|      |    async fn connect_worker(
  267|      |        &self,
  268|      |        grpc_request: Request<ConnectWorkerRequest>,
  269|     6|    ) -> Result<Response<Self::ConnectWorkerStream>, Status> {
  270|     6|        let resp = self
  271|     6|            .inner_connect_worker(grpc_request.into_inner())
  272|     6|            .await
  273|     6|            .map_err(Into::into);
  274|     6|        if resp.is_ok() {
  Branch (274:12): [True: 6, False: 0]
  Branch (274:12): [Folded - Ignored]
  275|     6|            event!(Level::DEBUG, return = "Ok(<stream>)");
  276|     0|        }
  277|     6|        resp
  278|    12|    }
  279|      |
  280|      |    #[allow(clippy::blocks_in_conditions)]
  281|      |    #[instrument(
  282|      |        err,
  283|      |        ret(level = Level::INFO),
  284|      |        level = Level::INFO,
  285|      |        skip_all,
  286|      |        fields(request = ?grpc_request.get_ref())
  287|      |    )]
  288|      |    async fn keep_alive(
  289|      |        &self,
  290|      |        grpc_request: Request<KeepAliveRequest>,
  291|     1|    ) -> Result<Response<()>, Status> {
  292|     1|        self.inner_keep_alive(grpc_request.into_inner())
  293|     1|            .await
  294|     1|            .map_err(Into::into)
  295|     2|    }
  296|      |
  297|      |    #[allow(clippy::blocks_in_conditions)]
  298|      |    #[instrument(
  299|      |        err,
  300|      |        ret(level = Level::INFO),
  301|      |        level = Level::ERROR,
  302|      |        skip_all,
  303|      |        fields(request = ?grpc_request.get_ref())
  304|      |    )]
  305|      |    async fn going_away(
  306|      |        &self,
  307|      |        grpc_request: Request<GoingAwayRequest>,
  308|     0|    ) -> Result<Response<()>, Status> {
  309|     0|        self.inner_going_away(grpc_request.into_inner())
  310|     0|            .await
  311|     0|            .map_err(Into::into)
  312|     0|    }
  313|      |
  314|      |    #[allow(clippy::blocks_in_conditions)]
  315|      |    #[instrument(
  316|      |        err,
  317|      |        ret(level = Level::INFO),
  318|      |        level = Level::ERROR,
  319|      |        skip_all,
  320|      |        fields(request = ?grpc_request.get_ref())
  321|      |    )]
  322|      |    async fn execution_response(
  323|      |        &self,
  324|      |        grpc_request: Request<ExecuteResult>,
  325|     1|    ) -> Result<Response<()>, Status> {
  326|     1|        self.inner_execution_response(grpc_request.into_inner())
  327|     1|            .await
  328|     1|            .map_err(Into::into)
  329|     2|    }
  330|      |}
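
Two of the zero-count regions above are reachable only through runtime lifecycle events rather than ordinary requests, which is why routine request-driven tests leave them uncovered.

First, the reaper loop in new() (lines 65-90) holds only a Weak reference to the scheduler, so the background task exits once the service drops its Arc. A minimal, self-contained sketch of that shutdown mechanism, using a placeholder Scheduler type rather than NativeLink's WorkerScheduler trait:

    use std::sync::{Arc, Weak};
    use std::time::Duration;

    struct Scheduler;

    impl Scheduler {
        fn remove_timedout_workers(&self) { /* placeholder for the real cleanup */ }
    }

    #[tokio::main]
    async fn main() {
        let scheduler = Arc::new(Scheduler);
        let weak: Weak<Scheduler> = Arc::downgrade(&scheduler);
        let task = tokio::spawn(async move {
            let mut ticker = tokio::time::interval(Duration::from_millis(10));
            loop {
                ticker.tick().await;
                match weak.upgrade() {
                    // A strong reference still exists: run the periodic cleanup.
                    Some(s) => s.remove_timedout_workers(),
                    // The owner dropped its Arc: exit so the task cannot leak.
                    None => return,
                }
            }
        });
        drop(scheduler); // Drop the last strong reference...
        task.await.unwrap(); // ...and the background task terminates.
    }

Second, the stream-close path in inner_connect_worker() (lines 180-187) only runs when the sender half of the worker's channel is dropped, which these tests never trigger (the branch at 178:24 is [True: 8, False: 0]). The stream itself is built with futures::stream::unfold over an mpsc receiver; a sketch of that pattern with a placeholder item type (the real stream yields Result<UpdateForWorker, Status>):

    use futures::stream::{unfold, StreamExt};
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, rx) = mpsc::unbounded_channel::<u32>();
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        // Dropping the sender closes the channel; this is the path that logs
        // "UpdateForWorker channel was closed" in the listing above.
        drop(tx);

        let stream = unfold(rx, |mut rx| async move {
            // Yield the next update, or end the stream once the channel closes.
            rx.recv().await.map(|item| (item, rx))
        });
        assert_eq!(stream.collect::<Vec<u32>>().await, vec![1, 2]);
    }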