/build/source/nativelink-worker/src/persistent_worker/protocol.rs
Line | Count | Source |
1 | | // Copyright 2024 Trace Machina, Inc. All rights reserved. |
2 | | // |
3 | | // Licensed under the Business Source License, Version 1.1 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may requested a copy of the License by emailing contact@nativelink.com. |
6 | | // |
7 | | // Use of this module requires an enterprise license agreement, which can be |
8 | | // attained by emailing contact@nativelink.com or signing up for Nativelink |
9 | | // Cloud at app.nativelink.com. |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | |
17 | | //! Wire-format types for Bazel's persistent worker protocol. |
18 | | //! |
19 | | //! Reference: <https://bazel.build/remote/persistent#work-protocol> |
20 | | //! and <https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto>. |
21 | | //! |
22 | | //! Both JSON (newline-delimited, camelCase field names) and proto (length-delimited |
23 | | //! varint prefix) wire formats are supported. The choice is per-request, driven by |
24 | | //! the action's `requires-worker-protocol` execution requirement. |
25 | | |
26 | | use bytes::{Bytes, BytesMut}; |
27 | | use nativelink_error::{Code, Error, make_err}; |
28 | | use prost::Message as ProstMessage; |
29 | | use serde::{Deserialize, Serialize}; |
30 | | |
31 | | /// One input file declared in a `WorkRequest`. The tool may use the digest to |
32 | | /// verify cached state matches. |
33 | | #[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)] |
34 | | pub struct Input { |
35 | | #[prost(string, tag = "1")] |
36 | | #[serde(default, skip_serializing_if = "String::is_empty")] |
37 | | pub path: String, |
38 | | |
39 | | #[prost(bytes = "vec", tag = "2")] |
40 | | #[serde( |
41 | | default, |
42 | | skip_serializing_if = "Vec::is_empty", |
43 | | with = "serde_bytes_b64" |
44 | | )] |
45 | | pub digest: Vec<u8>, |
46 | | } |
47 | | |
48 | | /// A single unit of work dispatched to a persistent worker subprocess. |
49 | | #[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)] |
50 | | #[serde(rename_all = "camelCase")] |
51 | | pub struct WorkRequest { |
52 | | #[prost(string, repeated, tag = "1")] |
53 | | #[serde(default, skip_serializing_if = "Vec::is_empty")] |
54 | | pub arguments: Vec<String>, |
55 | | |
56 | | #[prost(message, repeated, tag = "2")] |
57 | | #[serde(default, skip_serializing_if = "Vec::is_empty")] |
58 | | pub inputs: Vec<Input>, |
59 | | |
60 | | /// Multiplex worker request id. v1 always sends 0 and rejects responses |
61 | | /// with non-zero ids. |
62 | | #[prost(int32, tag = "3")] |
63 | | #[serde(default)] |
64 | | pub request_id: i32, |
65 | | |
66 | | #[prost(bool, tag = "4")] |
67 | | #[serde(default)] |
68 | | pub cancel: bool, |
69 | | |
70 | | #[prost(int32, tag = "5")] |
71 | | #[serde(default)] |
72 | | pub verbosity: i32, |
73 | | |
74 | | #[prost(string, tag = "6")] |
75 | | #[serde(default, skip_serializing_if = "String::is_empty")] |
76 | | pub sandbox_dir: String, |
77 | | } |
78 | | |
79 | | /// A single response emitted by a persistent worker subprocess. |
80 | | #[derive(Clone, PartialEq, Eq, ProstMessage, Serialize, Deserialize)] |
81 | | #[serde(rename_all = "camelCase")] |
82 | | pub struct WorkResponse { |
83 | | #[prost(int32, tag = "1")] |
84 | | #[serde(default)] |
85 | | pub exit_code: i32, |
86 | | |
87 | | #[prost(string, tag = "2")] |
88 | | #[serde(default, skip_serializing_if = "String::is_empty")] |
89 | | pub output: String, |
90 | | |
91 | | /// Echoed from `WorkRequest.request_id`. v1 expects 0. |
92 | | #[prost(int32, tag = "3")] |
93 | | #[serde(default)] |
94 | | pub request_id: i32, |
95 | | |
96 | | #[prost(bool, tag = "4")] |
97 | | #[serde(default)] |
98 | | pub was_cancelled: bool, |
99 | | } |
100 | | |
101 | | /// Wire format negotiated per request from the action's |
102 | | /// `requires-worker-protocol` execution requirement. |
103 | | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] |
104 | | pub enum WireFormat { |
105 | | /// Length-delimited proto (Bazel default). |
106 | | Proto, |
107 | | /// Newline-delimited JSON, camelCase fields. |
108 | | Json, |
109 | | } |
110 | | |
111 | | impl WireFormat { |
112 | | /// Parse the execution-requirement value. Bazel accepts only "proto" and |
113 | | /// "json"; anything else is an error. |
114 | 5 | pub fn parse(s: &str) -> Result<Self, Error> { |
115 | 5 | match s { |
116 | 5 | "proto" => Ok(Self::Proto)1 , |
117 | 4 | "json" => Ok(Self::Json)3 , |
118 | 1 | other => Err(make_err!( |
119 | 1 | Code::InvalidArgument, |
120 | 1 | "Unsupported requires-worker-protocol value: {other:?}; expected 'proto' or 'json'" |
121 | 1 | )), |
122 | | } |
123 | 5 | } |
124 | | } |
125 | | |
126 | | impl WorkRequest { |
127 | | /// Serialize this request in the given wire format. The returned bytes |
128 | | /// include the framing prefix/suffix appropriate for the format |
129 | | /// (length-delimited varint for proto, trailing newline for JSON). |
130 | 6 | pub fn encode_framed(&self, format: WireFormat) -> Result<Bytes, Error> { |
131 | 6 | match format { |
132 | | WireFormat::Proto => { |
133 | 2 | let mut buf = |
134 | 2 | BytesMut::with_capacity(<Self as ProstMessage>::encoded_len(self) + 10); |
135 | 2 | <Self as ProstMessage>::encode_length_delimited(self, &mut buf) |
136 | 2 | .map_err(|e| make_err!0 (Code::Internal0 , "WorkRequest proto encode: {e}"))?0 ; |
137 | 2 | Ok(buf.freeze()) |
138 | | } |
139 | | WireFormat::Json => { |
140 | 4 | let mut s = serde_json::to_string(self) |
141 | 4 | .map_err(|e| make_err!0 (Code::Internal0 , "WorkRequest JSON encode: {e}"))?0 ; |
142 | 4 | s.push('\n'); |
143 | 4 | Ok(Bytes::from(s)) |
144 | | } |
145 | | } |
146 | 6 | } |
147 | | } |
148 | | |
149 | | impl WorkResponse { |
150 | | /// Decode a single response from a byte buffer. For proto: expects a |
151 | | /// length-delimited varint frame at the buffer start. For JSON: expects a |
152 | | /// single complete JSON object, optionally with a trailing newline. |
153 | 4 | pub fn decode_framed(buf: &[u8], format: WireFormat) -> Result<Self, Error> { |
154 | 4 | match format { |
155 | 0 | WireFormat::Proto => <Self as ProstMessage>::decode_length_delimited(buf) |
156 | 0 | .map_err(|e| make_err!(Code::Internal, "WorkResponse proto decode: {e}")), |
157 | 4 | WireFormat::Json => serde_json::from_slice(buf) |
158 | 4 | .map_err(|e| make_err!0 (Code::Internal0 , "WorkResponse JSON decode: {e}")), |
159 | | } |
160 | 4 | } |
161 | | } |
162 | | |
163 | | /// Serde adapter for `Input.digest`: Bazel's JSON serialization base64-encodes |
164 | | /// the digest bytes (as is standard for proto-JSON byte fields). |
165 | | mod serde_bytes_b64 { |
166 | | use serde::{Deserialize, Deserializer, Serializer}; |
167 | | |
168 | | // Minimal RFC 4648 base64 encoder/decoder so we don't pull in a `base64` |
169 | | // crate dep just for this adapter. v1 expects digests under ~64 bytes. |
170 | | const ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; |
171 | | |
172 | 1 | fn encode(input: &[u8]) -> String { |
173 | 1 | let mut out = String::with_capacity(input.len().div_ceil(3) * 4); |
174 | 2 | for chunk in input1 .chunks1 (3) { |
175 | 2 | let b0 = chunk[0]; |
176 | 2 | let b1 = chunk.get(1).copied().unwrap_or(0); |
177 | 2 | let b2 = chunk.get(2).copied().unwrap_or(0); |
178 | 2 | out.push(ALPHA[(b0 >> 2) as usize] as char); |
179 | 2 | out.push(ALPHA[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char); |
180 | 2 | if chunk.len() > 1 { |
181 | 1 | out.push(ALPHA[(((b1 & 0x0f) << 2) | (b2 >> 6)) as usize] as char); |
182 | 1 | } else { |
183 | 1 | out.push('='); |
184 | 1 | } |
185 | 2 | if chunk.len() > 2 { |
186 | 1 | out.push(ALPHA[(b2 & 0x3f) as usize] as char); |
187 | 1 | } else { |
188 | 1 | out.push('='); |
189 | 1 | } |
190 | | } |
191 | 1 | out |
192 | 1 | } |
193 | | |
194 | 1 | fn decode(s: &str) -> Result<Vec<u8>, &'static str> { |
195 | 1 | let mut lookup = [255u8; 256]; |
196 | 64 | for (i, b) in ALPHA1 .iter1 ().enumerate1 () { |
197 | 64 | lookup[*b as usize] = u8::try_from(i).expect("base64 alphabet index fits in u8"); |
198 | 64 | } |
199 | 1 | let bytes = s.as_bytes(); |
200 | 1 | if !bytes.len().is_multiple_of(4) { |
201 | 0 | return Err("base64 input length not multiple of 4"); |
202 | 1 | } |
203 | 1 | let mut out = Vec::with_capacity(bytes.len() / 4 * 3); |
204 | 2 | for chunk in bytes1 .chunks1 (4) { |
205 | 2 | let v0 = lookup[chunk[0] as usize]; |
206 | 2 | let v1 = lookup[chunk[1] as usize]; |
207 | 2 | let v2 = lookup[chunk[2] as usize]; |
208 | 2 | let v3 = lookup[chunk[3] as usize]; |
209 | 2 | if v0 == 255 || v1 == 255 { |
210 | 0 | return Err("invalid base64 char"); |
211 | 2 | } |
212 | 2 | out.push((v0 << 2) | (v1 >> 4)); |
213 | 2 | if chunk[2] != b'=' { |
214 | 1 | if v2 == 255 { |
215 | 0 | return Err("invalid base64 char"); |
216 | 1 | } |
217 | 1 | out.push((v1 << 4) | (v2 >> 2)); |
218 | 1 | } |
219 | 2 | if chunk[3] != b'=' { |
220 | 1 | if v3 == 255 { |
221 | 0 | return Err("invalid base64 char"); |
222 | 1 | } |
223 | 1 | out.push((v2 << 6) | v3); |
224 | 1 | } |
225 | | } |
226 | 1 | Ok(out) |
227 | 1 | } |
228 | | |
229 | 1 | pub(super) fn serialize<S: Serializer>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error> { |
230 | 1 | s.serialize_str(&encode(bytes)) |
231 | 1 | } |
232 | | |
233 | 1 | pub(super) fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> { |
234 | 1 | let s = String::deserialize(d)?0 ; |
235 | 1 | decode(&s).map_err(serde::de::Error::custom) |
236 | 1 | } |
237 | | } |
238 | | |
239 | | #[cfg(test)] |
240 | | mod tests { |
241 | | use super::*; |
242 | | |
243 | | #[test] |
244 | 1 | fn parse_wire_format() { |
245 | 1 | assert_eq!(WireFormat::parse("proto").unwrap(), WireFormat::Proto); |
246 | 1 | assert_eq!(WireFormat::parse("json").unwrap(), WireFormat::Json); |
247 | 1 | assert!(WireFormat::parse("xml").is_err()); |
248 | 1 | } |
249 | | |
250 | | #[test] |
251 | 1 | fn proto_round_trip_minimal() { |
252 | 1 | let req = WorkRequest { |
253 | 1 | arguments: vec!["-source".into(), "21".into(), "Foo.java".into()], |
254 | 1 | ..WorkRequest::default() |
255 | 1 | }; |
256 | 1 | let bytes = req.encode_framed(WireFormat::Proto).unwrap(); |
257 | | // Encoding is length-delimited; round-trip via prost's decoder. |
258 | 1 | let mut cursor: &[u8] = &bytes; |
259 | 1 | let decoded = <WorkRequest as ProstMessage>::decode_length_delimited(&mut cursor).unwrap(); |
260 | 1 | assert_eq!(decoded, req); |
261 | 1 | } |
262 | | |
263 | | #[test] |
264 | 1 | fn json_round_trip_uses_camel_case() { |
265 | 1 | let resp = WorkResponse { |
266 | 1 | exit_code: 0, |
267 | 1 | output: "compiled Foo.java".into(), |
268 | 1 | request_id: 0, |
269 | 1 | was_cancelled: false, |
270 | 1 | }; |
271 | 1 | let bytes = serde_json::to_vec(&resp).unwrap(); |
272 | 1 | let s = core::str::from_utf8(&bytes).unwrap(); |
273 | | // Bazel's JSON convention is camelCase; assert the wire form. |
274 | 1 | assert!(s.contains(r#""exitCode":0"#), "got: {s}"); |
275 | 1 | assert!(s.contains(r#""requestId":0"#), "got: {s}"); |
276 | 1 | assert!(s.contains(r#""wasCancelled":false"#), "got: {s}"); |
277 | | |
278 | 1 | let parsed: WorkResponse = serde_json::from_slice(&bytes).unwrap(); |
279 | 1 | assert_eq!(parsed, resp); |
280 | 1 | } |
281 | | |
282 | | #[test] |
283 | 1 | fn json_decodes_minimum_response() { |
284 | | // A real worker may omit fields whose values are proto defaults. |
285 | 1 | let bytes = br#"{"exitCode":1,"output":"oh no"}"#; |
286 | 1 | let resp = WorkResponse::decode_framed(bytes, WireFormat::Json).unwrap(); |
287 | 1 | assert_eq!(resp.exit_code, 1); |
288 | 1 | assert_eq!(resp.output, "oh no"); |
289 | 1 | assert_eq!(resp.request_id, 0); |
290 | 1 | assert!(!resp.was_cancelled); |
291 | 1 | } |
292 | | |
293 | | #[test] |
294 | 1 | fn proto_request_includes_inputs_with_digest() { |
295 | 1 | let req = WorkRequest { |
296 | 1 | arguments: vec!["@argfile".into()], |
297 | 1 | inputs: vec![Input { |
298 | 1 | path: "Foo.java".into(), |
299 | 1 | digest: vec![0xde, 0xad, 0xbe, 0xef], |
300 | 1 | }], |
301 | 1 | ..WorkRequest::default() |
302 | 1 | }; |
303 | 1 | let bytes = req.encode_framed(WireFormat::Proto).unwrap(); |
304 | 1 | let mut cursor: &[u8] = &bytes; |
305 | 1 | let decoded = <WorkRequest as ProstMessage>::decode_length_delimited(&mut cursor).unwrap(); |
306 | 1 | assert_eq!(decoded, req); |
307 | 1 | } |
308 | | |
309 | | #[test] |
310 | 1 | fn json_input_digest_uses_base64() { |
311 | 1 | let req = WorkRequest { |
312 | 1 | inputs: vec![Input { |
313 | 1 | path: "Foo.java".into(), |
314 | 1 | digest: vec![0xde, 0xad, 0xbe, 0xef], |
315 | 1 | }], |
316 | 1 | ..WorkRequest::default() |
317 | 1 | }; |
318 | 1 | let s = serde_json::to_string(&req).unwrap(); |
319 | | // 0xdeadbeef base64-encodes to "3q2+7w==". |
320 | 1 | assert!(s.contains(r#""digest":"3q2+7w==""#), "got: {s}"); |
321 | | |
322 | 1 | let parsed: WorkRequest = serde_json::from_str(&s).unwrap(); |
323 | 1 | assert_eq!(parsed, req); |
324 | 1 | } |
325 | | } |