Coverage Report

Created: 2025-09-16 19:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::future::Future;
16
use core::time::Duration;
17
use std::collections::HashMap;
18
use std::sync::Arc;
19
20
use async_trait::async_trait;
21
use futures::stream::unfold;
22
use futures::{StreamExt, TryFutureExt};
23
use nativelink_config::schedulers::GrpcSpec;
24
use nativelink_error::{Code, Error, ResultExt, error_if, make_err};
25
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
26
use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient;
27
use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient;
28
use nativelink_proto::build::bazel::remote::execution::v2::{
29
    ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest,
30
};
31
use nativelink_proto::google::longrunning::Operation;
32
use nativelink_util::action_messages::{
33
    ActionInfo, ActionState, ActionUniqueQualifier, DEFAULT_EXECUTION_PRIORITY, OperationId,
34
};
35
use nativelink_util::connection_manager::ConnectionManager;
36
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
37
use nativelink_util::operation_state_manager::{
38
    ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
39
};
40
use nativelink_util::origin_event::OriginMetadata;
41
use nativelink_util::retry::{Retrier, RetryResult};
42
use nativelink_util::{background_spawn, tls_utils};
43
use parking_lot::Mutex;
44
use tokio::select;
45
use tokio::sync::watch;
46
use tokio::time::sleep;
47
use tonic::{Request, Streaming};
48
use tracing::{error, info, warn};
49
50
struct GrpcActionStateResult {
51
    client_operation_id: OperationId,
52
    rx: watch::Receiver<Arc<ActionState>>,
53
}
54
55
#[async_trait]
56
impl ActionStateResult for GrpcActionStateResult {
57
0
    async fn as_state(&self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
58
0
        let mut action_state = self.rx.borrow().clone();
59
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
60
        // TODO(palfrey) We currently don't support OriginMetadata in this implementation, but
61
        // we should.
62
0
        Ok((action_state, None))
63
0
    }
64
65
0
    async fn changed(&mut self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
66
0
        self.rx.changed().await.map_err(|_| {
67
0
            make_err!(
68
0
                Code::Internal,
69
                "Channel closed in GrpcActionStateResult::changed"
70
            )
71
0
        })?;
72
0
        let mut action_state = self.rx.borrow().clone();
73
0
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
74
        // TODO(palfrey) We currently don't support OriginMetadata in this implementation, but
75
        // we should.
76
0
        Ok((action_state, None))
77
0
    }
78
79
0
    async fn as_action_info(&self) -> Result<(Arc<ActionInfo>, Option<OriginMetadata>), Error> {
80
        // TODO(palfrey) We should probably remove as_action_info()
81
        // or implement it properly.
82
0
        return Err(make_err!(
83
0
            Code::Unimplemented,
84
0
            "as_action_info not implemented for GrpcActionStateResult::as_action_info"
85
0
        ));
86
0
    }
87
}
88
89
/// Scheduler that forwards client requests to an upstream scheduler over gRPC.
// Field order is kept as-is: the `Debug`/`MetricsComponent` derives follow it.
#[derive(Debug, MetricsComponent)]
pub struct GrpcScheduler {
    /// Cache of `instance_name -> supported node properties` fetched from the upstream.
    #[metric(group = "property_managers")]
    supported_props: Mutex<HashMap<String, Vec<String>>>,
    /// Retries failed upstream requests (driven by `perform_request`).
    retrier: Retrier,
    /// Hands out channels to the single configured upstream endpoint.
    connection_manager: ConnectionManager,
}
96
97
impl GrpcScheduler {
98
0
    pub fn new(spec: &GrpcSpec) -> Result<Self, Error> {
99
0
        Self::new_with_jitter(spec, spec.retry.make_jitter_fn())
100
0
    }
101
102
0
    pub fn new_with_jitter(
103
0
        spec: &GrpcSpec,
104
0
        jitter_fn: Arc<dyn Fn(Duration) -> Duration + Send + Sync>,
105
0
    ) -> Result<Self, Error> {
106
0
        let endpoint = tls_utils::endpoint(&spec.endpoint)?;
107
        Ok(Self {
108
0
            supported_props: Mutex::new(HashMap::new()),
109
0
            retrier: Retrier::new(
110
0
                Arc::new(|duration| Box::pin(sleep(duration))),
111
0
                jitter_fn.clone(),
112
0
                spec.retry.clone(),
113
            ),
114
0
            connection_manager: ConnectionManager::new(
115
0
                core::iter::once(endpoint),
116
0
                spec.connections_per_endpoint,
117
0
                spec.max_concurrent_requests,
118
0
                spec.retry.clone(),
119
0
                jitter_fn,
120
            ),
121
        })
122
0
    }
123
124
0
    async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error>
125
0
    where
126
0
        F: FnMut(I) -> Fut + Send + Copy,
127
0
        Fut: Future<Output = Result<R, Error>> + Send,
128
0
        R: Send,
129
0
        I: Send + Clone,
130
0
    {
131
0
        self.retrier
132
0
            .retry(unfold(input, move |input| async move {
133
0
                let input_clone = input.clone();
134
                Some((
135
0
                    request(input_clone)
136
0
                        .await
137
0
                        .map_or_else(RetryResult::Retry, RetryResult::Ok),
138
0
                    input,
139
                ))
140
0
            }))
141
0
            .await
142
0
    }
143
144
0
    async fn stream_state(
145
0
        mut result_stream: Streaming<Operation>,
146
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
147
0
        if let Some(initial_response) = result_stream
  Branch (147:16): [True: 0, False: 0]
  Branch (147:16): [Folded - Ignored]
148
0
            .message()
149
0
            .await
150
0
            .err_tip(|| "Receiving response from upstream scheduler")?
151
        {
152
0
            let client_operation_id = OperationId::from(initial_response.name.as_str());
153
            // Our operation_id is not needed here is just a place holder to recycle existing object.
154
            // The only thing that actually matters is the operation_id.
155
0
            let operation_id = OperationId::default();
156
0
            let action_state =
157
0
                ActionState::try_from_operation(initial_response, operation_id.clone())
158
0
                    .err_tip(|| "In GrpcScheduler::stream_state")?;
159
0
            let (tx, mut rx) = watch::channel(Arc::new(action_state));
160
0
            rx.mark_changed();
161
0
            background_spawn!("grpc_scheduler_stream_state", async move {
162
                loop {
163
0
                    select!(
164
0
                        () = tx.closed() => {
165
0
                            info!(
166
0
                                "Client disconnected in GrpcScheduler"
167
                            );
168
0
                            return;
169
                        }
170
0
                        response = result_stream.message() => {
171
                            // When the upstream closes the channel, close the
172
                            // downstream too.
173
0
                            let Ok(Some(response)) = response else {
  Branch (173:33): [True: 0, False: 0]
  Branch (173:33): [Folded - Ignored]
174
0
                                return;
175
                            };
176
0
                            let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone());
177
0
                            match maybe_action_state {
178
0
                                Ok(response) => {
179
0
                                    if let Err(err) = tx.send(Arc::new(response)) {
  Branch (179:44): [True: 0, False: 0]
  Branch (179:44): [Folded - Ignored]
180
0
                                        info!(
181
                                            ?err,
182
0
                                            "Client error in GrpcScheduler"
183
                                        );
184
0
                                        return;
185
0
                                    }
186
                                }
187
0
                                Err(err) => {
188
0
                                    error!(
189
                                        ?err,
190
0
                                        "Error converting response to ActionState in GrpcScheduler"
191
                                    );
192
                                },
193
                            }
194
                        }
195
                    );
196
                }
197
0
            });
198
199
0
            return Ok(Box::new(GrpcActionStateResult {
200
0
                client_operation_id,
201
0
                rx,
202
0
            }));
203
0
        }
204
0
        Err(make_err!(
205
0
            Code::Internal,
206
0
            "Upstream scheduler didn't accept action."
207
0
        ))
208
0
    }
209
210
0
    async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
211
0
        if let Some(supported_props) = self.supported_props.lock().get(instance_name) {
  Branch (211:16): [True: 0, False: 0]
  Branch (211:16): [Folded - Ignored]
212
0
            return Ok(supported_props.clone());
213
0
        }
214
215
0
        self.perform_request(instance_name, |instance_name| async move {
216
            // Not in the cache, lookup the capabilities with the upstream.
217
0
            let channel = self
218
0
                .connection_manager
219
0
                .connection()
220
0
                .await
221
0
                .err_tip(|| "in get_platform_property_manager()")?;
222
0
            let capabilities_result = CapabilitiesClient::new(channel)
223
0
                .get_capabilities(GetCapabilitiesRequest {
224
0
                    instance_name: instance_name.to_string(),
225
0
                })
226
0
                .await
227
0
                .err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
228
0
            let capabilities = capabilities_result?.into_inner();
229
0
            let supported_props = capabilities
230
0
                .execution_capabilities
231
0
                .err_tip(|| "Unable to get execution properties in GrpcScheduler")?
232
                .supported_node_properties
233
0
                .into_iter()
234
0
                .collect::<Vec<String>>();
235
236
0
            self.supported_props
237
0
                .lock()
238
0
                .insert(instance_name.to_string(), supported_props.clone());
239
0
            Ok(supported_props)
240
0
        })
241
0
        .await
242
0
    }
243
244
0
    async fn inner_add_action(
245
0
        &self,
246
0
        _client_operation_id: OperationId,
247
0
        action_info: Arc<ActionInfo>,
248
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
249
0
        let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY {
  Branch (249:35): [True: 0, False: 0]
  Branch (249:35): [Folded - Ignored]
250
0
            None
251
        } else {
252
0
            Some(ExecutionPolicy {
253
0
                priority: action_info.priority,
254
0
            })
255
        };
256
0
        let skip_cache_lookup = match action_info.unique_qualifier {
257
0
            ActionUniqueQualifier::Cacheable(_) => false,
258
0
            ActionUniqueQualifier::Uncacheable(_) => true,
259
        };
260
0
        let request = ExecuteRequest {
261
0
            instance_name: action_info.instance_name().clone(),
262
0
            skip_cache_lookup,
263
0
            action_digest: Some(action_info.digest().into()),
264
0
            execution_policy,
265
0
            // TODO: Get me from the original request, not very important as we ignore it.
266
0
            results_cache_policy: None,
267
0
            digest_function: action_info
268
0
                .unique_qualifier
269
0
                .digest_function()
270
0
                .proto_digest_func()
271
0
                .into(),
272
0
        };
273
0
        let result_stream = self
274
0
            .perform_request(request, |request| async move {
275
0
                let channel = self
276
0
                    .connection_manager
277
0
                    .connection()
278
0
                    .await
279
0
                    .err_tip(|| "in add_action()")?;
280
0
                ExecutionClient::new(channel)
281
0
                    .execute(Request::new(request))
282
0
                    .await
283
0
                    .err_tip(|| "Sending action to upstream scheduler")
284
0
            })
285
0
            .await?
286
0
            .into_inner();
287
0
        Self::stream_state(result_stream).await
288
0
    }
289
290
0
    async fn inner_filter_operations(
291
0
        &self,
292
0
        filter: OperationFilter,
293
0
    ) -> Result<ActionStateResultStream<'_>, Error> {
294
0
        error_if!(
295
0
            filter
  Branch (295:13): [True: 0, False: 0]
  Branch (295:13): [Folded - Ignored]
296
0
                != OperationFilter {
297
0
                    client_operation_id: filter.client_operation_id.clone(),
298
0
                    ..Default::default()
299
0
                },
300
            "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}"
301
        );
302
0
        let client_operation_id = filter.client_operation_id.ok_or_else(|| {
303
0
            make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations")
304
0
        })?;
305
0
        let request = WaitExecutionRequest {
306
0
            name: client_operation_id.to_string(),
307
0
        };
308
0
        let result_stream = self
309
0
            .perform_request(request, |request| async move {
310
0
                let channel = self
311
0
                    .connection_manager
312
0
                    .connection()
313
0
                    .await
314
0
                    .err_tip(|| "in find_by_client_operation_id()")?;
315
0
                ExecutionClient::new(channel)
316
0
                    .wait_execution(Request::new(request))
317
0
                    .await
318
0
                    .err_tip(|| "While getting wait_execution stream")
319
0
            })
320
0
            .and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
321
0
            .await;
322
0
        match result_stream {
323
0
            Ok(result_stream) => Ok(unfold(
324
0
                Some(result_stream),
325
0
                |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) },
326
            )
327
0
            .boxed()),
328
0
            Err(err) => {
329
0
                warn!(?err, "Error looking up action with upstream scheduler");
330
0
                Ok(futures::stream::empty().boxed())
331
            }
332
        }
333
0
    }
334
}
335
336
#[async_trait]
337
impl ClientStateManager for GrpcScheduler {
338
    async fn add_action(
339
        &self,
340
        client_operation_id: OperationId,
341
        action_info: Arc<ActionInfo>,
342
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
343
0
        self.inner_add_action(client_operation_id, action_info)
344
0
            .await
345
0
    }
346
347
    async fn filter_operations<'a>(
348
        &'a self,
349
        filter: OperationFilter,
350
0
    ) -> Result<ActionStateResultStream<'a>, Error> {
351
0
        self.inner_filter_operations(filter).await
352
0
    }
353
354
0
    fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> {
355
0
        Some(self)
356
0
    }
357
}
358
359
#[async_trait]
360
impl KnownPlatformPropertyProvider for GrpcScheduler {
361
0
    async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
362
0
        self.inner_get_known_properties(instance_name).await
363
0
    }
364
}
365
366
/// Marks `GrpcScheduler` as a root metrics component.
///
/// NOTE(review): where it gets registered is outside this file — confirm.
impl RootMetricsComponent for GrpcScheduler {
    // Relies entirely on the trait's default behavior; nothing is overridden.
}