/build/source/nativelink-redis-tester/src/dynamic_fake_redis.rs
Line | Count | Source |
1 | | // Copyright 2026 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::fmt; |
16 | | use std::collections::HashMap; |
17 | | use std::collections::hash_map::Entry; |
18 | | use std::sync::{Arc, Mutex}; |
19 | | |
20 | | use nativelink_util::background_spawn; |
21 | | use redis::Value; |
22 | | use redis_protocol::resp2::decode::decode; |
23 | | use redis_protocol::resp2::types::{OwnedFrame, Resp2Frame}; |
24 | | use tokio::net::TcpListener; |
25 | | use tracing::{debug, info, trace}; |
26 | | |
27 | | use crate::fake_redis::{arg_as_string, fake_redis_internal}; |
28 | | |
29 | | pub trait SubscriptionManagerNotify { |
30 | | fn notify_for_test(&self, value: String); |
31 | | } |
32 | | |
33 | | #[derive(Clone)] |
34 | | pub struct FakeRedisBackend<S: SubscriptionManagerNotify> { |
35 | | /// Contains a list of all of the Redis keys -> fields. |
36 | | pub table: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>, |
37 | | subscription_manager: Arc<Mutex<Option<Arc<S>>>>, |
38 | | } |
39 | | |
40 | | impl<S: SubscriptionManagerNotify + Send + 'static + Sync> Default for FakeRedisBackend<S> { |
41 | 0 | fn default() -> Self { |
42 | 0 | Self::new() |
43 | 0 | } |
44 | | } |
45 | | |
46 | | impl<S: SubscriptionManagerNotify> fmt::Debug for FakeRedisBackend<S> { |
47 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
48 | 0 | f.debug_struct("FakeRedisBackend").finish() |
49 | 0 | } |
50 | | } |
51 | | |
52 | | const FAKE_SCRIPT_SHA: &str = "b22b9926cbce9dd9ba97fa7ba3626f89feea1ed5"; |
53 | | |
54 | | impl<S: SubscriptionManagerNotify + Send + 'static + Sync> FakeRedisBackend<S> { |
55 | 3 | pub fn new() -> Self { |
56 | 3 | Self { |
57 | 3 | table: Arc::new(Mutex::new(HashMap::new())), |
58 | 3 | subscription_manager: Arc::new(Mutex::new(None)), |
59 | 3 | } |
60 | 3 | } |
61 | | |
62 | 2 | pub fn set_subscription_manager(&self, subscription_manager: Arc<S>) { |
63 | 2 | self.subscription_manager |
64 | 2 | .lock() |
65 | 2 | .unwrap() |
66 | 2 | .replace(subscription_manager); |
67 | 2 | } |
68 | | |
69 | 3 | async fn dynamic_fake_redis(self, listener: TcpListener) { |
70 | 154 | let inner3 = move |buf: &[u8]| -> String { |
71 | 154 | let mut output = String::new(); |
72 | 154 | let mut buf_index = 0; |
73 | | loop { |
74 | 166 | let frame = match decode(&buf[buf_index..]).unwrap() { |
75 | 166 | Some((frame, amt)) => { |
76 | 166 | buf_index += amt; |
77 | 166 | frame |
78 | | } |
79 | | None => { |
80 | 0 | panic!("No frame!"); |
81 | | } |
82 | | }; |
83 | 166 | let (cmd, args) = { |
84 | 166 | if let OwnedFrame::Array(a) = frame { Branch (84:28): [Folded - Ignored]
Branch (84:28): [Folded - Ignored]
Branch (84:28): [True: 166, False: 0]
|
85 | 166 | if let OwnedFrame::BulkString(s) = a.first().unwrap() { Branch (85:32): [Folded - Ignored]
Branch (85:32): [Folded - Ignored]
Branch (85:32): [True: 166, False: 0]
|
86 | 166 | let args: Vec<_> = a[1..].to_vec(); |
87 | 166 | (str::from_utf8(s).unwrap().to_string(), args) |
88 | | } else { |
89 | 0 | panic!("Array not starting with cmd: {a:?}"); |
90 | | } |
91 | | } else { |
92 | 0 | panic!("Non array cmd: {frame:?}"); |
93 | | } |
94 | | }; |
95 | | |
96 | 166 | let ret: Value = match cmd.as_str() { |
97 | 166 | "CLIENT" => { |
98 | | // We can safely ignore these, as it's just setting the library name/version |
99 | 12 | Value::Int(0) |
100 | | } |
101 | 154 | "SCRIPT" => { |
102 | 3 | assert_eq!(args[0], OwnedFrame::BulkString(b"LOAD".to_vec())); |
103 | | |
104 | 3 | let OwnedFrame::BulkString(ref _script) = args[1] else { Branch (104:29): [Folded - Ignored]
Branch (104:29): [Folded - Ignored]
Branch (104:29): [True: 3, False: 0]
|
105 | 0 | panic!("Script should be a bulkstring: {args:?}"); |
106 | | }; |
107 | 3 | Value::SimpleString(FAKE_SCRIPT_SHA.to_string()) |
108 | | } |
109 | | |
110 | 151 | "PSUBSCRIBE" => { |
111 | | // This does nothing at the moment, maybe we need to implement it later. |
112 | 3 | Value::Int(0) |
113 | | } |
114 | | |
115 | 148 | "PUBLISH" => { |
116 | 14 | if let Some(subscription_manager) = Branch (116:32): [Folded - Ignored]
Branch (116:32): [Folded - Ignored]
Branch (116:32): [True: 14, False: 1]
|
117 | 15 | self.subscription_manager.lock().unwrap().as_ref() |
118 | | { |
119 | 14 | subscription_manager.notify_for_test( |
120 | 14 | str::from_utf8(args[1].as_bytes().expect("Notification not bytes")) |
121 | 14 | .expect("Notification not UTF-8") |
122 | 14 | .into(), |
123 | 14 | ); |
124 | 14 | Value::Int(1) |
125 | | } else { |
126 | 1 | Value::Int(0) |
127 | | } |
128 | | } |
129 | | |
130 | 133 | "FT.AGGREGATE" => { |
131 | | // The query is either "*" (match all) or @field:{ value }. |
132 | 20 | let OwnedFrame::BulkString(ref raw_query) = args[1] else { Branch (132:29): [Folded - Ignored]
Branch (132:29): [Folded - Ignored]
Branch (132:29): [True: 20, False: 0]
|
133 | 0 | panic!("Aggregate query should be a string: {args:?}"); |
134 | | }; |
135 | 20 | let query = str::from_utf8(raw_query).unwrap(); |
136 | | // Lazy implementation making assumptions. |
137 | 20 | assert_eq!( |
138 | 20 | args[2..6], |
139 | 20 | vec![ |
140 | 20 | OwnedFrame::BulkString(b"LOAD".to_vec()), |
141 | 20 | OwnedFrame::BulkString(b"2".to_vec()), |
142 | 20 | OwnedFrame::BulkString(b"data".to_vec()), |
143 | 20 | OwnedFrame::BulkString(b"version".to_vec()) |
144 | | ] |
145 | | ); |
146 | 20 | let mut results = vec![Value::Int(0)]; |
147 | | |
148 | 20 | if query == "*" { Branch (148:28): [Folded - Ignored]
Branch (148:28): [Folded - Ignored]
Branch (148:28): [True: 0, False: 20]
|
149 | | // Wildcard query - return all records that have both data and version fields. |
150 | | // Some entries (e.g., from HSET) may not have version field. |
151 | 0 | for fields in self.table.lock().unwrap().values() { |
152 | 0 | if let (Some(data), Some(version)) = Branch (152:40): [Folded - Ignored]
Branch (152:40): [Folded - Ignored]
Branch (152:40): [True: 0, False: 0]
|
153 | 0 | (fields.get("data"), fields.get("version")) |
154 | 0 | { |
155 | 0 | results.push(Value::Array(vec![ |
156 | 0 | Value::BulkString(b"data".to_vec()), |
157 | 0 | data.clone(), |
158 | 0 | Value::BulkString(b"version".to_vec()), |
159 | 0 | version.clone(), |
160 | 0 | ])); |
161 | 0 | } |
162 | | } |
163 | | } else { |
164 | | // Field-specific query: @field:{ value } |
165 | 20 | assert_eq!(&query[..1], "@"); |
166 | 20 | let mut parts = query[1..].split(':'); |
167 | 20 | let field = parts.next().expect("No field name"); |
168 | 20 | let value = parts.next().expect("No value"); |
169 | 20 | let value = value |
170 | 20 | .strip_prefix("{ ") |
171 | 20 | .and_then(|s| s.strip_suffix(" }")) |
172 | 20 | .unwrap_or(value); |
173 | 56 | for fields in self.table.lock().unwrap()20 .values20 () { |
174 | 56 | if let Some(key_value19 ) = fields.get(field) { Branch (174:40): [Folded - Ignored]
Branch (174:40): [Folded - Ignored]
Branch (174:40): [True: 19, False: 37]
|
175 | 19 | if *key_value == Value::BulkString(value.as_bytes().to_vec()) { Branch (175:40): [Folded - Ignored]
Branch (175:40): [Folded - Ignored]
Branch (175:40): [True: 10, False: 9]
|
176 | 10 | results.push(Value::Array(vec![ |
177 | 10 | Value::BulkString(b"data".to_vec()), |
178 | 10 | fields.get("data").expect("No data field").clone(), |
179 | 10 | Value::BulkString(b"version".to_vec()), |
180 | 10 | fields |
181 | 10 | .get("version") |
182 | 10 | .expect("No version field") |
183 | 10 | .clone(), |
184 | 10 | ])); |
185 | 10 | }9 |
186 | 37 | } |
187 | | } |
188 | | } |
189 | | |
190 | 20 | results[0] = |
191 | 20 | Value::Int(i64::try_from(results.len() - 1).unwrap_or(i64::MAX)); |
192 | 20 | Value::Array(vec![ |
193 | 20 | Value::Array(results), |
194 | 20 | Value::Int(0), // Means no more items in cursor. |
195 | 20 | ]) |
196 | | } |
197 | | |
198 | 113 | "EVALSHA" => { |
199 | 16 | assert_eq!( |
200 | 16 | args[0], |
201 | 16 | OwnedFrame::BulkString(FAKE_SCRIPT_SHA.as_bytes().to_vec()) |
202 | | ); |
203 | 16 | assert_eq!(args[1], OwnedFrame::BulkString(b"1".to_vec())); |
204 | 16 | let mut value: HashMap<_, Value> = HashMap::new(); |
205 | 16 | value.insert( |
206 | 16 | "data".into(), |
207 | 16 | Value::BulkString(args[4].as_bytes().unwrap().to_vec()), |
208 | | ); |
209 | 48 | for pair in args[5..]16 .chunks16 (2) { |
210 | 48 | value.insert( |
211 | 48 | str::from_utf8(pair[0].as_bytes().expect("Field name not bytes")) |
212 | 48 | .expect("Unable to parse field name as string") |
213 | 48 | .into(), |
214 | 48 | Value::BulkString(pair[1].as_bytes().unwrap().to_vec()), |
215 | 48 | ); |
216 | 48 | } |
217 | 16 | let mut ret: Option<Value> = None; |
218 | 16 | let key: String = |
219 | 16 | str::from_utf8(args[2].as_bytes().expect("Key not bytes")) |
220 | 16 | .expect("Key cannot be parsed as string") |
221 | 16 | .into(); |
222 | 16 | let expected_existing_version: i64 = |
223 | 16 | str::from_utf8(args[3].as_bytes().unwrap()) |
224 | 16 | .unwrap() |
225 | 16 | .parse() |
226 | 16 | .expect("Unable to parse existing version field"); |
227 | 16 | trace!(%key, %expected_existing_version, ?value, "Want to insert with EVALSHA"); |
228 | 16 | let version = match self.table.lock().unwrap().entry(key.clone()) { |
229 | 13 | Entry::Occupied(mut occupied_entry) => { |
230 | 13 | let version = occupied_entry |
231 | 13 | .get() |
232 | 13 | .get("version") |
233 | 13 | .expect("No version field"); |
234 | 13 | let Value::BulkString(version_bytes) = version else { Branch (234:37): [Folded - Ignored]
Branch (234:37): [Folded - Ignored]
Branch (234:37): [True: 13, False: 0]
|
235 | 0 | panic!("Non-bulkstring version: {version:?}"); |
236 | | }; |
237 | 13 | let version_int: i64 = str::from_utf8(version_bytes) |
238 | 13 | .expect("Version field not valid string") |
239 | 13 | .parse() |
240 | 13 | .expect("Unable to parse version field"); |
241 | 13 | if version_int == expected_existing_version { Branch (241:36): [Folded - Ignored]
Branch (241:36): [Folded - Ignored]
Branch (241:36): [True: 9, False: 4]
|
242 | 9 | let new_version = version_int + 1; |
243 | 9 | debug!(%key, %new_version, "Version update"); |
244 | 9 | value.insert( |
245 | 9 | "version".into(), |
246 | 9 | Value::BulkString( |
247 | 9 | format!("{new_version}").as_bytes().to_vec(), |
248 | 9 | ), |
249 | | ); |
250 | 9 | occupied_entry.insert(value); |
251 | 9 | new_version |
252 | | } else { |
253 | | // Version mismatch. |
254 | 4 | debug!(%key, %version_int, %expected_existing_version, "Version mismatch"); |
255 | 4 | ret = Some(Value::Array(vec![ |
256 | 4 | Value::Int(0), |
257 | 4 | Value::Int(version_int), |
258 | 4 | ])); |
259 | 4 | -1 |
260 | | } |
261 | | } |
262 | 3 | Entry::Vacant(vacant_entry) => { |
263 | 3 | if expected_existing_version != 0 { Branch (263:36): [Folded - Ignored]
Branch (263:36): [Folded - Ignored]
Branch (263:36): [True: 0, False: 3]
|
264 | | // Version mismatch. |
265 | 0 | debug!(%key, %expected_existing_version, "Version mismatch, expected zero"); |
266 | 0 | ret = Some(Value::Array(vec![Value::Int(0), Value::Int(0)])); |
267 | 0 | -1 |
268 | | } else { |
269 | 3 | debug!(%key, "Version insert"); |
270 | 3 | value |
271 | 3 | .insert("version".into(), Value::BulkString(b"1".to_vec())); |
272 | 3 | vacant_entry.insert_entry(value); |
273 | 3 | 1 |
274 | | } |
275 | | } |
276 | | }; |
277 | 16 | if let Some(r4 ) = ret { Branch (277:32): [Folded - Ignored]
Branch (277:32): [Folded - Ignored]
Branch (277:32): [True: 4, False: 12]
|
278 | 4 | r |
279 | | } else { |
280 | 12 | Value::Array(vec![Value::Int(1), Value::Int(version)]) |
281 | | } |
282 | | } |
283 | | |
284 | 97 | "HMSET" => { |
285 | 3 | let mut values = HashMap::new(); |
286 | 3 | assert_eq!( |
287 | 3 | (args.len() - 1).rem_euclid(2), |
288 | | 0, |
289 | 0 | "Non-even args for hmset: {args:?}" |
290 | | ); |
291 | 3 | let chunks = args[1..].chunks_exact(2); |
292 | 6 | for chunk3 in chunks { |
293 | 3 | let [key, value] = chunk else { Branch (293:33): [Folded - Ignored]
Branch (293:33): [Folded - Ignored]
Branch (293:33): [True: 3, False: 0]
|
294 | 0 | panic!("Uneven hmset args"); |
295 | | }; |
296 | 3 | let key_name: String = |
297 | 3 | str::from_utf8(key.as_bytes().expect("Key argument is not bytes")) |
298 | 3 | .expect("Unable to parse key as string") |
299 | 3 | .into(); |
300 | 3 | values.insert( |
301 | 3 | key_name, |
302 | 3 | Value::BulkString(value.as_bytes().unwrap().to_vec()), |
303 | | ); |
304 | | } |
305 | 3 | let key = |
306 | 3 | str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes")) |
307 | 3 | .expect("Unable to parse key as string") |
308 | 3 | .into(); |
309 | 3 | debug!(%key, ?values, "Inserting with HMSET"); |
310 | 3 | self.table.lock().unwrap().insert(key, values); |
311 | 3 | Value::Okay |
312 | | } |
313 | | |
314 | 94 | "HMGET" => { |
315 | 94 | let key_name = |
316 | 94 | str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes")) |
317 | 94 | .expect("Unable to parse key name"); |
318 | | |
319 | 94 | if let Some(fields48 ) = self.table.lock().unwrap().get(key_name) { Branch (319:32): [Folded - Ignored]
Branch (319:32): [Folded - Ignored]
Branch (319:32): [True: 48, False: 46]
|
320 | 48 | trace!(%key_name, keys = ?fields.keys(), "Getting keys with HMGET, some keys"); |
321 | 48 | let mut result = vec![]; |
322 | 96 | for key in &args[1..]48 { |
323 | 96 | let field_name = str::from_utf8( |
324 | 96 | key.as_bytes().expect("Field argument is not bytes"), |
325 | | ) |
326 | 96 | .expect("Unable to parse requested field"); |
327 | 96 | if let Some(value94 ) = fields.get(field_name) { Branch (327:40): [Folded - Ignored]
Branch (327:40): [Folded - Ignored]
Branch (327:40): [True: 94, False: 2]
|
328 | 94 | result.push(value.clone()); |
329 | 94 | } else { |
330 | 2 | debug!(%key_name, %field_name, "Missing field"); |
331 | 2 | result.push(Value::Nil); |
332 | | } |
333 | | } |
334 | 48 | Value::Array(result) |
335 | | } else { |
336 | 46 | trace!(%key_name, "Getting keys with HMGET, empty"); |
337 | 46 | let null_count = i64::try_from(args.len() - 1).unwrap(); |
338 | 46 | Value::Array(vec![Value::Nil, Value::Int(null_count)]) |
339 | | } |
340 | | } |
341 | 0 | actual => { |
342 | 0 | panic!("Mock command not implemented! {actual:?}"); |
343 | | } |
344 | | }; |
345 | | |
346 | 166 | arg_as_string(&mut output, ret); |
347 | 166 | if buf_index == buf.len() { Branch (347:20): [Folded - Ignored]
Branch (347:20): [Folded - Ignored]
Branch (347:20): [True: 154, False: 12]
|
348 | 154 | break; |
349 | 12 | } |
350 | | } |
351 | 154 | output |
352 | 154 | }; |
353 | 3 | fake_redis_internal(listener, inner).await; |
354 | 0 | } |
355 | | |
356 | 3 | pub async fn run(self) -> u16 { |
357 | 3 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); |
358 | 3 | let port = listener.local_addr().unwrap().port(); |
359 | 3 | info!("Using port {port}"); |
360 | | |
361 | 3 | background_spawn!("listener", async move { |
362 | 3 | self.dynamic_fake_redis(listener).await; |
363 | 0 | }); |
364 | | |
365 | 3 | port |
366 | 3 | } |
367 | | } |