Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-util/src/common.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::cmp::Ordering;
16
use std::collections::HashMap;
17
use std::fmt;
18
use std::hash::Hash;
19
use std::io::{Cursor, Write};
20
use std::ops::{Deref, DerefMut};
21
22
use bytes::{Buf, BufMut, Bytes, BytesMut};
23
use nativelink_error::{make_input_err, Error, ResultExt};
24
use nativelink_metric::{
25
    MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
26
};
27
use nativelink_proto::build::bazel::remote::execution::v2::Digest;
28
use prost::Message;
29
use serde::de::Visitor;
30
use serde::ser::Error as _;
31
use serde::{Deserialize, Deserializer, Serialize, Serializer};
32
use tracing::{event, Level};
33
34
pub use crate::fs;
35
36
/// A content digest: a packed 32-byte hash plus the size of the digested content.
///
/// The custom `Serialize`/`Deserialize` and `Display` impls in this module
/// render it as `{hex}-{size}`.
#[derive(Default, Clone, Copy, Eq, PartialEq, Hash)]
#[repr(C)]
pub struct DigestInfo {
    /// Raw hash in packed (binary, non-hex) form.
    packed_hash: PackedHash,

    /// Size in bytes of the digested content. `try_new` and the proto
    /// conversions keep this at or below `i64::MAX`; `new` does not check it.
    size_bytes: u64,
}
45
46
impl MetricsComponent for DigestInfo {
47
0
    fn publish(
48
0
        &self,
49
0
        _kind: MetricKind,
50
0
        field_metadata: MetricFieldData,
51
0
    ) -> Result<MetricPublishKnownKindData, nativelink_metric::Error> {
52
0
        format!("{self}").publish(MetricKind::String, field_metadata)
53
0
    }
54
}
55
56
impl DigestInfo {
57
5.35k
    pub const fn new(packed_hash: [u8; 32], size_bytes: u64) -> Self {
58
5.35k
        DigestInfo {
59
5.35k
            size_bytes,
60
5.35k
            packed_hash: PackedHash(packed_hash),
61
5.35k
        }
62
5.35k
    }
63
64
5.26k
    pub fn try_new<T>(hash: &str, size_bytes: T) -> Result<Self, Error>
65
5.26k
    where
66
5.26k
        T: TryInto<u64> + std::fmt::Display + Copy,
67
5.26k
    {
68
5.25k
        let packed_hash =
69
5.26k
            PackedHash::from_hex(hash).err_tip(|| 
format!("Invalid sha256 hash: {hash}")9
)
?9
;
70
5.25k
        let size_bytes = size_bytes
71
5.25k
            .try_into()
72
5.25k
            .map_err(|_| 
make_input_err!("Could not convert {} into u64", size_bytes)0
)
?0
;
73
        // The proto `Digest` takes an i64, so to keep compatibility
74
        // we only allow sizes that can fit into an i64.
75
5.25k
        if size_bytes > i64::MAX as u64 {
  Branch (75:12): [True: 0, False: 3]
  Branch (75:12): [True: 0, False: 8]
  Branch (75:12): [True: 0, False: 20]
  Branch (75:12): [True: 0, False: 4]
  Branch (75:12): [True: 0, False: 0]
  Branch (75:12): [True: 0, False: 4.60k]
  Branch (75:12): [Folded - Ignored]
  Branch (75:12): [True: 0, False: 1]
  Branch (75:12): [True: 0, False: 8]
  Branch (75:12): [True: 0, False: 1]
  Branch (75:12): [True: 0, False: 6]
  Branch (75:12): [True: 2, False: 5]
  Branch (75:12): [True: 0, False: 6]
  Branch (75:12): [True: 0, False: 9]
  Branch (75:12): [True: 0, False: 58]
  Branch (75:12): [True: 0, False: 4]
  Branch (75:12): [True: 0, False: 1]
  Branch (75:12): [True: 0, False: 5]
  Branch (75:12): [True: 0, False: 21]
  Branch (75:12): [True: 0, False: 4]
  Branch (75:12): [True: 0, False: 0]
  Branch (75:12): [True: 0, False: 405]
  Branch (75:12): [True: 0, False: 1]
  Branch (75:12): [True: 0, False: 2]
  Branch (75:12): [Folded - Ignored]
  Branch (75:12): [True: 0, False: 6]
  Branch (75:12): [True: 0, False: 6]
  Branch (75:12): [True: 0, False: 17]
  Branch (75:12): [True: 0, False: 3]
  Branch (75:12): [True: 0, False: 3]
  Branch (75:12): [True: 0, False: 4]
  Branch (75:12): [True: 0, False: 3]
  Branch (75:12): [True: 0, False: 13]
  Branch (75:12): [True: 0, False: 12]
  Branch (75:12): [True: 0, False: 10]
76
2
            return Err(make_input_err!(
77
2
                "Size bytes is too large: {} - max: {}",
78
2
                size_bytes,
79
2
                i64::MAX
80
2
            ));
81
5.25k
        }
82
5.25k
        Ok(DigestInfo {
83
5.25k
            packed_hash,
84
5.25k
            size_bytes,
85
5.25k
        })
86
5.26k
    }
87
88
0
    pub const fn zero_digest() -> DigestInfo {
89
0
        DigestInfo {
90
0
            size_bytes: 0,
91
0
            packed_hash: PackedHash::new(),
92
0
        }
93
0
    }
94
95
40.2k
    pub const fn packed_hash(&self) -> &PackedHash {
96
40.2k
        &self.packed_hash
97
40.2k
    }
98
99
90
    pub fn set_packed_hash(&mut self, packed_hash: [u8; 32]) {
100
90
        self.packed_hash = PackedHash(packed_hash);
101
90
    }
102
103
15.9k
    pub const fn size_bytes(&self) -> u64 {
104
15.9k
        self.size_bytes
105
15.9k
    }
106
107
    /// Returns a struct that can turn the `DigestInfo` into a string.
108
697
    const fn stringifier(&self) -> DigestStackStringifier<'_> {
109
697
        DigestStackStringifier::new(self)
110
697
    }
111
}
112
113
/// Counts the number of decimal digits `num` needs when converted to a string.
///
/// `0` renders as `"0"`, so it needs one digit (the loop alone would report 0).
/// This is a `const fn` so it can size buffers at compile time.
const fn count_digits(mut num: u64) -> usize {
    if num == 0 {
        return 1;
    }
    let mut count = 0;
    while num != 0 {
        count += 1;
        num /= 10;
    }
    count
}
123
124
/// Renders a `DigestInfo` as `{hex}-{size}` into a fixed-size buffer held
/// inline in this struct, so producing the string needs no heap allocation.
struct DigestStackStringifier<'a> {
    /// The digest being rendered.
    digest: &'a DigestInfo,
    /// Buffer that can hold the string representation of the `DigestInfo`.
    /// - Hex is '2 * sizeof(PackedHash)'.
    /// - Digits can be at most `count_digits(u64::MAX)`.
    /// - We also have a hyphen separator.
    buf: [u8; std::mem::size_of::<PackedHash>() * 2 + count_digits(u64::MAX) + 1],
}
134
135
impl<'a> DigestStackStringifier<'a> {
136
697
    const fn new(digest: &'a DigestInfo) -> Self {
137
697
        DigestStackStringifier {
138
697
            digest,
139
697
            buf: [b'-'; std::mem::size_of::<PackedHash>() * 2 + count_digits(u64::MAX) + 1],
140
697
        }
141
697
    }
142
143
697
    fn as_str(&mut self) -> Result<&str, Error> {
144
        // Populate the buffer and return the amount of bytes written
145
        // to the buffer.
146
697
        let len = {
147
697
            let mut cursor = Cursor::new(&mut self.buf[..]);
148
697
            let hex = self.digest.packed_hash.to_hex().map_err(|e| {
149
0
                make_input_err!(
150
0
                    "Could not convert PackedHash to hex - {e:?} - {:?}",
151
0
                    self.digest
152
0
                )
153
697
            })
?0
;
154
697
            cursor
155
697
                .write_all(&hex)
156
697
                .err_tip(|| 
format!("Could not write hex to buffer - {hex:?} - {hex:?}",)0
)
?0
;
157
            // Note: We already have a hyphen at this point because we
158
            // initialized the buffer with hyphens.
159
697
            cursor.advance(1);
160
697
            cursor
161
697
                .write_fmt(format_args!("{}", self.digest.size_bytes()))
162
697
                .err_tip(|| 
format!("Could not write size_bytes to buffer - {hex:?}",)0
)
?0
;
163
697
            cursor.position() as usize
164
697
        };
165
697
        // Convert the buffer into utf8 string.
166
697
        std::str::from_utf8(&self.buf[..len]).map_err(|e| {
167
0
            make_input_err!(
168
0
                "Could not convert [u8] to string - {} - {:?} - {:?}",
169
0
                self.digest,
170
0
                self.buf,
171
0
                e,
172
0
            )
173
697
        })
174
697
    }
175
}
176
177
/// Custom serializer for `DigestInfo` because the default Serializer
178
/// would try to encode the data as a byte array, but we use {hex}-{size}.
179
impl Serialize for DigestInfo {
180
221
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
181
221
    where
182
221
        S: Serializer,
183
221
    {
184
221
        let mut stringifier = self.stringifier();
185
221
        serializer.serialize_str(
186
221
            stringifier
187
221
                .as_str()
188
221
                .err_tip(|| 
"During serialization of DigestInfo"0
)
189
221
                .map_err(S::Error::custom)
?0
,
190
        )
191
221
    }
192
}
193
194
/// Custom deserializer for `DigestInfo` because the default Deserializer
195
/// would try to decode the data as a byte array, but we use {hex}-{size}.
196
impl<'de> Deserialize<'de> for DigestInfo {
197
4.61k
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198
4.61k
    where
199
4.61k
        D: Deserializer<'de>,
200
4.61k
    {
201
        struct DigestInfoVisitor;
202
        impl<'a> Visitor<'a> for DigestInfoVisitor {
203
            type Value = DigestInfo;
204
205
0
            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
206
0
                formatter.write_str("a string representing a DigestInfo")
207
0
            }
208
209
4.61k
            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
210
4.61k
            where
211
4.61k
                E: serde::de::Error,
212
4.61k
            {
213
4.61k
                let Some((hash, size)) = s.split_once('-') else {
  Branch (213:21): [True: 8, False: 0]
  Branch (213:21): [True: 4.60k, False: 0]
  Branch (213:21): [Folded - Ignored]
  Branch (213:21): [True: 3, False: 0]
  Branch (213:21): [Folded - Ignored]
214
0
                    return Err(E::custom(
215
0
                        "Invalid DigestInfo format, expected '-' separator",
216
0
                    ));
217
                };
218
4.61k
                let size_bytes = size
219
4.61k
                    .parse::<u64>()
220
4.61k
                    .map_err(|e| 
E::custom(format!("Could not parse size_bytes: {e:?}"))0
)
?0
;
221
4.61k
                DigestInfo::try_new(hash, size_bytes)
222
4.61k
                    .map_err(|e| 
E::custom(format!("Could not create DigestInfo: {e:?}"))1
)
223
4.61k
            }
224
        }
225
4.61k
        deserializer.deserialize_str(DigestInfoVisitor)
226
4.61k
    }
227
}
228
229
impl fmt::Display for DigestInfo {
230
445
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231
445
        let mut stringifier = self.stringifier();
232
445
        f.write_str(
233
445
            stringifier
234
445
                .as_str()
235
445
                .err_tip(|| 
"During serialization of DigestInfo"0
)
236
445
                .map_err(|e| {
237
0
                    event!(
238
0
                        Level::ERROR,
239
0
                        "Could not convert DigestInfo to string - {e:?}"
240
                    );
241
0
                    fmt::Error
242
445
                })
?0
,
243
        )
244
445
    }
245
}
246
247
impl fmt::Debug for DigestInfo {
248
31
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249
31
        let mut stringifier = self.stringifier();
250
31
        match stringifier.as_str() {
251
31
            Ok(s) => f.debug_tuple("DigestInfo").field(&s).finish(),
252
0
            Err(e) => {
253
0
                event!(
254
0
                    Level::ERROR,
255
0
                    "Could not convert DigestInfo to string - {e:?}"
256
                );
257
0
                Err(fmt::Error)
258
            }
259
        }
260
31
    }
261
}
262
263
impl Ord for DigestInfo {
264
0
    fn cmp(&self, other: &Self) -> Ordering {
265
0
        self.packed_hash
266
0
            .cmp(&other.packed_hash)
267
0
            .then_with(|| self.size_bytes.cmp(&other.size_bytes))
268
0
    }
269
}
270
271
impl PartialOrd for DigestInfo {
    /// Delegates to [`Ord::cmp`] so the partial and total orderings always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
276
277
impl TryFrom<Digest> for DigestInfo {
278
    type Error = Error;
279
280
155
    fn try_from(digest: Digest) -> Result<Self, Self::Error> {
281
155
        let 
packed_hash154
= PackedHash::from_hex(&digest.hash)
282
155
            .err_tip(|| 
format!("Invalid sha256 hash: {}", digest.hash)1
)
?1
;
283
154
        let size_bytes = digest
284
154
            .size_bytes
285
154
            .try_into()
286
154
            .map_err(|_| 
make_input_err!("Could not convert {} into u64", digest.size_bytes)0
)
?0
;
287
154
        Ok(DigestInfo {
288
154
            packed_hash,
289
154
            size_bytes,
290
154
        })
291
155
    }
292
}
293
294
impl TryFrom<&Digest> for DigestInfo {
295
    type Error = Error;
296
297
0
    fn try_from(digest: &Digest) -> Result<Self, Self::Error> {
298
0
        let packed_hash = PackedHash::from_hex(&digest.hash)
299
0
            .err_tip(|| format!("Invalid sha256 hash: {}", digest.hash))?;
300
0
        let size_bytes = digest
301
0
            .size_bytes
302
0
            .try_into()
303
0
            .map_err(|_| make_input_err!("Could not convert {} into u64", digest.size_bytes))?;
304
0
        Ok(DigestInfo {
305
0
            packed_hash,
306
0
            size_bytes,
307
0
        })
308
0
    }
309
}
310
311
impl From<DigestInfo> for Digest {
312
190
    fn from(val: DigestInfo) -> Self {
313
190
        Digest {
314
190
            hash: val.packed_hash.to_string(),
315
190
            size_bytes: val.size_bytes.try_into().unwrap_or_else(|e| {
316
0
                event!(
317
0
                    Level::ERROR,
318
0
                    "Could not convert {} into u64 - {e:?}",
319
                    val.size_bytes
320
                );
321
                // This is a placeholder value that can help a user identify
322
                // that the conversion failed.
323
0
                -255
324
190
            }),
325
190
        }
326
190
    }
327
}
328
329
impl From<&DigestInfo> for Digest {
330
7
    fn from(val: &DigestInfo) -> Self {
331
7
        Digest {
332
7
            hash: val.packed_hash.to_string(),
333
7
            size_bytes: val.size_bytes.try_into().unwrap_or_else(|e| {
334
0
                event!(
335
0
                    Level::ERROR,
336
0
                    "Could not convert {} into u64 - {e:?}",
337
                    val.size_bytes
338
                );
339
                // This is a placeholder value that can help a user identify
340
                // that the conversion failed.
341
0
                -255
342
7
            }),
343
7
        }
344
7
    }
345
}
346
347
0
#[derive(Serialize, Deserialize, Default, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord)]
348
pub struct PackedHash([u8; 32]);
349
350
const SIZE_OF_PACKED_HASH: usize = 32;
351
impl PackedHash {
352
19
    const fn new() -> Self {
353
19
        PackedHash([0; SIZE_OF_PACKED_HASH])
354
19
    }
355
356
5.42k
    fn from_hex(hash: &str) -> Result<Self, Error> {
357
5.42k
        let mut packed_hash = [0u8; 32];
358
5.42k
        hex::decode_to_slice(hash, &mut packed_hash)
359
5.42k
            .map_err(|e| 
make_input_err!("Invalid sha256 hash: {hash} - {e:?}")10
)
?10
;
360
5.41k
        Ok(PackedHash(packed_hash))
361
5.42k
    }
362
363
    /// Converts the packed hash into a hex string.
364
    #[inline]
365
910
    fn to_hex(self) -> Result<[u8; SIZE_OF_PACKED_HASH * 2], fmt::Error> {
366
910
        let mut hash = [0u8; SIZE_OF_PACKED_HASH * 2];
367
910
        hex::encode_to_slice(self.0, &mut hash).map_err(|e| {
368
0
            event!(
369
0
                Level::ERROR,
370
0
                "Could not convert PackedHash to hex - {e:?} - {:?}",
371
                self.0
372
            );
373
0
            fmt::Error
374
910
        })
?0
;
375
910
        Ok(hash)
376
910
    }
377
}
378
379
impl fmt::Display for PackedHash {
380
213
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381
213
        let hash = self.to_hex()
?0
;
382
213
        match std::str::from_utf8(&hash) {
383
213
            Ok(hash) => f.write_str(hash)
?0
,
384
0
            Err(_) => f.write_str(&format!("Could not convert hash to utf8 {:?}", self.0))?,
385
        }
386
213
        Ok(())
387
213
    }
388
}
389
390
impl Deref for PackedHash {
391
    type Target = [u8; 32];
392
393
40.2k
    fn deref(&self) -> &Self::Target {
394
40.2k
        &self.0
395
40.2k
    }
396
}
397
398
impl DerefMut for PackedHash {
399
90
    fn deref_mut(&mut self) -> &mut Self::Target {
400
90
        &mut self.0
401
90
    }
402
}
403
404
/// Simple utility trait that makes it easier to apply `.try_map` to `Vec`.
///
/// Converts one vector into another vector with a different element type.
pub trait VecExt<T> {
    /// Applies `f` to each element in order, collecting the results; intended
    /// to return the first error `f` produces.
    fn try_map<F, U>(self, f: F) -> Result<Vec<U>, Error>
    where
        Self: Sized,
        F: (std::ops::Fn(T) -> Result<U, Error>) + Sized;
}
412
413
impl<T> VecExt<T> for Vec<T> {
414
12
    fn try_map<F, U>(self, f: F) -> Result<Vec<U>, Error>
415
12
    where
416
12
        Self: Sized,
417
12
        F: (std::ops::Fn(T) -> Result<U, Error>) + Sized,
418
12
    {
419
12
        let mut output = Vec::with_capacity(self.len());
420
18
        for 
item6
in self {
421
6
            output.push((f)(item)
?0
);
422
        }
423
12
        Ok(output)
424
12
    }
425
}
426
427
/// Simple utility trait that makes it easier to apply `.try_map` to `HashMap`.
///
/// Converts one `HashMap` into another, keeping the keys the same but mapping
/// each value to a different type.
pub trait HashMapExt<K: std::cmp::Eq + std::hash::Hash, T> {
    /// Applies `f` to each value, keeping its key; intended to return the
    /// first error `f` produces.
    fn try_map<F, U>(self, f: F) -> Result<HashMap<K, U>, Error>
    where
        Self: Sized,
        F: (std::ops::Fn(T) -> Result<U, Error>) + Sized;
}
436
437
impl<K: std::cmp::Eq + std::hash::Hash, T> HashMapExt<K, T> for HashMap<K, T> {
438
3
    fn try_map<F, U>(self, f: F) -> Result<HashMap<K, U>, Error>
439
3
    where
440
3
        Self: Sized,
441
3
        F: (std::ops::Fn(T) -> Result<U, Error>) + Sized,
442
3
    {
443
3
        let mut output = HashMap::with_capacity(self.len());
444
5
        for (
k, v2
) in self {
445
2
            output.insert(k, (f)(v)
?0
);
446
        }
447
3
        Ok(output)
448
3
    }
449
}
450
451
// Utility to encode our proto into GRPC stream format.
452
36
pub fn encode_stream_proto<T: Message>(proto: &T) -> Result<Bytes, Box<dyn std::error::Error>> {
453
    // See below comment on spec.
454
    use std::mem::size_of;
455
    const PREFIX_BYTES: usize = size_of::<u8>() + size_of::<u32>();
456
457
36
    let mut buf = BytesMut::new();
458
459
216
    for _ in 0..PREFIX_BYTES {
460
180
        // Advance our buffer first.
461
180
        // We will backfill it once we know the size of the message.
462
180
        buf.put_u8(0);
463
180
    }
464
36
    proto.encode(&mut buf)
?0
;
465
36
    let len = buf.len() - PREFIX_BYTES;
466
36
    {
467
36
        let mut buf = &mut buf[0..PREFIX_BYTES];
468
36
        // See: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#:~:text=Compressed-Flag
469
36
        // for more details on spec.
470
36
        // Compressed-Flag -> 0 / 1 # encoded as 1 byte unsigned integer.
471
36
        buf.put_u8(0);
472
36
        // Message-Length -> {length of Message} # encoded as 4 byte unsigned integer (big endian).
473
36
        buf.put_u32(len as u32);
474
36
        // Message -> *{binary octet}.
475
36
    }
476
36
477
36
    Ok(buf.freeze())
478
36
}