Coverage Report

Created: 2025-10-24 14:08

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-worker/src/worker_api_client_wrapper.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::future::Future;
16
17
use futures::stream::unfold;
18
use nativelink_error::{make_err, Error, ResultExt};
19
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_scheduler::Update;
20
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
21
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
22
    ConnectWorkerRequest, ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForScheduler, UpdateForWorker
23
};
24
use tokio::sync::mpsc::Sender;
25
use tonic::codec::Streaming;
26
use tonic::transport::Channel;
27
use tonic::{Code, Response, Status};
28
29
/// This is used in order to allow unit tests to intercept these calls. This should always match
30
/// the API of `WorkerApiClient` defined in the `worker_api.proto` file.
31
pub trait WorkerApiClientTrait: Clone + Sync + Send + Sized + Unpin {
32
    fn connect_worker(
33
        &mut self,
34
        request: ConnectWorkerRequest,
35
    ) -> impl Future<Output = Result<Response<Streaming<UpdateForWorker>>, Status>> + Send;
36
37
    fn keep_alive(
38
        &mut self,
39
        request: KeepAliveRequest,
40
    ) -> impl Future<Output = Result<(), Error>> + Send;
41
42
    fn going_away(
43
        &mut self,
44
        request: GoingAwayRequest,
45
    ) -> impl Future<Output = Result<(), Error>> + Send;
46
47
    fn execution_response(
48
        &mut self,
49
        request: ExecuteResult,
50
    ) -> impl Future<Output = Result<(), Error>> + Send;
51
52
    fn execution_complete(
53
        &mut self,
54
        request: ExecuteComplete,
55
    ) -> impl Future<Output = Result<(), Error>> + Send;
56
}
57
58
#[derive(Debug, Clone)]
59
pub struct WorkerApiClientWrapper {
60
    inner: WorkerApiClient<Channel>,
61
    channel: Option<Sender<Update>>,
62
}
63
64
impl WorkerApiClientWrapper {
65
0
    async fn send_update(&mut self, update: Update) -> Result<(), Error> {
66
0
        let tx = self
67
0
            .channel
68
0
            .as_ref()
69
0
            .err_tip(|| "worker update without connect_worker")?;
70
0
        match tx.send(update).await {
71
0
            Ok(()) => Ok(()),
72
0
            Err(_err) => {
73
                // Remove the sender if it's not going anywhere.
74
0
                self.channel.take();
75
0
                Err(make_err!(
76
0
                    Code::Unavailable,
77
0
                    "worker update with disconnected channel"
78
0
                ))
79
            }
80
        }
81
0
    }
82
}
83
84
impl From<WorkerApiClient<Channel>> for WorkerApiClientWrapper {
    /// Wraps `client` with no update channel installed yet.
    ///
    /// The channel is created by the first `connect_worker` call.
    fn from(client: WorkerApiClient<Channel>) -> Self {
        Self {
            channel: None,
            inner: client,
        }
    }
}
92
93
impl WorkerApiClientTrait for WorkerApiClientWrapper {
94
0
    async fn connect_worker(
95
0
        &mut self,
96
0
        request: ConnectWorkerRequest,
97
0
    ) -> Result<Response<Streaming<UpdateForWorker>>, Status> {
98
0
        drop(self.channel.take());
99
0
        let (tx, rx) = tokio::sync::mpsc::channel(1);
100
0
        if tx
  Branch (100:12): [Folded - Ignored]
  Branch (100:12): [Folded - Ignored]
101
0
            .send(Update::ConnectWorkerRequest(request))
102
0
            .await
103
0
            .is_err()
104
        {
105
0
            return Err(Status::data_loss("Unable to push to newly created channel"));
106
0
        }
107
0
        self.channel = Some(tx);
108
0
        self.inner
109
0
            .connect_worker(unfold(rx, |mut rx| async move {
110
0
                let update = rx.recv().await?;
111
0
                Some((
112
0
                    UpdateForScheduler {
113
0
                        update: Some(update),
114
0
                    },
115
0
                    rx,
116
0
                ))
117
0
            }))
118
0
            .await
119
0
    }
120
121
0
    async fn keep_alive(&mut self, request: KeepAliveRequest) -> Result<(), Error> {
122
0
        self.send_update(Update::KeepAliveRequest(request)).await
123
0
    }
124
125
0
    async fn going_away(&mut self, request: GoingAwayRequest) -> Result<(), Error> {
126
0
        self.send_update(Update::GoingAwayRequest(request)).await
127
0
    }
128
129
0
    async fn execution_response(&mut self, request: ExecuteResult) -> Result<(), Error> {
130
0
        self.send_update(Update::ExecuteResult(request)).await
131
0
    }
132
133
0
    async fn execution_complete(&mut self, request: ExecuteComplete) -> Result<(), Error> {
134
0
        self.send_update(Update::ExecuteComplete(request)).await
135
0
    }
136
}