/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::collections::HashMap; |
16 | | use std::future::Future; |
17 | | use std::sync::Arc; |
18 | | use std::time::Duration; |
19 | | |
20 | | use async_trait::async_trait; |
21 | | use futures::stream::unfold; |
22 | | use futures::{StreamExt, TryFutureExt}; |
23 | | use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; |
24 | | use nativelink_metric::{MetricsComponent, RootMetricsComponent}; |
25 | | use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient; |
26 | | use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient; |
27 | | use nativelink_proto::build::bazel::remote::execution::v2::{ |
28 | | ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest, |
29 | | }; |
30 | | use nativelink_proto::google::longrunning::Operation; |
31 | | use nativelink_util::action_messages::{ |
32 | | ActionInfo, ActionState, ActionUniqueQualifier, OperationId, DEFAULT_EXECUTION_PRIORITY, |
33 | | }; |
34 | | use nativelink_util::connection_manager::ConnectionManager; |
35 | | use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider; |
36 | | use nativelink_util::operation_state_manager::{ |
37 | | ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, |
38 | | }; |
39 | | use nativelink_util::retry::{Retrier, RetryResult}; |
40 | | use nativelink_util::{background_spawn, tls_utils}; |
41 | | use parking_lot::Mutex; |
42 | | use rand::rngs::OsRng; |
43 | | use rand::Rng; |
44 | | use tokio::select; |
45 | | use tokio::sync::watch; |
46 | | use tokio::time::sleep; |
47 | | use tonic::{Request, Streaming}; |
48 | | use tracing::{event, Level}; |
49 | | |
/// Tracks the state of one action forwarded to the upstream gRPC scheduler,
/// surfacing state updates received over a watch channel.
struct GrpcActionStateResult {
    // Operation id the client used; stamped onto every state handed out so
    // callers see their own id rather than the upstream scheduler's id.
    client_operation_id: OperationId,
    // Receives the latest ActionState published by the upstream stream pump.
    rx: watch::Receiver<Arc<ActionState>>,
}
54 | | |
55 | | #[async_trait] |
56 | | impl ActionStateResult for GrpcActionStateResult { |
57 | 0 | async fn as_state(&self) -> Result<Arc<ActionState>, Error> { |
58 | 0 | let mut action_state = self.rx.borrow().clone(); |
59 | 0 | Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); |
60 | 0 | Ok(action_state) |
61 | 0 | } |
62 | | |
63 | 0 | async fn changed(&mut self) -> Result<Arc<ActionState>, Error> { |
64 | 0 | self.rx.changed().await.map_err(|_| { |
65 | 0 | make_err!( |
66 | 0 | Code::Internal, |
67 | 0 | "Channel closed in GrpcActionStateResult::changed" |
68 | 0 | ) |
69 | 0 | })?; |
70 | 0 | let mut action_state = self.rx.borrow().clone(); |
71 | 0 | Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone(); |
72 | 0 | Ok(action_state) |
73 | 0 | } |
74 | | |
75 | 0 | async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> { |
76 | | // TODO(allada) We should probably remove as_action_info() |
77 | | // or implement it properly. |
78 | 0 | return Err(make_err!( |
79 | 0 | Code::Unimplemented, |
80 | 0 | "as_action_info not implemented for GrpcActionStateResult::as_action_info" |
81 | 0 | )); |
82 | 0 | } |
83 | | } |
84 | | |
/// Scheduler that proxies all scheduling operations to an upstream
/// scheduler over gRPC (Remote Execution API Execution/Capabilities).
#[derive(MetricsComponent)]
pub struct GrpcScheduler {
    // Cache of supported platform properties keyed by instance name,
    // populated lazily from the upstream scheduler's capabilities.
    #[metric(group = "property_managers")]
    supported_props: Mutex<HashMap<String, Vec<String>>>,
    // Retry policy applied to every upstream request.
    retrier: Retrier,
    // Pool of gRPC connections to the configured upstream endpoint.
    connection_manager: ConnectionManager,
}
92 | | |
93 | | impl GrpcScheduler { |
94 | 0 | pub fn new(config: &nativelink_config::schedulers::GrpcScheduler) -> Result<Self, Error> { |
95 | 0 | let jitter_amt = config.retry.jitter; |
96 | 0 | Self::new_with_jitter( |
97 | 0 | config, |
98 | 0 | Box::new(move |delay: Duration| { |
99 | 0 | if jitter_amt == 0. { Branch (99:20): [True: 0, False: 0]
Branch (99:20): [Folded - Ignored]
|
100 | 0 | return delay; |
101 | 0 | } |
102 | 0 | let min = 1. - (jitter_amt / 2.); |
103 | 0 | let max = 1. + (jitter_amt / 2.); |
104 | 0 | delay.mul_f32(OsRng.gen_range(min..max)) |
105 | 0 | }), |
106 | 0 | ) |
107 | 0 | } |
108 | | |
109 | 0 | pub fn new_with_jitter( |
110 | 0 | config: &nativelink_config::schedulers::GrpcScheduler, |
111 | 0 | jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>, |
112 | 0 | ) -> Result<Self, Error> { |
113 | 0 | let endpoint = tls_utils::endpoint(&config.endpoint)?; |
114 | 0 | let jitter_fn = Arc::new(jitter_fn); |
115 | 0 | Ok(Self { |
116 | 0 | supported_props: Mutex::new(HashMap::new()), |
117 | 0 | retrier: Retrier::new( |
118 | 0 | Arc::new(|duration| Box::pin(sleep(duration))), |
119 | 0 | jitter_fn.clone(), |
120 | 0 | config.retry.clone(), |
121 | 0 | ), |
122 | 0 | connection_manager: ConnectionManager::new( |
123 | 0 | std::iter::once(endpoint), |
124 | 0 | config.connections_per_endpoint, |
125 | 0 | config.max_concurrent_requests, |
126 | 0 | config.retry.clone(), |
127 | 0 | jitter_fn, |
128 | 0 | ), |
129 | 0 | }) |
130 | 0 | } |
131 | | |
132 | 0 | async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error> |
133 | 0 | where |
134 | 0 | F: FnMut(I) -> Fut + Send + Copy, |
135 | 0 | Fut: Future<Output = Result<R, Error>> + Send, |
136 | 0 | R: Send, |
137 | 0 | I: Send + Clone, |
138 | 0 | { |
139 | 0 | self.retrier |
140 | 0 | .retry(unfold(input, move |input| async move { |
141 | 0 | let input_clone = input.clone(); |
142 | 0 | Some(( |
143 | 0 | request(input_clone) |
144 | 0 | .await |
145 | 0 | .map_or_else(RetryResult::Retry, RetryResult::Ok), |
146 | 0 | input, |
147 | 0 | )) |
148 | 0 | })) |
149 | 0 | .await |
150 | 0 | } |
151 | | |
152 | 0 | async fn stream_state( |
153 | 0 | mut result_stream: Streaming<Operation>, |
154 | 0 | ) -> Result<Box<dyn ActionStateResult>, Error> { |
155 | 0 | if let Some(initial_response) = result_stream Branch (155:16): [True: 0, False: 0]
Branch (155:16): [Folded - Ignored]
|
156 | 0 | .message() |
157 | 0 | .await |
158 | 0 | .err_tip(|| "Recieving response from upstream scheduler")? |
159 | | { |
160 | 0 | let client_operation_id = OperationId::from(initial_response.name.as_str()); |
161 | 0 | // Our operation_id is not needed here is just a place holder to recycle existing object. |
162 | 0 | // The only thing that actually matters is the operation_id. |
163 | 0 | let operation_id = OperationId::default(); |
164 | 0 | let action_state = |
165 | 0 | ActionState::try_from_operation(initial_response, operation_id.clone()) |
166 | 0 | .err_tip(|| "In GrpcScheduler::stream_state")?; |
167 | 0 | let (tx, mut rx) = watch::channel(Arc::new(action_state)); |
168 | 0 | rx.mark_changed(); |
169 | 0 | background_spawn!("grpc_scheduler_stream_state", async move { |
170 | | loop { |
171 | 0 | select!( |
172 | 0 | _ = tx.closed() => { |
173 | 0 | event!( |
174 | 0 | Level::INFO, |
175 | 0 | "Client disconnected in GrpcScheduler" |
176 | | ); |
177 | 0 | return; |
178 | | } |
179 | 0 | response = result_stream.message() => { |
180 | | // When the upstream closes the channel, close the |
181 | | // downstream too. |
182 | 0 | let Ok(Some(response)) = response else { Branch (182:33): [True: 0, False: 0]
Branch (182:33): [Folded - Ignored]
|
183 | 0 | return; |
184 | | }; |
185 | 0 | let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone()); |
186 | 0 | match maybe_action_state { |
187 | 0 | Ok(response) => { |
188 | 0 | if let Err(err) = tx.send(Arc::new(response)) { Branch (188:44): [True: 0, False: 0]
Branch (188:44): [Folded - Ignored]
|
189 | 0 | event!( |
190 | 0 | Level::INFO, |
191 | | ?err, |
192 | 0 | "Client error in GrpcScheduler" |
193 | | ); |
194 | 0 | return; |
195 | 0 | } |
196 | | } |
197 | 0 | Err(err) => { |
198 | 0 | event!( |
199 | 0 | Level::ERROR, |
200 | | ?err, |
201 | 0 | "Error converting response to ActionState in GrpcScheduler" |
202 | | ); |
203 | | }, |
204 | | } |
205 | | } |
206 | | ); |
207 | | } |
208 | 0 | }); |
209 | | |
210 | 0 | return Ok(Box::new(GrpcActionStateResult { |
211 | 0 | client_operation_id, |
212 | 0 | rx, |
213 | 0 | })); |
214 | 0 | } |
215 | 0 | Err(make_err!( |
216 | 0 | Code::Internal, |
217 | 0 | "Upstream scheduler didn't accept action." |
218 | 0 | )) |
219 | 0 | } |
220 | | |
221 | 0 | async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> { |
222 | 0 | if let Some(supported_props) = self.supported_props.lock().get(instance_name) { Branch (222:16): [True: 0, False: 0]
Branch (222:16): [Folded - Ignored]
|
223 | 0 | return Ok(supported_props.clone()); |
224 | 0 | } |
225 | 0 |
|
226 | 0 | self.perform_request(instance_name, |instance_name| async move { |
227 | | // Not in the cache, lookup the capabilities with the upstream. |
228 | 0 | let channel = self |
229 | 0 | .connection_manager |
230 | 0 | .connection() |
231 | 0 | .await |
232 | 0 | .err_tip(|| "in get_platform_property_manager()")?; |
233 | 0 | let capabilities_result = CapabilitiesClient::new(channel) |
234 | 0 | .get_capabilities(GetCapabilitiesRequest { |
235 | 0 | instance_name: instance_name.to_string(), |
236 | 0 | }) |
237 | 0 | .await |
238 | 0 | .err_tip(|| "Retrieving upstream GrpcScheduler capabilities"); |
239 | 0 | let capabilities = capabilities_result?.into_inner(); |
240 | 0 | let supported_props = capabilities |
241 | 0 | .execution_capabilities |
242 | 0 | .err_tip(|| "Unable to get execution properties in GrpcScheduler")? |
243 | | .supported_node_properties |
244 | 0 | .into_iter() |
245 | 0 | .collect::<Vec<String>>(); |
246 | 0 |
|
247 | 0 | self.supported_props |
248 | 0 | .lock() |
249 | 0 | .insert(instance_name.to_string(), supported_props.clone()); |
250 | 0 | Ok(supported_props) |
251 | 0 | }) |
252 | 0 | .await |
253 | 0 | } |
254 | | |
255 | 0 | async fn inner_add_action( |
256 | 0 | &self, |
257 | 0 | _client_operation_id: OperationId, |
258 | 0 | action_info: Arc<ActionInfo>, |
259 | 0 | ) -> Result<Box<dyn ActionStateResult>, Error> { |
260 | 0 | let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY { Branch (260:35): [True: 0, False: 0]
Branch (260:35): [Folded - Ignored]
|
261 | 0 | None |
262 | | } else { |
263 | 0 | Some(ExecutionPolicy { |
264 | 0 | priority: action_info.priority, |
265 | 0 | }) |
266 | | }; |
267 | 0 | let skip_cache_lookup = match action_info.unique_qualifier { |
268 | 0 | ActionUniqueQualifier::Cachable(_) => false, |
269 | 0 | ActionUniqueQualifier::Uncachable(_) => true, |
270 | | }; |
271 | 0 | let request = ExecuteRequest { |
272 | 0 | instance_name: action_info.instance_name().clone(), |
273 | 0 | skip_cache_lookup, |
274 | 0 | action_digest: Some(action_info.digest().into()), |
275 | 0 | execution_policy, |
276 | 0 | // TODO: Get me from the original request, not very important as we ignore it. |
277 | 0 | results_cache_policy: None, |
278 | 0 | digest_function: action_info |
279 | 0 | .unique_qualifier |
280 | 0 | .digest_function() |
281 | 0 | .proto_digest_func() |
282 | 0 | .into(), |
283 | 0 | }; |
284 | 0 | let result_stream = self |
285 | 0 | .perform_request(request, |request| async move { |
286 | 0 | let channel = self |
287 | 0 | .connection_manager |
288 | 0 | .connection() |
289 | 0 | .await |
290 | 0 | .err_tip(|| "in add_action()")?; |
291 | 0 | ExecutionClient::new(channel) |
292 | 0 | .execute(Request::new(request)) |
293 | 0 | .await |
294 | 0 | .err_tip(|| "Sending action to upstream scheduler") |
295 | 0 | }) |
296 | 0 | .await? |
297 | 0 | .into_inner(); |
298 | 0 | Self::stream_state(result_stream).await |
299 | 0 | } |
300 | | |
301 | 0 | async fn inner_filter_operations( |
302 | 0 | &self, |
303 | 0 | filter: OperationFilter, |
304 | 0 | ) -> Result<ActionStateResultStream, Error> { |
305 | 0 | error_if!(filter != OperationFilter { Branch (305:19): [True: 0, False: 0]
Branch (305:19): [Folded - Ignored]
|
306 | 0 | client_operation_id: filter.client_operation_id.clone(), |
307 | 0 | ..Default::default() |
308 | 0 | }, "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}"); |
309 | 0 | let client_operation_id = filter.client_operation_id.ok_or_else(|| { |
310 | 0 | make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations") |
311 | 0 | })?; |
312 | 0 | let request = WaitExecutionRequest { |
313 | 0 | name: client_operation_id.to_string(), |
314 | 0 | }; |
315 | 0 | let result_stream = self |
316 | 0 | .perform_request(request, |request| async move { |
317 | 0 | let channel = self |
318 | 0 | .connection_manager |
319 | 0 | .connection() |
320 | 0 | .await |
321 | 0 | .err_tip(|| "in find_by_client_operation_id()")?; |
322 | 0 | ExecutionClient::new(channel) |
323 | 0 | .wait_execution(Request::new(request)) |
324 | 0 | .await |
325 | 0 | .err_tip(|| "While getting wait_execution stream") |
326 | 0 | }) |
327 | 0 | .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) |
328 | 0 | .await; |
329 | 0 | match result_stream { |
330 | 0 | Ok(result_stream) => Ok(unfold( |
331 | 0 | Some(result_stream), |
332 | 0 | |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) }, |
333 | 0 | ) |
334 | 0 | .boxed()), |
335 | 0 | Err(err) => { |
336 | 0 | event!( |
337 | 0 | Level::WARN, |
338 | | ?err, |
339 | 0 | "Error looking up action with upstream scheduler" |
340 | | ); |
341 | 0 | Ok(futures::stream::empty().boxed()) |
342 | | } |
343 | | } |
344 | 0 | } |
345 | | } |
346 | | |
347 | | #[async_trait] |
348 | | impl ClientStateManager for GrpcScheduler { |
349 | | async fn add_action( |
350 | | &self, |
351 | | client_operation_id: OperationId, |
352 | | action_info: Arc<ActionInfo>, |
353 | 0 | ) -> Result<Box<dyn ActionStateResult>, Error> { |
354 | 0 | self.inner_add_action(client_operation_id, action_info) |
355 | 0 | .await |
356 | 0 | } |
357 | | |
358 | | async fn filter_operations<'a>( |
359 | | &'a self, |
360 | | filter: OperationFilter, |
361 | 0 | ) -> Result<ActionStateResultStream<'a>, Error> { |
362 | 0 | self.inner_filter_operations(filter).await |
363 | 0 | } |
364 | | |
365 | 0 | fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> { |
366 | 0 | Some(self) |
367 | 0 | } |
368 | | } |
369 | | |
370 | | #[async_trait] |
371 | | impl KnownPlatformPropertyProvider for GrpcScheduler { |
372 | 0 | async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> { |
373 | 0 | self.inner_get_known_properties(instance_name).await |
374 | 0 | } |
375 | | } |
376 | | |
377 | | impl RootMetricsComponent for GrpcScheduler {} |