/build/source/nativelink-config/src/backcompat.rs
Line | Count | Source |
1 | | use std::collections::HashMap; |
2 | | |
3 | | use serde::{Deserialize, Deserializer, Serialize}; |
4 | | use tracing::warn; |
5 | | |
6 | | use crate::cas_server::{ByteStreamConfig, OldByteStreamConfig, WithInstanceName}; |
7 | | |
/// Accepts either the legacy map-keyed layout or the current list layout for a set of
/// per-instance configs.
///
/// Because of `#[serde(untagged)]`, serde tries each variant in declaration order and keeps the
/// first one that deserializes successfully, so the order of the variants below is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum WithInstanceNameBackCompat<T> {
    /// Legacy layout: keys are instance names, values are the per-instance configs.
    Map(HashMap<String, T>),
    /// Current layout: a list of configs, each carrying its own `instance_name`.
    Vec(Vec<WithInstanceName<T>>),
}
14 | | |
15 | 13 | fn deprecated(old_map: &String, new_map: &String) { |
16 | 13 | warn!( |
17 | 2 | r"WARNING: Using deprecated map format for services. Please migrate to the new array format: |
18 | 2 | // Old: |
19 | 2 | {} |
20 | 2 | // New: |
21 | 2 | {} |
22 | 2 | ", |
23 | | old_map, new_map |
24 | | ); |
25 | 13 | } |
26 | | |
27 | | /// Use `#[serde(default, deserialize_with = "backcompat::opt_vec_with_instance_name")]` for backwards |
28 | | /// compatibility with map-based access. A deprecation warning will be written to stderr if the |
29 | | /// old format is used. |
30 | 33 | pub(crate) fn opt_vec_with_instance_name<'de, D, T>( |
31 | 33 | deserializer: D, |
32 | 33 | ) -> Result<Option<Vec<WithInstanceName<T>>>, D::Error> |
33 | 33 | where |
34 | 33 | D: Deserializer<'de>, |
35 | 33 | T: Deserialize<'de> + Serialize, |
36 | | { |
37 | 33 | let Some(back_compat) = Option::deserialize(deserializer)?0 else { Branch (37:9): [True: 0, False: 0]
Branch (37:9): [True: 0, False: 0]
Branch (37:9): [True: 8, False: 0]
Branch (37:9): [True: 8, False: 0]
Branch (37:9): [True: 7, False: 0]
Branch (37:9): [True: 8, False: 0]
Branch (37:9): [True: 2, False: 0]
|
38 | 0 | return Ok(None); |
39 | | }; |
40 | | |
41 | 33 | match back_compat { |
42 | 5 | WithInstanceNameBackCompat::Map(map) => { |
43 | | // TODO(palfrey): ideally this would be serde_json5::to_string_pretty but that doesn't exist |
44 | | // JSON is close enough to be workable for now |
45 | 5 | let serde_map = serde_json::to_string_pretty(&map).expect("valid map"); |
46 | 5 | let vec: Vec<WithInstanceName<T>> = map |
47 | 5 | .into_iter() |
48 | 5 | .map(|(instance_name, config)| WithInstanceName { |
49 | 6 | instance_name, |
50 | 6 | config, |
51 | 6 | }) |
52 | 5 | .collect(); |
53 | 5 | deprecated( |
54 | 5 | &serde_map, |
55 | | // TODO(palfrey): ideally this would be serde_json5::to_string_pretty but that doesn't exist |
56 | | // JSON is close enough to be workable for now |
57 | 5 | &serde_json::to_string_pretty(&vec).expect("valid new map"), |
58 | | ); |
59 | 5 | Ok(Some(vec)) |
60 | | } |
61 | 28 | WithInstanceNameBackCompat::Vec(vec) => Ok(Some(vec)), |
62 | | } |
63 | 33 | } |
64 | | |
/// Accepts either the current list-of-configs layout for the bytestream service or the older
/// single `OldByteStreamConfig` layout.
///
/// Because of `#[serde(untagged)]`, serde tries variants in declaration order (`New` before
/// `Old`) and keeps the first that deserializes, so the variant order is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ByteStreamKind {
    /// Current layout: a list of per-instance bytestream configs.
    New(Vec<WithInstanceName<ByteStreamConfig>>),
    /// Legacy layout; `opt_bytestream` converts it into the `New` shape.
    Old(OldByteStreamConfig),
}
71 | | |
72 | | /// Use `#[serde(default, deserialize_with = "backcompat::opt_bytestream")]` for backwards |
73 | | /// compatibility with older bytestream config . A deprecation warning will be written to stderr if the |
74 | | /// old format is used. |
75 | 10 | pub(crate) fn opt_bytestream<'de, D>( |
76 | 10 | deserializer: D, |
77 | 10 | ) -> Result<Option<Vec<WithInstanceName<ByteStreamConfig>>>, D::Error> |
78 | 10 | where |
79 | 10 | D: Deserializer<'de>, |
80 | | { |
81 | 10 | let Some(back_compat) = Option::deserialize(deserializer)?0 else { Branch (81:9): [True: 8, False: 0]
Branch (81:9): [True: 2, False: 0]
|
82 | 0 | return Ok(None); |
83 | | }; |
84 | | |
85 | 10 | match back_compat { |
86 | 8 | ByteStreamKind::Old(old_config) => { |
87 | 8 | if old_config.max_decoding_message_size != 0 { Branch (87:16): [True: 0, False: 7]
Branch (87:16): [True: 0, False: 1]
|
88 | 0 | warn!( |
89 | 0 | "WARNING: max_decoding_message_size is ignored on Bytestream now. Please set on the HTTP listener instead" |
90 | | ); |
91 | 8 | } |
92 | | // TODO(palfrey): ideally this would be serde_json5::to_string_pretty but that doesn't exist |
93 | | // JSON is close enough to be workable for now |
94 | 8 | let serde_map = serde_json::to_string_pretty(&old_config).expect("valid map"); |
95 | 8 | let names = old_config.cas_stores; |
96 | 8 | let vec: Vec<WithInstanceName<_>> = names |
97 | 8 | .iter() |
98 | 8 | .map(|(instance_name, cas_store)| WithInstanceName { |
99 | 8 | instance_name: instance_name.clone(), |
100 | 8 | config: ByteStreamConfig { |
101 | 8 | cas_store: cas_store.clone(), |
102 | 8 | max_bytes_per_stream: old_config.max_bytes_per_stream, |
103 | 8 | persist_stream_on_disconnect_timeout: old_config |
104 | 8 | .persist_stream_on_disconnect_timeout, |
105 | 8 | }, |
106 | 8 | }) |
107 | 8 | .collect(); |
108 | 8 | deprecated( |
109 | 8 | &serde_map, |
110 | | // TODO(palfrey): ideally this would be serde_json5::to_string_pretty but that doesn't exist |
111 | | // JSON is close enough to be workable for now |
112 | 8 | &serde_json::to_string_pretty(&vec).expect("valid new map"), |
113 | | ); |
114 | 8 | Ok(Some(vec)) |
115 | | } |
116 | 2 | ByteStreamKind::New(vec) => Ok(Some(vec)), |
117 | | } |
118 | 10 | } |
119 | | |
#[cfg(test)]
mod tests {
    use serde_json::json;
    use tracing_test::traced_test;

    use super::*;

    /// Minimal per-instance config used to exercise `opt_vec_with_instance_name`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct PartialConfig {
        store: String,
    }

    /// Wrapper whose `cas` field goes through the map/list back-compat deserializer.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct FullConfig {
        #[serde(default, deserialize_with = "opt_vec_with_instance_name")]
        cas: Option<Vec<WithInstanceName<PartialConfig>>>,
    }

    /// Wrapper whose `bytestream` field goes through the bytestream back-compat deserializer.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct FullBytestreamConfig {
        #[serde(default, deserialize_with = "opt_bytestream")]
        pub bytestream: Option<Vec<WithInstanceName<ByteStreamConfig>>>,
    }

    /// Orders the `cas` entries by instance name so comparisons don't depend on map order.
    fn sort_cas_by_instance_name(config: &mut FullConfig) {
        if let Some(entries) = config.cas.as_mut() {
            entries.sort_by(|left, right| left.instance_name.cmp(&right.instance_name));
        }
    }

    /// `logs_assert` callback: exactly one captured line, and it is the deprecation warning.
    fn expect_single_deprecation_warning(lines: &[&str]) -> Result<(), String> {
        let [line] = lines else {
            return Err(format!("Expected 1 log line, got: {lines:?}"));
        };
        // TODO(palfrey): we should be checking the whole thing, but tracing-test is broken with multi-line items
        // See https://github.com/dbrgn/tracing-test/issues/48
        assert!(line.ends_with("WARNING: Using deprecated map format for services. Please migrate to the new array format:"));
        Ok(())
    }

    #[test]
    #[traced_test]
    fn test_configs_deserialization() {
        let legacy_json = json!({
            "cas": {
                "foo": { "store": "foo_store" },
                "bar": { "store": "bar_store" }
            }
        });
        let current_json = json!({
            "cas": [
                { "instance_name": "foo", "store": "foo_store" },
                { "instance_name": "bar", "store": "bar_store" }
            ]
        });

        let mut from_legacy: FullConfig = serde_json::from_value(legacy_json).unwrap();
        let mut from_current: FullConfig = serde_json::from_value(current_json).unwrap();

        // Ensure deterministic ordering.
        sort_cas_by_instance_name(&mut from_legacy);
        sort_cas_by_instance_name(&mut from_current);

        assert_eq!(from_legacy, from_current);
        logs_assert(expect_single_deprecation_warning);
    }

    #[test]
    fn test_deserialize_none() {
        let parsed: FullConfig = serde_json::from_value(json!({})).unwrap();
        assert_eq!(parsed.cas, None);
    }

    #[test]
    #[traced_test]
    fn test_bytestream_old_config() {
        let legacy_json = json!({
            "bytestream": {
                "cas_stores": {
                    "": "WORKER_FAST_SLOW_STORE"
                }
            }
        });
        let current_json = json!({
            "bytestream": [
                { "cas_store": "WORKER_FAST_SLOW_STORE" }
            ]
        });

        let from_legacy: FullBytestreamConfig = serde_json::from_value(legacy_json).unwrap();
        let from_current: FullBytestreamConfig = serde_json::from_value(current_json).unwrap();

        assert_eq!(from_legacy, from_current);
        logs_assert(expect_single_deprecation_warning);
    }
}