/build/source/nativelink-worker/src/worker_api_client_wrapper.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::future::Future; |
16 | | |
17 | | use futures::stream::unfold; |
18 | | use nativelink_error::{make_err, Error, ResultExt}; |
19 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_scheduler::Update; |
20 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient; |
21 | | use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ |
22 | | ConnectWorkerRequest, ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForScheduler, UpdateForWorker |
23 | | }; |
24 | | use tokio::sync::mpsc::Sender; |
25 | | use tonic::codec::Streaming; |
26 | | use tonic::transport::Channel; |
27 | | use tonic::{Code, Response, Status}; |
28 | | |
29 | | /// This is used in order to allow unit tests to intercept these calls. This should always match |
30 | | /// the API of `WorkerApiClient` defined in the `worker_api.proto` file. |
31 | | pub trait WorkerApiClientTrait: Clone + Sync + Send + Sized + Unpin { |
32 | | fn connect_worker( |
33 | | &mut self, |
34 | | request: ConnectWorkerRequest, |
35 | | ) -> impl Future<Output = Result<Response<Streaming<UpdateForWorker>>, Status>> + Send; |
36 | | |
37 | | fn keep_alive( |
38 | | &mut self, |
39 | | request: KeepAliveRequest, |
40 | | ) -> impl Future<Output = Result<(), Error>> + Send; |
41 | | |
42 | | fn going_away( |
43 | | &mut self, |
44 | | request: GoingAwayRequest, |
45 | | ) -> impl Future<Output = Result<(), Error>> + Send; |
46 | | |
47 | | fn execution_response( |
48 | | &mut self, |
49 | | request: ExecuteResult, |
50 | | ) -> impl Future<Output = Result<(), Error>> + Send; |
51 | | |
52 | | fn execution_complete( |
53 | | &mut self, |
54 | | request: ExecuteComplete, |
55 | | ) -> impl Future<Output = Result<(), Error>> + Send; |
56 | | } |
57 | | |
58 | | #[derive(Debug, Clone)] |
59 | | pub struct WorkerApiClientWrapper { |
60 | | inner: WorkerApiClient<Channel>, |
61 | | channel: Option<Sender<Update>>, |
62 | | } |
63 | | |
64 | | impl WorkerApiClientWrapper { |
65 | 0 | async fn send_update(&mut self, update: Update) -> Result<(), Error> { |
66 | 0 | let tx = self |
67 | 0 | .channel |
68 | 0 | .as_ref() |
69 | 0 | .err_tip(|| "worker update without connect_worker")?; |
70 | 0 | match tx.send(update).await { |
71 | 0 | Ok(()) => Ok(()), |
72 | 0 | Err(_err) => { |
73 | | // Remove the sender if it's not going anywhere. |
74 | 0 | self.channel.take(); |
75 | 0 | Err(make_err!( |
76 | 0 | Code::Unavailable, |
77 | 0 | "worker update with disconnected channel" |
78 | 0 | )) |
79 | | } |
80 | | } |
81 | 0 | } |
82 | | } |
83 | | |
84 | | impl From<WorkerApiClient<Channel>> for WorkerApiClientWrapper { |
85 | 0 | fn from(other: WorkerApiClient<Channel>) -> Self { |
86 | 0 | Self { |
87 | 0 | inner: other, |
88 | 0 | channel: None, |
89 | 0 | } |
90 | 0 | } |
91 | | } |
92 | | |
93 | | impl WorkerApiClientTrait for WorkerApiClientWrapper { |
94 | 0 | async fn connect_worker( |
95 | 0 | &mut self, |
96 | 0 | request: ConnectWorkerRequest, |
97 | 0 | ) -> Result<Response<Streaming<UpdateForWorker>>, Status> { |
98 | 0 | drop(self.channel.take()); |
99 | 0 | let (tx, rx) = tokio::sync::mpsc::channel(1); |
100 | 0 | if tx Branch (100:12): [Folded - Ignored]
Branch (100:12): [Folded - Ignored]
|
101 | 0 | .send(Update::ConnectWorkerRequest(request)) |
102 | 0 | .await |
103 | 0 | .is_err() |
104 | | { |
105 | 0 | return Err(Status::data_loss("Unable to push to newly created channel")); |
106 | 0 | } |
107 | 0 | self.channel = Some(tx); |
108 | 0 | self.inner |
109 | 0 | .connect_worker(unfold(rx, |mut rx| async move { |
110 | 0 | let update = rx.recv().await?; |
111 | 0 | Some(( |
112 | 0 | UpdateForScheduler { |
113 | 0 | update: Some(update), |
114 | 0 | }, |
115 | 0 | rx, |
116 | 0 | )) |
117 | 0 | })) |
118 | 0 | .await |
119 | 0 | } |
120 | | |
121 | 0 | async fn keep_alive(&mut self, request: KeepAliveRequest) -> Result<(), Error> { |
122 | 0 | self.send_update(Update::KeepAliveRequest(request)).await |
123 | 0 | } |
124 | | |
125 | 0 | async fn going_away(&mut self, request: GoingAwayRequest) -> Result<(), Error> { |
126 | 0 | self.send_update(Update::GoingAwayRequest(request)).await |
127 | 0 | } |
128 | | |
129 | 0 | async fn execution_response(&mut self, request: ExecuteResult) -> Result<(), Error> { |
130 | 0 | self.send_update(Update::ExecuteResult(request)).await |
131 | 0 | } |
132 | | |
133 | 0 | async fn execution_complete(&mut self, request: ExecuteComplete) -> Result<(), Error> { |
134 | 0 | self.send_update(Update::ExecuteComplete(request)).await |
135 | 0 | } |
136 | | } |