Coverage Report

Created: 2025-03-08 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::collections::HashMap;
16
use std::future::Future;
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use async_trait::async_trait;
21
use futures::stream::unfold;
22
use futures::{StreamExt, TryFutureExt};
23
use nativelink_config::schedulers::GrpcSpec;
24
use nativelink_error::{error_if, make_err, Code, Error, ResultExt};
25
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
26
use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient;
27
use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient;
28
use nativelink_proto::build::bazel::remote::execution::v2::{
29
    ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest,
30
};
31
use nativelink_proto::google::longrunning::Operation;
32
use nativelink_util::action_messages::{
33
    ActionInfo, ActionState, ActionUniqueQualifier, OperationId, DEFAULT_EXECUTION_PRIORITY,
34
};
35
use nativelink_util::connection_manager::ConnectionManager;
36
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
37
use nativelink_util::operation_state_manager::{
38
    ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
39
};
40
use nativelink_util::origin_event::OriginMetadata;
41
use nativelink_util::retry::{Retrier, RetryResult};
42
use nativelink_util::{background_spawn, tls_utils};
43
use parking_lot::Mutex;
44
use rand::Rng;
45
use tokio::select;
46
use tokio::sync::watch;
47
use tokio::time::sleep;
48
use tonic::{Request, Streaming};
49
use tracing::{event, Level};
50
51
struct GrpcActionStateResult {
52
    client_operation_id: OperationId,
53
    rx: watch::Receiver<Arc<ActionState>>,
54
}
55
56
#[async_trait]
57
impl ActionStateResult for GrpcActionStateResult {
58
0
    async fn as_state(&self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
59
0
        let mut action_state = self.rx.borrow().clone();
60
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
61
0
        // TODO(allada) We currently don't support OriginMetadata in this implementation, but
62
0
        // we should.
63
0
        Ok((action_state, None))
64
0
    }
65
66
0
    async fn changed(&mut self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
67
0
        self.rx.changed().await.map_err(|_| {
68
0
            make_err!(
69
0
                Code::Internal,
70
0
                "Channel closed in GrpcActionStateResult::changed"
71
0
            )
72
0
        })?;
73
0
        let mut action_state = self.rx.borrow().clone();
74
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
75
0
        // TODO(allada) We currently don't support OriginMetadata in this implementation, but
76
0
        // we should.
77
0
        Ok((action_state, None))
78
0
    }
79
80
0
    async fn as_action_info(&self) -> Result<(Arc<ActionInfo>, Option<OriginMetadata>), Error> {
81
        // TODO(allada) We should probably remove as_action_info()
82
        // or implement it properly.
83
0
        return Err(make_err!(
84
0
            Code::Unimplemented,
85
0
            "as_action_info not implemented for GrpcActionStateResult::as_action_info"
86
0
        ));
87
0
    }
88
}
89
90
#[derive(MetricsComponent)]
91
pub struct GrpcScheduler {
92
    #[metric(group = "property_managers")]
93
    supported_props: Mutex<HashMap<String, Vec<String>>>,
94
    retrier: Retrier,
95
    connection_manager: ConnectionManager,
96
}
97
98
impl GrpcScheduler {
99
0
    pub fn new(spec: &GrpcSpec) -> Result<Self, Error> {
100
0
        let jitter_amt = spec.retry.jitter;
101
0
        Self::new_with_jitter(
102
0
            spec,
103
0
            Box::new(move |delay: Duration| {
104
0
                if jitter_amt == 0. {
  Branch (104:20): [True: 0, False: 0]
  Branch (104:20): [Folded - Ignored]
105
0
                    return delay;
106
0
                }
107
0
                let min = 1. - (jitter_amt / 2.);
108
0
                let max = 1. + (jitter_amt / 2.);
109
0
                delay.mul_f32(rand::rng().random_range(min..max))
110
0
            }),
111
0
        )
112
0
    }
113
114
0
    pub fn new_with_jitter(
115
0
        spec: &GrpcSpec,
116
0
        jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
117
0
    ) -> Result<Self, Error> {
118
0
        let endpoint = tls_utils::endpoint(&spec.endpoint)?;
119
0
        let jitter_fn = Arc::new(jitter_fn);
120
0
        Ok(Self {
121
0
            supported_props: Mutex::new(HashMap::new()),
122
0
            retrier: Retrier::new(
123
0
                Arc::new(|duration| Box::pin(sleep(duration))),
124
0
                jitter_fn.clone(),
125
0
                spec.retry.clone(),
126
0
            ),
127
0
            connection_manager: ConnectionManager::new(
128
0
                std::iter::once(endpoint),
129
0
                spec.connections_per_endpoint,
130
0
                spec.max_concurrent_requests,
131
0
                spec.retry.clone(),
132
0
                jitter_fn,
133
0
            ),
134
0
        })
135
0
    }
136
137
0
    async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error>
138
0
    where
139
0
        F: FnMut(I) -> Fut + Send + Copy,
140
0
        Fut: Future<Output = Result<R, Error>> + Send,
141
0
        R: Send,
142
0
        I: Send + Clone,
143
0
    {
144
0
        self.retrier
145
0
            .retry(unfold(input, move |input| async move {
146
0
                let input_clone = input.clone();
147
0
                Some((
148
0
                    request(input_clone)
149
0
                        .await
150
0
                        .map_or_else(RetryResult::Retry, RetryResult::Ok),
151
0
                    input,
152
0
                ))
153
0
            }))
154
0
            .await
155
0
    }
156
157
0
    async fn stream_state(
158
0
        mut result_stream: Streaming<Operation>,
159
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
160
0
        if let Some(initial_response) = result_stream
  Branch (160:16): [True: 0, False: 0]
  Branch (160:16): [Folded - Ignored]
161
0
            .message()
162
0
            .await
163
0
            .err_tip(|| "Recieving response from upstream scheduler")?
164
        {
165
0
            let client_operation_id = OperationId::from(initial_response.name.as_str());
166
0
            // Our operation_id is not needed here is just a place holder to recycle existing object.
167
0
            // The only thing that actually matters is the operation_id.
168
0
            let operation_id = OperationId::default();
169
0
            let action_state =
170
0
                ActionState::try_from_operation(initial_response, operation_id.clone())
171
0
                    .err_tip(|| "In GrpcScheduler::stream_state")?;
172
0
            let (tx, mut rx) = watch::channel(Arc::new(action_state));
173
0
            rx.mark_changed();
174
0
            background_spawn!("grpc_scheduler_stream_state", async move {
175
                loop {
176
0
                    select!(
177
0
                        () = tx.closed() => {
178
0
                            event!(
179
0
                                Level::INFO,
180
0
                                "Client disconnected in GrpcScheduler"
181
                            );
182
0
                            return;
183
                        }
184
0
                        response = result_stream.message() => {
185
                            // When the upstream closes the channel, close the
186
                            // downstream too.
187
0
                            let Ok(Some(response)) = response else {
  Branch (187:33): [True: 0, False: 0]
  Branch (187:33): [Folded - Ignored]
188
0
                                return;
189
                            };
190
0
                            let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone());
191
0
                            match maybe_action_state {
192
0
                                Ok(response) => {
193
0
                                    if let Err(err) = tx.send(Arc::new(response)) {
  Branch (193:44): [True: 0, False: 0]
  Branch (193:44): [Folded - Ignored]
194
0
                                        event!(
195
0
                                            Level::INFO,
196
                                            ?err,
197
0
                                            "Client error in GrpcScheduler"
198
                                        );
199
0
                                        return;
200
0
                                    }
201
                                }
202
0
                                Err(err) => {
203
0
                                    event!(
204
0
                                        Level::ERROR,
205
                                        ?err,
206
0
                                        "Error converting response to ActionState in GrpcScheduler"
207
                                    );
208
                                },
209
                            }
210
                        }
211
                    );
212
                }
213
0
            });
214
215
0
            return Ok(Box::new(GrpcActionStateResult {
216
0
                client_operation_id,
217
0
                rx,
218
0
            }));
219
0
        }
220
0
        Err(make_err!(
221
0
            Code::Internal,
222
0
            "Upstream scheduler didn't accept action."
223
0
        ))
224
0
    }
225
226
0
    async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
227
0
        if let Some(supported_props) = self.supported_props.lock().get(instance_name) {
  Branch (227:16): [True: 0, False: 0]
  Branch (227:16): [Folded - Ignored]
228
0
            return Ok(supported_props.clone());
229
0
        }
230
0
231
0
        self.perform_request(instance_name, |instance_name| async move {
232
            // Not in the cache, lookup the capabilities with the upstream.
233
0
            let channel = self
234
0
                .connection_manager
235
0
                .connection()
236
0
                .await
237
0
                .err_tip(|| "in get_platform_property_manager()")?;
238
0
            let capabilities_result = CapabilitiesClient::new(channel)
239
0
                .get_capabilities(GetCapabilitiesRequest {
240
0
                    instance_name: instance_name.to_string(),
241
0
                })
242
0
                .await
243
0
                .err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
244
0
            let capabilities = capabilities_result?.into_inner();
245
0
            let supported_props = capabilities
246
0
                .execution_capabilities
247
0
                .err_tip(|| "Unable to get execution properties in GrpcScheduler")?
248
                .supported_node_properties
249
0
                .into_iter()
250
0
                .collect::<Vec<String>>();
251
0
252
0
            self.supported_props
253
0
                .lock()
254
0
                .insert(instance_name.to_string(), supported_props.clone());
255
0
            Ok(supported_props)
256
0
        })
257
0
        .await
258
0
    }
259
260
0
    async fn inner_add_action(
261
0
        &self,
262
0
        _client_operation_id: OperationId,
263
0
        action_info: Arc<ActionInfo>,
264
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
265
0
        let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY {
  Branch (265:35): [True: 0, False: 0]
  Branch (265:35): [Folded - Ignored]
266
0
            None
267
        } else {
268
0
            Some(ExecutionPolicy {
269
0
                priority: action_info.priority,
270
0
            })
271
        };
272
0
        let skip_cache_lookup = match action_info.unique_qualifier {
273
0
            ActionUniqueQualifier::Cachable(_) => false,
274
0
            ActionUniqueQualifier::Uncachable(_) => true,
275
        };
276
0
        let request = ExecuteRequest {
277
0
            instance_name: action_info.instance_name().clone(),
278
0
            skip_cache_lookup,
279
0
            action_digest: Some(action_info.digest().into()),
280
0
            execution_policy,
281
0
            // TODO: Get me from the original request, not very important as we ignore it.
282
0
            results_cache_policy: None,
283
0
            digest_function: action_info
284
0
                .unique_qualifier
285
0
                .digest_function()
286
0
                .proto_digest_func()
287
0
                .into(),
288
0
        };
289
0
        let result_stream = self
290
0
            .perform_request(request, |request| async move {
291
0
                let channel = self
292
0
                    .connection_manager
293
0
                    .connection()
294
0
                    .await
295
0
                    .err_tip(|| "in add_action()")?;
296
0
                ExecutionClient::new(channel)
297
0
                    .execute(Request::new(request))
298
0
                    .await
299
0
                    .err_tip(|| "Sending action to upstream scheduler")
300
0
            })
301
0
            .await?
302
0
            .into_inner();
303
0
        Self::stream_state(result_stream).await
304
0
    }
305
306
0
    async fn inner_filter_operations(
307
0
        &self,
308
0
        filter: OperationFilter,
309
0
    ) -> Result<ActionStateResultStream, Error> {
310
0
        error_if!(filter != OperationFilter {
  Branch (310:19): [True: 0, False: 0]
  Branch (310:19): [Folded - Ignored]
311
0
            client_operation_id: filter.client_operation_id.clone(),
312
0
            ..Default::default()
313
0
        }, "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}");
314
0
        let client_operation_id = filter.client_operation_id.ok_or_else(|| {
315
0
            make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations")
316
0
        })?;
317
0
        let request = WaitExecutionRequest {
318
0
            name: client_operation_id.to_string(),
319
0
        };
320
0
        let result_stream = self
321
0
            .perform_request(request, |request| async move {
322
0
                let channel = self
323
0
                    .connection_manager
324
0
                    .connection()
325
0
                    .await
326
0
                    .err_tip(|| "in find_by_client_operation_id()")?;
327
0
                ExecutionClient::new(channel)
328
0
                    .wait_execution(Request::new(request))
329
0
                    .await
330
0
                    .err_tip(|| "While getting wait_execution stream")
331
0
            })
332
0
            .and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
333
0
            .await;
334
0
        match result_stream {
335
0
            Ok(result_stream) => Ok(unfold(
336
0
                Some(result_stream),
337
0
                |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) },
338
0
            )
339
0
            .boxed()),
340
0
            Err(err) => {
341
0
                event!(
342
0
                    Level::WARN,
343
                    ?err,
344
0
                    "Error looking up action with upstream scheduler"
345
                );
346
0
                Ok(futures::stream::empty().boxed())
347
            }
348
        }
349
0
    }
350
}
351
352
#[async_trait]
353
impl ClientStateManager for GrpcScheduler {
354
    async fn add_action(
355
        &self,
356
        client_operation_id: OperationId,
357
        action_info: Arc<ActionInfo>,
358
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
359
0
        self.inner_add_action(client_operation_id, action_info)
360
0
            .await
361
0
    }
362
363
    async fn filter_operations<'a>(
364
        &'a self,
365
        filter: OperationFilter,
366
0
    ) -> Result<ActionStateResultStream<'a>, Error> {
367
0
        self.inner_filter_operations(filter).await
368
0
    }
369
370
0
    fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> {
371
0
        Some(self)
372
0
    }
373
}
374
375
#[async_trait]
376
impl KnownPlatformPropertyProvider for GrpcScheduler {
377
0
    async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
378
0
        self.inner_get_known_properties(instance_name).await
379
0
    }
380
}
381
382
// Empty impl: no item of `RootMetricsComponent` is overridden here.
impl RootMetricsComponent for GrpcScheduler {}