/build/source/nativelink-service/src/worker_api_server.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::convert::Into; |
16 | | use core::pin::Pin; |
17 | | use core::time::Duration; |
18 | | use std::collections::HashMap; |
19 | | use std::sync::Arc; |
20 | | use std::time::{SystemTime, UNIX_EPOCH}; |
21 | | |
22 | | use futures::stream::unfold; |
23 | | use futures::{Stream, StreamExt}; |
24 | | use nativelink_config::cas_server::WorkerApiConfig; |
25 | | use nativelink_error::{make_err, Code, Error, ResultExt}; |
26 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_scheduler::Update; |
27 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_server::{ |
28 | | WorkerApi, WorkerApiServer as Server, |
29 | | }; |
30 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ |
31 | | execute_result, ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForScheduler, UpdateForWorker |
32 | | }; |
33 | | use nativelink_scheduler::worker::Worker; |
34 | | use nativelink_scheduler::worker_scheduler::WorkerScheduler; |
35 | | use nativelink_util::background_spawn; |
36 | | use nativelink_util::action_messages::{OperationId, WorkerId}; |
37 | | use nativelink_util::operation_state_manager::UpdateOperationType; |
38 | | use nativelink_util::platform_properties::PlatformProperties; |
39 | | use rand::RngCore; |
40 | | use tokio::sync::mpsc; |
41 | | use tokio::time::interval; |
42 | | use tonic::{Response, Status}; |
43 | | use tracing::{debug, error, warn, instrument, Level}; |
44 | | use uuid::Uuid; |
45 | | |
46 | | pub type ConnectWorkerStream = |
47 | | Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>; |
48 | | |
49 | | pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>; |
50 | | |
51 | | pub struct WorkerApiServer { |
52 | | scheduler: Arc<dyn WorkerScheduler>, |
53 | | now_fn: Arc<NowFn>, |
54 | | node_id: [u8; 6], |
55 | | } |
56 | | |
57 | | impl core::fmt::Debug for WorkerApiServer { |
58 | 0 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
59 | 0 | f.debug_struct("WorkerApiServer") |
60 | 0 | .field("node_id", &self.node_id) |
61 | 0 | .finish_non_exhaustive() |
62 | 0 | } |
63 | | } |
64 | | |
65 | | impl WorkerApiServer { |
66 | 0 | pub fn new( |
67 | 0 | config: &WorkerApiConfig, |
68 | 0 | schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>, |
69 | 0 | ) -> Result<Self, Error> { |
70 | 0 | let node_id = { |
71 | 0 | let mut out = [0; 6]; |
72 | 0 | rand::rng().fill_bytes(&mut out); |
73 | 0 | out |
74 | | }; |
75 | 0 | for scheduler in schedulers.values() { |
76 | | // This will protect us from holding a reference to the scheduler forever in the |
77 | | // event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will |
78 | | // eventually see the Arc went away and return. |
79 | 0 | let weak_scheduler = Arc::downgrade(scheduler); |
80 | 0 | background_spawn!("worker_api_server", async move { |
81 | 0 | let mut ticker = interval(Duration::from_secs(1)); |
82 | | loop { |
83 | 0 | ticker.tick().await; |
84 | 0 | let timestamp = SystemTime::now() |
85 | 0 | .duration_since(UNIX_EPOCH) |
86 | 0 | .expect("Error: system time is now behind unix epoch"); |
87 | 0 | match weak_scheduler.upgrade() { |
88 | 0 | Some(scheduler) => { |
89 | 0 | if let Err(err) = Branch (89:36): [True: 0, False: 0]
Branch (89:36): [Folded - Ignored]
|
90 | 0 | scheduler.remove_timedout_workers(timestamp.as_secs()).await |
91 | | { |
92 | 0 | error!(?err, "Failed to remove_timedout_workers",); |
93 | 0 | } |
94 | | } |
95 | | // If we fail to upgrade, our service is probably destroyed, so return. |
96 | 0 | None => return, |
97 | | } |
98 | | } |
99 | 0 | }); |
100 | | } |
101 | | |
102 | 0 | Self::new_with_now_fn( |
103 | 0 | config, |
104 | 0 | schedulers, |
105 | 0 | Box::new(move || { |
106 | 0 | SystemTime::now() |
107 | 0 | .duration_since(UNIX_EPOCH) |
108 | 0 | .map_err(|_| make_err!(Code::Internal, "System time is now behind unix epoch")) |
109 | 0 | }), |
110 | 0 | node_id, |
111 | | ) |
112 | 0 | } |
113 | | |
114 | | /// Same as `new()`, but you can pass a custom `now_fn`, that returns a Duration since `UNIX_EPOCH` |
115 | | /// representing the current time. Used mostly in unit tests. |
116 | 7 | pub fn new_with_now_fn( |
117 | 7 | config: &WorkerApiConfig, |
118 | 7 | schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>, |
119 | 7 | now_fn: NowFn, |
120 | 7 | node_id: [u8; 6], |
121 | 7 | ) -> Result<Self, Error> { |
122 | 7 | let scheduler = schedulers |
123 | 7 | .get(&config.scheduler) |
124 | 7 | .err_tip(|| {0 |
125 | 0 | format!( |
126 | 0 | "Scheduler needs config for '{}' because it exists in worker_api", |
127 | | config.scheduler |
128 | | ) |
129 | 0 | })? |
130 | 7 | .clone(); |
131 | 7 | Ok(Self { |
132 | 7 | scheduler, |
133 | 7 | now_fn: Arc::new(now_fn), |
134 | 7 | node_id, |
135 | 7 | }) |
136 | 7 | } |
137 | | |
138 | 0 | pub fn into_service(self) -> Server<Self> { |
139 | 0 | Server::new(self) |
140 | 0 | } |
141 | | |
142 | 7 | async fn inner_connect_worker( |
143 | 7 | &self, |
144 | 7 | mut update_stream: impl Stream<Item = Result<UpdateForScheduler, Status>> |
145 | 7 | + Unpin |
146 | 7 | + Send |
147 | 7 | + 'static, |
148 | 7 | ) -> Result<Response<ConnectWorkerStream>, Error> { |
149 | 7 | let first_message = update_stream |
150 | 7 | .next() |
151 | 7 | .await |
152 | 7 | .err_tip(|| "Missing first message for connect_worker")?0 |
153 | 7 | .err_tip(|| "Error reading first message for connect_worker")?0 ; |
154 | 7 | let Some(Update::ConnectWorkerRequest(connect_worker_request)) = first_message.update Branch (154:13): [True: 0, False: 0]
Branch (154:13): [Folded - Ignored]
Branch (154:13): [True: 7, False: 0]
|
155 | | else { |
156 | 0 | return Err(make_err!( |
157 | 0 | Code::Internal, |
158 | 0 | "First message was not a ConnectWorkerRequest" |
159 | 0 | )); |
160 | | }; |
161 | | |
162 | 7 | let (tx, rx) = mpsc::unbounded_channel(); |
163 | | |
164 | | // First convert our proto platform properties into one our scheduler understands. |
165 | 7 | let platform_properties = { |
166 | 7 | let mut platform_properties = PlatformProperties::default(); |
167 | 7 | for property0 in connect_worker_request.properties { |
168 | 0 | let platform_property_value = self |
169 | 0 | .scheduler |
170 | 0 | .get_platform_property_manager() |
171 | 0 | .make_prop_value(&property.name, &property.value) |
172 | 0 | .err_tip(|| "Bad Property during connect_worker()")?; |
173 | 0 | platform_properties |
174 | 0 | .properties |
175 | 0 | .insert(property.name.clone(), platform_property_value); |
176 | | } |
177 | 7 | platform_properties |
178 | | }; |
179 | | |
180 | | // Now register the worker with the scheduler. |
181 | 7 | let worker_id = { |
182 | 7 | let worker_id = WorkerId(format!( |
183 | 7 | "{}{}", |
184 | 7 | connect_worker_request.worker_id_prefix, |
185 | 7 | Uuid::now_v6(&self.node_id).hyphenated() |
186 | 7 | )); |
187 | 7 | let worker = Worker::new( |
188 | 7 | worker_id.clone(), |
189 | 7 | platform_properties, |
190 | 7 | tx, |
191 | 7 | (self.now_fn)()?0 .as_secs(), |
192 | 7 | connect_worker_request.max_inflight_tasks, |
193 | | ); |
194 | 7 | self.scheduler |
195 | 7 | .add_worker(worker) |
196 | 7 | .await |
197 | 7 | .err_tip(|| "Failed to add worker in inner_connect_worker()")?0 ; |
198 | 7 | worker_id |
199 | | }; |
200 | | |
201 | 7 | WorkerConnection::start( |
202 | 7 | self.scheduler.clone(), |
203 | 7 | self.now_fn.clone(), |
204 | 7 | worker_id.clone(), |
205 | 7 | update_stream, |
206 | | ); |
207 | | |
208 | 7 | Ok(Response::new(Box::pin(unfold( |
209 | 7 | (rx, worker_id), |
210 | 9 | move |state| async move { |
211 | 9 | let (mut rx, worker_id) = state; |
212 | 9 | if let Some(update_for_worker) = rx.recv().await { Branch (212:24): [True: 0, False: 0]
Branch (212:24): [Folded - Ignored]
Branch (212:24): [True: 9, False: 0]
|
213 | 9 | return Some((Ok(update_for_worker), (rx, worker_id))); |
214 | 0 | } |
215 | 0 | warn!( |
216 | | ?worker_id, |
217 | 0 | "UpdateForWorker channel was closed, thus closing connection to worker node", |
218 | | ); |
219 | | |
220 | 0 | None |
221 | 18 | }, |
222 | | )))) |
223 | 7 | } |
224 | | |
225 | 7 | pub async fn inner_connect_worker_for_testing( |
226 | 7 | &self, |
227 | 7 | update_stream: impl Stream<Item = Result<UpdateForScheduler, Status>> + Unpin + Send + 'static, |
228 | 7 | ) -> Result<Response<ConnectWorkerStream>, Error> { |
229 | 7 | self.inner_connect_worker(update_stream).await |
230 | 7 | } |
231 | | } |
232 | | |
233 | | #[tonic::async_trait] |
234 | | impl WorkerApi for WorkerApiServer { |
235 | | type ConnectWorkerStream = ConnectWorkerStream; |
236 | | |
237 | | #[instrument( |
238 | | err, |
239 | | level = Level::ERROR, |
240 | | skip_all, |
241 | | fields(request = ?grpc_request.get_ref()) |
242 | | )] |
243 | | async fn connect_worker( |
244 | | &self, |
245 | | grpc_request: tonic::Request<tonic::Streaming<UpdateForScheduler>>, |
246 | | ) -> Result<Response<Self::ConnectWorkerStream>, Status> { |
247 | | let resp = self |
248 | | .inner_connect_worker(grpc_request.into_inner()) |
249 | | .await |
250 | | .map_err(Into::into); |
251 | | if resp.is_ok() { |
252 | | debug!(return = "Ok(<stream>)"); |
253 | | } |
254 | | resp |
255 | | } |
256 | | } |
257 | | |
258 | | struct WorkerConnection { |
259 | | scheduler: Arc<dyn WorkerScheduler>, |
260 | | now_fn: Arc<NowFn>, |
261 | | worker_id: WorkerId, |
262 | | } |
263 | | |
264 | | impl WorkerConnection { |
265 | 7 | fn start( |
266 | 7 | scheduler: Arc<dyn WorkerScheduler>, |
267 | 7 | now_fn: Arc<NowFn>, |
268 | 7 | worker_id: WorkerId, |
269 | 7 | mut connection: impl Stream<Item = Result<UpdateForScheduler, Status>> + Unpin + Send + 'static, |
270 | 7 | ) { |
271 | 7 | let instance = Self { |
272 | 7 | scheduler, |
273 | 7 | now_fn, |
274 | 7 | worker_id, |
275 | 7 | }; |
276 | | |
277 | 7 | background_spawn!("worker_api", async move {2 |
278 | 2 | let mut had_going_away = false; |
279 | 3 | while let Some(maybe_update2 ) = connection.next().await { Branch (279:23): [True: 0, False: 0]
Branch (279:23): [Folded - Ignored]
Branch (279:23): [True: 2, False: 0]
|
280 | 2 | let update = match maybe_update.map(|u| u.update) { |
281 | 2 | Ok(Some(update)) => update, |
282 | | Ok(None) => { |
283 | 0 | tracing::warn!(worker_id=?instance.worker_id, "Empty update"); |
284 | 0 | continue; |
285 | | } |
286 | 0 | Err(err) => { |
287 | 0 | tracing::warn!(worker_id=?instance.worker_id, ?err, "Error from worker"); |
288 | 0 | break; |
289 | | } |
290 | | }; |
291 | 2 | let result1 = match update { |
292 | 0 | Update::ConnectWorkerRequest(_connect_worker_request) => Err(make_err!( |
293 | 0 | Code::Internal, |
294 | 0 | "Got ConnectWorkerRequest after initial message for {}", |
295 | 0 | instance.worker_id |
296 | 0 | )), |
297 | 1 | Update::KeepAliveRequest(keep_alive_request) => { |
298 | 1 | instance.inner_keep_alive(keep_alive_request).await |
299 | | } |
300 | 0 | Update::GoingAwayRequest(going_away_request) => { |
301 | 0 | had_going_away = true; |
302 | 0 | instance.inner_going_away(going_away_request).await |
303 | | } |
304 | 1 | Update::ExecuteResult(execute_result) => { |
305 | 1 | instance.inner_execution_response(execute_result).await |
306 | | } |
307 | 0 | Update::ExecuteComplete(execute_complete) => { |
308 | 0 | instance.execution_complete(execute_complete).await |
309 | | } |
310 | | }; |
311 | 1 | if let Err(err0 ) = result { Branch (311:24): [True: 0, False: 0]
Branch (311:24): [Folded - Ignored]
Branch (311:24): [True: 0, False: 1]
|
312 | 0 | tracing::warn!(worker_id=?instance.worker_id, ?err, "Error processing worker message"); |
313 | 1 | } |
314 | | } |
315 | 0 | tracing::debug!(worker_id=?instance.worker_id, "Update for scheduler dropped"); |
316 | 0 | if !had_going_away { Branch (316:16): [True: 0, False: 0]
Branch (316:16): [Folded - Ignored]
Branch (316:16): [True: 0, False: 0]
|
317 | 0 | drop(instance.scheduler.remove_worker(&instance.worker_id).await); |
318 | 0 | } |
319 | 0 | }); |
320 | 7 | } |
321 | | |
322 | 1 | async fn inner_keep_alive(&self, _keep_alive_request: KeepAliveRequest) -> Result<(), Error> { |
323 | 1 | self.scheduler |
324 | 1 | .worker_keep_alive_received(&self.worker_id, (self.now_fn)()?0 .as_secs()) |
325 | 1 | .await |
326 | 1 | .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()")?0 ; |
327 | 1 | Ok(()) |
328 | 1 | } |
329 | | |
330 | 0 | async fn inner_going_away(&self, _going_away_request: GoingAwayRequest) -> Result<(), Error> { |
331 | 0 | self.scheduler |
332 | 0 | .remove_worker(&self.worker_id) |
333 | 0 | .await |
334 | 0 | .err_tip(|| "While calling WorkerApiServer::inner_going_away")?; |
335 | 0 | Ok(()) |
336 | 0 | } |
337 | | |
338 | 1 | async fn inner_execution_response(&self, execute_result: ExecuteResult) -> Result<(), Error> { |
339 | 1 | let operation_id = OperationId::from(execute_result.operation_id); |
340 | | |
341 | 1 | match execute_result |
342 | 1 | .result |
343 | 1 | .err_tip(|| "Expected result to exist in ExecuteResult")?0 |
344 | | { |
345 | 1 | execute_result::Result::ExecuteResponse(finished_result) => { |
346 | 1 | let action_stage = finished_result |
347 | 1 | .try_into() |
348 | 1 | .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")?0 ; |
349 | 1 | self.scheduler |
350 | 1 | .update_action( |
351 | 1 | &self.worker_id, |
352 | 1 | &operation_id, |
353 | 1 | UpdateOperationType::UpdateWithActionStage(action_stage), |
354 | 1 | ) |
355 | 1 | .await |
356 | 0 | .err_tip(|| format!("Failed to operation {operation_id}"))?; |
357 | | } |
358 | 0 | execute_result::Result::InternalError(e) => { |
359 | 0 | self.scheduler |
360 | 0 | .update_action( |
361 | 0 | &self.worker_id, |
362 | 0 | &operation_id, |
363 | 0 | UpdateOperationType::UpdateWithError(e.into()), |
364 | 0 | ) |
365 | 0 | .await |
366 | 0 | .err_tip(|| format!("Failed to operation {operation_id}"))?; |
367 | | } |
368 | | } |
369 | 0 | Ok(()) |
370 | 0 | } |
371 | | |
372 | 0 | async fn execution_complete(&self, execute_complete: ExecuteComplete) -> Result<(), Error> { |
373 | 0 | let operation_id = OperationId::from(execute_complete.operation_id); |
374 | 0 | self.scheduler |
375 | 0 | .update_action( |
376 | 0 | &self.worker_id, |
377 | 0 | &operation_id, |
378 | 0 | UpdateOperationType::ExecutionComplete, |
379 | 0 | ) |
380 | 0 | .await |
381 | 0 | .err_tip(|| format!("Failed to operation {operation_id}"))?; |
382 | 0 | Ok(()) |
383 | 0 | } |
384 | | } |