Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::HashMap;
16
use std::future::Future;
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use async_trait::async_trait;
21
use futures::stream::unfold;
22
use futures::{StreamExt, TryFutureExt};
23
use nativelink_config::schedulers::GrpcSpec;
24
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
25
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
26
use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient;
27
use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient;
28
use nativelink_proto::build::bazel::remote::execution::v2::{
29
    ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest,
30
};
31
use nativelink_proto::google::longrunning::Operation;
32
use nativelink_util::action_messages::{
33
    ActionInfo, ActionState, ActionUniqueQualifier, OperationId, DEFAULT_EXECUTION_PRIORITY,
34
};
35
use nativelink_util::connection_manager::ConnectionManager;
36
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
37
use nativelink_util::operation_state_manager::{
38
    ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
39
};
40
use nativelink_util::retry::{Retrier, RetryResult};
41
use nativelink_util::{background_spawn, tls_utils};
42
use parking_lot::Mutex;
43
use rand::rngs::OsRng;
44
use rand::Rng;
45
use tokio::select;
46
use tokio::sync::watch;
47
use tokio::time::sleep;
48
use tonic::{Request, Streaming};
49
use tracing::{event, Level};
50
51
struct GrpcActionStateResult {
52
    client_operation_id: OperationId,
53
    rx: watch::Receiver<Arc<ActionState>>,
54
}
55
56
#[async_trait]
57
impl ActionStateResult for GrpcActionStateResult {
58
0
    async fn as_state(&self) -> Result<Arc<ActionState>, Error> {
59
0
        let mut action_state = self.rx.borrow().clone();
60
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
61
0
        Ok(action_state)
62
0
    }
63
64
0
    async fn changed(&mut self) -> Result<Arc<ActionState>, Error> {
65
0
        self.rx.changed().await.map_err(|_| {
66
0
            make_err!(
67
0
                Code::Internal,
68
0
                "Channel closed in GrpcActionStateResult::changed"
69
0
            )
70
0
        })?;
71
0
        let mut action_state = self.rx.borrow().clone();
72
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
73
0
        Ok(action_state)
74
0
    }
75
76
0
    async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> {
77
        // TODO(allada) We should probably remove as_action_info()
78
        // or implement it properly.
79
0
        return Err(make_err!(
80
0
            Code::Unimplemented,
81
0
            "as_action_info not implemented for GrpcActionStateResult::as_action_info"
82
0
        ));
83
0
    }
84
}
85
86
0
#[derive(MetricsComponent)]
87
pub struct GrpcScheduler {
88
    #[metric(group = "property_managers")]
89
    supported_props: Mutex<HashMap<String, Vec<String>>>,
90
    retrier: Retrier,
91
    connection_manager: ConnectionManager,
92
}
93
94
impl GrpcScheduler {
95
0
    pub fn new(spec: &GrpcSpec) -> Result<Self, Error> {
96
0
        let jitter_amt = spec.retry.jitter;
97
0
        Self::new_with_jitter(
98
0
            spec,
99
0
            Box::new(move |delay: Duration| {
100
0
                if jitter_amt == 0. {
  Branch (100:20): [True: 0, False: 0]
  Branch (100:20): [Folded - Ignored]
101
0
                    return delay;
102
0
                }
103
0
                let min = 1. - (jitter_amt / 2.);
104
0
                let max = 1. + (jitter_amt / 2.);
105
0
                delay.mul_f32(OsRng.gen_range(min..max))
106
0
            }),
107
0
        )
108
0
    }
109
110
0
    pub fn new_with_jitter(
111
0
        spec: &GrpcSpec,
112
0
        jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
113
0
    ) -> Result<Self, Error> {
114
0
        let endpoint = tls_utils::endpoint(&spec.endpoint)?;
115
0
        let jitter_fn = Arc::new(jitter_fn);
116
0
        Ok(Self {
117
0
            supported_props: Mutex::new(HashMap::new()),
118
0
            retrier: Retrier::new(
119
0
                Arc::new(|duration| Box::pin(sleep(duration))),
120
0
                jitter_fn.clone(),
121
0
                spec.retry.clone(),
122
0
            ),
123
0
            connection_manager: ConnectionManager::new(
124
0
                std::iter::once(endpoint),
125
0
                spec.connections_per_endpoint,
126
0
                spec.max_concurrent_requests,
127
0
                spec.retry.clone(),
128
0
                jitter_fn,
129
0
            ),
130
0
        })
131
0
    }
132
133
0
    async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error>
134
0
    where
135
0
        F: FnMut(I) -> Fut + Send + Copy,
136
0
        Fut: Future<Output = Result<R, Error>> + Send,
137
0
        R: Send,
138
0
        I: Send + Clone,
139
0
    {
140
0
        self.retrier
141
0
            .retry(unfold(input, move |input| async move {
142
0
                let input_clone = input.clone();
143
0
                Some((
144
0
                    request(input_clone)
145
0
                        .await
146
0
                        .map_or_else(RetryResult::Retry, RetryResult::Ok),
147
0
                    input,
148
0
                ))
149
0
            }))
150
0
            .await
151
0
    }
152
153
0
    async fn stream_state(
154
0
        mut result_stream: Streaming<Operation>,
155
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
156
0
        if let Some(initial_response) = result_stream
  Branch (156:16): [True: 0, False: 0]
  Branch (156:16): [Folded - Ignored]
157
0
            .message()
158
0
            .await
159
0
            .err_tip(|| "Recieving response from upstream scheduler")?
160
        {
161
0
            let client_operation_id = OperationId::from(initial_response.name.as_str());
162
0
            // Our operation_id is not needed here is just a place holder to recycle existing object.
163
0
            // The only thing that actually matters is the operation_id.
164
0
            let operation_id = OperationId::default();
165
0
            let action_state =
166
0
                ActionState::try_from_operation(initial_response, operation_id.clone())
167
0
                    .err_tip(|| "In GrpcScheduler::stream_state")?;
168
0
            let (tx, mut rx) = watch::channel(Arc::new(action_state));
169
0
            rx.mark_changed();
170
0
            background_spawn!("grpc_scheduler_stream_state", async move {
171
                loop {
172
0
                    select!(
173
0
                        _ = tx.closed() => {
174
0
                            event!(
175
0
                                Level::INFO,
176
0
                                "Client disconnected in GrpcScheduler"
177
                            );
178
0
                            return;
179
                        }
180
0
                        response = result_stream.message() => {
181
                            // When the upstream closes the channel, close the
182
                            // downstream too.
183
0
                            let Ok(Some(response)) = response else {
  Branch (183:33): [True: 0, False: 0]
  Branch (183:33): [Folded - Ignored]
184
0
                                return;
185
                            };
186
0
                            let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone());
187
0
                            match maybe_action_state {
188
0
                                Ok(response) => {
189
0
                                    if let Err(err) = tx.send(Arc::new(response)) {
  Branch (189:44): [True: 0, False: 0]
  Branch (189:44): [Folded - Ignored]
190
0
                                        event!(
191
0
                                            Level::INFO,
192
                                            ?err,
193
0
                                            "Client error in GrpcScheduler"
194
                                        );
195
0
                                        return;
196
0
                                    }
197
                                }
198
0
                                Err(err) => {
199
0
                                    event!(
200
0
                                        Level::ERROR,
201
                                        ?err,
202
0
                                        "Error converting response to ActionState in GrpcScheduler"
203
                                    );
204
                                },
205
                            }
206
                        }
207
                    );
208
                }
209
0
            });
210
211
0
            return Ok(Box::new(GrpcActionStateResult {
212
0
                client_operation_id,
213
0
                rx,
214
0
            }));
215
0
        }
216
0
        Err(make_err!(
217
0
            Code::Internal,
218
0
            "Upstream scheduler didn't accept action."
219
0
        ))
220
0
    }
221
222
0
    async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
223
0
        if let Some(supported_props) = self.supported_props.lock().get(instance_name) {
  Branch (223:16): [True: 0, False: 0]
  Branch (223:16): [Folded - Ignored]
224
0
            return Ok(supported_props.clone());
225
0
        }
226
0
227
0
        self.perform_request(instance_name, |instance_name| async move {
228
            // Not in the cache, lookup the capabilities with the upstream.
229
0
            let channel = self
230
0
                .connection_manager
231
0
                .connection()
232
0
                .await
233
0
                .err_tip(|| "in get_platform_property_manager()")?;
234
0
            let capabilities_result = CapabilitiesClient::new(channel)
235
0
                .get_capabilities(GetCapabilitiesRequest {
236
0
                    instance_name: instance_name.to_string(),
237
0
                })
238
0
                .await
239
0
                .err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
240
0
            let capabilities = capabilities_result?.into_inner();
241
0
            let supported_props = capabilities
242
0
                .execution_capabilities
243
0
                .err_tip(|| "Unable to get execution properties in GrpcScheduler")?
244
                .supported_node_properties
245
0
                .into_iter()
246
0
                .collect::<Vec<String>>();
247
0
248
0
            self.supported_props
249
0
                .lock()
250
0
                .insert(instance_name.to_string(), supported_props.clone());
251
0
            Ok(supported_props)
252
0
        })
253
0
        .await
254
0
    }
255
256
0
    async fn inner_add_action(
257
0
        &self,
258
0
        _client_operation_id: OperationId,
259
0
        action_info: Arc<ActionInfo>,
260
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
261
0
        let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY {
  Branch (261:35): [True: 0, False: 0]
  Branch (261:35): [Folded - Ignored]
262
0
            None
263
        } else {
264
0
            Some(ExecutionPolicy {
265
0
                priority: action_info.priority,
266
0
            })
267
        };
268
0
        let skip_cache_lookup = match action_info.unique_qualifier {
269
0
            ActionUniqueQualifier::Cachable(_) => false,
270
0
            ActionUniqueQualifier::Uncachable(_) => true,
271
        };
272
0
        let request = ExecuteRequest {
273
0
            instance_name: action_info.instance_name().clone(),
274
0
            skip_cache_lookup,
275
0
            action_digest: Some(action_info.digest().into()),
276
0
            execution_policy,
277
0
            // TODO: Get me from the original request, not very important as we ignore it.
278
0
            results_cache_policy: None,
279
0
            digest_function: action_info
280
0
                .unique_qualifier
281
0
                .digest_function()
282
0
                .proto_digest_func()
283
0
                .into(),
284
0
        };
285
0
        let result_stream = self
286
0
            .perform_request(request, |request| async move {
287
0
                let channel = self
288
0
                    .connection_manager
289
0
                    .connection()
290
0
                    .await
291
0
                    .err_tip(|| "in add_action()")?;
292
0
                ExecutionClient::new(channel)
293
0
                    .execute(Request::new(request))
294
0
                    .await
295
0
                    .err_tip(|| "Sending action to upstream scheduler")
296
0
            })
297
0
            .await?
298
0
            .into_inner();
299
0
        Self::stream_state(result_stream).await
300
0
    }
301
302
0
    async fn inner_filter_operations(
303
0
        &self,
304
0
        filter: OperationFilter,
305
0
    ) -> Result<ActionStateResultStream, Error> {
306
0
        error_if!(filter != OperationFilter {
  Branch (306:19): [True: 0, False: 0]
  Branch (306:19): [Folded - Ignored]
307
0
            client_operation_id: filter.client_operation_id.clone(),
308
0
            ..Default::default()
309
0
        }, "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}");
310
0
        let client_operation_id = filter.client_operation_id.ok_or_else(|| {
311
0
            make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations")
312
0
        })?;
313
0
        let request = WaitExecutionRequest {
314
0
            name: client_operation_id.to_string(),
315
0
        };
316
0
        let result_stream = self
317
0
            .perform_request(request, |request| async move {
318
0
                let channel = self
319
0
                    .connection_manager
320
0
                    .connection()
321
0
                    .await
322
0
                    .err_tip(|| "in find_by_client_operation_id()")?;
323
0
                ExecutionClient::new(channel)
324
0
                    .wait_execution(Request::new(request))
325
0
                    .await
326
0
                    .err_tip(|| "While getting wait_execution stream")
327
0
            })
328
0
            .and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
329
0
            .await;
330
0
        match result_stream {
331
0
            Ok(result_stream) => Ok(unfold(
332
0
                Some(result_stream),
333
0
                |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) },
334
0
            )
335
0
            .boxed()),
336
0
            Err(err) => {
337
0
                event!(
338
0
                    Level::WARN,
339
                    ?err,
340
0
                    "Error looking up action with upstream scheduler"
341
                );
342
0
                Ok(futures::stream::empty().boxed())
343
            }
344
        }
345
0
    }
346
}
347
348
#[async_trait]
349
impl ClientStateManager for GrpcScheduler {
350
    async fn add_action(
351
        &self,
352
        client_operation_id: OperationId,
353
        action_info: Arc<ActionInfo>,
354
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
355
0
        self.inner_add_action(client_operation_id, action_info)
356
0
            .await
357
0
    }
358
359
    async fn filter_operations<'a>(
360
        &'a self,
361
        filter: OperationFilter,
362
0
    ) -> Result<ActionStateResultStream<'a>, Error> {
363
0
        self.inner_filter_operations(filter).await
364
0
    }
365
366
0
    fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> {
367
0
        Some(self)
368
0
    }
369
}
370
371
#[async_trait]
372
impl KnownPlatformPropertyProvider for GrpcScheduler {
373
0
    async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
374
0
        self.inner_get_known_properties(instance_name).await
375
0
    }
376
}
377
378
// Marker impl: registers `GrpcScheduler` as a root metrics component using the
// trait's default behavior (no methods are overridden here).
impl RootMetricsComponent for GrpcScheduler {}