Coverage Report

Created: 2026-05-23 21:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-scheduler/src/grpc_scheduler.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::future::Future;
16
use core::time::Duration;
17
use std::collections::HashMap;
18
use std::sync::Arc;
19
20
use async_trait::async_trait;
21
use futures::stream::unfold;
22
use futures::{StreamExt, TryFutureExt};
23
use nativelink_config::schedulers::GrpcSpec;
24
use nativelink_error::{Code, Error, ResultExt, error_if, make_err};
25
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
26
use nativelink_proto::build::bazel::remote::execution::v2::capabilities_client::CapabilitiesClient;
27
use nativelink_proto::build::bazel::remote::execution::v2::execution_client::ExecutionClient;
28
use nativelink_proto::build::bazel::remote::execution::v2::{
29
    ExecuteRequest, ExecutionPolicy, GetCapabilitiesRequest, WaitExecutionRequest,
30
};
31
use nativelink_proto::google::longrunning::Operation;
32
use nativelink_util::action_messages::{
33
    ActionInfo, ActionState, ActionUniqueQualifier, DEFAULT_EXECUTION_PRIORITY, OperationId,
34
};
35
use nativelink_util::connection_manager::ConnectionManager;
36
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
37
use nativelink_util::operation_state_manager::{
38
    ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
39
};
40
use nativelink_util::origin_event::OriginMetadata;
41
use nativelink_util::retry::{Retrier, RetryResult};
42
use nativelink_util::{background_spawn, tls_utils};
43
use parking_lot::Mutex;
44
use tokio::select;
45
use tokio::sync::watch;
46
use tokio::time::sleep;
47
use tonic::{Request, Streaming};
48
use tracing::{error, info, warn};
49
50
struct GrpcActionStateResult {
51
    client_operation_id: OperationId,
52
    rx: watch::Receiver<Arc<ActionState>>,
53
}
54
55
#[async_trait]
56
impl ActionStateResult for GrpcActionStateResult {
57
0
    async fn as_state(&self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
58
        let mut action_state = self.rx.borrow().clone();
59
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
60
        // TODO(palfrey) We currently don't support OriginMetadata in this implementation, but
61
        // we should.
62
        Ok((action_state, None))
63
0
    }
64
65
0
    async fn changed(&mut self) -> Result<(Arc<ActionState>, Option<OriginMetadata>), Error> {
66
0
        self.rx.changed().await.map_err(|e| {
67
0
            Error::from_std_err(Code::Internal, &e)
68
0
                .append("Channel closed in GrpcActionStateResult::changed")
69
0
        })?;
70
        let mut action_state = self.rx.borrow().clone();
71
        Arc::make_mut(&mut action_state).client_operation_id = self.client_operation_id.clone();
72
        // TODO(palfrey) We currently don't support OriginMetadata in this implementation, but
73
        // we should.
74
        Ok((action_state, None))
75
0
    }
76
77
0
    async fn as_action_info(&self) -> Result<(Arc<ActionInfo>, Option<OriginMetadata>), Error> {
78
        // TODO(palfrey) We should probably remove as_action_info()
79
        // or implement it properly.
80
        return Err(make_err!(
81
            Code::Unimplemented,
82
            "as_action_info not implemented for GrpcActionStateResult::as_action_info"
83
        ));
84
0
    }
85
}
86
87
#[derive(Debug, MetricsComponent)]
88
pub struct GrpcScheduler {
89
    #[metric(group = "property_managers")]
90
    supported_props: Mutex<HashMap<String, Vec<String>>>,
91
    retrier: Retrier,
92
    connection_manager: ConnectionManager,
93
}
94
95
impl GrpcScheduler {
96
0
    pub fn new(spec: &GrpcSpec) -> Result<Self, Error> {
97
0
        Self::new_with_jitter(spec, spec.retry.make_jitter_fn())
98
0
    }
99
100
0
    pub fn new_with_jitter(
101
0
        spec: &GrpcSpec,
102
0
        jitter_fn: Arc<dyn Fn(Duration) -> Duration + Send + Sync>,
103
0
    ) -> Result<Self, Error> {
104
0
        let endpoint = tls_utils::endpoint(&spec.endpoint)?;
105
        Ok(Self {
106
0
            supported_props: Mutex::new(HashMap::new()),
107
0
            retrier: Retrier::new(
108
0
                Arc::new(|duration| Box::pin(sleep(duration))),
109
0
                jitter_fn.clone(),
110
0
                spec.retry.clone(),
111
            ),
112
0
            connection_manager: ConnectionManager::new(
113
0
                core::iter::once(endpoint),
114
0
                spec.connections_per_endpoint,
115
0
                spec.max_concurrent_requests,
116
0
                spec.retry.clone(),
117
0
                jitter_fn,
118
            ),
119
        })
120
0
    }
121
122
0
    async fn perform_request<F, Fut, R, I>(&self, input: I, mut request: F) -> Result<R, Error>
123
0
    where
124
0
        F: FnMut(I) -> Fut + Send + Copy,
125
0
        Fut: Future<Output = Result<R, Error>> + Send,
126
0
        R: Send,
127
0
        I: Send + Clone,
128
0
    {
129
0
        self.retrier
130
0
            .retry(unfold(input, move |input| async move {
131
0
                let input_clone = input.clone();
132
                Some((
133
0
                    request(input_clone)
134
0
                        .await
135
0
                        .map_or_else(RetryResult::Retry, RetryResult::Ok),
136
0
                    input,
137
                ))
138
0
            }))
139
0
            .await
140
0
    }
141
142
0
    async fn stream_state(
143
0
        mut result_stream: Streaming<Operation>,
144
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
145
0
        if let Some(initial_response) = result_stream
146
0
            .message()
147
0
            .await
148
0
            .err_tip(|| "Receiving response from upstream scheduler")?
149
        {
150
0
            let client_operation_id = OperationId::from(initial_response.name.as_str());
151
            // Our operation_id is not needed here is just a place holder to recycle existing object.
152
            // The only thing that actually matters is the operation_id.
153
0
            let operation_id = OperationId::default();
154
0
            let action_state =
155
0
                ActionState::try_from_operation(initial_response, operation_id.clone())
156
0
                    .err_tip(|| "In GrpcScheduler::stream_state")?;
157
0
            let (tx, mut rx) = watch::channel(Arc::new(action_state));
158
0
            rx.mark_changed();
159
0
            background_spawn!("grpc_scheduler_stream_state", async move {
160
                loop {
161
0
                    select!(
162
0
                        () = tx.closed() => {
163
0
                            info!(
164
                                "Client disconnected in GrpcScheduler"
165
                            );
166
0
                            return;
167
                        }
168
0
                        response = result_stream.message() => {
169
                            // When the upstream closes the channel, close the
170
                            // downstream too.
171
0
                            let Ok(Some(response)) = response else {
172
0
                                return;
173
                            };
174
0
                            let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone());
175
0
                            match maybe_action_state {
176
0
                                Ok(response) => {
177
0
                                    if let Err(err) = tx.send(Arc::new(response)) {
178
0
                                        info!(
179
                                            ?err,
180
                                            "Client error in GrpcScheduler"
181
                                        );
182
0
                                        return;
183
0
                                    }
184
                                }
185
0
                                Err(err) => {
186
0
                                    error!(
187
                                        ?err,
188
                                        "Error converting response to ActionState in GrpcScheduler"
189
                                    );
190
                                },
191
                            }
192
                        }
193
                    );
194
                }
195
0
            });
196
197
0
            return Ok(Box::new(GrpcActionStateResult {
198
0
                client_operation_id,
199
0
                rx,
200
0
            }));
201
0
        }
202
0
        Err(make_err!(
203
0
            Code::Internal,
204
0
            "Upstream scheduler didn't accept action."
205
0
        ))
206
0
    }
207
208
0
    async fn inner_get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
209
0
        if let Some(supported_props) = self.supported_props.lock().get(instance_name) {
210
0
            return Ok(supported_props.clone());
211
0
        }
212
213
0
        self.perform_request(instance_name, |instance_name| async move {
214
            // Not in the cache, lookup the capabilities with the upstream.
215
0
            let channel = self
216
0
                .connection_manager
217
0
                .connection("get_known_properties".into())
218
0
                .await
219
0
                .err_tip(|| "in get_platform_property_manager()")?;
220
0
            let capabilities_result = CapabilitiesClient::new(channel)
221
0
                .get_capabilities(GetCapabilitiesRequest {
222
0
                    instance_name: instance_name.to_string(),
223
0
                })
224
0
                .await
225
0
                .err_tip(|| "Retrieving upstream GrpcScheduler capabilities");
226
0
            let capabilities = capabilities_result?.into_inner();
227
0
            let supported_props = capabilities
228
0
                .execution_capabilities
229
0
                .err_tip(|| "Unable to get execution properties in GrpcScheduler")?
230
                .supported_node_properties
231
0
                .into_iter()
232
0
                .collect::<Vec<String>>();
233
234
0
            self.supported_props
235
0
                .lock()
236
0
                .insert(instance_name.to_string(), supported_props.clone());
237
0
            Ok(supported_props)
238
0
        })
239
0
        .await
240
0
    }
241
242
0
    async fn inner_add_action(
243
0
        &self,
244
0
        _client_operation_id: OperationId,
245
0
        action_info: Arc<ActionInfo>,
246
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
247
0
        let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY {
248
0
            None
249
        } else {
250
0
            Some(ExecutionPolicy {
251
0
                priority: action_info.priority,
252
0
            })
253
        };
254
0
        let skip_cache_lookup = match action_info.unique_qualifier {
255
0
            ActionUniqueQualifier::Cacheable(_) => false,
256
0
            ActionUniqueQualifier::Uncacheable(_) => true,
257
        };
258
0
        let request = ExecuteRequest {
259
0
            instance_name: action_info.instance_name().clone(),
260
0
            skip_cache_lookup,
261
0
            action_digest: Some(action_info.digest().into()),
262
0
            execution_policy,
263
0
            // TODO: Get me from the original request, not very important as we ignore it.
264
0
            results_cache_policy: None,
265
0
            digest_function: action_info
266
0
                .unique_qualifier
267
0
                .digest_function()
268
0
                .proto_digest_func()
269
0
                .into(),
270
0
        };
271
0
        let result_stream = self
272
0
            .perform_request(request, |request| async move {
273
0
                let channel = self
274
0
                    .connection_manager
275
0
                    .connection(format!("add_action: {:?}", request.action_digest))
276
0
                    .await
277
0
                    .err_tip(|| "in add_action()")?;
278
0
                ExecutionClient::new(channel)
279
0
                    .execute(Request::new(request))
280
0
                    .await
281
0
                    .err_tip(|| "Sending action to upstream scheduler")
282
0
            })
283
0
            .await?
284
0
            .into_inner();
285
0
        Self::stream_state(result_stream).await
286
0
    }
287
288
0
    async fn inner_filter_operations(
289
0
        &self,
290
0
        filter: OperationFilter,
291
0
    ) -> Result<ActionStateResultStream<'_>, Error> {
292
0
        error_if!(
293
0
            filter
294
0
                != OperationFilter {
295
0
                    client_operation_id: filter.client_operation_id.clone(),
296
0
                    ..Default::default()
297
0
                },
298
            "Unsupported filter in GrpcScheduler::filter_operations. Only client_operation_id is supported - {filter:?}"
299
        );
300
0
        let client_operation_id = filter.client_operation_id.ok_or_else(|| {
301
0
            make_err!(Code::InvalidArgument, "`client_operation_id` is the only supported filter in GrpcScheduler::filter_operations")
302
0
        })?;
303
0
        let request = WaitExecutionRequest {
304
0
            name: client_operation_id.to_string(),
305
0
        };
306
0
        let result_stream = self
307
0
            .perform_request(request, |request| async move {
308
0
                let channel = self
309
0
                    .connection_manager
310
0
                    .connection(format!("filter_operations: {}", request.name))
311
0
                    .await
312
0
                    .err_tip(|| "in find_by_client_operation_id()")?;
313
0
                ExecutionClient::new(channel)
314
0
                    .wait_execution(Request::new(request))
315
0
                    .await
316
0
                    .err_tip(|| "While getting wait_execution stream")
317
0
            })
318
0
            .and_then(|result_stream| Self::stream_state(result_stream.into_inner()))
319
0
            .await;
320
0
        match result_stream {
321
0
            Ok(result_stream) => Ok(unfold(
322
0
                Some(result_stream),
323
0
                |maybe_result_stream| async move { maybe_result_stream.map(|v| (v, None)) },
324
            )
325
0
            .boxed()),
326
0
            Err(err) => {
327
0
                warn!(?err, "Error looking up action with upstream scheduler");
328
0
                Ok(futures::stream::empty().boxed())
329
            }
330
        }
331
0
    }
332
}
333
334
#[async_trait]
335
impl ClientStateManager for GrpcScheduler {
336
    async fn add_action(
337
        &self,
338
        client_operation_id: OperationId,
339
        action_info: Arc<ActionInfo>,
340
0
    ) -> Result<Box<dyn ActionStateResult>, Error> {
341
        self.inner_add_action(client_operation_id, action_info)
342
            .await
343
0
    }
344
345
    async fn filter_operations<'a>(
346
        &'a self,
347
        filter: OperationFilter,
348
0
    ) -> Result<ActionStateResultStream<'a>, Error> {
349
        self.inner_filter_operations(filter).await
350
0
    }
351
352
0
    fn as_known_platform_property_provider(&self) -> Option<&dyn KnownPlatformPropertyProvider> {
353
0
        Some(self)
354
0
    }
355
}
356
357
#[async_trait]
358
impl KnownPlatformPropertyProvider for GrpcScheduler {
359
0
    async fn get_known_properties(&self, instance_name: &str) -> Result<Vec<String>, Error> {
360
        self.inner_get_known_properties(instance_name).await
361
0
    }
362
}
363
364
// Empty marker impl: lets GrpcScheduler be used where a `RootMetricsComponent`
// is required. Its metrics come from the `MetricsComponent` derive on the struct.
impl RootMetricsComponent for GrpcScheduler {}