Coverage Report

Created: 2025-07-30 16:11

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-config/src/serde_utils.rs
Line
Count
Source
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::marker::PhantomData;
16
use std::borrow::Cow;
17
use std::fmt;
18
19
use byte_unit::Byte;
20
use humantime::parse_duration;
21
use serde::de::Visitor;
22
use serde::{Deserialize, Deserializer, de};
23
24
/// Helper for serde macro so you can use shellexpand variables in the json configuration
25
/// files when the number is a numeric type.
26
///
27
/// # Errors
28
///
29
/// Will return `Err` if deserialization fails.
30
12
pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
31
12
where
32
12
    D: Deserializer<'de>,
33
12
    T: TryFrom<i64>,
34
12
    <T as TryFrom<i64>>::Error: fmt::Display,
35
{
36
    struct NumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
37
38
    impl<T> Visitor<'_> for NumericVisitor<T>
39
    where
40
        T: TryFrom<i64>,
41
        <T as TryFrom<i64>>::Error: fmt::Display,
42
    {
43
        type Value = T;
44
45
0
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
46
0
            formatter.write_str("an integer or a plain number string")
47
0
        }
48
49
12
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
50
12
            T::try_from(v).map_err(de::Error::custom)
51
12
        }
52
53
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
54
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
55
0
            T::try_from(v_i64).map_err(de::Error::custom)
56
0
        }
57
58
0
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
59
0
            let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
60
0
            let s = expanded.as_ref().trim();
61
0
            let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
62
0
            T::try_from(parsed).map_err(de::Error::custom)
63
0
        }
64
    }
65
66
12
    deserializer.deserialize_any(NumericVisitor::<T>(PhantomData))
67
12
}
68
69
/// Same as `convert_numeric_with_shellexpand`, but supports `Option<T>`.
70
///
71
/// # Errors
72
///
73
/// Will return `Err` if deserialization fails.
74
14
pub fn convert_optional_numeric_with_shellexpand<'de, D, T>(
75
14
    deserializer: D,
76
14
) -> Result<Option<T>, D::Error>
77
14
where
78
14
    D: Deserializer<'de>,
79
14
    T: TryFrom<i64>,
80
14
    <T as TryFrom<i64>>::Error: fmt::Display,
81
{
82
    struct OptionalNumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
83
84
    impl<'de, T> Visitor<'de> for OptionalNumericVisitor<T>
85
    where
86
        T: TryFrom<i64>,
87
        <T as TryFrom<i64>>::Error: fmt::Display,
88
    {
89
        type Value = Option<T>;
90
91
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
92
1
            formatter.write_str("an optional integer or a plain number string")
93
1
        }
94
95
5
        fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
96
5
            Ok(None)
97
5
        }
98
99
0
        fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
100
0
            Ok(None)
101
0
        }
102
103
9
        fn visit_some<D2: Deserializer<'de>>(
104
9
            self,
105
9
            deserializer: D2,
106
9
        ) -> Result<Self::Value, D2::Error> {
107
9
            deserializer.deserialize_any(self)
108
9
        }
109
110
3
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
111
3
            T::try_from(v).map(Some).map_err(de::Error::custom)
112
3
        }
113
114
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
115
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
116
0
            T::try_from(v_i64).map(Some).map_err(de::Error::custom)
117
0
        }
118
119
5
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
120
5
            if v.is_empty() {
  Branch (120:16): [True: 0, False: 0]
  Branch (120:16): [True: 0, False: 5]
  Branch (120:16): [Folded - Ignored]
121
0
                return Err(de::Error::custom("empty string is not a valid number"));
122
5
            }
123
5
            if v.trim().is_empty() {
  Branch (123:16): [True: 0, False: 0]
  Branch (123:16): [True: 0, False: 5]
  Branch (123:16): [Folded - Ignored]
124
0
                return Ok(None);
125
5
            }
126
5
            let 
expanded4
= shellexpand::env(v).map_err(de::Error::custom)
?1
;
127
4
            let s = expanded.as_ref().trim();
128
4
            let 
parsed2
= s.parse::<i64>().map_err(de::Error::custom)
?2
;
129
2
            T::try_from(parsed).map(Some).map_err(de::Error::custom)
130
5
        }
131
    }
132
133
14
    deserializer.deserialize_option(OptionalNumericVisitor::<T>(PhantomData))
134
14
}
135
136
/// Helper for serde macro so you can use shellexpand variables in the json
137
/// configuration files when the input is a string.
138
///
139
/// Handles YAML/JSON values according to the YAML 1.2 specification:
140
/// - Empty string (`""`) remains an empty string
141
/// - `null` becomes `None`
142
/// - Missing field becomes `None`
143
/// - Whitespace is preserved
144
///
145
/// # Errors
146
///
147
/// Will return `Err` if deserialization fails.
148
84
pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>(
149
84
    deserializer: D,
150
84
) -> Result<String, D::Error> {
151
84
    let value = String::deserialize(deserializer)
?0
;
152
84
    Ok((*(shellexpand::env(&value).map_err(de::Error::custom)
?0
)).to_string())
153
84
}
154
155
/// Same as `convert_string_with_shellexpand`, but supports `Vec<String>`.
156
///
157
/// # Errors
158
///
159
/// Will return `Err` if deserialization fails.
160
10
pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
161
10
    deserializer: D,
162
10
) -> Result<Vec<String>, D::Error> {
163
10
    let vec = Vec::<String>::deserialize(deserializer)
?0
;
164
10
    vec.into_iter()
165
10
        .map(|s| {
166
10
            shellexpand::env(&s)
167
10
                .map_err(de::Error::custom)
168
10
                .map(Cow::into_owned)
169
10
        })
170
10
        .collect()
171
10
}
172
173
/// Same as `convert_string_with_shellexpand`, but supports `Option<String>`.
174
///
175
/// # Errors
176
///
177
/// Will return `Err` if deserialization fails.
178
16
pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>(
179
16
    deserializer: D,
180
16
) -> Result<Option<String>, D::Error> {
181
16
    let value = Option::<String>::deserialize(deserializer)
?0
;
182
13
    match value {
183
13
        Some(
v2
) if v.is_empty(
)2
=>
Ok(Some(String::new()))2
, // Keep empty string as empty string
  Branch (183:20): [True: 0, False: 6]
  Branch (183:20): [True: 0, False: 0]
  Branch (183:20): [True: 2, False: 5]
  Branch (183:20): [Folded - Ignored]
184
11
        Some(v) => Ok(Some(
185
11
            (*(shellexpand::env(&v).map_err(de::Error::custom)
?0
)).to_string(),
186
        )),
187
3
        None => Ok(None), // Handle both null and field not present
188
    }
189
16
}
190
191
/// # Errors
192
///
193
/// Will return `Err` if deserialization fails.
194
38
pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
195
38
where
196
38
    D: Deserializer<'de>,
197
38
    T: TryFrom<u128>,
198
38
    <T as TryFrom<u128>>::Error: fmt::Display,
199
{
200
    struct DataSizeVisitor<T: TryFrom<u128>>(PhantomData<T>);
201
202
    impl<T> Visitor<'_> for DataSizeVisitor<T>
203
    where
204
        T: TryFrom<u128>,
205
        <T as TryFrom<u128>>::Error: fmt::Display,
206
    {
207
        type Value = T;
208
209
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
210
1
            formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")")
211
1
        }
212
213
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
214
0
            T::try_from(u128::from(v)).map_err(de::Error::custom)
215
0
        }
216
217
12
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
218
12
            if v < 0 {
  Branch (218:16): [True: 0, False: 9]
  Branch (218:16): [True: 0, False: 0]
  Branch (218:16): [True: 0, False: 1]
  Branch (218:16): [True: 1, False: 1]
  Branch (218:16): [Folded - Ignored]
219
1
                return Err(de::Error::custom("Negative data size is not allowed"));
220
11
            }
221
11
            let v_u128 = u128::try_from(v).map_err(de::Error::custom)
?0
;
222
11
            T::try_from(v_u128).map_err(de::Error::custom)
223
12
        }
224
225
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
226
0
            T::try_from(v).map_err(de::Error::custom)
227
0
        }
228
229
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
230
0
            if v < 0 {
  Branch (230:16): [Folded - Ignored]
  Branch (230:16): [Folded - Ignored]
231
0
                return Err(de::Error::custom("Negative data size is not allowed"));
232
0
            }
233
0
            let v_u128 = u128::try_from(v).map_err(de::Error::custom)?;
234
0
            T::try_from(v_u128).map_err(de::Error::custom)
235
0
        }
236
237
25
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
238
25
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
239
25
            let s = expanded.as_ref().trim();
240
25
            let 
byte_size23
= Byte::parse_str(s, true).map_err(de::Error::custom)
?2
;
241
23
            let bytes = byte_size.as_u128();
242
23
            T::try_from(bytes).map_err(de::Error::custom)
243
25
        }
244
    }
245
246
38
    deserializer.deserialize_any(DataSizeVisitor::<T>(PhantomData))
247
38
}
248
249
/// # Errors
250
///
251
/// Will return `Err` if deserialization fails.
252
18
pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
253
18
where
254
18
    D: Deserializer<'de>,
255
18
    T: TryFrom<u64>,
256
18
    <T as TryFrom<u64>>::Error: fmt::Display,
257
{
258
    struct DurationVisitor<T: TryFrom<u64>>(PhantomData<T>);
259
260
    impl<T> Visitor<'_> for DurationVisitor<T>
261
    where
262
        T: TryFrom<u64>,
263
        <T as TryFrom<u64>>::Error: fmt::Display,
264
    {
265
        type Value = T;
266
267
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
268
1
            formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")")
269
1
        }
270
271
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
272
0
            T::try_from(v).map_err(de::Error::custom)
273
0
        }
274
275
7
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
276
7
            if v < 0 {
  Branch (276:16): [True: 0, False: 0]
  Branch (276:16): [True: 0, False: 0]
  Branch (276:16): [True: 0, False: 1]
  Branch (276:16): [True: 1, False: 5]
  Branch (276:16): [Folded - Ignored]
277
1
                return Err(de::Error::custom("Negative duration is not allowed"));
278
6
            }
279
6
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)
?0
;
280
6
            T::try_from(v_u64).map_err(de::Error::custom)
281
7
        }
282
283
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
284
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
285
0
            T::try_from(v_u64).map_err(de::Error::custom)
286
0
        }
287
288
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
289
0
            if v < 0 {
  Branch (289:16): [Folded - Ignored]
  Branch (289:16): [Folded - Ignored]
290
0
                return Err(de::Error::custom("Negative duration is not allowed"));
291
0
            }
292
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
293
0
            T::try_from(v_u64).map_err(de::Error::custom)
294
0
        }
295
296
10
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
297
10
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
298
10
            let expanded = expanded.as_ref().trim();
299
10
            let 
duration8
= parse_duration(expanded).map_err(de::Error::custom)
?2
;
300
8
            let secs = duration.as_secs();
301
8
            T::try_from(secs).map_err(de::Error::custom)
302
10
        }
303
    }
304
305
18
    deserializer.deserialize_any(DurationVisitor::<T>(PhantomData))
306
18
}