/build/source/nativelink-store/src/redis_utils/ft_aggregate.rs
Line | Count | Source |
1 | | // Copyright 2024-2025 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::fmt::Debug; |
16 | | |
17 | | use futures::Stream; |
18 | | use redis::aio::ConnectionLike; |
19 | | use redis::{Arg, ErrorKind, RedisError, Value}; |
20 | | use tracing::error; |
21 | | |
22 | | use crate::redis_utils::aggregate_types::RedisCursorData; |
23 | | use crate::redis_utils::ft_cursor_read::ft_cursor_read; |
24 | | |
/// Cursor settings sent with an `FT.AGGREGATE` request.
///
/// Both values are forwarded verbatim after the `WITHCURSOR` keyword.
#[derive(Debug)]
pub(crate) struct FtAggregateCursor {
    /// Value sent as the `COUNT` argument.
    pub count: u64,
    /// Value sent as the `MAXIDLE` argument.
    // NOTE(review): presumably a timeout in milliseconds — confirm against the server docs.
    pub max_idle: u64,
}
30 | | |
31 | | #[derive(Debug)] |
32 | | pub(crate) struct FtAggregateOptions { |
33 | | pub load: Vec<String>, |
34 | | pub cursor: FtAggregateCursor, |
35 | | pub sort_by: Vec<String>, |
36 | | } |
37 | | |
38 | | /// Calls `FT.AGGREGATE` in redis. redis-rs does not properly support this command |
39 | | /// so we have to manually handle it. |
40 | 25 | pub(crate) async fn ft_aggregate<C>( |
41 | 25 | mut connection_manager: C, |
42 | 25 | index: String, |
43 | 25 | query: String, |
44 | 25 | options: FtAggregateOptions, |
45 | 25 | ) -> Result<impl Stream<Item = Result<Value, RedisError>> + Send, RedisError> |
46 | 25 | where |
47 | 25 | C: ConnectionLike + Send, |
48 | 25 | { |
49 | | struct State<C: ConnectionLike> { |
50 | | connection_manager: C, |
51 | | index: String, |
52 | | data: RedisCursorData, |
53 | | } |
54 | | |
55 | 25 | let mut cmd = redis::cmd("FT.AGGREGATE"); |
56 | 25 | let mut ft_aggregate_cmd = cmd |
57 | 25 | .arg(&index) |
58 | 25 | .arg(&query) |
59 | 25 | .arg("LOAD") |
60 | 25 | .arg(options.load.len()) |
61 | 25 | .arg(&options.load) |
62 | 25 | .arg("WITHCURSOR") |
63 | 25 | .arg("COUNT") |
64 | 25 | .arg(options.cursor.count) |
65 | 25 | .arg("MAXIDLE") |
66 | 25 | .arg(options.cursor.max_idle) |
67 | 25 | .arg("SORTBY") |
68 | 25 | .arg(options.sort_by.len() * 2); |
69 | 47 | for key22 in &options.sort_by { |
70 | 22 | ft_aggregate_cmd = ft_aggregate_cmd.arg(key).arg("ASC"); |
71 | 22 | } |
72 | 25 | let res = ft_aggregate_cmd |
73 | 25 | .query_async::<Value>(&mut connection_manager) |
74 | 25 | .await; |
75 | 25 | let data23 = match res { |
76 | 23 | Ok(d) => d, |
77 | 2 | Err(e) => { |
78 | 2 | let all_args: Vec<_> = ft_aggregate_cmd |
79 | 2 | .args_iter() |
80 | 32 | .map2 (|a| match a { |
81 | 32 | Arg::Simple(bytes) => match str::from_utf8(bytes) { |
82 | 32 | Ok(s) => s.to_string(), |
83 | 0 | Err(_) => format!("{bytes:?}"), |
84 | | }, |
85 | 0 | other => { |
86 | 0 | format!("{other:?}") |
87 | | } |
88 | 32 | }) |
89 | 2 | .collect(); |
90 | 2 | error!( |
91 | | ?e, |
92 | | index, |
93 | | ?query, |
94 | | ?options, |
95 | | ?all_args, |
96 | 2 | "Error calling ft.aggregate" |
97 | | ); |
98 | 2 | return Err(e); |
99 | | } |
100 | | }; |
101 | | |
102 | 23 | let state = State { |
103 | 23 | connection_manager, |
104 | 23 | index, |
105 | 23 | data: data.try_into()?0 , |
106 | | }; |
107 | | |
108 | 23 | Ok(futures::stream::unfold( |
109 | 23 | Some(state), |
110 | 34 | move |maybe_state| async move { |
111 | 34 | let mut state = maybe_state?0 ; |
112 | | loop { |
113 | 34 | if let Some(map13 ) = state.data.data.pop_front() { Branch (113:24): [Folded - Ignored]
Branch (113:24): [Folded - Ignored]
Branch (113:24): [True: 10, False: 18]
Branch (113:24): [True: 3, False: 3]
|
114 | 13 | return Some((Ok(map), Some(state))); |
115 | 21 | } |
116 | 21 | if state.data.cursor == 0 { Branch (116:20): [Folded - Ignored]
Branch (116:20): [Folded - Ignored]
Branch (116:20): [True: 18, False: 0]
Branch (116:20): [True: 3, False: 0]
|
117 | 21 | return None; |
118 | 0 | } |
119 | 0 | let data_res = ft_cursor_read( |
120 | 0 | &mut state.connection_manager, |
121 | 0 | state.index.clone(), |
122 | 0 | state.data.cursor, |
123 | 0 | ) |
124 | 0 | .await; |
125 | 0 | state.data = match data_res { |
126 | 0 | Ok(data) => data, |
127 | 0 | Err(err) => return Some((Err(err), None)), |
128 | | }; |
129 | | } |
130 | 68 | }, |
131 | | )) |
132 | 25 | } |
133 | | |
134 | 22 | fn resp2_data_parse( |
135 | 22 | output: &mut RedisCursorData, |
136 | 22 | results_array: &[Value], |
137 | 22 | ) -> Result<(), RedisError> { |
138 | 22 | let mut results_iter = results_array.iter(); |
139 | 22 | match results_iter.next() { |
140 | 22 | Some(Value::Int(t)) => { |
141 | 22 | output.total = *t; |
142 | 22 | } |
143 | 0 | Some(other) => { |
144 | 0 | error!(?other, "Non-int for first value in ft.aggregate"); |
145 | 0 | return Err(RedisError::from(( |
146 | 0 | ErrorKind::Parse, |
147 | 0 | "Non int for aggregate total", |
148 | 0 | format!("{other:?}"), |
149 | 0 | ))); |
150 | | } |
151 | | None => { |
152 | 0 | error!("No items in results array for ft.aggregate!"); |
153 | 0 | return Err(RedisError::from(( |
154 | 0 | ErrorKind::Parse, |
155 | 0 | "No items in results array for ft.aggregate", |
156 | 0 | ))); |
157 | | } |
158 | | } |
159 | | |
160 | 34 | for item12 in results_iter { |
161 | 12 | match item { |
162 | 12 | Value::Array(items) if items.len() % 2 == 0 => {} Branch (162:36): [True: 12, False: 0]
Branch (162:36): [Folded - Ignored]
|
163 | 0 | other => { |
164 | 0 | error!( |
165 | | ?other, |
166 | 0 | "Expected an array with an even number of items, didn't get it for aggregate value" |
167 | | ); |
168 | 0 | return Err(RedisError::from(( |
169 | 0 | ErrorKind::Parse, |
170 | 0 | "Expected an array with an even number of items, didn't get it for aggregate value", |
171 | 0 | format!("{other:?}"), |
172 | 0 | ))); |
173 | | } |
174 | | } |
175 | | |
176 | 12 | output.data.push_back(item.clone()); |
177 | | } |
178 | 22 | Ok(()) |
179 | 22 | } |
180 | | |
181 | 1 | fn resp3_data_parse( |
182 | 1 | output: &mut RedisCursorData, |
183 | 1 | results_map: &Vec<(Value, Value)>, |
184 | 1 | ) -> Result<(), RedisError> { |
185 | 6 | for (raw_key5 , value5 ) in results_map { |
186 | 5 | let Value::SimpleString(key) = raw_key else { Branch (186:13): [True: 5, False: 0]
Branch (186:13): [Folded - Ignored]
|
187 | 0 | return Err(RedisError::from(( |
188 | 0 | ErrorKind::Parse, |
189 | 0 | "Expected SimpleString keys", |
190 | 0 | format!("{raw_key:?}"), |
191 | 0 | ))); |
192 | | }; |
193 | 5 | match key.as_str() { |
194 | 5 | "attributes" => { |
195 | 1 | let Value::Array(attributes) = value else { Branch (195:21): [True: 1, False: 0]
Branch (195:21): [Folded - Ignored]
|
196 | 0 | return Err(RedisError::from(( |
197 | 0 | ErrorKind::Parse, |
198 | 0 | "Expected array for attributes", |
199 | 0 | format!("{value:?}"), |
200 | 0 | ))); |
201 | | }; |
202 | 1 | if !attributes.is_empty() { Branch (202:20): [True: 0, False: 1]
Branch (202:20): [Folded - Ignored]
|
203 | 0 | return Err(RedisError::from(( |
204 | 0 | ErrorKind::Parse, |
205 | 0 | "Expected empty attributes", |
206 | 0 | format!("{attributes:?}"), |
207 | 0 | ))); |
208 | 1 | } |
209 | | } |
210 | 4 | "format" => { |
211 | 1 | let Value::SimpleString(format) = value else { Branch (211:21): [True: 1, False: 0]
Branch (211:21): [Folded - Ignored]
|
212 | 0 | return Err(RedisError::from(( |
213 | 0 | ErrorKind::Parse, |
214 | 0 | "Expected SimpleString for format", |
215 | 0 | format!("{value:?}"), |
216 | 0 | ))); |
217 | | }; |
218 | 1 | if format.as_str() != "STRING" { Branch (218:20): [True: 0, False: 1]
Branch (218:20): [Folded - Ignored]
|
219 | 0 | return Err(RedisError::from(( |
220 | 0 | ErrorKind::Parse, |
221 | 0 | "Expected STRING format", |
222 | 0 | format.to_string(), |
223 | 0 | ))); |
224 | 1 | } |
225 | | } |
226 | 3 | "results" => { |
227 | 1 | let Value::Array(values) = value else { Branch (227:21): [True: 1, False: 0]
Branch (227:21): [Folded - Ignored]
|
228 | 0 | return Err(RedisError::from(( |
229 | 0 | ErrorKind::Parse, |
230 | 0 | "Expected Array for results", |
231 | 0 | format!("{value:?}"), |
232 | 0 | ))); |
233 | | }; |
234 | 2 | for raw_value1 in values { |
235 | 1 | let Value::Map(value) = raw_value else { Branch (235:25): [True: 1, False: 0]
Branch (235:25): [Folded - Ignored]
|
236 | 0 | return Err(RedisError::from(( |
237 | 0 | ErrorKind::Parse, |
238 | 0 | "Expected list of maps in result", |
239 | 0 | format!("{raw_value:?}"), |
240 | 0 | ))); |
241 | | }; |
242 | 3 | for (raw_map_key2 , raw_map_value2 ) in value { |
243 | 2 | let Value::SimpleString(map_key) = raw_map_key else { Branch (243:29): [True: 2, False: 0]
Branch (243:29): [Folded - Ignored]
|
244 | 0 | return Err(RedisError::from(( |
245 | 0 | ErrorKind::Parse, |
246 | 0 | "Expected SimpleString keys for result maps", |
247 | 0 | format!("{raw_key:?}"), |
248 | 0 | ))); |
249 | | }; |
250 | 2 | match map_key.as_str() { |
251 | 2 | "extra_attributes" => { |
252 | 1 | let Value::Map(extra_attributes_values) = raw_map_value else { Branch (252:37): [True: 1, False: 0]
Branch (252:37): [Folded - Ignored]
|
253 | 0 | return Err(RedisError::from(( |
254 | 0 | ErrorKind::Parse, |
255 | 0 | "Expected Map for extra_attributes", |
256 | 0 | format!("{raw_map_value:?}"), |
257 | 0 | ))); |
258 | | }; |
259 | 1 | let mut output_array = vec![]; |
260 | 3 | for (e_key2 , e_value2 ) in extra_attributes_values { |
261 | 2 | output_array.push(e_key.clone()); |
262 | 2 | output_array.push(e_value.clone()); |
263 | 2 | } |
264 | 1 | output.data.push_back(Value::Array(output_array)); |
265 | | } |
266 | 1 | "values" => { |
267 | 1 | let Value::Array(values_values) = raw_map_value else { Branch (267:37): [True: 1, False: 0]
Branch (267:37): [Folded - Ignored]
|
268 | 0 | return Err(RedisError::from(( |
269 | 0 | ErrorKind::Parse, |
270 | 0 | "Expected Array for values", |
271 | 0 | format!("{raw_map_value:?}"), |
272 | 0 | ))); |
273 | | }; |
274 | 1 | if !values_values.is_empty() { Branch (274:36): [True: 0, False: 1]
Branch (274:36): [Folded - Ignored]
|
275 | 0 | return Err(RedisError::from(( |
276 | 0 | ErrorKind::Parse, |
277 | 0 | "Expected empty values (all in extra_attributes)", |
278 | 0 | format!("{values_values:?}"), |
279 | 0 | ))); |
280 | 1 | } |
281 | | } |
282 | | _ => { |
283 | 0 | return Err(RedisError::from(( |
284 | 0 | ErrorKind::Parse, |
285 | 0 | "Unknown result map key", |
286 | 0 | format!("{map_key:?}"), |
287 | 0 | ))); |
288 | | } |
289 | | } |
290 | | } |
291 | | } |
292 | | } |
293 | 2 | "total_results" => { |
294 | 1 | let Value::Int(total) = value else { Branch (294:21): [True: 1, False: 0]
Branch (294:21): [Folded - Ignored]
|
295 | 0 | return Err(RedisError::from(( |
296 | 0 | ErrorKind::Parse, |
297 | 0 | "Expected int for total_results", |
298 | 0 | format!("{value:?}"), |
299 | 0 | ))); |
300 | | }; |
301 | 1 | output.total = *total; |
302 | | } |
303 | 1 | "warning" => { |
304 | 1 | let Value::Array(warnings) = value else { Branch (304:21): [True: 1, False: 0]
Branch (304:21): [Folded - Ignored]
|
305 | 0 | return Err(RedisError::from(( |
306 | 0 | ErrorKind::Parse, |
307 | 0 | "Expected Array for warning", |
308 | 0 | format!("{value:?}"), |
309 | 0 | ))); |
310 | | }; |
311 | 1 | if !warnings.is_empty() { Branch (311:20): [True: 0, False: 1]
Branch (311:20): [Folded - Ignored]
|
312 | 0 | return Err(RedisError::from(( |
313 | 0 | ErrorKind::Parse, |
314 | 0 | "Expected empty warnings", |
315 | 0 | format!("{warnings:?}"), |
316 | 0 | ))); |
317 | 1 | } |
318 | | } |
319 | | _ => { |
320 | 0 | return Err(RedisError::from(( |
321 | 0 | ErrorKind::Parse, |
322 | 0 | "Unexpected key in ft.aggregate", |
323 | 0 | format!("{key} => {value:?}"), |
324 | 0 | ))); |
325 | | } |
326 | | } |
327 | | } |
328 | 1 | Ok(()) |
329 | 1 | } |
330 | | |
331 | | impl TryFrom<Value> for RedisCursorData { |
332 | | type Error = RedisError; |
333 | 23 | fn try_from(raw_value: Value) -> Result<Self, RedisError> { |
334 | 23 | let Value::Array(value) = raw_value else { Branch (334:13): [True: 23, False: 0]
Branch (334:13): [Folded - Ignored]
|
335 | 0 | error!( |
336 | | ?raw_value, |
337 | 0 | "Bad data in ft.aggregate, expected array at top-level" |
338 | | ); |
339 | 0 | return Err(RedisError::from((ErrorKind::Parse, "Expected array"))); |
340 | | }; |
341 | 23 | if value.len() < 2 { Branch (341:12): [True: 0, False: 23]
Branch (341:12): [Folded - Ignored]
|
342 | 0 | return Err(RedisError::from(( |
343 | 0 | ErrorKind::Parse, |
344 | 0 | "Expected at least 2 elements", |
345 | 0 | ))); |
346 | 23 | } |
347 | 23 | let mut output = Self::default(); |
348 | 23 | let mut value = value.into_iter(); |
349 | 23 | match value.next().unwrap() { |
350 | 22 | Value::Array(d) => resp2_data_parse(&mut output, &d)?0 , |
351 | 1 | Value::Map(d) => resp3_data_parse(&mut output, &d)?0 , |
352 | 0 | other => { |
353 | 0 | error!( |
354 | | ?other, |
355 | 0 | "Bad data in ft.aggregate, expected array for results" |
356 | | ); |
357 | 0 | return Err(RedisError::from(( |
358 | 0 | ErrorKind::Parse, |
359 | 0 | "Non map item", |
360 | 0 | format!("{other:?}"), |
361 | 0 | ))); |
362 | | } |
363 | | } |
364 | 23 | let Value::Int(cursor) = value.next().unwrap() else { Branch (364:13): [True: 23, False: 0]
Branch (364:13): [Folded - Ignored]
|
365 | 0 | return Err(RedisError::from(( |
366 | 0 | ErrorKind::Parse, |
367 | 0 | "Expected integer as last element", |
368 | 0 | ))); |
369 | | }; |
370 | 23 | output.cursor = cursor as u64; |
371 | 23 | Ok(output) |
372 | 23 | } |
373 | | } |