Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-util/src/digest_hasher.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::sync::{Arc, OnceLock};
16
17
use blake3::Hasher as Blake3Hasher;
18
use bytes::BytesMut;
19
use futures::Future;
20
use nativelink_config::stores::ConfigDigestHashFunction;
21
use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
22
use nativelink_metric::{
23
    MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
24
};
25
use nativelink_proto::build::bazel::remote::execution::v2::digest_function::Value as ProtoDigestFunction;
26
use serde::{Deserialize, Serialize};
27
use sha2::{Digest, Sha256};
28
use tokio::io::{AsyncRead, AsyncReadExt};
29
30
use crate::common::DigestInfo;
31
use crate::origin_context::{ActiveOriginContext, OriginContext};
32
use crate::{fs, make_symbol, spawn_blocking};
33
34
// The symbol can be use to retrieve the active hasher function.
35
// from an `OriginContext`.
36
make_symbol!(ACTIVE_HASHER_FUNC, DigestHasherFunc);
37
38
static DEFAULT_DIGEST_HASHER_FUNC: OnceLock<DigestHasherFunc> = OnceLock::new();
39
40
/// Utility function to make a context with a specific hasher function set.
41
36
pub fn make_ctx_for_hash_func<H>(hasher: H) -> Result<Arc<OriginContext>, Error>
42
36
where
43
36
    H: TryInto<DigestHasherFunc>,
44
36
    H::Error: Into<Error>,
45
36
{
46
36
    let digest_hasher_func = hasher
47
36
        .try_into()
48
36
        .err_tip(|| 
"Could not convert into DigestHasherFunc"0
)
?0
;
49
50
36
    let mut new_ctx = ActiveOriginContext::fork().err_tip(|| 
"In BytestreamServer::inner_write"0
)
?0
;
51
36
    new_ctx.set_value(&ACTIVE_HASHER_FUNC, Arc::new(digest_hasher_func));
52
36
    Ok(Arc::new(new_ctx))
53
36
}
54
55
/// Get the default hasher.
56
26
pub fn default_digest_hasher_func() -> DigestHasherFunc {
57
26
    *DEFAULT_DIGEST_HASHER_FUNC.get_or_init(|| 
DigestHasherFunc::Sha2563
)
58
26
}
59
60
/// Sets the default hasher to use if no hasher was requested by the client.
61
0
pub fn set_default_digest_hasher_func(hasher: DigestHasherFunc) -> Result<(), Error> {
62
0
    DEFAULT_DIGEST_HASHER_FUNC
63
0
        .set(hasher)
64
0
        .map_err(|_| make_err!(Code::Internal, "default_digest_hasher_func already set"))
65
0
}
66
67
/// Supported digest hash functions.
68
4
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
69
pub enum DigestHasherFunc {
70
    Sha256,
71
    Blake3,
72
}
73
74
impl MetricsComponent for DigestHasherFunc {
75
0
    fn publish(
76
0
        &self,
77
0
        kind: MetricKind,
78
0
        field_metadata: MetricFieldData,
79
0
    ) -> Result<MetricPublishKnownKindData, nativelink_metric::Error> {
80
0
        format!("{self:?}").publish(kind, field_metadata)
81
0
    }
82
}
83
84
impl DigestHasherFunc {
85
5.08k
    pub fn hasher(&self) -> DigestHasherImpl {
86
5.08k
        self.into()
87
5.08k
    }
88
89
    #[must_use]
90
37
    pub const fn proto_digest_func(&self) -> ProtoDigestFunction {
91
37
        match self {
92
35
            Self::Sha256 => ProtoDigestFunction::Sha256,
93
2
            Self::Blake3 => ProtoDigestFunction::Blake3,
94
        }
95
37
    }
96
}
97
98
impl From<ConfigDigestHashFunction> for DigestHasherFunc {
99
0
    fn from(value: ConfigDigestHashFunction) -> Self {
100
0
        match value {
101
0
            ConfigDigestHashFunction::sha256 => Self::Sha256,
102
0
            ConfigDigestHashFunction::blake3 => Self::Blake3,
103
        }
104
0
    }
105
}
106
107
impl TryFrom<ProtoDigestFunction> for DigestHasherFunc {
108
    type Error = Error;
109
110
0
    fn try_from(value: ProtoDigestFunction) -> Result<Self, Self::Error> {
111
0
        match value {
112
0
            ProtoDigestFunction::Sha256 => Ok(Self::Sha256),
113
0
            ProtoDigestFunction::Blake3 => Ok(Self::Blake3),
114
0
            v => Err(make_input_err!(
115
0
                "Unknown or unsupported digest function for proto conversion {v:?}"
116
0
            )),
117
        }
118
0
    }
119
}
120
121
impl TryFrom<&str> for DigestHasherFunc {
122
    type Error = Error;
123
124
0
    fn try_from(value: &str) -> Result<Self, Self::Error> {
125
0
        match value.to_uppercase().as_str() {
126
0
            "SHA256" => Ok(Self::Sha256),
127
0
            "BLAKE3" => Ok(Self::Blake3),
128
0
            v => Err(make_input_err!(
129
0
                "Unknown or unsupported digest function for string conversion: {v:?}"
130
0
            )),
131
        }
132
0
    }
133
}
134
135
impl std::fmt::Display for DigestHasherFunc {
136
3
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137
3
        match self {
138
3
            DigestHasherFunc::Sha256 => write!(f, "SHA256"),
139
0
            DigestHasherFunc::Blake3 => write!(f, "BLAKE3"),
140
        }
141
3
    }
142
}
143
144
impl TryFrom<i32> for DigestHasherFunc {
145
    type Error = Error;
146
147
28
    fn try_from(value: i32) -> Result<Self, Self::Error> {
148
28
        // Zero means not-set.
149
28
        if value == 0 {
  Branch (149:12): [True: 6, False: 22]
  Branch (149:12): [Folded - Ignored]
150
6
            return Ok(default_digest_hasher_func());
151
22
        }
152
22
        match ProtoDigestFunction::try_from(value) {
153
19
            Ok(ProtoDigestFunction::Sha256) => Ok(Self::Sha256),
154
3
            Ok(ProtoDigestFunction::Blake3) => Ok(Self::Blake3),
155
0
            value => Err(make_input_err!(
156
0
                "Unknown or unsupported digest function for int conversion: {:?}",
157
0
                value.map(|v| v.as_str_name())
158
0
            )),
159
        }
160
28
    }
161
}
162
163
impl From<&DigestHasherFunc> for DigestHasherImpl {
164
5.08k
    fn from(value: &DigestHasherFunc) -> Self {
165
5.08k
        let hash_func_impl = match value {
166
48
            DigestHasherFunc::Sha256 => DigestHasherFuncImpl::Sha256(Sha256::new()),
167
5.03k
            DigestHasherFunc::Blake3 => DigestHasherFuncImpl::Blake3(Box::new(Blake3Hasher::new())),
168
        };
169
5.08k
        Self {
170
5.08k
            hashed_size: 0,
171
5.08k
            hash_func_impl,
172
5.08k
        }
173
5.08k
    }
174
}
175
176
/// Wrapper to compute a hash of arbitrary data.
177
pub trait DigestHasher {
178
    /// Update the hasher with some additional data.
179
    fn update(&mut self, input: &[u8]);
180
181
    /// Finalize the hash function and collect the results into a digest.
182
    fn finalize_digest(&mut self) -> DigestInfo;
183
184
    /// Specialized version of the hashing function that is optimized for
185
    /// handling files. These optimizations take into account things like,
186
    /// the file size and the hasher algorithm to decide how to best process
187
    /// the file and feed it into the hasher.
188
    fn digest_for_file(
189
        self,
190
        file: fs::ResumeableFileSlot,
191
        size_hint: Option<u64>,
192
    ) -> impl Future<Output = Result<(DigestInfo, fs::ResumeableFileSlot), Error>>;
193
194
    /// Utility function to compute a hash from a generic reader.
195
2
    fn compute_from_reader<R: AsyncRead + Unpin + Send>(
196
2
        &mut self,
197
2
        mut reader: R,
198
2
    ) -> impl Future<Output = Result<DigestInfo, Error>> {
199
2
        async move {
200
2
            let mut chunk = BytesMut::with_capacity(fs::DEFAULT_READ_BUFF_SIZE);
201
            loop {
202
4
                reader
203
4
                    .read_buf(&mut chunk)
204
2
                    .await
205
4
                    .err_tip(|| 
"Could not read chunk during compute_from_reader"0
)
?0
;
206
4
                if chunk.is_empty() {
  Branch (206:20): [Folded - Ignored]
  Branch (206:20): [Folded - Ignored]
  Branch (206:20): [True: 0, False: 0]
  Branch (206:20): [True: 2, False: 2]
207
2
                    break; // EOF.
208
2
                }
209
2
                DigestHasher::update(self, &chunk);
210
2
                chunk.clear();
211
            }
212
2
            Ok(DigestHasher::finalize_digest(self))
213
2
        }
214
2
    }
215
}
216
217
pub enum DigestHasherFuncImpl {
218
    Sha256(Sha256),
219
    Blake3(Box<Blake3Hasher>), // Box because Blake3Hasher is 1.3kb in size.
220
}
221
222
/// The individual implementation of the hash function.
223
pub struct DigestHasherImpl {
224
    hashed_size: u64,
225
    hash_func_impl: DigestHasherFuncImpl,
226
}
227
228
impl DigestHasherImpl {
229
    #[inline]
230
2
    async fn hash_file(
231
2
        &mut self,
232
2
        mut file: fs::ResumeableFileSlot,
233
2
    ) -> Result<(DigestInfo, fs::ResumeableFileSlot), Error> {
234
2
        let reader = file.as_reader().
await0
.err_tip(||
"In digest_for_file"0
)
?0
;
235
2
        let digest = self
236
2
            .compute_from_reader(reader)
237
2
            .await
238
2
            .err_tip(|| 
"In digest_for_file"0
)
?0
;
239
2
        Ok((digest, file))
240
2
    }
241
}
242
243
impl DigestHasher for DigestHasherImpl {
244
    #[inline]
245
5.08k
    fn update(&mut self, input: &[u8]) {
246
5.08k
        self.hashed_size += input.len() as u64;
247
5.08k
        match &mut self.hash_func_impl {
248
48
            DigestHasherFuncImpl::Sha256(h) => sha2::digest::Update::update(h, input),
249
5.03k
            DigestHasherFuncImpl::Blake3(h) => {
250
5.03k
                Blake3Hasher::update(h, input);
251
5.03k
            }
252
        }
253
5.08k
    }
254
255
    #[inline]
256
5.08k
    fn finalize_digest(&mut self) -> DigestInfo {
257
5.08k
        let hash = match &mut self.hash_func_impl {
258
48
            DigestHasherFuncImpl::Sha256(h) => h.finalize_reset().into(),
259
5.03k
            DigestHasherFuncImpl::Blake3(h) => h.finalize().into(),
260
        };
261
5.08k
        DigestInfo::new(hash, self.hashed_size)
262
5.08k
    }
263
264
0
    async fn digest_for_file(
265
0
        mut self,
266
0
        mut file: fs::ResumeableFileSlot,
267
0
        size_hint: Option<u64>,
268
2
    ) -> Result<(DigestInfo, fs::ResumeableFileSlot), Error> {
269
2
        let file_position = file
270
2
            .stream_position()
271
2
            .await
272
2
            .err_tip(|| 
"Couldn't get stream position in digest_for_file"0
)
?0
;
273
2
        if file_position != 0 {
  Branch (273:12): [Folded - Ignored]
  Branch (273:12): [Folded - Ignored]
  Branch (273:12): [True: 0, False: 2]
274
0
            return self.hash_file(file).await;
275
2
        }
276
        // If we are a small file, it's faster to just do it the "slow" way.
277
        // Great read: https://github.com/david-slatinek/c-read-vs.-mmap
278
2
        if let Some(size_hint) = size_hint {
  Branch (278:16): [Folded - Ignored]
  Branch (278:16): [Folded - Ignored]
  Branch (278:16): [True: 2, False: 0]
279
2
            if size_hint <= fs::DEFAULT_READ_BUFF_SIZE as u64 {
  Branch (279:16): [Folded - Ignored]
  Branch (279:16): [Folded - Ignored]
  Branch (279:16): [True: 2, False: 0]
280
2
                return self.hash_file(file).await;
281
0
            }
282
0
        }
283
0
        match self.hash_func_impl {
284
0
            DigestHasherFuncImpl::Sha256(_) => self.hash_file(file).await,
285
0
            DigestHasherFuncImpl::Blake3(mut hasher) => {
286
0
                spawn_blocking!("digest_for_file", move || {
287
0
                    hasher.update_mmap(file.get_path()).map_err(|e| {
288
0
                        make_err!(Code::Internal, "Error in blake3's update_mmap: {e:?}")
289
0
                    })?;
290
0
                    Result::<_, Error>::Ok((
291
0
                        DigestInfo::new(hasher.finalize().into(), hasher.count()),
292
0
                        file,
293
0
                    ))
294
0
                })
295
0
                .await
296
0
                .err_tip(|| "Could not spawn blocking task in digest_for_file")?
297
            }
298
        }
299
2
    }
300
}