/build/source/nativelink-service/src/worker_api_server.rs
Line | Count | Source
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::collections::HashMap; |
16 | | use std::convert::Into; |
17 | | use std::pin::Pin; |
18 | | use std::sync::Arc; |
19 | | use std::time::{Duration, SystemTime, UNIX_EPOCH}; |
20 | | |
21 | | use futures::stream::unfold; |
22 | | use futures::Stream; |
23 | | use nativelink_config::cas_server::WorkerApiConfig; |
24 | | use nativelink_error::{make_err, Code, Error, ResultExt}; |
25 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_server::{ |
26 | | WorkerApi, WorkerApiServer as Server, |
27 | | }; |
28 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ |
29 | | execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker, |
30 | | }; |
31 | | use nativelink_scheduler::worker::Worker; |
32 | | use nativelink_scheduler::worker_scheduler::WorkerScheduler; |
33 | | use nativelink_util::background_spawn; |
34 | | use nativelink_util::action_messages::{OperationId, WorkerId}; |
35 | | use nativelink_util::operation_state_manager::UpdateOperationType; |
36 | | use nativelink_util::platform_properties::PlatformProperties; |
37 | | use tokio::sync::mpsc; |
38 | | use tokio::time::interval; |
39 | | use tonic::{Request, Response, Status}; |
40 | | use tracing::{event, instrument, Level}; |
41 | | use uuid::Uuid; |
42 | | |
43 | | pub type ConnectWorkerStream = |
44 | | Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>; |
45 | | |
46 | | pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>; |
47 | | |
48 | | pub struct WorkerApiServer { |
49 | | scheduler: Arc<dyn WorkerScheduler>, |
50 | | now_fn: NowFn, |
51 | | } |
52 | | |
53 | | impl WorkerApiServer { |
54 | 0 | pub fn new( |
55 | 0 | config: &WorkerApiConfig, |
56 | 0 | schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>, |
57 | 0 | ) -> Result<Self, Error> { |
58 | 0 | for scheduler in schedulers.values() { |
59 | | // This protects us from holding a reference to the scheduler forever in the
60 | | // event our WorkerApiServer dies. Our scheduler is a weak ref, so the spawn will
61 | | // eventually see that the Arc went away and return.
62 | 0 | let weak_scheduler = Arc::downgrade(scheduler); |
63 | 0 | background_spawn!("worker_api_server", async move { |
64 | 0 | let mut ticker = interval(Duration::from_secs(1)); |
65 | | loop { |
66 | 0 | ticker.tick().await; |
67 | 0 | let timestamp = SystemTime::now() |
68 | 0 | .duration_since(UNIX_EPOCH) |
69 | 0 | .expect("Error: system time is now behind unix epoch"); |
70 | 0 | match weak_scheduler.upgrade() { |
71 | 0 | Some(scheduler) => { |
72 | 0 | if let Err(err) =
  Branch (72:36): [True: 0, False: 0]
  Branch (72:36): [Folded - Ignored]
73 | 0 | scheduler.remove_timedout_workers(timestamp.as_secs()).await |
74 | | { |
75 | 0 | event!(Level::ERROR, ?err, "Failed to remove_timedout_workers",); |
76 | 0 | } |
77 | | } |
78 | | // If we fail to upgrade, our service is probably destroyed, so return. |
79 | 0 | None => return, |
80 | 0 | } |
81 | 0 | } |
82 | 0 | }); |
83 | | } |
84 | | |
85 | 0 | Self::new_with_now_fn( |
86 | 0 | config, |
87 | 0 | schedulers, |
88 | 0 | Box::new(move || { |
89 | 0 | SystemTime::now() |
90 | 0 | .duration_since(UNIX_EPOCH) |
91 | 0 | .map_err(|_| make_err!(Code::Internal, "System time is now behind unix epoch")) |
92 | 0 | }), |
93 | 0 | ) |
94 | 0 | } |
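
A minimal, standalone sketch of the weak-reference reaper pattern that new() sets up above: the background task holds only a Weak to the scheduler, so dropping the last Arc lets the task exit on its next tick. Scheduler, spawn_reaper, and the tokio setup here are illustrative stand-ins, not items from this file.

use std::sync::Arc;
use std::time::Duration;
use tokio::time::interval;

// Hypothetical stand-in for the real WorkerScheduler trait object.
struct Scheduler;
impl Scheduler {
    async fn remove_timedout_workers(&self, _now_secs: u64) {}
}

fn spawn_reaper(scheduler: &Arc<Scheduler>) {
    // Hold only a Weak so this task cannot keep the scheduler alive forever.
    let weak = Arc::downgrade(scheduler);
    tokio::spawn(async move {
        let mut ticker = interval(Duration::from_secs(1));
        loop {
            ticker.tick().await;
            match weak.upgrade() {
                // Scheduler still alive: run one cleanup round.
                Some(scheduler) => scheduler.remove_timedout_workers(0).await,
                // Every strong reference is gone: return so the task does not leak.
                None => return,
            }
        }
    });
}

#[tokio::main]
async fn main() {
    let scheduler = Arc::new(Scheduler);
    spawn_reaper(&scheduler);
    drop(scheduler); // The reaper observes the drop and exits on its next tick.
}
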
95 | | |
96 | | /// Same as `new()`, but accepts a custom `now_fn` that returns the Duration since UNIX_EPOCH
97 | | /// representing the current time. Used mostly in unit tests.
98 | 6 | pub fn new_with_now_fn( |
99 | 6 | config: &WorkerApiConfig, |
100 | 6 | schedulers: &HashMap<String, Arc<dyn WorkerScheduler>>, |
101 | 6 | now_fn: NowFn, |
102 | 6 | ) -> Result<Self, Error> { |
103 | 6 | let scheduler = schedulers |
104 | 6 | .get(&config.scheduler) |
105 | 6 | .err_tip(|| { |
106 | 0 | format!( |
107 | 0 | "Scheduler needs config for '{}' because it exists in worker_api", |
108 | 0 | config.scheduler |
109 | 0 | ) |
110 | 6 | })?0 |
111 | 6 | .clone(); |
112 | 6 | Ok(Self { scheduler, now_fn }) |
113 | 6 | } |
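
A sketch of how a test might supply the custom now_fn described in the doc comment above. The error type is String purely to keep the sketch self-contained; the file's real NowFn alias returns Result<Duration, nativelink_error::Error>, and building the WorkerApiConfig and scheduler map is elided here.

use std::time::Duration;

// Stand-in for this file's NowFn alias, with String substituted for the crate's Error type.
type TestNowFn = Box<dyn Fn() -> Result<Duration, String> + Send + Sync>;

fn frozen_now_fn(now_secs: u64) -> TestNowFn {
    // Always report the same wall-clock time so worker-timeout behavior is deterministic.
    Box::new(move || Ok(Duration::from_secs(now_secs)))
}

fn main() {
    let now_fn = frozen_now_fn(10_000);
    assert_eq!(now_fn().unwrap().as_secs(), 10_000);
    // In a real test the boxed closure is passed straight through, e.g.:
    // let server = WorkerApiServer::new_with_now_fn(&config, &schedulers, now_fn)?;
}
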
114 | | |
115 | 0 | pub fn into_service(self) -> Server<WorkerApiServer> { |
116 | 0 | Server::new(self) |
117 | 0 | } |
118 | | |
119 | 6 | async fn inner_connect_worker( |
120 | 6 | &self, |
121 | 6 | supported_properties: SupportedProperties, |
122 | 6 | ) -> Result<Response<ConnectWorkerStream>, Error> { |
123 | 6 | let (tx, rx) = mpsc::unbounded_channel(); |
124 | | |
125 | | // First convert our proto platform properties into one our scheduler understands. |
126 | 6 | let platform_properties = { |
127 | 6 | let mut platform_properties = PlatformProperties::default(); |
128 | 6 | for property in supported_properties.properties {
129 | 0 | let platform_property_value = self |
130 | 0 | .scheduler |
131 | 0 | .get_platform_property_manager() |
132 | 0 | .make_prop_value(&property.name, &property.value) |
133 | 0 | .err_tip(|| "Bad Property during connect_worker()")?; |
134 | 0 | platform_properties |
135 | 0 | .properties |
136 | 0 | .insert(property.name.clone(), platform_property_value); |
137 | | } |
138 | 6 | platform_properties |
139 | | }; |
140 | | |
141 | | // Now register the worker with the scheduler. |
142 | 6 | let worker_id = { |
143 | 6 | let worker_id = Uuid::new_v4(); |
144 | 6 | let worker = Worker::new( |
145 | 6 | WorkerId(worker_id), |
146 | 6 | platform_properties, |
147 | 6 | tx, |
148 | 6 | (self.now_fn)()?0 .as_secs(), |
149 | 6 | ); |
150 | 6 | self.scheduler |
151 | 6 | .add_worker(worker) |
152 | 0 | .await |
153 | 6 | .err_tip(|| "Failed to add worker in inner_connect_worker()"0 )?0 ; |
154 | 6 | worker_id |
155 | 6 | }; |
156 | 6 | |
157 | 6 | Ok(Response::new(Box::pin(unfold( |
158 | 6 | (rx, worker_id), |
159 | 8 | move |state| async move { |
160 | 8 | let (mut rx, worker_id) = state; |
161 | 8 | if let Some(update_for_worker) = rx.recv().await {
  Branch (161:24): [True: 8, False: 0]
  Branch (161:24): [Folded - Ignored]
162 | 8 | return Some((Ok(update_for_worker), (rx, worker_id))); |
163 | 0 | } |
164 | 0 | event!( |
165 | 0 | Level::WARN, |
166 | | ?worker_id, |
167 | 0 | "UpdateForWorker channel was closed, thus closing connection to worker node", |
168 | | ); |
169 | | |
170 | 0 | None |
171 | 16 | }, |
172 | 6 | )))) |
173 | 6 | } |
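
A minimal sketch of the unfold adapter inner_connect_worker uses above to expose the worker's mpsc receiver as the response stream, here with plain i32 items standing in for UpdateForWorker and no worker_id carried in the state.

use futures::stream::{unfold, StreamExt};
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel::<i32>();
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    // Dropping the sender closes the channel, which ends the stream; the WARN
    // branch above logs this same condition before disconnecting the worker.
    drop(tx);

    // Thread the receiver through unfold's state; yield items until recv() returns None.
    let stream = unfold(rx, |mut rx| async move {
        rx.recv().await.map(|item| (item, rx))
    });

    let items: Vec<i32> = stream.collect().await;
    assert_eq!(items, vec![1, 2]);
}
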
174 | | |
175 | 1 | async fn inner_keep_alive( |
176 | 1 | &self, |
177 | 1 | keep_alive_request: KeepAliveRequest, |
178 | 1 | ) -> Result<Response<()>, Error> { |
179 | 1 | let worker_id: WorkerId = keep_alive_request.worker_id.try_into()?0 ; |
180 | 1 | self.scheduler |
181 | 1 | .worker_keep_alive_received(&worker_id, (self.now_fn)()?0 .as_secs()) |
182 | 0 | .await |
183 | 1 | .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()"0 )?0 ; |
184 | 1 | Ok(Response::new(())) |
185 | 1 | } |
186 | | |
187 | 0 | async fn inner_going_away( |
188 | 0 | &self, |
189 | 0 | going_away_request: GoingAwayRequest, |
190 | 0 | ) -> Result<Response<()>, Error> { |
191 | 0 | let worker_id: WorkerId = going_away_request.worker_id.try_into()?; |
192 | 0 | self.scheduler |
193 | 0 | .remove_worker(&worker_id) |
194 | 0 | .await |
195 | 0 | .err_tip(|| "While calling WorkerApiServer::inner_going_away")?; |
196 | 0 | Ok(Response::new(())) |
197 | 0 | } |
198 | | |
199 | 1 | async fn inner_execution_response( |
200 | 1 | &self, |
201 | 1 | execute_result: ExecuteResult, |
202 | 1 | ) -> Result<Response<()>, Error> { |
203 | 1 | let worker_id: WorkerId = execute_result.worker_id.try_into()?0 ; |
204 | 1 | let operation_id = OperationId::from(execute_result.operation_id); |
205 | 1 | |
206 | 1 | match execute_result |
207 | 1 | .result |
208 | 1 | .err_tip(|| "Expected result to exist in ExecuteResult"0 )?0 |
209 | | { |
210 | 1 | execute_result::Result::ExecuteResponse(finished_result) => { |
211 | 1 | let action_stage = finished_result |
212 | 1 | .try_into() |
213 | 1 | .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage"0 )?0 ; |
214 | 1 | self.scheduler |
215 | 1 | .update_action( |
216 | 1 | &worker_id, |
217 | 1 | &operation_id, |
218 | 1 | UpdateOperationType::UpdateWithActionStage(action_stage), |
219 | 1 | ) |
220 | 1 | .await |
221 | 1 | .err_tip(|| format!("Failed to update operation {operation_id:?}"))?;
222 | | } |
223 | 0 | execute_result::Result::InternalError(e) => { |
224 | 0 | self.scheduler |
225 | 0 | .update_action( |
226 | 0 | &worker_id, |
227 | 0 | &operation_id, |
228 | 0 | UpdateOperationType::UpdateWithError(e.into()), |
229 | 0 | ) |
230 | 0 | .await |
231 | 0 | .err_tip(|| format!("Failed to update operation {operation_id:?}"))?;
232 | | } |
233 | | } |
234 | 1 | Ok(Response::new(())) |
235 | 1 | } |
236 | | } |
237 | | |
238 | | #[tonic::async_trait] |
239 | | impl WorkerApi for WorkerApiServer { |
240 | | type ConnectWorkerStream = ConnectWorkerStream; |
241 | | |
242 | | #[allow(clippy::blocks_in_conditions)] |
243 | 6 | #[instrument( |
244 | | err, |
245 | | level = Level::ERROR, |
246 | | skip_all, |
247 | | fields(request = ?grpc_request.get_ref()) |
248 | 12 | )] |
249 | | async fn connect_worker( |
250 | | &self, |
251 | | grpc_request: Request<SupportedProperties>, |
252 | 6 | ) -> Result<Response<Self::ConnectWorkerStream>, Status> { |
253 | 6 | let resp = self |
254 | 6 | .inner_connect_worker(grpc_request.into_inner()) |
255 | 0 | .await |
256 | 6 | .map_err(Into::into); |
257 | 6 | if resp.is_ok() {
  Branch (257:12): [True: 6, False: 0]
  Branch (257:12): [Folded - Ignored]
258 | 6 | event!(Level::DEBUG, return = "Ok(<stream>)"); |
259 | 0 | } |
260 | 6 | resp |
261 | 12 | } |
262 | | |
263 | | #[allow(clippy::blocks_in_conditions)] |
264 | 1 | #[instrument( |
265 | | err, |
266 | | ret(level = Level::INFO), |
267 | | level = Level::INFO, |
268 | | skip_all, |
269 | | fields(request = ?grpc_request.get_ref()) |
270 | 2 | )] |
271 | | async fn keep_alive( |
272 | | &self, |
273 | | grpc_request: Request<KeepAliveRequest>, |
274 | 1 | ) -> Result<Response<()>, Status> { |
275 | 1 | self.inner_keep_alive(grpc_request.into_inner()) |
276 | 0 | .await |
277 | 1 | .map_err(Into::into) |
278 | 2 | } |
279 | | |
280 | | #[allow(clippy::blocks_in_conditions)] |
281 | 0 | #[instrument( |
282 | | err, |
283 | | ret(level = Level::INFO), |
284 | | level = Level::ERROR, |
285 | | skip_all, |
286 | | fields(request = ?grpc_request.get_ref()) |
287 | 0 | )] |
288 | | async fn going_away( |
289 | | &self, |
290 | | grpc_request: Request<GoingAwayRequest>, |
291 | 0 | ) -> Result<Response<()>, Status> { |
292 | 0 | self.inner_going_away(grpc_request.into_inner()) |
293 | 0 | .await |
294 | 0 | .map_err(Into::into) |
295 | 0 | } |
296 | | |
297 | | #[allow(clippy::blocks_in_conditions)] |
298 | 1 | #[instrument( |
299 | | err, |
300 | | ret(level = Level::INFO), |
301 | | level = Level::ERROR, |
302 | | skip_all, |
303 | | fields(request = ?grpc_request.get_ref()) |
304 | 2 | )] |
305 | | async fn execution_response( |
306 | | &self, |
307 | | grpc_request: Request<ExecuteResult>, |
308 | 1 | ) -> Result<Response<()>, Status> { |
309 | 1 | self.inner_execution_response(grpc_request.into_inner()) |
310 | 1 | .await |
311 | 1 | .map_err(Into::into) |
312 | 2 | } |
313 | | } |