/build/source/nativelink-store/src/shard_store.rs
Line  | Count  | Source  | 
1  |  | // Copyright 2024 The NativeLink Authors. All rights reserved.  | 
2  |  | //  | 
3  |  | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");  | 
4  |  | // you may not use this file except in compliance with the License.  | 
5  |  | // You may obtain a copy of the License at  | 
6  |  | //  | 
7  |  | //    See LICENSE file for details  | 
8  |  | //  | 
9  |  | // Unless required by applicable law or agreed to in writing, software  | 
10  |  | // distributed under the License is distributed on an "AS IS" BASIS,  | 
11  |  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
12  |  | // See the License for the specific language governing permissions and  | 
13  |  | // limitations under the License.  | 
14  |  |  | 
15  |  | use core::hash::Hasher;  | 
16  |  | use core::ops::BitXor;  | 
17  |  | use core::pin::Pin;  | 
18  |  | use std::hash::DefaultHasher;  | 
19  |  | use std::sync::Arc;  | 
20  |  |  | 
21  |  | use async_trait::async_trait;  | 
22  |  | use futures::stream::{FuturesUnordered, TryStreamExt}; | 
23  |  | use nativelink_config::stores::ShardSpec;  | 
24  |  | use nativelink_error::{Error, ResultExt, error_if}; | 
25  |  | use nativelink_metric::MetricsComponent;  | 
26  |  | use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; | 
27  |  | use nativelink_util::health_utils::{HealthStatusIndicator, default_health_status_indicator}; | 
28  |  | use nativelink_util::store_trait::{ | 
29  |  |     RemoveItemCallback, Store, StoreDriver, StoreKey, StoreLike, UploadSizeInfo,  | 
30  |  | };  | 
31  |  |  | 
/// Pairs one backing store with the upper bound of the hash range it serves.
#[derive(Debug, MetricsComponent)]
struct StoreAndWeight {
    // As built by `ShardStore::new`, this is the *cumulative* weight: the inclusive upper bound
    // of the `u32` hash range routed to `store`, not the raw weight from the configuration.
    #[metric(help = "The weight of the store")]
    weight: u32,
    #[metric(help = "The underlying store")]
    store: Store,
}
39  |  |  | 
/// A store that routes every key to exactly one of several backing stores, chosen by hashing
/// the key into the `u32` space and looking that hash up in a table of cumulative weights.
#[derive(Debug, MetricsComponent)]
pub struct ShardStore {
    // Entries are kept in ascending order of `weight`. A store is chosen by reducing the key to
    // a `u32` hash and binary-searching this list for the entry whose `weight` is the smallest
    // value greater than or equal to that hash.
    #[metric(
        group = "stores",
        help = "The weights and stores that are used to determine which store to use"
    )]
    weights_and_stores: Vec<StoreAndWeight>,
}
50  |  |  | 
51  |  | impl ShardStore { | 
52  | 15  |     pub fn new(spec: &ShardSpec, stores: Vec<Store>) -> Result<Arc<Self>, Error> { | 
53  | 0  |         error_if!(  | 
54  | 15  |             spec.stores.len() != stores.len(),   Branch (54:13): [True: 0, False: 15]
   Branch (54:13): [Folded - Ignored]
  | 
55  |  |             "Config shards do not match stores length"  | 
56  |  |         );  | 
57  | 0  |         error_if!(  | 
58  | 15  |             spec.stores.is_empty(),   Branch (58:13): [True: 0, False: 15]
   Branch (58:13): [Folded - Ignored]
  | 
59  |  |             "ShardStore must have at least one store"  | 
60  |  |         );  | 
61  | 15  |         let total_weight: u64 = spec  | 
62  | 15  |             .stores  | 
63  | 15  |             .iter()  | 
64  | 50  |             .map15 (|shard_config| u64::from(shard_config.weight.unwrap_or(1)))  | 
65  | 15  |             .sum();  | 
66  | 15  |         let mut weights: Vec<u32> = spec  | 
67  | 15  |             .stores  | 
68  | 15  |             .iter()  | 
69  | 50  |             .map15 (|shard_config| {  | 
70  | 50  |                 u32::try_from(  | 
71  | 50  |                     u64::from(u32::MAX) * u64::from(shard_config.weight.unwrap_or(1))  | 
72  | 50  |                         / total_weight,  | 
73  |  |                 )  | 
74  | 50  |                 .unwrap_or(u32::MAX)  | 
75  | 50  |             })  | 
76  | 50  |             .scan15 (0, |state, weight| {  | 
77  | 50  |                 *state += weight;  | 
78  | 50  |                 Some(*state)  | 
79  | 50  |             })  | 
80  | 15  |             .collect();  | 
81  |  |         // Our last item should always be the max.  | 
82  | 15  |         *weights.last_mut().unwrap() = u32::MAX;  | 
83  | 15  |         Ok(Arc::new(Self { | 
84  | 15  |             weights_and_stores: weights  | 
85  | 15  |                 .into_iter()  | 
86  | 15  |                 .zip(stores)  | 
87  | 50  |                 .map15 (|(weight, store)| StoreAndWeight { weight, store })  | 
88  | 15  |                 .collect(),  | 
89  |  |         }))  | 
90  | 15  |     }  | 
91  |  |  | 
92  | 5.01k  |     fn get_store_index(&self, store_key: &StoreKey) -> usize { | 
93  | 5.01k  |         let key = match store_key { | 
94  | 5.01k  |             StoreKey::Digest(digest) => { | 
95  |  |                 // Quote from std primitive array documentation:  | 
96  |  |                 //     Array’s try_from(slice) implementations (and the corresponding slice.try_into()  | 
97  |  |                 //     array implementations) succeed if the input slice length is the same as the result  | 
98  |  |                 //     array length. They optimize especially well when the optimizer can easily determine  | 
99  |  |                 //     the slice length, e.g. <[u8; 4]>::try_from(&slice[4..8]).unwrap(). Array implements  | 
100  |  |                 //     TryFrom returning.  | 
101  | 5.01k  |                 let size_bytes = digest.size_bytes().to_le_bytes();  | 
102  | 5.01k  |                 0.bitxor(u32::from_le_bytes(  | 
103  | 5.01k  |                     digest.packed_hash()[0..4].try_into().unwrap(),  | 
104  |  |                 ))  | 
105  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
106  | 5.01k  |                     digest.packed_hash()[4..8].try_into().unwrap(),  | 
107  |  |                 ))  | 
108  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
109  | 5.01k  |                     digest.packed_hash()[8..12].try_into().unwrap(),  | 
110  |  |                 ))  | 
111  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
112  | 5.01k  |                     digest.packed_hash()[12..16].try_into().unwrap(),  | 
113  |  |                 ))  | 
114  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
115  | 5.01k  |                     digest.packed_hash()[16..20].try_into().unwrap(),  | 
116  |  |                 ))  | 
117  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
118  | 5.01k  |                     digest.packed_hash()[20..24].try_into().unwrap(),  | 
119  |  |                 ))  | 
120  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
121  | 5.01k  |                     digest.packed_hash()[24..28].try_into().unwrap(),  | 
122  |  |                 ))  | 
123  | 5.01k  |                 .bitxor(u32::from_le_bytes(  | 
124  | 5.01k  |                     digest.packed_hash()[28..32].try_into().unwrap(),  | 
125  |  |                 ))  | 
126  | 5.01k  |                 .bitxor(u32::from_le_bytes(size_bytes[0..4].try_into().unwrap()))  | 
127  | 5.01k  |                 .bitxor(u32::from_le_bytes(size_bytes[4..8].try_into().unwrap()))  | 
128  |  |             }  | 
129  | 0  |             StoreKey::Str(s) => { | 
130  | 0  |                 let mut hasher = DefaultHasher::new();  | 
131  | 0  |                 hasher.write(s.as_bytes());  | 
132  | 0  |                 let key_u64 = hasher.finish();  | 
133  | 0  |                 (key_u64 >> 32) as u32 // We only need the top 32 bits.  | 
134  |  |             }  | 
135  |  |         };  | 
136  | 5.01k  |         self.weights_and_stores  | 
137  | 5.01k  |             .binary_search_by_key(&key, |item| item.weight)  | 
138  | 5.01k  |             .unwrap_or_else(|index| index)  | 
139  | 5.01k  |     }  | 
140  |  |  | 
141  | 5.00k  |     fn get_store(&self, key: &StoreKey) -> &Store { | 
142  | 5.00k  |         let index = self.get_store_index(key);  | 
143  | 5.00k  |         &self.weights_and_stores[index].store  | 
144  | 5.00k  |     }  | 
145  |  | }  | 
146  |  |  | 
147  |  | #[async_trait]  | 
148  |  | impl StoreDriver for ShardStore { | 
149  |  |     async fn has_with_results(  | 
150  |  |         self: Pin<&Self>,  | 
151  |  |         keys: &[StoreKey<'_>],  | 
152  |  |         results: &mut [Option<u64>],  | 
153  | 6  |     ) -> Result<(), Error> { | 
154  |  |         type KeyIdxVec = Vec<usize>;  | 
155  |  |         type KeyVec<'a> = Vec<StoreKey<'a>>;  | 
156  |  |  | 
157  |  |         if keys.len() == 1 { | 
158  |  |             // Hot path: It is very common to lookup only one key.  | 
159  |  |             let store_idx = self.get_store_index(&keys[0]);  | 
160  |  |             let store = &self.weights_and_stores[store_idx].store;  | 
161  |  |             return store  | 
162  |  |                 .has_with_results(keys, results)  | 
163  |  |                 .await  | 
164  |  |                 .err_tip(|| "In ShardStore::has_with_results() for store {store_idx}}"); | 
165  |  |         }  | 
166  |  |         let mut keys_for_store: Vec<(KeyIdxVec, KeyVec)> = self  | 
167  |  |             .weights_and_stores  | 
168  |  |             .iter()  | 
169  | 6  |             .map(|_| (Vec::new(), Vec::new()))  | 
170  |  |             .collect();  | 
171  |  |         // Bucket each key into the store that it belongs to.  | 
172  |  |         keys.iter()  | 
173  |  |             .enumerate()  | 
174  | 6  |             .map(|(key_idx, key)| (key, key_idx, self.get_store_index(key)))  | 
175  | 6  |             .for_each(|(key, key_idx, store_idx)| { | 
176  | 6  |                 keys_for_store[store_idx].0.push(key_idx);  | 
177  | 6  |                 keys_for_store[store_idx].1.push(key.borrow());  | 
178  | 6  |             });  | 
179  |  |  | 
180  |  |         // Build all our futures for each store.  | 
181  |  |         let mut future_stream: FuturesUnordered<_> = keys_for_store  | 
182  |  |             .into_iter()  | 
183  |  |             .enumerate()  | 
184  | 6  |             .map(|(store_idx, (key_idxs, keys))| async move { | 
185  | 6  |                 let store = &self.weights_and_stores[store_idx].store;  | 
186  | 6  |                 let mut inner_results = vec![None; keys.len()];  | 
187  | 6  |                 store  | 
188  | 6  |                     .has_with_results(&keys, &mut inner_results)  | 
189  | 6  |                     .await  | 
190  | 6  |                     .err_tip(|| "In ShardStore::has_with_results() for store {store_idx}")?0 ; | 
191  | 6  |                 Result::<_, Error>::Ok((key_idxs, inner_results))  | 
192  | 12  |             })  | 
193  |  |             .collect();  | 
194  |  |  | 
195  |  |         // Wait for all the stores to finish and populate our output results.  | 
196  |  |         while let Some((key_idxs, inner_results)) = future_stream.try_next().await? { | 
197  |  |             for (key_idx, inner_result) in key_idxs.into_iter().zip(inner_results) { | 
198  |  |                 results[key_idx] = inner_result;  | 
199  |  |             }  | 
200  |  |         }  | 
201  |  |         Ok(())  | 
202  | 6  |     }  | 
203  |  |  | 
204  |  |     async fn update(  | 
205  |  |         self: Pin<&Self>,  | 
206  |  |         key: StoreKey<'_>,  | 
207  |  |         reader: DropCloserReadHalf,  | 
208  |  |         size_info: UploadSizeInfo,  | 
209  | 5.00k  |     ) -> Result<(), Error> { | 
210  |  |         let store = self.get_store(&key);  | 
211  |  |         store  | 
212  |  |             .update(key, reader, size_info)  | 
213  |  |             .await  | 
214  |  |             .err_tip(|| "In ShardStore::update()")  | 
215  | 5.00k  |     }  | 
216  |  |  | 
217  |  |     async fn get_part(  | 
218  |  |         self: Pin<&Self>,  | 
219  |  |         key: StoreKey<'_>,  | 
220  |  |         writer: &mut DropCloserWriteHalf,  | 
221  |  |         offset: u64,  | 
222  |  |         length: Option<u64>,  | 
223  | 3  |     ) -> Result<(), Error> { | 
224  |  |         let store = self.get_store(&key);  | 
225  |  |         store  | 
226  |  |             .get_part(key, writer, offset, length)  | 
227  |  |             .await  | 
228  |  |             .err_tip(|| "In ShardStore::get_part()")  | 
229  | 3  |     }  | 
230  |  |  | 
231  | 0  |     fn inner_store(&self, key: Option<StoreKey>) -> &'_ dyn StoreDriver { | 
232  | 0  |         let Some(key) = key else {  Branch (232:13): [True: 0, False: 0]
   Branch (232:13): [Folded - Ignored]
  | 
233  | 0  |             return self;  | 
234  |  |         };  | 
235  | 0  |         let index = self.get_store_index(&key);  | 
236  | 0  |         self.weights_and_stores[index].store.inner_store(Some(key))  | 
237  | 0  |     }  | 
238  |  |  | 
239  | 0  |     fn as_any<'a>(&'a self) -> &'a (dyn core::any::Any + Sync + Send + 'static) { | 
240  | 0  |         self  | 
241  | 0  |     }  | 
242  |  |  | 
243  | 0  |     fn as_any_arc(self: Arc<Self>) -> Arc<dyn core::any::Any + Sync + Send + 'static> { | 
244  | 0  |         self  | 
245  | 0  |     }  | 
246  |  |  | 
247  | 0  |     fn register_remove_callback(  | 
248  | 0  |         self: Arc<Self>,  | 
249  | 0  |         callback: Arc<dyn RemoveItemCallback>,  | 
250  | 0  |     ) -> Result<(), Error> { | 
251  | 0  |         for store in &self.weights_and_stores { | 
252  | 0  |             store.store.register_remove_callback(callback.clone())?;  | 
253  |  |         }  | 
254  | 0  |         Ok(())  | 
255  | 0  |     }  | 
256  |  | }  | 
257  |  |  | 
// Invokes the `default_health_status_indicator!` macro from `nativelink_util::health_utils`
// for `ShardStore`. NOTE(review): presumably this generates a default `HealthStatusIndicator`
// implementation; confirm against the macro definition.
default_health_status_indicator!(ShardStore);