Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::HashMap;
16
use std::future::Future;
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use async_trait::async_trait;
21
use futures::stream::unfold;
22
use futures::{StreamExt, TryFutureExt};
23
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
24
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
25
use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient;
26
use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient;
27
use nativelink_proto::build::bazel::remote::execution::v2::{
28
    ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest,
29
};
30
use nativelink_proto::google::longrunning::Operation;
31
use nativelink_util::action_messages::{
32
    ActionInfo, ActionState, ActionUniqueQualifier, OperationId, DEFAULT_EXECUTION_PRIORITY,
33
};
34
use nativelink_util::connection_manager::ConnectionManager;
35
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
36
use nativelink_util::operation_state_manager::{
37
    ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
38
};
39
use nativelink_util::retry::{Retrier, RetryResult};
40
use nativelink_util::{background_spawn, tls_utils};
41
use parking_lot::Mutex;
42
use rand::rngs::OsRng;
43
use rand::Rng;
44
use tokio::select;
45
use tokio::sync::watch;
46
use tokio::time::sleep;
47
use tonic::{Request, Streaming};
48
use tracing::{event, Level};
49
50
struct GrpcActionStateResult {
51
    client_operation_id: OperationId,
52
    rx: watch::Receiver<Arc<ActionState>>,
53
}
54
55
#[async_trait]
56
impl ActionStateResult for GrpcActionStateResult {
57
0
    async fn as_state(&self) -> Result<Arc<ActionState>, Error> {
58
0
        let mut action_state = self.rx.borrow().clone();
59
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
60
0
        Ok(action_state)
61
0
    }
62
63
0
    async fn changed(&mut self) -> Result<Arc<ActionState>, Error> {
64
0
        self.rx.changed().await.map_err(|_| {
65
0
            make_err!(
66
0
                Code::Internal,
67
0
                "Channel closed in GrpcActionStateResult::changed"
68
0
            )
69
0
        })?;
70
0
        let mut action_state = self.rx.borrow().clone();
71
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
72
0
        Ok(action_state)
73
0
    }
74
75
0
    async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> {
76
        // TODO(allada) We should probably remove as_action_info()
77
        // or implement it properly.
78
0
        return Err(make_err!(
79
0
            Code::Unimplemented,
80
0
            "as_action_info not implemented for GrpcActionStateResult::as_action_info"
81
0
        ));
82
0
    }
83
}
84
85
0
#[derive(MetricsComponent)]
86
pub struct GrpcScheduler {
87
    #[metric(group = "property_managers")]
88
    supported_props: Mutex<HashMap<String, Vec<String>>>,
89
    retrier: Retrier,
90
    connection_manager: ConnectionManager,
91
}
92
93
impl GrpcScheduler {
94
0
    pub fn new(config: &nativelink_config::schedulers::GrpcScheduler) -> Result<Self, Error> {
95
0
        let jitter_amt = config.retry.jitter;
96
0
        Self::new_with_jitter(
97
0
            config,
98
0
            Box::new(move |delay: Duration| {
99
0
                if jitter_amt == 0. {
  Branch (99:20): [True: 0, False: 0]
  Branch (99:20): [Folded - Ignored]
100
0
                    return delay;
101
0
                }
102
0
                let min = 1. - (jitter_amt / 2.);
103
0
                let max = 1. + (jitter_amt / 2.);
104
0
                delay.mul_f32(OsRng.gen_range(min..max))
105
0
            }),
106
0
        )
107
0
    }
108
109
0
    pub fn new_with_jitter(
110
0
        config: &nativelink_config::schedulers::GrpcScheduler,
111
0
        jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
112
0
    ) -> Result<Self, Error> {
113
0
        let endpoint = tls_utils::endpoint(&config.endpoint)?;
114
0
        let jitter_fn = Arc::new(jitter_fn);
115
0
        Ok(Self {
116
0
            supported_props: Mutex::new(HashMap::new()),
117
0
            retrier: Retrier::new(
118
0
                Arc::new(|duration| Box::pin(sleep(duration))),
119
0
                jitter_fn.clone(),
120
0
                config.retry.clone(),
121
0
            ),
122
0
            connection_manager: ConnectionManager::new(
123
0
                std::iter::once(endpoint),
124
0
                config.connections_per_endpoint,
125
0
                config.max_concurrent_requests,
126
0
                config.retry.clone(),
127
0
                jitter_fn,
128
0
            ),
129
0
        })
130
0
    }
131
132
0
    async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error>
133
0
    where
134
0
        F: FnMut(I) -> Fut + Send + Copy,
135
0
        Fut: Future<Output = Result<R, Error>> + Send,
136
0
        R: Send,
137
0
        I: Send + Clone,
138
0
    {
139
0
        self.retrier
140
0
            .retry(unfold(input, move |input| async move {
141
0
                let input_clone = input.clone();
142
0
                Some((
143
0
                    request(input_clone)
144
0
                        .await
145
0
                        .map_or_else(RetryResult::Retry, RetryResult::Ok),
146
0
                    input,
147
0
                ))
148
0
            }))
149
0
            .await
150
0
    }
151
152
0
    async fn stream_state(
153
0
        mut result_stream: Streaming<Operation>,
154
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
155
0
        if let Some(initial_response) = result_stream
  Branch (155:16): [True: 0, False: 0]
  Branch (155:16): [Folded - Ignored]
156
0
            .message()
157
0
            .await
158
0
            .err_tip(|| "Recieving response from upstream scheduler")?
159
        {
160
0
            let client_operation_id = OperationId::from(initial_response.name.as_str());
161
0
            // Our `operation_id` is not needed here; it is just a placeholder that lets us reuse the existing object.
162
0
            // The only ID that actually matters is the `client_operation_id` parsed from the upstream response name.
163
0
            let operation_id = OperationId::default();
164
0
            let action_state =
165
0
                ActionState::try_from_operation(initial_response, operation_id.clone())
166
0
                    .err_tip(|| "In GrpcScheduler::stream_state")?;
167
0
            let (tx, mut rx) = watch::channel(Arc::new(action_state));
168
0
            rx.mark_changed();
169
0
            background_spawn!("grpc_scheduler_stream_state", async move {
170
                loop {
171
0
                    select!(
172
0
                        _ = tx.closed() => {
173
0
                            event!(
174
0
                                Level::INFO,
175
0
                                "Client disconnected in GrpcScheduler"
176
                            );
177
0
                            return;
178
                        }
179
0
                        response = result_stream.message() => {
180
                            // When the upstream closes the channel, close the
181
                            // downstream too.
182
0
                            let Ok(Some(response)) = response else {
  Branch (182:33): [True: 0, False: 0]
  Branch (182:33): [Folded - Ignored]
183
0
                                return;
184
                            };
185
0
                            let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone());
186
0
                            match maybe_action_state {
187
0
                                Ok(response) => {
188
0
                                    if let Err(err) = tx.send(Arc::new(response)) {
  Branch (188:44): [True: 0, False: 0]
  Branch (188:44): [Folded - Ignored]
189
0
                                        event!(
190
0
                                            Level::INFO,
191
                                            ?err,
192
0
                                            "Client error in GrpcScheduler"
193
                                        );
194
0
                                        return;
195
0
                                    }
196
                                }
197
0
                                Err(err) => {
198
0
                                    event!(
199
0
                                        Level::ERROR,
200
                                        ?err,
201
0
                                        "Error converting response to ActionState in GrpcScheduler"
202
                                    );
203
                                },
204
                            }
205
                        }
206
                    );
207
                }
208
0
            });
209
210
0
            return Ok(Box::new(GrpcActionStateResult {
211
0
                client_operation_id,
212
0
                rx,
213
0
            }));
214
0
        }
215
0
        Err(make_err!(
216
0
            Code::Internal,
217
0
            "Upstream scheduler didn't accept action."
218
0
        ))
219
0
    }
220
221
0
    async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
222
0
        if let Some(supported_props) = self.supported_props.lock().get(instance_name) {
  Branch (222:16): [True: 0, False: 0]
  Branch (222:16): [Folded - Ignored]
223
0
            return Ok(supported_props.clone());
224
0
        }
225
0
226
0
        self.perform_request(instance_name, |instance_name| async move {
227
            // Not in the cache; look up the capabilities with the upstream.
228
0
            let channel = self
229
0
                .connection_manager
230
0
                .connection()
231
0
                .await
232
0
                .err_tip(|| "in get_platform_property_manager()")?;
233
0
            let capabilities_result = CapabilitiesClient::new(channel)
234
0
                .get_capabilities(GetCapabilitiesRequest {
235
0
                    instance_name: instance_name.to_string(),
236
0
                })
237
0
                .await
238
0
                .err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
239
0
            let capabilities = capabilities_result?.into_inner();
240
0
            let supported_props = capabilities
241
0
                .execution_capabilities
242
0
                .err_tip(|| "Unable to get execution properties in GrpcScheduler")?
243
                .supported_node_properties
244
0
                .into_iter()
245
0
                .collect::<Vec<String>>();
246
0
247
0
            self.supported_props
248
0
                .lock()
249
0
                .insert(instance_name.to_string(), supported_props.clone());
250
0
            Ok(supported_props)
251
0
        })
252
0
        .await
253
0
    }
254
255
0
    async fn inner_add_action(
256
0
        &self,
257
0
        _client_operation_id: OperationId,
258
0
        action_info: Arc<ActionInfo>,
259
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
260
0
        let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY {
  Branch (260:35): [True: 0, False: 0]
  Branch (260:35): [Folded - Ignored]
261
0
            None
262
        } else {
263
0
            Some(ExecutionPolicy {
264
0
                priority: action_info.priority,
265
0
            })
266
        };
267
0
        let skip_cache_lookup = match action_info.unique_qualifier {
268
0
            ActionUniqueQualifier::Cachable(_) => false,
269
0
            ActionUniqueQualifier::Uncachable(_) => true,
270
        };
271
0
        let request = ExecuteRequest {
272
0
            instance_name: action_info.instance_name().clone(),
273
0
            skip_cache_lookup,
274
0
            action_digest: Some(action_info.digest().into()),
275
0
            execution_policy,
276
0
            // TODO: Get this from the original request; not very important, as we ignore it.
277
0
            results_cache_policy: None,
278
0
            digest_function: action_info
279
0
                .unique_qualifier
280
0
                .digest_function()
281
0
                .proto_digest_func()
282
0
                .into(),
283
0
        };
284
0
        let result_stream = self
285
0
            .perform_request(request, |request| async move {
286
0
                let channel = self
287
0
                    .connection_manager
288
0
                    .connection()
289
0
                    .await
290
0
                    .err_tip(|| "in add_action()")?;
291
0
                ExecutionClient::new(channel)
292
0
                    .execute(Request::new(request))
293
0
                    .await
294
0
                    .err_tip(|| "Sending action to upstream scheduler")
295
0
            })
296
0
            .await?
297
0
            .into_inner();
298
0
        Self::stream_state(result_stream).await
299
0
    }
300
301
0
    async fn inner_filter_operations(
302
0
        &self,
303
0
        filter: OperationFilter,
304
0
    ) -> Result<ActionStateResultStream, Error> {
305
0
        error_if!(filter != OperationFilter {
  Branch (305:19): [True: 0, False: 0]
  Branch (305:19): [Folded - Ignored]
306
0
            client_operation_id: filter.client_operation_id.clone(),
307
0
            ..Default::default()
308
0
        }, "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}");
309
0
        let client_operation_id = filter.client_operation_id.ok_or_else(|| {
310
0
            make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations")
311
0
        })?;
312
0
        let request = WaitExecutionRequest {
313
0
            name: client_operation_id.to_string(),
314
0
        };
315
0
        let result_stream = self
316
0
            .perform_request(request, |request| async move {
317
0
                let channel = self
318
0
                    .connection_manager
319
0
                    .connection()
320
0
                    .await
321
0
                    .err_tip(|| "in find_by_client_operation_id()")?;
322
0
                ExecutionClient::new(channel)
323
0
                    .wait_execution(Request::new(request))
324
0
                    .await
325
0
                    .err_tip(|| "While getting wait_execution stream")
326
0
            })
327
0
            .and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
328
0
            .await;
329
0
        match result_stream {
330
0
            Ok(result_stream) => Ok(unfold(
331
0
                Some(result_stream),
332
0
                |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) },
333
0
            )
334
0
            .boxed()),
335
0
            Err(err) => {
336
0
                event!(
337
0
                    Level::WARN,
338
                    ?err,
339
0
                    "Error looking up action with upstream scheduler"
340
                );
341
0
                Ok(futures::stream::empty().boxed())
342
            }
343
        }
344
0
    }
345
}
346
347
#[async_trait]
348
impl ClientStateManager for GrpcScheduler {
349
    async fn add_action(
350
        &self,
351
        client_operation_id: OperationId,
352
        action_info: Arc<ActionInfo>,
353
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
354
0
        self.inner_add_action(client_operation_id, action_info)
355
0
            .await
356
0
    }
357
358
    async fn filter_operations<'a>(
359
        &'a self,
360
        filter: OperationFilter,
361
0
    ) -> Result<ActionStateResultStream<'a>, Error> {
362
0
        self.inner_filter_operations(filter).await
363
0
    }
364
365
0
    fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> {
366
0
        Some(self)
367
0
    }
368
}
369
370
#[async_trait]
371
impl KnownPlatformPropertyProvider for GrpcScheduler {
372
0
    async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
373
0
        self.inner_get_known_properties(instance_name).await
374
0
    }
375
}
376
377
impl RootMetricsComponent for GrpcScheduler {}