Coverage Report

Created: 2024-11-20 10:13

/build/source/nativelink-store/src/shard_store.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::hash::{DefaultHasher, Hasher};
16
use std::ops::BitXor;
17
use std::pin::Pin;
18
use std::sync::Arc;
19
20
use async_trait::async_trait;
21
use futures::stream::{FuturesUnordered, TryStreamExt};
22
use nativelink_config::stores::ShardSpec;
23
use nativelink_error::{error_if, Error, ResultExt};
24
use nativelink_metric::MetricsComponent;
25
use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
26
use nativelink_util::health_utils::{default_health_status_indicator, HealthStatusIndicator};
27
use nativelink_util::store_trait::{Store, StoreDriver, StoreKey, StoreLike, UploadSizeInfo};
28
29
0
/// One shard of a `ShardStore`: a backing store paired with the upper bound of the
/// `u32` key range routed to it.
#[derive(MetricsComponent)]
struct StoreAndWeight {
    // Despite the name, this is not the raw configured weight. `ShardStore::new` stores
    // the cumulative (running-sum) bound here, and the last shard is pinned to u32::MAX.
    // Keys greater than the previous shard's bound and <= this value map to `store`.
    #[metric(help = "The weight of the store")]
    weight: u32,
    // The store that receives every key falling in this shard's range.
    #[metric(help = "The underlying store")]
    store: Store,
}
36
37
0
/// A store that routes each key to exactly one of several underlying stores.
///
/// Every key is reduced to a `u32`, and the owning shard is located by binary search
/// over the shards' cumulative weights.
#[derive(MetricsComponent)]
pub struct ShardStore {
    // Entries are in ascending order of `weight`, where each `weight` is the cumulative
    // upper bound of the u32 key range owned by that shard (the last one is always
    // u32::MAX). A store is chosen by reducing the key to a u32 and binary-searching
    // these weights for the range that contains it.
    #[metric(
        group = "stores",
        help = "The weights and stores that are used to determine which store to use"
    )]
    weights_and_stores: Vec<StoreAndWeight>,
}
47
48
impl ShardStore {
49
15
    pub fn new(spec: &ShardSpec, stores: Vec<Store>) -> Result<Arc<Self>, Error> {
50
0
        error_if!(
51
15
            spec.stores.len() != stores.len(),
  Branch (51:13): [True: 0, False: 15]
  Branch (51:13): [Folded - Ignored]
52
            "Config shards do not match stores length"
53
        );
54
0
        error_if!(
55
15
            spec.stores.is_empty(),
  Branch (55:13): [True: 0, False: 15]
  Branch (55:13): [Folded - Ignored]
56
            "ShardStore must have at least one store"
57
        );
58
15
        let total_weight: u64 = spec
59
15
            .stores
60
15
            .iter()
61
50
            .map(|shard_config| u64::from(shard_config.weight.unwrap_or(1)))
62
15
            .sum();
63
15
        let mut weights: Vec<u32> = spec
64
15
            .stores
65
15
            .iter()
66
50
            .map(|shard_config| {
67
50
                (u64::from(u32::MAX) * u64::from(shard_config.weight.unwrap_or(1)) / total_weight)
68
50
                    as u32
69
50
            })
70
50
            .scan(0, |state, weight| {
71
50
                *state += weight;
72
50
                Some(*state)
73
50
            })
74
15
            .collect();
75
15
        // Our last item should always be the max.
76
15
        *weights.last_mut().unwrap() = u32::MAX;
77
15
        Ok(Arc::new(Self {
78
15
            weights_and_stores: weights
79
15
                .into_iter()
80
15
                .zip(stores)
81
50
                .map(|(weight, store)| StoreAndWeight { weight, store })
82
15
                .collect(),
83
15
        }))
84
15
    }
85
86
5.01k
    fn get_store_index(&self, store_key: &StoreKey) -> usize {
87
5.01k
        let key = match store_key {
88
5.01k
            StoreKey::Digest(digest) => {
89
5.01k
                // Quote from std primitive array documentation:
90
5.01k
                //     Array’s try_from(slice) implementations (and the corresponding slice.try_into()
91
5.01k
                //     array implementations) succeed if the input slice length is the same as the result
92
5.01k
                //     array length. They optimize especially well when the optimizer can easily determine
93
5.01k
                //     the slice length, e.g. <[u8; 4]>::try_from(&slice[4..8]).unwrap(). Array implements
94
5.01k
                //     TryFrom returning.
95
5.01k
                let size_bytes = digest.size_bytes().to_le_bytes();
96
5.01k
                0.bitxor(u32::from_le_bytes(
97
5.01k
                    digest.packed_hash()[0..4].try_into().unwrap(),
98
5.01k
                ))
99
5.01k
                .bitxor(u32::from_le_bytes(
100
5.01k
                    digest.packed_hash()[4..8].try_into().unwrap(),
101
5.01k
                ))
102
5.01k
                .bitxor(u32::from_le_bytes(
103
5.01k
                    digest.packed_hash()[8..12].try_into().unwrap(),
104
5.01k
                ))
105
5.01k
                .bitxor(u32::from_le_bytes(
106
5.01k
                    digest.packed_hash()[12..16].try_into().unwrap(),
107
5.01k
                ))
108
5.01k
                .bitxor(u32::from_le_bytes(
109
5.01k
                    digest.packed_hash()[16..20].try_into().unwrap(),
110
5.01k
                ))
111
5.01k
                .bitxor(u32::from_le_bytes(
112
5.01k
                    digest.packed_hash()[20..24].try_into().unwrap(),
113
5.01k
                ))
114
5.01k
                .bitxor(u32::from_le_bytes(
115
5.01k
                    digest.packed_hash()[24..28].try_into().unwrap(),
116
5.01k
                ))
117
5.01k
                .bitxor(u32::from_le_bytes(
118
5.01k
                    digest.packed_hash()[28..32].try_into().unwrap(),
119
5.01k
                ))
120
5.01k
                .bitxor(u32::from_le_bytes(size_bytes[0..4].try_into().unwrap()))
121
5.01k
                .bitxor(u32::from_le_bytes(size_bytes[4..8].try_into().unwrap()))
122
            }
123
0
            StoreKey::Str(s) => {
124
0
                let mut hasher = DefaultHasher::new();
125
0
                hasher.write(s.as_bytes());
126
0
                let key_u64 = hasher.finish();
127
0
                (key_u64 >> 32) as u32 // We only need the top 32 bits.
128
            }
129
        };
130
5.01k
        self.weights_and_stores
131
15.0k
            .binary_search_by_key(&key, |item| item.weight)
132
5.01k
            .unwrap_or_else(|index| index)
133
5.01k
    }
134
135
5.00k
    fn get_store(&self, key: &StoreKey) -> &Store {
136
5.00k
        let index = self.get_store_index(key);
137
5.00k
        &self.weights_and_stores[index].store
138
5.00k
    }
139
}
140
141
#[async_trait]
142
impl StoreDriver for ShardStore {
143
    async fn has_with_results(
144
        self: Pin<&Self>,
145
        keys: &[StoreKey<'_>],
146
        results: &mut [Option<u64>],
147
6
    ) -> Result<(), Error> {
148
6
        if keys.len() == 1 {
  Branch (148:12): [True: 3, False: 3]
  Branch (148:12): [Folded - Ignored]
149
            // Hot path: It is very common to lookup only one key.
150
3
            let store_idx = self.get_store_index(&keys[0]);
151
3
            let store = &self.weights_and_stores[store_idx].store;
152
3
            return store
153
3
                .has_with_results(keys, results)
154
0
                .await
155
3
                .err_tip(|| 
"In ShardStore::has_with_results() for store {store_idx}}"0
);
156
3
        }
157
        type KeyIdxVec = Vec<usize>;
158
        type KeyVec<'a> = Vec<StoreKey<'a>>;
159
3
        let mut keys_for_store: Vec<(KeyIdxVec, KeyVec)> = self
160
3
            .weights_and_stores
161
3
            .iter()
162
6
            .map(|_| (Vec::new(), Vec::new()))
163
3
            .collect();
164
3
        // Bucket each key into the store that it belongs to.
165
3
        keys.iter()
166
3
            .enumerate()
167
6
            .map(|(key_idx, key)| (key, key_idx, self.get_store_index(key)))
168
6
            .for_each(|(key, key_idx, store_idx)| {
169
6
                keys_for_store[store_idx].0.push(key_idx);
170
6
                keys_for_store[store_idx].1.push(key.borrow());
171
6
            });
172
3
173
3
        // Build all our futures for each store.
174
3
        let mut future_stream: FuturesUnordered<_> = keys_for_store
175
3
            .into_iter()
176
3
            .enumerate()
177
6
            .map(|(store_idx, (key_idxs, keys))| async move {
178
6
                let store = &self.weights_and_stores[store_idx].store;
179
6
                let mut inner_results = vec![None; keys.len()];
180
6
                store
181
6
                    .has_with_results(&keys, &mut inner_results)
182
0
                    .await
183
6
                    .err_tip(|| 
"In ShardStore::has_with_results() for store {store_idx}"0
)
?0
;
184
6
                Result::<_, Error>::Ok((key_idxs, inner_results))
185
12
            })
186
3
            .collect();
187
188
        // Wait for all the stores to finish and populate our output results.
189
9
        while let Some((
key_idxs, inner_results6
)) = future_stream.try_next().
await0
?0
{
  Branch (189:19): [True: 6, False: 3]
  Branch (189:19): [Folded - Ignored]
190
6
            for (key_idx, inner_result) in key_idxs.into_iter().zip(inner_results) {
191
6
                results[key_idx] = inner_result;
192
6
            }
193
        }
194
3
        Ok(())
195
12
    }
196
197
    async fn update(
198
        self: Pin<&Self>,
199
        key: StoreKey<'_>,
200
        reader: DropCloserReadHalf,
201
        size_info: UploadSizeInfo,
202
5.00k
    ) -> Result<(), Error> {
203
5.00k
        let store = self.get_store(&key);
204
5.00k
        store
205
5.00k
            .update(key, reader, size_info)
206
115
            .await
207
5.00k
            .err_tip(|| 
"In ShardStore::update()"0
)
208
10.0k
    }
209
210
    async fn get_part(
211
        self: Pin<&Self>,
212
        key: StoreKey<'_>,
213
        writer: &mut DropCloserWriteHalf,
214
        offset: u64,
215
        length: Option<u64>,
216
3
    ) -> Result<(), Error> {
217
3
        let store = self.get_store(&key);
218
3
        store
219
3
            .get_part(key, writer, offset, length)
220
0
            .await
221
3
            .err_tip(|| 
"In ShardStore::get_part()"0
)
222
6
    }
223
224
0
    fn inner_store(&self, key: Option<StoreKey>) -> &'_ dyn StoreDriver {
225
0
        let Some(key) = key else {
  Branch (225:13): [True: 0, False: 0]
  Branch (225:13): [Folded - Ignored]
226
0
            return self;
227
        };
228
0
        let index = self.get_store_index(&key);
229
0
        self.weights_and_stores[index].store.inner_store(Some(key))
230
0
    }
231
232
0
    fn as_any<'a>(&'a self) -> &'a (dyn std::any::Any + Sync + Send + 'static) {
233
0
        self
234
0
    }
235
236
0
    fn as_any_arc(self: Arc<Self>) -> Arc<dyn std::any::Any + Sync + Send + 'static> {
237
0
        self
238
0
    }
239
}
240
241
// Uses the `default_health_status_indicator!` helper from
// `nativelink_util::health_utils` to give `ShardStore` its `HealthStatusIndicator`
// implementation (presumably the default, non-shard-aware check — confirm in health_utils).
default_health_status_indicator!(ShardStore);