Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-store/src/shard_store.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::hash::{DefaultHasher, Hasher};
16
use std::ops::BitXor;
17
use std::pin::Pin;
18
use std::sync::Arc;
19
20
use async_trait::async_trait;
21
use futures::stream::{FuturesUnordered, TryStreamExt};
22
use nativelink_error::{error_if, Error, ResultExt};
23
use nativelink_metric::MetricsComponent;
24
use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
25
use nativelink_util::health_utils::{default_health_status_indicator, HealthStatusIndicator};
26
use nativelink_util::store_trait::{Store, StoreDriver, StoreKey, StoreLike, UploadSizeInfo};
27
28
0
#[derive(MetricsComponent)]
29
struct StoreAndWeight {
30
    #[metric(help = "The weight of the store")]
31
    weight: u32,
32
    #[metric(help = "The underlying store")]
33
    store: Store,
34
}
35
36
0
#[derive(MetricsComponent)]
37
pub struct ShardStore {
38
    // The weights will always be in ascending order a specific store is choosen based on the
39
    // the hash of the key hash that is nearest-binary searched using the u32 as the index.
40
    #[metric(
41
        group = "stores",
42
        help = "The weights and stores that are used to determine which store to use"
43
    )]
44
    weights_and_stores: Vec<StoreAndWeight>,
45
}
46
47
impl ShardStore {
48
15
    pub fn new(
49
15
        config: &nativelink_config::stores::ShardStore,
50
15
        stores: Vec<Store>,
51
15
    ) -> Result<Arc<Self>, Error> {
52
0
        error_if!(
53
15
            config.stores.len() != stores.len(),
  Branch (53:13): [True: 0, False: 15]
  Branch (53:13): [Folded - Ignored]
54
            "Config shards do not match stores length"
55
        );
56
0
        error_if!(
57
15
            config.stores.is_empty(),
  Branch (57:13): [True: 0, False: 15]
  Branch (57:13): [Folded - Ignored]
58
            "ShardStore must have at least one store"
59
        );
60
15
        let total_weight: u64 = config
61
15
            .stores
62
15
            .iter()
63
50
            .map(|shard_config| shard_config.weight.unwrap_or(1) as u64)
64
15
            .sum();
65
15
        let mut weights: Vec<u32> = config
66
15
            .stores
67
15
            .iter()
68
50
            .map(|shard_config| {
69
50
                (u32::MAX as u64 * shard_config.weight.unwrap_or(1) as u64 / total_weight) as u32
70
50
            })
71
50
            .scan(0, |state, weight| {
72
50
                *state += weight;
73
50
                Some(*state)
74
50
            })
75
15
            .collect();
76
15
        // Our last item should always be the max.
77
15
        *weights.last_mut().unwrap() = u32::MAX;
78
15
        Ok(Arc::new(Self {
79
15
            weights_and_stores: weights
80
15
                .into_iter()
81
15
                .zip(stores)
82
50
                .map(|(weight, store)| StoreAndWeight { weight, store })
83
15
                .collect(),
84
15
        }))
85
15
    }
86
87
5.01k
    fn get_store_index(&self, store_key: &StoreKey) -> usize {
88
5.01k
        let key = match store_key {
89
5.01k
            StoreKey::Digest(digest) => {
90
5.01k
                // Quote from std primitive array documentation:
91
5.01k
                //     Array’s try_from(slice) implementations (and the corresponding slice.try_into()
92
5.01k
                //     array implementations) succeed if the input slice length is the same as the result
93
5.01k
                //     array length. They optimize especially well when the optimizer can easily determine
94
5.01k
                //     the slice length, e.g. <[u8; 4]>::try_from(&slice[4..8]).unwrap(). Array implements
95
5.01k
                //     TryFrom returning.
96
5.01k
                let size_bytes = digest.size_bytes().to_le_bytes();
97
5.01k
                0.bitxor(u32::from_le_bytes(
98
5.01k
                    digest.packed_hash()[0..4].try_into().unwrap(),
99
5.01k
                ))
100
5.01k
                .bitxor(u32::from_le_bytes(
101
5.01k
                    digest.packed_hash()[4..8].try_into().unwrap(),
102
5.01k
                ))
103
5.01k
                .bitxor(u32::from_le_bytes(
104
5.01k
                    digest.packed_hash()[8..12].try_into().unwrap(),
105
5.01k
                ))
106
5.01k
                .bitxor(u32::from_le_bytes(
107
5.01k
                    digest.packed_hash()[12..16].try_into().unwrap(),
108
5.01k
                ))
109
5.01k
                .bitxor(u32::from_le_bytes(
110
5.01k
                    digest.packed_hash()[16..20].try_into().unwrap(),
111
5.01k
                ))
112
5.01k
                .bitxor(u32::from_le_bytes(
113
5.01k
                    digest.packed_hash()[20..24].try_into().unwrap(),
114
5.01k
                ))
115
5.01k
                .bitxor(u32::from_le_bytes(
116
5.01k
                    digest.packed_hash()[24..28].try_into().unwrap(),
117
5.01k
                ))
118
5.01k
                .bitxor(u32::from_le_bytes(
119
5.01k
                    digest.packed_hash()[28..32].try_into().unwrap(),
120
5.01k
                ))
121
5.01k
                .bitxor(u32::from_le_bytes(size_bytes[0..4].try_into().unwrap()))
122
5.01k
                .bitxor(u32::from_le_bytes(size_bytes[4..8].try_into().unwrap()))
123
            }
124
0
            StoreKey::Str(s) => {
125
0
                let mut hasher = DefaultHasher::new();
126
0
                hasher.write(s.as_bytes());
127
0
                let key_u64 = hasher.finish();
128
0
                (key_u64 >> 32) as u32 // We only need the top 32 bits.
129
            }
130
        };
131
5.01k
        self.weights_and_stores
132
15.0k
            .binary_search_by_key(&key, |item| item.weight)
133
5.01k
            .unwrap_or_else(|index| index)
134
5.01k
    }
135
136
5.00k
    fn get_store(&self, key: &StoreKey) -> &Store {
137
5.00k
        let index = self.get_store_index(key);
138
5.00k
        &self.weights_and_stores[index].store
139
5.00k
    }
140
}
141
142
#[async_trait]
143
impl StoreDriver for ShardStore {
144
    async fn has_with_results(
145
        self: Pin<&Self>,
146
        keys: &[StoreKey<'_>],
147
        results: &mut [Option<u64>],
148
6
    ) -> Result<(), Error> {
149
6
        if keys.len() == 1 {
  Branch (149:12): [True: 3, False: 3]
  Branch (149:12): [Folded - Ignored]
150
            // Hot path: It is very common to lookup only one key.
151
3
            let store_idx = self.get_store_index(&keys[0]);
152
3
            let store = &self.weights_and_stores[store_idx].store;
153
3
            return store
154
3
                .has_with_results(keys, results)
155
0
                .await
156
3
                .err_tip(|| 
"In ShardStore::has_with_results() for store {store_idx}}"0
);
157
3
        }
158
        type KeyIdxVec = Vec<usize>;
159
        type KeyVec<'a> = Vec<StoreKey<'a>>;
160
3
        let mut keys_for_store: Vec<(KeyIdxVec, KeyVec)> = self
161
3
            .weights_and_stores
162
3
            .iter()
163
6
            .map(|_| (Vec::new(), Vec::new()))
164
3
            .collect();
165
3
        // Bucket each key into the store that it belongs to.
166
3
        keys.iter()
167
3
            .enumerate()
168
6
            .map(|(key_idx, key)| (key, key_idx, self.get_store_index(key)))
169
6
            .for_each(|(key, key_idx, store_idx)| {
170
6
                keys_for_store[store_idx].0.push(key_idx);
171
6
                keys_for_store[store_idx].1.push(key.borrow());
172
6
            });
173
3
174
3
        // Build all our futures for each store.
175
3
        let mut future_stream: FuturesUnordered<_> = keys_for_store
176
3
            .into_iter()
177
3
            .enumerate()
178
6
            .map(|(store_idx, (key_idxs, keys))| async move {
179
6
                let store = &self.weights_and_stores[store_idx].store;
180
6
                let mut inner_results = vec![None; keys.len()];
181
6
                store
182
6
                    .has_with_results(&keys, &mut inner_results)
183
0
                    .await
184
6
                    .err_tip(|| 
"In ShardStore::has_with_results() for store {store_idx}"0
)
?0
;
185
6
                Result::<_, Error>::Ok((key_idxs, inner_results))
186
12
            })
187
3
            .collect();
188
189
        // Wait for all the stores to finish and populate our output results.
190
9
        while let Some((
key_idxs, inner_results6
)) = future_stream.try_next().
await0
?0
{
  Branch (190:19): [True: 6, False: 3]
  Branch (190:19): [Folded - Ignored]
191
6
            for (key_idx, inner_result) in key_idxs.into_iter().zip(inner_results) {
192
6
                results[key_idx] = inner_result;
193
6
            }
194
        }
195
3
        Ok(())
196
12
    }
197
198
    async fn update(
199
        self: Pin<&Self>,
200
        key: StoreKey<'_>,
201
        reader: DropCloserReadHalf,
202
        size_info: UploadSizeInfo,
203
5.00k
    ) -> Result<(), Error> {
204
5.00k
        let store = self.get_store(&key);
205
5.00k
        store
206
5.00k
            .update(key, reader, size_info)
207
115
            .await
208
5.00k
            .err_tip(|| 
"In ShardStore::update()"0
)
209
10.0k
    }
210
211
    async fn get_part(
212
        self: Pin<&Self>,
213
        key: StoreKey<'_>,
214
        writer: &mut DropCloserWriteHalf,
215
        offset: u64,
216
        length: Option<u64>,
217
3
    ) -> Result<(), Error> {
218
3
        let store = self.get_store(&key);
219
3
        store
220
3
            .get_part(key, writer, offset, length)
221
0
            .await
222
3
            .err_tip(|| 
"In ShardStore::get_part()"0
)
223
6
    }
224
225
0
    fn inner_store(&self, key: Option<StoreKey>) -> &'_ dyn StoreDriver {
226
0
        let Some(key) = key else {
  Branch (226:13): [True: 0, False: 0]
  Branch (226:13): [Folded - Ignored]
227
0
            return self;
228
        };
229
0
        let index = self.get_store_index(&key);
230
0
        self.weights_and_stores[index].store.inner_store(Some(key))
231
0
    }
232
233
0
    fn as_any<'a>(&'a self) -> &'a (dyn std::any::Any + Sync + Send + 'static) {
234
0
        self
235
0
    }
236
237
0
    fn as_any_arc(self: Arc<Self>) -> Arc<dyn std::any::Any + Sync + Send + 'static> {
238
0
        self
239
0
    }
240
}
241
242
// Hooks `ShardStore` into health reporting via the `default_health_status_indicator!`
// macro from `nativelink_util::health_utils` (presumably a default
// `HealthStatusIndicator` impl — see the macro for exact behavior).
default_health_status_indicator!(ShardStore);