Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-config/src/serde_utils.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::borrow::Cow;
16
use std::fmt;
17
use std::marker::PhantomData;
18
19
use byte_unit::Byte;
20
use humantime::parse_duration;
21
use serde::de::Visitor;
22
use serde::{de, Deserialize, Deserializer};
23
24
/// Helper for serde macro so you can use shellexpand variables in the json configuration
25
/// files when the number is a numeric type.
26
0
pub fn convert_numeric_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
27
0
where
28
0
    D: Deserializer<'de>,
29
0
    T: TryFrom<i64>,
30
0
    <T as TryFrom<i64>>::Error: fmt::Display,
31
0
{
32
    struct NumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
33
34
    impl<'de, T> Visitor<'de> for NumericVisitor<T>
35
    where
36
        T: TryFrom<i64>,
37
        <T as TryFrom<i64>>::Error: fmt::Display,
38
    {
39
        type Value = T;
40
41
0
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42
0
            formatter.write_str("an integer or a plain number string")
43
0
        }
44
45
0
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
46
0
            T::try_from(v).map_err(de::Error::custom)
47
0
        }
48
49
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
50
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
51
0
            T::try_from(v_i64).map_err(de::Error::custom)
52
0
        }
53
54
0
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
55
0
            let expanded = shellexpand::env(v).map_err(de::Error::custom)?;
56
0
            let s = expanded.as_ref().trim();
57
0
            let parsed = s.parse::<i64>().map_err(de::Error::custom)?;
58
0
            T::try_from(parsed).map_err(de::Error::custom)
59
0
        }
60
    }
61
62
0
    deserializer.deserialize_any(NumericVisitor::<T>(PhantomData))
63
0
}
64
65
/// Same as `convert_numeric_with_shellexpand`, but supports `Option<T>`.
66
13
pub fn convert_optional_numeric_with_shellexpand<'de, D, T>(
67
13
    deserializer: D,
68
13
) -> Result<Option<T>, D::Error>
69
13
where
70
13
    D: Deserializer<'de>,
71
13
    T: TryFrom<i64>,
72
13
    <T as TryFrom<i64>>::Error: fmt::Display,
73
13
{
74
    struct OptionalNumericVisitor<T: TryFrom<i64>>(PhantomData<T>);
75
76
    impl<'de, T> Visitor<'de> for OptionalNumericVisitor<T>
77
    where
78
        T: TryFrom<i64>,
79
        <T as TryFrom<i64>>::Error: fmt::Display,
80
    {
81
        type Value = Option<T>;
82
83
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
84
1
            formatter.write_str("an optional integer or a plain number string")
85
1
        }
86
87
5
        fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
88
5
            Ok(None)
89
5
        }
90
91
0
        fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
92
0
            Ok(None)
93
0
        }
94
95
8
        fn visit_some<D2: Deserializer<'de>>(
96
8
            self,
97
8
            deserializer: D2,
98
8
        ) -> Result<Self::Value, D2::Error> {
99
8
            deserializer.deserialize_any(self)
100
8
        }
101
102
2
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
103
2
            T::try_from(v).map(Some).map_err(de::Error::custom)
104
2
        }
105
106
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
107
0
            let v_i64 = i64::try_from(v).map_err(de::Error::custom)?;
108
0
            T::try_from(v_i64).map(Some).map_err(de::Error::custom)
109
0
        }
110
111
5
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
112
5
            if v.is_empty() {
  Branch (112:16): [Folded - Ignored]
  Branch (112:16): [True: 0, False: 5]
  Branch (112:16): [Folded - Ignored]
113
0
                return Err(de::Error::custom("empty string is not a valid number"));
114
5
            }
115
5
            if v.trim().is_empty() {
  Branch (115:16): [Folded - Ignored]
  Branch (115:16): [True: 0, False: 5]
  Branch (115:16): [Folded - Ignored]
116
0
                return Ok(None);
117
5
            }
118
5
            let 
expanded4
= shellexpand::env(v).map_err(de::Error::custom)
?1
;
119
4
            let s = expanded.as_ref().trim();
120
4
            let 
parsed2
= s.parse::<i64>().map_err(de::Error::custom)
?2
;
121
2
            T::try_from(parsed).map(Some).map_err(de::Error::custom)
122
5
        }
123
    }
124
125
13
    deserializer.deserialize_option(OptionalNumericVisitor::<T>(PhantomData))
126
13
}
127
128
/// Helper for serde macro so you can use shellexpand variables in the json
129
/// configuration files when the input is a string.
130
///
131
/// Handles YAML/JSON values according to the YAML 1.2 specification:
132
/// - Empty string (`""`) remains an empty string
133
/// - `null` becomes `None`
134
/// - Missing field becomes `None`
135
/// - Whitespace is preserved
136
0
pub fn convert_string_with_shellexpand<'de, D: Deserializer<'de>>(
137
0
    deserializer: D,
138
0
) -> Result<String, D::Error> {
139
0
    let value = String::deserialize(deserializer)?;
140
0
    Ok((*(shellexpand::env(&value).map_err(de::Error::custom)?)).to_string())
141
0
}
142
143
/// Same as `convert_string_with_shellexpand`, but supports `Vec<String>`.
144
0
pub fn convert_vec_string_with_shellexpand<'de, D: Deserializer<'de>>(
145
0
    deserializer: D,
146
0
) -> Result<Vec<String>, D::Error> {
147
0
    let vec = Vec::<String>::deserialize(deserializer)?;
148
0
    vec.into_iter()
149
0
        .map(|s| {
150
0
            shellexpand::env(&s)
151
0
                .map_err(de::Error::custom)
152
0
                .map(Cow::into_owned)
153
0
        })
154
0
        .collect()
155
0
}
156
157
/// Same as `convert_string_with_shellexpand`, but supports `Option<String>`.
158
10
pub fn convert_optional_string_with_shellexpand<'de, D: Deserializer<'de>>(
159
10
    deserializer: D,
160
10
) -> Result<Option<String>, D::Error> {
161
10
    let value = Option::<String>::deserialize(deserializer)
?0
;
162
7
    match value {
163
7
        Some(
v2
) if v.is_empty(
) => Ok(Some(String::new()))2
, // Keep empty string as empty string
  Branch (163:20): [Folded - Ignored]
  Branch (163:20): [True: 2, False: 5]
  Branch (163:20): [Folded - Ignored]
164
5
        Some(v) => Ok(Some(
165
5
            (*(shellexpand::env(&v).map_err(de::Error::custom)
?0
)).to_string(),
166
        )),
167
3
        None => Ok(None), // Handle both null and field not present
168
    }
169
10
}
170
171
21
pub fn convert_data_size_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
172
21
where
173
21
    D: Deserializer<'de>,
174
21
    T: TryFrom<u128>,
175
21
    <T as TryFrom<u128>>::Error: fmt::Display,
176
21
{
177
    struct DataSizeVisitor<T: TryFrom<u128>>(PhantomData<T>);
178
179
    impl<'de, T> Visitor<'de> for DataSizeVisitor<T>
180
    where
181
        T: TryFrom<u128>,
182
        <T as TryFrom<u128>>::Error: fmt::Display,
183
    {
184
        type Value = T;
185
186
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
187
1
            formatter.write_str("either a number of bytes as an integer, or a string with a data size format (e.g., \"1GB\", \"500MB\", \"1.5TB\")")
188
1
        }
189
190
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
191
0
            T::try_from(u128::from(v)).map_err(de::Error::custom)
192
0
        }
193
194
2
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
195
2
            if v < 0 {
  Branch (195:16): [Folded - Ignored]
  Branch (195:16): [True: 1, False: 1]
  Branch (195:16): [Folded - Ignored]
196
1
                return Err(de::Error::custom("Negative data size is not allowed"));
197
1
            }
198
1
            T::try_from(v as u128).map_err(de::Error::custom)
199
2
        }
200
201
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
202
0
            T::try_from(v).map_err(de::Error::custom)
203
0
        }
204
205
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
206
0
            if v < 0 {
  Branch (206:16): [Folded - Ignored]
  Branch (206:16): [Folded - Ignored]
207
0
                return Err(de::Error::custom("Negative data size is not allowed"));
208
0
            }
209
0
            T::try_from(v as u128).map_err(de::Error::custom)
210
0
        }
211
212
18
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
213
18
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
214
18
            let s = expanded.as_ref().trim();
215
18
            let 
byte_size16
= Byte::parse_str(s, true).map_err(de::Error::custom)
?2
;
216
16
            let bytes = byte_size.as_u128();
217
16
            T::try_from(bytes).map_err(de::Error::custom)
218
18
        }
219
    }
220
221
21
    deserializer.deserialize_any(DataSizeVisitor::<T>(PhantomData))
222
21
}
223
224
17
pub fn convert_duration_with_shellexpand<'de, D, T>(deserializer: D) -> Result<T, D::Error>
225
17
where
226
17
    D: Deserializer<'de>,
227
17
    T: TryFrom<u64>,
228
17
    <T as TryFrom<u64>>::Error: fmt::Display,
229
17
{
230
    struct DurationVisitor<T: TryFrom<u64>>(PhantomData<T>);
231
232
    impl<'de, T> Visitor<'de> for DurationVisitor<T>
233
    where
234
        T: TryFrom<u64>,
235
        <T as TryFrom<u64>>::Error: fmt::Display,
236
    {
237
        type Value = T;
238
239
1
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
240
1
            formatter.write_str("either a number of seconds as an integer, or a string with a duration format (e.g., \"1h2m3s\", \"30m\", \"1d\")")
241
1
        }
242
243
0
        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
244
0
            T::try_from(v).map_err(de::Error::custom)
245
0
        }
246
247
6
        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
248
6
            if v < 0 {
  Branch (248:16): [Folded - Ignored]
  Branch (248:16): [True: 1, False: 5]
  Branch (248:16): [Folded - Ignored]
249
1
                return Err(de::Error::custom("Negative duration is not allowed"));
250
5
            }
251
5
            T::try_from(v as u64).map_err(de::Error::custom)
252
6
        }
253
254
0
        fn visit_u128<E: de::Error>(self, v: u128) -> Result<Self::Value, E> {
255
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
256
0
            T::try_from(v_u64).map_err(de::Error::custom)
257
0
        }
258
259
0
        fn visit_i128<E: de::Error>(self, v: i128) -> Result<Self::Value, E> {
260
0
            if v < 0 {
  Branch (260:16): [Folded - Ignored]
  Branch (260:16): [Folded - Ignored]
261
0
                return Err(de::Error::custom("Negative duration is not allowed"));
262
0
            }
263
0
            let v_u64 = u64::try_from(v).map_err(de::Error::custom)?;
264
0
            T::try_from(v_u64).map_err(de::Error::custom)
265
0
        }
266
267
10
        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
268
10
            let expanded = shellexpand::env(v).map_err(de::Error::custom)
?0
;
269
10
            let expanded = expanded.as_ref().trim();
270
10
            let 
duration8
= parse_duration(expanded).map_err(de::Error::custom)
?2
;
271
8
            let secs = duration.as_secs();
272
8
            T::try_from(secs).map_err(de::Error::custom)
273
10
        }
274
    }
275
276
17
    deserializer.deserialize_any(DurationVisitor::<T>(PhantomData))
277
17
}