/build/source/nativelink-store/src/shard_store.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::hash::Hasher; |
16 | | use core::ops::BitXor; |
17 | | use core::pin::Pin; |
18 | | use std::hash::DefaultHasher; |
19 | | use std::sync::Arc; |
20 | | |
21 | | use async_trait::async_trait; |
22 | | use futures::stream::{FuturesUnordered, TryStreamExt}; |
23 | | use nativelink_config::stores::ShardSpec; |
24 | | use nativelink_error::{Error, ResultExt, error_if}; |
25 | | use nativelink_metric::MetricsComponent; |
26 | | use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; |
27 | | use nativelink_util::health_utils::{HealthStatusIndicator, default_health_status_indicator}; |
28 | | use nativelink_util::store_trait::{ |
29 | | RemoveItemCallback, Store, StoreDriver, StoreKey, StoreLike, UploadSizeInfo, |
30 | | }; |
31 | | |
// One shard of a `ShardStore`: a backing store plus the inclusive upper bound
// of the 32-bit key range that is routed to it.
#[derive(Debug, MetricsComponent)]
struct StoreAndWeight {
    // Cumulative weight, scaled into the `u32` key space by `ShardStore::new`.
    // This is a running total over the shards in configuration order, not the
    // raw weight taken from the configuration.
    #[metric(help = "The weight of the store")]
    weight: u32,
    // The store that serves keys falling into this shard's range.
    #[metric(help = "The underlying store")]
    store: Store,
}
39 | | |
// A store that routes every key to exactly one of several backing stores,
// chosen from a 32-bit value derived from the key.
#[derive(Debug, MetricsComponent)]
pub struct ShardStore {
    // Entries are kept in ascending order of their cumulative weight. A store is
    // chosen by reducing the key to a `u32` and binary-searching these cumulative
    // weights for it; when there is no exact match the insertion point is used,
    // i.e. the first entry whose weight is greater than the key value.
    #[metric(
        group = "stores",
        help = "The weights and stores that are used to determine which store to use"
    )]
    weights_and_stores: Vec<StoreAndWeight>,
}
50 | | |
51 | | impl ShardStore { |
52 | 15 | pub fn new(spec: &ShardSpec, stores: Vec<Store>) -> Result<Arc<Self>, Error> { |
53 | 0 | error_if!( |
54 | 15 | spec.stores.len() != stores.len(), Branch (54:13): [True: 0, False: 15]
Branch (54:13): [Folded - Ignored]
|
55 | | "Config shards do not match stores length" |
56 | | ); |
57 | 0 | error_if!( |
58 | 15 | spec.stores.is_empty(), Branch (58:13): [True: 0, False: 15]
Branch (58:13): [Folded - Ignored]
|
59 | | "ShardStore must have at least one store" |
60 | | ); |
61 | 15 | let total_weight: u64 = spec |
62 | 15 | .stores |
63 | 15 | .iter() |
64 | 50 | .map15 (|shard_config| u64::from(shard_config.weight.unwrap_or(1))) |
65 | 15 | .sum(); |
66 | 15 | let mut weights: Vec<u32> = spec |
67 | 15 | .stores |
68 | 15 | .iter() |
69 | 50 | .map15 (|shard_config| { |
70 | 50 | u32::try_from( |
71 | 50 | u64::from(u32::MAX) * u64::from(shard_config.weight.unwrap_or(1)) |
72 | 50 | / total_weight, |
73 | | ) |
74 | 50 | .unwrap_or(u32::MAX) |
75 | 50 | }) |
76 | 50 | .scan15 (0, |state, weight| { |
77 | 50 | *state += weight; |
78 | 50 | Some(*state) |
79 | 50 | }) |
80 | 15 | .collect(); |
81 | | // Our last item should always be the max. |
82 | 15 | *weights.last_mut().unwrap() = u32::MAX; |
83 | 15 | Ok(Arc::new(Self { |
84 | 15 | weights_and_stores: weights |
85 | 15 | .into_iter() |
86 | 15 | .zip(stores) |
87 | 50 | .map15 (|(weight, store)| StoreAndWeight { weight, store }) |
88 | 15 | .collect(), |
89 | | })) |
90 | 15 | } |
91 | | |
92 | 5.01k | fn get_store_index(&self, store_key: &StoreKey) -> usize { |
93 | 5.01k | let key = match store_key { |
94 | 5.01k | StoreKey::Digest(digest) => { |
95 | | // Quote from std primitive array documentation: |
96 | | // Array’s try_from(slice) implementations (and the corresponding slice.try_into() |
97 | | // array implementations) succeed if the input slice length is the same as the result |
98 | | // array length. They optimize especially well when the optimizer can easily determine |
99 | | // the slice length, e.g. <[u8; 4]>::try_from(&slice[4..8]).unwrap(). Array implements |
100 | | // TryFrom returning. |
101 | 5.01k | let size_bytes = digest.size_bytes().to_le_bytes(); |
102 | 5.01k | 0.bitxor(u32::from_le_bytes( |
103 | 5.01k | digest.packed_hash()[0..4].try_into().unwrap(), |
104 | | )) |
105 | 5.01k | .bitxor(u32::from_le_bytes( |
106 | 5.01k | digest.packed_hash()[4..8].try_into().unwrap(), |
107 | | )) |
108 | 5.01k | .bitxor(u32::from_le_bytes( |
109 | 5.01k | digest.packed_hash()[8..12].try_into().unwrap(), |
110 | | )) |
111 | 5.01k | .bitxor(u32::from_le_bytes( |
112 | 5.01k | digest.packed_hash()[12..16].try_into().unwrap(), |
113 | | )) |
114 | 5.01k | .bitxor(u32::from_le_bytes( |
115 | 5.01k | digest.packed_hash()[16..20].try_into().unwrap(), |
116 | | )) |
117 | 5.01k | .bitxor(u32::from_le_bytes( |
118 | 5.01k | digest.packed_hash()[20..24].try_into().unwrap(), |
119 | | )) |
120 | 5.01k | .bitxor(u32::from_le_bytes( |
121 | 5.01k | digest.packed_hash()[24..28].try_into().unwrap(), |
122 | | )) |
123 | 5.01k | .bitxor(u32::from_le_bytes( |
124 | 5.01k | digest.packed_hash()[28..32].try_into().unwrap(), |
125 | | )) |
126 | 5.01k | .bitxor(u32::from_le_bytes(size_bytes[0..4].try_into().unwrap())) |
127 | 5.01k | .bitxor(u32::from_le_bytes(size_bytes[4..8].try_into().unwrap())) |
128 | | } |
129 | 0 | StoreKey::Str(s) => { |
130 | 0 | let mut hasher = DefaultHasher::new(); |
131 | 0 | hasher.write(s.as_bytes()); |
132 | 0 | let key_u64 = hasher.finish(); |
133 | 0 | (key_u64 >> 32) as u32 // We only need the top 32 bits. |
134 | | } |
135 | | }; |
136 | 5.01k | self.weights_and_stores |
137 | 5.01k | .binary_search_by_key(&key, |item| item.weight) |
138 | 5.01k | .unwrap_or_else(|index| index) |
139 | 5.01k | } |
140 | | |
141 | 5.00k | fn get_store(&self, key: &StoreKey) -> &Store { |
142 | 5.00k | let index = self.get_store_index(key); |
143 | 5.00k | &self.weights_and_stores[index].store |
144 | 5.00k | } |
145 | | } |
146 | | |
147 | | #[async_trait] |
148 | | impl StoreDriver for ShardStore { |
149 | | async fn has_with_results( |
150 | | self: Pin<&Self>, |
151 | | keys: &[StoreKey<'_>], |
152 | | results: &mut [Option<u64>], |
153 | 6 | ) -> Result<(), Error> { |
154 | | type KeyIdxVec = Vec<usize>; |
155 | | type KeyVec<'a> = Vec<StoreKey<'a>>; |
156 | | |
157 | | if keys.len() == 1 { |
158 | | // Hot path: It is very common to lookup only one key. |
159 | | let store_idx = self.get_store_index(&keys[0]); |
160 | | let store = &self.weights_and_stores[store_idx].store; |
161 | | return store |
162 | | .has_with_results(keys, results) |
163 | | .await |
164 | | .err_tip(|| "In ShardStore::has_with_results() for store {store_idx}}"); |
165 | | } |
166 | | let mut keys_for_store: Vec<(KeyIdxVec, KeyVec)> = self |
167 | | .weights_and_stores |
168 | | .iter() |
169 | 6 | .map(|_| (Vec::new(), Vec::new())) |
170 | | .collect(); |
171 | | // Bucket each key into the store that it belongs to. |
172 | | keys.iter() |
173 | | .enumerate() |
174 | 6 | .map(|(key_idx, key)| (key, key_idx, self.get_store_index(key))) |
175 | 6 | .for_each(|(key, key_idx, store_idx)| { |
176 | 6 | keys_for_store[store_idx].0.push(key_idx); |
177 | 6 | keys_for_store[store_idx].1.push(key.borrow()); |
178 | 6 | }); |
179 | | |
180 | | // Build all our futures for each store. |
181 | | let mut future_stream: FuturesUnordered<_> = keys_for_store |
182 | | .into_iter() |
183 | | .enumerate() |
184 | 6 | .map(|(store_idx, (key_idxs, keys))| async move { |
185 | 6 | let store = &self.weights_and_stores[store_idx].store; |
186 | 6 | let mut inner_results = vec![None; keys.len()]; |
187 | 6 | store |
188 | 6 | .has_with_results(&keys, &mut inner_results) |
189 | 6 | .await |
190 | 6 | .err_tip(|| "In ShardStore::has_with_results() for store {store_idx}")?0 ; |
191 | 6 | Result::<_, Error>::Ok((key_idxs, inner_results)) |
192 | 12 | }) |
193 | | .collect(); |
194 | | |
195 | | // Wait for all the stores to finish and populate our output results. |
196 | | while let Some((key_idxs, inner_results)) = future_stream.try_next().await? { |
197 | | for (key_idx, inner_result) in key_idxs.into_iter().zip(inner_results) { |
198 | | results[key_idx] = inner_result; |
199 | | } |
200 | | } |
201 | | Ok(()) |
202 | 6 | } |
203 | | |
204 | | async fn update( |
205 | | self: Pin<&Self>, |
206 | | key: StoreKey<'_>, |
207 | | reader: DropCloserReadHalf, |
208 | | size_info: UploadSizeInfo, |
209 | 5.00k | ) -> Result<(), Error> { |
210 | | let store = self.get_store(&key); |
211 | | store |
212 | | .update(key, reader, size_info) |
213 | | .await |
214 | | .err_tip(|| "In ShardStore::update()") |
215 | 5.00k | } |
216 | | |
217 | | async fn get_part( |
218 | | self: Pin<&Self>, |
219 | | key: StoreKey<'_>, |
220 | | writer: &mut DropCloserWriteHalf, |
221 | | offset: u64, |
222 | | length: Option<u64>, |
223 | 3 | ) -> Result<(), Error> { |
224 | | let store = self.get_store(&key); |
225 | | store |
226 | | .get_part(key, writer, offset, length) |
227 | | .await |
228 | | .err_tip(|| "In ShardStore::get_part()") |
229 | 3 | } |
230 | | |
231 | 0 | fn inner_store(&self, key: Option<StoreKey>) -> &'_ dyn StoreDriver { |
232 | 0 | let Some(key) = key else { Branch (232:13): [True: 0, False: 0]
Branch (232:13): [Folded - Ignored]
|
233 | 0 | return self; |
234 | | }; |
235 | 0 | let index = self.get_store_index(&key); |
236 | 0 | self.weights_and_stores[index].store.inner_store(Some(key)) |
237 | 0 | } |
238 | | |
239 | 0 | fn as_any<'a>(&'a self) -> &'a (dyn core::any::Any + Sync + Send + 'static) { |
240 | 0 | self |
241 | 0 | } |
242 | | |
243 | 0 | fn as_any_arc(self: Arc<Self>) -> Arc<dyn core::any::Any + Sync + Send + 'static> { |
244 | 0 | self |
245 | 0 | } |
246 | | |
247 | 0 | fn register_remove_callback( |
248 | 0 | self: Arc<Self>, |
249 | 0 | callback: Arc<dyn RemoveItemCallback>, |
250 | 0 | ) -> Result<(), Error> { |
251 | 0 | for store in &self.weights_and_stores { |
252 | 0 | store.store.register_remove_callback(callback.clone())?; |
253 | | } |
254 | 0 | Ok(()) |
255 | 0 | } |
256 | | } |
257 | | |
// NOTE(review): presumably expands to the default `HealthStatusIndicator`
// implementation for `ShardStore`; the macro body lives in
// `nativelink_util::health_utils` and is not visible here — confirm.
default_health_status_indicator!(ShardStore);