Coverage Report

Created: 2026-05-23 21:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-worker/src/persistent_worker/protocol.rs
Line
Count
Source
1
// Copyright 2024 Trace Machina, Inc. All rights reserved.
2
//
3
// Licensed under the Business Source License, Version 1.1 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may requested a copy of the License by emailing contact@nativelink.com.
6
//
7
// Use of this module requires an enterprise license agreement, which can be
8
// attained by emailing contact@nativelink.com or signing up for Nativelink
9
// Cloud at app.nativelink.com.
10
//
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
17
//! Wire-format types for Bazel's persistent worker protocol.
18
//!
19
//! Reference: <https://bazel.build/remote/persistent#work-protocol>
20
//! and <https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto>.
21
//!
22
//! Both JSON (newline-delimited, camelCase field names) and proto (length-delimited
23
//! varint prefix) wire formats are supported. The choice is per-request, driven by
24
//! the action's `requires-worker-protocol` execution requirement.
25
26
use bytes::{Bytes, BytesMut};
27
use nativelink_error::{Code, Error, make_err};
28
use prost::Message as ProstMessage;
29
use serde::{Deserialize, Serialize};
30
31
/// One input file declared in a `WorkRequest`. The tool may use the digest to
32
/// verify cached state matches.
33
#[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)]
34
pub struct Input {
35
    #[prost(string, tag = "1")]
36
    #[serde(default, skip_serializing_if = "String::is_empty")]
37
    pub path: String,
38
39
    #[prost(bytes = "vec", tag = "2")]
40
    #[serde(
41
        default,
42
        skip_serializing_if = "Vec::is_empty",
43
        with = "serde_bytes_b64"
44
    )]
45
    pub digest: Vec<u8>,
46
}
47
48
/// A single unit of work dispatched to a persistent worker subprocess.
49
#[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)]
50
#[serde(rename_all = "camelCase")]
51
pub struct WorkRequest {
52
    #[prost(string, repeated, tag = "1")]
53
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
54
    pub arguments: Vec<String>,
55
56
    #[prost(message, repeated, tag = "2")]
57
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
58
    pub inputs: Vec<Input>,
59
60
    /// Multiplex worker request id. v1 always sends 0 and rejects responses
61
    /// with non-zero ids.
62
    #[prost(int32, tag = "3")]
63
    #[serde(default)]
64
    pub request_id: i32,
65
66
    #[prost(bool, tag = "4")]
67
    #[serde(default)]
68
    pub cancel: bool,
69
70
    #[prost(int32, tag = "5")]
71
    #[serde(default)]
72
    pub verbosity: i32,
73
74
    #[prost(string, tag = "6")]
75
    #[serde(default, skip_serializing_if = "String::is_empty")]
76
    pub sandbox_dir: String,
77
}
78
79
/// A single response emitted by a persistent worker subprocess.
80
#[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)]
81
#[serde(rename_all = "camelCase")]
82
pub struct WorkResponse {
83
    #[prost(int32, tag = "1")]
84
    #[serde(default)]
85
    pub exit_code: i32,
86
87
    #[prost(string, tag = "2")]
88
    #[serde(default, skip_serializing_if = "String::is_empty")]
89
    pub output: String,
90
91
    /// Echoed from `WorkRequest.request_id`. v1 expects 0.
92
    #[prost(int32, tag = "3")]
93
    #[serde(default)]
94
    pub request_id: i32,
95
96
    #[prost(bool, tag = "4")]
97
    #[serde(default)]
98
    pub was_cancelled: bool,
99
}
100
101
/// The wire format used for one request, chosen from the action's
/// `requires-worker-protocol` execution requirement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WireFormat {
    /// Protobuf messages, each preceded by its length (Bazel's default).
    Proto,
    /// JSON objects with camelCase field names, one per line.
    Json,
}
110
111
impl WireFormat {
112
    /// Parse the execution-requirement value. Bazel accepts only "proto" and
113
    /// "json"; anything else is an error.
114
5
    pub fn parse(s: &str) -> Result<Self, Error> {
115
5
        match s {
116
5
            "proto" => 
Ok(Self::Proto)1
,
117
4
            "json" => 
Ok(Self::Json)3
,
118
1
            other => Err(make_err!(
119
1
                Code::InvalidArgument,
120
1
                "Unsupported requires-worker-protocol value: {other:?}; expected 'proto' or 'json'"
121
1
            )),
122
        }
123
5
    }
124
}
125
126
impl WorkRequest {
127
    /// Serialize this request in the given wire format. The returned bytes
128
    /// include the framing prefix/suffix appropriate for the format
129
    /// (length-delimited varint for proto, trailing newline for JSON).
130
6
    pub fn encode_framed(&self, format: WireFormat) -> Result<Bytes, Error> {
131
6
        match format {
132
            WireFormat::Proto => {
133
2
                let mut buf =
134
2
                    BytesMut::with_capacity(<Self as ProstMessage>::encoded_len(self) + 10);
135
2
                <Self as ProstMessage>::encode_length_delimited(self, &mut buf)
136
2
                    .map_err(|e| 
make_err!0
(
Code::Internal0
, "WorkRequest proto encode: {e}"))
?0
;
137
2
                Ok(buf.freeze())
138
            }
139
            WireFormat::Json => {
140
4
                let mut s = serde_json::to_string(self)
141
4
                    .map_err(|e| 
make_err!0
(
Code::Internal0
, "WorkRequest JSON encode: {e}"))
?0
;
142
4
                s.push('\n');
143
4
                Ok(Bytes::from(s))
144
            }
145
        }
146
6
    }
147
}
148
149
impl WorkResponse {
150
    /// Decode a single response from a byte buffer. For proto: expects a
151
    /// length-delimited varint frame at the buffer start. For JSON: expects a
152
    /// single complete JSON object, optionally with a trailing newline.
153
4
    pub fn decode_framed(buf: &[u8], format: WireFormat) -> Result<Self, Error> {
154
4
        match format {
155
0
            WireFormat::Proto => <Self as ProstMessage>::decode_length_delimited(buf)
156
0
                .map_err(|e| make_err!(Code::Internal, "WorkResponse proto decode: {e}")),
157
4
            WireFormat::Json => serde_json::from_slice(buf)
158
4
                .map_err(|e| 
make_err!0
(
Code::Internal0
, "WorkResponse JSON decode: {e}")),
159
        }
160
4
    }
161
}
162
163
/// Serde adapter for `Input.digest`: Bazel's JSON serialization base64-encodes
164
/// the digest bytes (as is standard for proto-JSON byte fields).
165
mod serde_bytes_b64 {
166
    use serde::{Deserialize, Deserializer, Serializer};
167
168
    // Minimal RFC 4648 base64 encoder/decoder so we don't pull in a `base64`
169
    // crate dep just for this adapter. v1 expects digests under ~64 bytes.
170
    const ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
171
172
1
    fn encode(input: &[u8]) -> String {
173
1
        let mut out = String::with_capacity(input.len().div_ceil(3) * 4);
174
2
        for chunk in 
input1
.
chunks1
(3) {
175
2
            let b0 = chunk[0];
176
2
            let b1 = chunk.get(1).copied().unwrap_or(0);
177
2
            let b2 = chunk.get(2).copied().unwrap_or(0);
178
2
            out.push(ALPHA[(b0 >> 2) as usize] as char);
179
2
            out.push(ALPHA[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
180
2
            if chunk.len() > 1 {
181
1
                out.push(ALPHA[(((b1 & 0x0f) << 2) | (b2 >> 6)) as usize] as char);
182
1
            } else {
183
1
                out.push('=');
184
1
            }
185
2
            if chunk.len() > 2 {
186
1
                out.push(ALPHA[(b2 & 0x3f) as usize] as char);
187
1
            } else {
188
1
                out.push('=');
189
1
            }
190
        }
191
1
        out
192
1
    }
193
194
1
    /// Decode padded RFC 4648 base64 text into raw bytes.
    ///
    /// The input must consist of whole 4-character groups. `=` padding is
    /// accepted only in the final group, as `xx==` (one byte) or `xxx=`
    /// (two bytes). An empty string decodes to an empty vector. Unused
    /// trailing bits in a padded group are ignored rather than rejected.
    ///
    /// # Errors
    ///
    /// * `"base64 input length not multiple of 4"` when the length is not a
    ///   multiple of four.
    /// * `"invalid base64 char"` when a byte is outside the alphabet or `=`
    ///   appears anywhere other than the valid trailing positions.
    fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        const INVALID: &str = "invalid base64 char";

        // Map one alphabet byte to its 6-bit value. `=` is deliberately not
        // part of the alphabet so that misplaced padding is rejected.
        fn sextet(c: u8) -> Option<u8> {
            match c {
                b'A'..=b'Z' => Some(c - b'A'),
                b'a'..=b'z' => Some(c - b'a' + 26),
                b'0'..=b'9' => Some(c - b'0' + 52),
                b'+' => Some(62),
                b'/' => Some(63),
                _ => None,
            }
        }

        let bytes = s.as_bytes();
        let groups = bytes.chunks_exact(4);
        if !groups.remainder().is_empty() {
            return Err("base64 input length not multiple of 4");
        }

        let group_count = groups.len();
        let mut out = Vec::with_capacity(group_count * 3);
        for (idx, group) in groups.enumerate() {
            let is_last = idx + 1 == group_count;
            let v0 = sextet(group[0]).ok_or(INVALID)?;
            let v1 = sextet(group[1]).ok_or(INVALID)?;
            out.push((v0 << 2) | (v1 >> 4));

            match (group[2], group[3]) {
                // `xx==`: one byte, only valid as the final group.
                (b'=', b'=') if is_last => {}
                // `xxx=`: two bytes, only valid as the final group.
                (c2, b'=') if is_last => {
                    let v2 = sextet(c2).ok_or(INVALID)?;
                    out.push((v1 << 4) | (v2 >> 2));
                }
                // Full group. Any `=` that reaches this arm (padding in the
                // middle of the input, or `xx=y`) fails the alphabet lookup.
                (c2, c3) => {
                    let v2 = sextet(c2).ok_or(INVALID)?;
                    let v3 = sextet(c3).ok_or(INVALID)?;
                    out.push((v1 << 4) | (v2 >> 2));
                    out.push((v2 << 6) | v3);
                }
            }
        }
        Ok(out)
    }
228
229
1
    pub(super) fn serialize<S: Serializer>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error> {
230
1
        s.serialize_str(&encode(bytes))
231
1
    }
232
233
1
    pub(super) fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
234
1
        let s = String::deserialize(d)
?0
;
235
1
        decode(&s).map_err(serde::de::Error::custom)
236
1
    }
237
}
238
239
#[cfg(test)]
mod tests {
    use super::*;

    /// The two supported protocol names parse; anything else is rejected.
    #[test]
    fn parse_wire_format() {
        assert_eq!(WireFormat::parse("proto").unwrap(), WireFormat::Proto);
        assert_eq!(WireFormat::parse("json").unwrap(), WireFormat::Json);
        assert!(WireFormat::parse("xml").is_err());
    }

    /// A request carrying only arguments survives proto framing.
    #[test]
    fn proto_round_trip_minimal() {
        let req = WorkRequest {
            arguments: vec!["-source".into(), "21".into(), "Foo.java".into()],
            ..WorkRequest::default()
        };
        let bytes = req.encode_framed(WireFormat::Proto).unwrap();
        // Encoding is length-delimited; round-trip via prost's decoder.
        let mut cursor: &[u8] = &bytes;
        let decoded = <WorkRequest as ProstMessage>::decode_length_delimited(&mut cursor).unwrap();
        assert_eq!(decoded, req);
    }

    /// JSON output uses camelCase keys and parses back to the same value.
    #[test]
    fn json_round_trip_uses_camel_case() {
        let resp = WorkResponse {
            exit_code: 0,
            output: "compiled Foo.java".into(),
            request_id: 0,
            was_cancelled: false,
        };
        let bytes = serde_json::to_vec(&resp).unwrap();
        let s = core::str::from_utf8(&bytes).unwrap();
        // Bazel's JSON convention is camelCase; assert the wire form.
        assert!(s.contains(r#""exitCode":0"#), "got: {s}");
        assert!(s.contains(r#""requestId":0"#), "got: {s}");
        assert!(s.contains(r#""wasCancelled":false"#), "got: {s}");

        let parsed: WorkResponse = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed, resp);
    }

    /// Fields missing from the JSON input fall back to their defaults.
    #[test]
    fn json_decodes_minimum_response() {
        // A real worker may omit fields whose values are proto defaults.
        let bytes = br#"{"exitCode":1,"output":"oh no"}"#;
        let resp = WorkResponse::decode_framed(bytes, WireFormat::Json).unwrap();
        assert_eq!(resp.exit_code, 1);
        assert_eq!(resp.output, "oh no");
        assert_eq!(resp.request_id, 0);
        assert!(!resp.was_cancelled);
    }

    /// Nested inputs and their digest bytes survive proto framing.
    #[test]
    fn proto_request_includes_inputs_with_digest() {
        let req = WorkRequest {
            arguments: vec!["@argfile".into()],
            inputs: vec![Input {
                path: "Foo.java".into(),
                digest: vec![0xde, 0xad, 0xbe, 0xef],
            }],
            ..WorkRequest::default()
        };
        let bytes = req.encode_framed(WireFormat::Proto).unwrap();
        let mut cursor: &[u8] = &bytes;
        let decoded = <WorkRequest as ProstMessage>::decode_length_delimited(&mut cursor).unwrap();
        assert_eq!(decoded, req);
    }

    /// Digest bytes appear as base64 text in JSON and parse back unchanged.
    #[test]
    fn json_input_digest_uses_base64() {
        let req = WorkRequest {
            inputs: vec![Input {
                path: "Foo.java".into(),
                digest: vec![0xde, 0xad, 0xbe, 0xef],
            }],
            ..WorkRequest::default()
        };
        let s = serde_json::to_string(&req).unwrap();
        // 0xdeadbeef base64-encodes to "3q2+7w==".
        assert!(s.contains(r#""digest":"3q2+7w==""#), "got: {s}");

        let parsed: WorkRequest = serde_json::from_str(&s).unwrap();
        assert_eq!(parsed, req);
    }

    /// Exercises the proto branch of `WorkResponse::decode_framed`.
    #[test]
    fn proto_response_round_trip_via_decode_framed() {
        let resp = WorkResponse {
            exit_code: 2,
            output: "boom".into(),
            request_id: 0,
            was_cancelled: true,
        };
        let bytes = <WorkResponse as ProstMessage>::encode_length_delimited_to_vec(&resp);
        let decoded = WorkResponse::decode_framed(&bytes, WireFormat::Proto).unwrap();
        assert_eq!(decoded, resp);
    }

    /// A length prefix that promises more bytes than are present is an error.
    #[test]
    fn proto_decode_rejects_truncated_frame() {
        // Prefix says 5 bytes follow, but only 1 does.
        let truncated: &[u8] = &[0x05, 0x08];
        assert!(WorkResponse::decode_framed(truncated, WireFormat::Proto).is_err());
    }

    /// Malformed JSON is reported as an error rather than a default value.
    #[test]
    fn json_decode_rejects_malformed_input() {
        assert!(WorkResponse::decode_framed(b"{not json", WireFormat::Json).is_err());
    }

    /// The JSON frame ends with a newline and still parses as the request.
    #[test]
    fn json_encode_framed_ends_with_newline() {
        let req = WorkRequest {
            arguments: vec!["--flag".into()],
            ..WorkRequest::default()
        };
        let bytes = req.encode_framed(WireFormat::Json).unwrap();
        assert_eq!(bytes.last(), Some(&b'\n'));
        let parsed: WorkRequest = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed, req);
    }

    /// Digests with a bad length or a character outside the alphabet fail.
    #[test]
    fn json_input_digest_rejects_bad_base64() {
        for bad in ["abc", "ab!d"] {
            let json = format!(r#"{{"inputs":[{{"path":"a","digest":"{bad}"}}]}}"#);
            assert!(
                serde_json::from_str::<WorkRequest>(&json).is_err(),
                "accepted invalid digest: {bad}"
            );
        }
    }
}