Coverage Report

Created: 2025-10-30 00:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-config/src/serde_utils.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::marker::PhantomData;
16
use std::borrow::Cow;
17
use std::fmt;
18
19
use byte_unit::Byte;
20
use humantime::parse_duration;
21
use serde::de::Visitor;
22
use serde::{Deserialize, Deserializer, de};
23
24
/// Helper for serde macro so you can use shellexpand variables in the json configuration
25
/// files when the number is a numeric type.
26
///
27
/// # Errors
28
///
29
/// Will return `Err` if deserialization fails.
30
23
pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
31
23
where
32
23
    D: Deserializer<'de>,
33
23
    T: TryFrom<i64>,
34
23
    <T as TryFrom<i64>>::Error: fmt::Display,
35
{
36
    struct NumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
37
38
    impl<T> Visitor<'_> for NumericVisitor<T>
39
    where
40
        T: TryFrom<i64>,
41
        <T as TryFrom<i64>>::Error: fmt::Display,
42
    {
43
        type Value = T;
44
45
0
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
46
0
            formatter.write_str("an integer or a plain number string")
47
0
        }
48
49
23
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
50
23
            T::try_from(v).map_err(de::Error::custom)
51
23
        }
52
53
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
54
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
55
0
            T::try_from(v_i64).map_err(de::Error::custom)
56
0
        }
57
58
0
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
59
0
            let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
60
0
            let s = expanded.as_ref().trim();
61
0
            let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
62
0
            T::try_from(parsed).map_err(de::Error::custom)
63
0
        }
64
    }
65
66
23
    deserializer.deserialize_any(NumericVisitor::<T>(PhantomData))
67
23
}
68
69
/// Same as `convert_numeric_with_shellexpand`, but supports `Option<T>`.
70
///
71
/// # Errors
72
///
73
/// Will return `Err` if deserialization fails.
74
14
pub fn convert_optional_numeric_with_shellexpand<'de, D, T>(
75
14
    deserializer: D,
76
14
) -> Result<Option<T>, D::Error>
77
14
where
78
14
    D: Deserializer<'de>,
79
14
    T: TryFrom<i64>,
80
14
    <T as TryFrom<i64>>::Error: fmt::Display,
81
{
82
    struct OptionalNumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
83
84
    impl<'de, T> Visitor<'de> for OptionalNumericVisitor<T>
85
    where
86
        T: TryFrom<i64>,
87
        <T as TryFrom<i64>>::Error: fmt::Display,
88
    {
89
        type Value = Option<T>;
90
91
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
92
1
            formatter.write_str("an optional integer or a plain number string")
93
1
        }
94
95
5
        fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
96
5
            Ok(None)
97
5
        }
98
99
0
        fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
100
0
            Ok(None)
101
0
        }
102
103
9
        fn visit_some<D2: Deserializer<'de>>(
104
9
            self,
105
9
            deserializer: D2,
106
9
        ) -> Result<Self::Value, D2::Error> {
107
9
            deserializer.deserialize_any(self)
108
9
        }
109
110
3
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
111
3
            T::try_from(v).map(Some).map_err(de::Error::custom)
112
3
        }
113
114
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
115
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
116
0
            T::try_from(v_i64).map(Some).map_err(de::Error::custom)
117
0
        }
118
119
5
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
120
5
            if v.is_empty() {
  Branch (120:16): [True: 0, False: 0]
  Branch (120:16): [True: 0, False: 5]
  Branch (120:16): [Folded - Ignored]
121
0
                return Err(de::Error::custom("empty string is not a valid number"));
122
5
            }
123
5
            if v.trim().is_empty() {
  Branch (123:16): [True: 0, False: 0]
  Branch (123:16): [True: 0, False: 5]
  Branch (123:16): [Folded - Ignored]
124
0
                return Ok(None);
125
5
            }
126
5
            let 
expanded4
= shellexpand::env(v).map_err(de::Error::custom)
?1
;
127
4
            let s = expanded.as_ref().trim();
128
4
            let 
parsed2
= s.parse::<i64>().map_err(de::Error::custom)
?2
;
129
2
            T::try_from(parsed).map(Some).map_err(de::Error::custom)
130
5
        }
131
    }
132
133
14
    deserializer.deserialize_option(OptionalNumericVisitor::<T>(PhantomData))
134
14
}
135
136
/// Helper for serde macro so you can use shellexpand variables in the json
137
/// configuration files when the input is a string.
138
///
139
/// Handles YAML/JSON values according to the YAML 1.2 specification:
140
/// - Empty string (`""`) remains an empty string
141
/// - `null` becomes `None`
142
/// - Missing field becomes `None`
143
/// - Whitespace is preserved
144
///
145
/// # Errors
146
///
147
/// Will return `Err` if deserialization fails.
148
146
pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>(
149
146
    deserializer: D,
150
146
) -> Result<String, D::Error> {
151
146
    let value = String::deserialize(deserializer)
?0
;
152
146
    Ok((*(shellexpand::env(&value).map_err(de::Error::custom)
?0
)).to_string())
153
146
}
154
155
/// Same as `convert_string_with_shellexpand`, but supports `Vec<String>`.
156
///
157
/// # Errors
158
///
159
/// Will return `Err` if deserialization fails.
160
11
pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
161
11
    deserializer: D,
162
11
) -> Result<Vec<String>, D::Error> {
163
11
    let vec = Vec::<String>::deserialize(deserializer)
?0
;
164
11
    vec.into_iter()
165
11
        .map(|s| {
166
11
            shellexpand::env(&s)
167
11
                .map_err(de::Error::custom)
168
11
                .map(Cow::into_owned)
169
11
        })
170
11
        .collect()
171
11
}
172
173
/// Same as `convert_string_with_shellexpand`, but supports `Option<String>`.
174
///
175
/// # Errors
176
///
177
/// Will return `Err` if deserialization fails.
178
17
pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>(
179
17
    deserializer: D,
180
17
) -> Result<Option<String>, D::Error> {
181
17
    let value = Option::<String>::deserialize(deserializer)
?0
;
182
14
    match value {
183
14
        Some(
v2
) if v.is_empty(
)2
=>
Ok(Some(String::new()))2
, // Keep empty string as empty string
  Branch (183:20): [True: 0, False: 7]
  Branch (183:20): [True: 0, False: 0]
  Branch (183:20): [True: 2, False: 5]
  Branch (183:20): [Folded - Ignored]
184
12
        Some(v) => Ok(Some(
185
12
            (*(shellexpand::env(&v).map_err(de::Error::custom)
?0
)).to_string(),
186
        )),
187
3
        None => Ok(None), // Handle both null and field not present
188
    }
189
17
}
190
191
/// # Errors
192
///
193
/// Will return `Err` if deserialization fails.
194
54
pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
195
54
where
196
54
    D: Deserializer<'de>,
197
54
    T: TryFrom<u128>,
198
54
    <T as TryFrom<u128>>::Error: fmt::Display,
199
{
200
    struct DataSizeVisitor<T: TryFrom<u128>>(PhantomData<T>);
201
202
    impl<T> Visitor<'_> for DataSizeVisitor<T>
203
    where
204
        T: TryFrom<u128>,
205
        <T as TryFrom<u128>>::Error: fmt::Display,
206
    {
207
        type Value = T;
208
209
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
210
1
            formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")")
211
1
        }
212
213
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
214
0
            T::try_from(u128::from(v)).map_err(de::Error::custom)
215
0
        }
216
217
13
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
218
13
            if v < 0 {
  Branch (218:16): [True: 0, False: 10]
  Branch (218:16): [True: 0, False: 0]
  Branch (218:16): [True: 0, False: 1]
  Branch (218:16): [True: 1, False: 1]
  Branch (218:16): [True: 0, False: 0]
219
1
                return Err(de::Error::custom("Negative data size is not allowed"));
220
12
            }
221
12
            let v_u128 = u128::try_from(v).map_err(de::Error::custom)
?0
;
222
12
            T::try_from(v_u128).map_err(de::Error::custom)
223
13
        }
224
225
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
226
0
            T::try_from(v).map_err(de::Error::custom)
227
0
        }
228
229
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
230
0
            if v < 0 {
  Branch (230:16): [Folded - Ignored]
  Branch (230:16): [Folded - Ignored]
231
0
                return Err(de::Error::custom("Negative data size is not allowed"));
232
0
            }
233
0
            let v_u128 = u128::try_from(v).map_err(de::Error::custom)?;
234
0
            T::try_from(v_u128).map_err(de::Error::custom)
235
0
        }
236
237
40
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
238
40
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
239
40
            let s = expanded.as_ref().trim();
240
40
            if v.is_empty() {
  Branch (240:16): [True: 0, False: 19]
  Branch (240:16): [True: 0, False: 0]
  Branch (240:16): [True: 0, False: 2]
  Branch (240:16): [True: 1, False: 18]
  Branch (240:16): [True: 0, False: 0]
241
1
                return Err(de::Error::custom("Missing value in a size field"));
242
39
            }
243
39
            let 
byte_size37
= Byte::parse_str(s, true).map_err(de::Error::custom)
?2
;
244
37
            let bytes = byte_size.as_u128();
245
37
            T::try_from(bytes).map_err(de::Error::custom)
246
40
        }
247
    }
248
249
54
    deserializer.deserialize_any(DataSizeVisitor::<T>(PhantomData))
250
54
}
251
252
/// # Errors
253
///
254
/// Will return `Err` if deserialization fails.
255
19
pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
256
19
where
257
19
    D: Deserializer<'de>,
258
19
    T: TryFrom<u64>,
259
19
    <T as TryFrom<u64>>::Error: fmt::Display,
260
{
261
    struct DurationVisitor<T: TryFrom<u64>>(PhantomData<T>);
262
263
    impl<T> Visitor<'_> for DurationVisitor<T>
264
    where
265
        T: TryFrom<u64>,
266
        <T as TryFrom<u64>>::Error: fmt::Display,
267
    {
268
        type Value = T;
269
270
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
271
1
            formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")")
272
1
        }
273
274
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
275
0
            T::try_from(v).map_err(de::Error::custom)
276
0
        }
277
278
8
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
279
8
            if v < 0 {
  Branch (279:16): [True: 0, False: 0]
  Branch (279:16): [True: 0, False: 1]
  Branch (279:16): [True: 0, False: 1]
  Branch (279:16): [True: 1, False: 5]
  Branch (279:16): [True: 0, False: 0]
280
1
                return Err(de::Error::custom("Negative duration is not allowed"));
281
7
            }
282
7
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)
?0
;
283
7
            T::try_from(v_u64).map_err(de::Error::custom)
284
8
        }
285
286
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
287
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
288
0
            T::try_from(v_u64).map_err(de::Error::custom)
289
0
        }
290
291
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
292
0
            if v < 0 {
  Branch (292:16): [Folded - Ignored]
  Branch (292:16): [Folded - Ignored]
293
0
                return Err(de::Error::custom("Negative duration is not allowed"));
294
0
            }
295
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
296
0
            T::try_from(v_u64).map_err(de::Error::custom)
297
0
        }
298
299
10
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
300
10
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
301
10
            let expanded = expanded.as_ref().trim();
302
10
            let 
duration8
= parse_duration(expanded).map_err(de::Error::custom)
?2
;
303
8
            let secs = duration.as_secs();
304
8
            T::try_from(secs).map_err(de::Error::custom)
305
10
        }
306
    }
307
308
19
    deserializer.deserialize_any(DurationVisitor::<T>(PhantomData))
309
19
}