/build/source/nativelink-redis-tester/src/dynamic_fake_redis.rs
Line | Count | Source |
1 | | // Copyright 2026 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::fmt; |
16 | | use std::collections::HashMap; |
17 | | use std::collections::hash_map::Entry; |
18 | | use std::sync::{Arc, Mutex}; |
19 | | |
20 | | use nativelink_util::background_spawn; |
21 | | use redis::Value; |
22 | | use redis_protocol::resp2::decode::decode; |
23 | | use redis_protocol::resp2::types::{OwnedFrame, Resp2Frame}; |
24 | | use tokio::net::TcpListener; |
25 | | use tracing::{debug, info, trace}; |
26 | | |
27 | | use crate::fake_redis::{arg_as_string, fake_redis_internal}; |
28 | | |
29 | | pub trait SubscriptionManagerNotify { |
30 | | fn notify_for_test(&self, value: String); |
31 | | } |
32 | | |
33 | | #[derive(Clone)] |
34 | | pub struct FakeRedisBackend<S: SubscriptionManagerNotify> { |
35 | | /// Contains a list of all of the Redis keys -> fields. |
36 | | pub table: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>, |
37 | | subscription_manager: Arc<Mutex<Option<Arc<S>>>>, |
38 | | } |
39 | | |
40 | | impl<S: SubscriptionManagerNotify + Send + 'static + Sync> Default for FakeRedisBackend<S> { |
41 | 0 | fn default() -> Self { |
42 | 0 | Self::new() |
43 | 0 | } |
44 | | } |
45 | | |
46 | | impl<S: SubscriptionManagerNotify> fmt::Debug for FakeRedisBackend<S> { |
47 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
48 | 0 | f.debug_struct("FakeRedisBackend").finish() |
49 | 0 | } |
50 | | } |
51 | | |
52 | | const FAKE_SCRIPT_SHA: &str = "b22b9926cbce9dd9ba97fa7ba3626f89feea1ed5"; |
53 | | |
54 | | impl<S: SubscriptionManagerNotify + Send + 'static + Sync> FakeRedisBackend<S> { |
55 | 3 | pub fn new() -> Self { |
56 | 3 | Self { |
57 | 3 | table: Arc::new(Mutex::new(HashMap::new())), |
58 | 3 | subscription_manager: Arc::new(Mutex::new(None)), |
59 | 3 | } |
60 | 3 | } |
61 | | |
62 | 2 | pub fn set_subscription_manager(&self, subscription_manager: Arc<S>) { |
63 | 2 | self.subscription_manager |
64 | 2 | .lock() |
65 | 2 | .unwrap() |
66 | 2 | .replace(subscription_manager); |
67 | 2 | } |
68 | | |
69 | 3 | async fn dynamic_fake_redis(self, listener: TcpListener) { |
70 | 154 | let inner3 = move |buf: &[u8]| -> String { |
71 | 154 | let mut output = String::new(); |
72 | 154 | let mut buf_index = 0; |
73 | | loop { |
74 | 166 | let frame = match decode(&buf[buf_index..]).unwrap() { |
75 | 166 | Some((frame, amt)) => { |
76 | 166 | buf_index += amt; |
77 | 166 | frame |
78 | | } |
79 | | None => { |
80 | 0 | panic!("No frame!"); |
81 | | } |
82 | | }; |
83 | 166 | let (cmd, args) = { |
84 | 166 | if let OwnedFrame::Array(a) = frame { |
85 | 166 | if let OwnedFrame::BulkString(s) = a.first().unwrap() { |
86 | 166 | let args: Vec<_> = a[1..].to_vec(); |
87 | 166 | (str::from_utf8(s).unwrap().to_string(), args) |
88 | | } else { |
89 | 0 | panic!("Array not starting with cmd: {a:?}"); |
90 | | } |
91 | | } else { |
92 | 0 | panic!("Non array cmd: {frame:?}"); |
93 | | } |
94 | | }; |
95 | | |
96 | 166 | let ret: Value = match cmd.as_str() { |
97 | 166 | "HELLO" => Value::Map(3 vec!3 [( |
98 | 3 | Value::SimpleString("server".into()), |
99 | 3 | Value::SimpleString("redis".into()), |
100 | 3 | )]), |
101 | 163 | "CLIENT" => { |
102 | | // We can safely ignore these, as it's just setting the library name/version |
103 | 6 | Value::Int(0) |
104 | | } |
105 | 157 | "SCRIPT" => { |
106 | 3 | assert_eq!(args[0], OwnedFrame::BulkString(b"LOAD".to_vec())); |
107 | | |
108 | 3 | let OwnedFrame::BulkString(ref _script) = args[1] else { |
109 | 0 | panic!("Script should be a bulkstring: {args:?}"); |
110 | | }; |
111 | 3 | Value::SimpleString(FAKE_SCRIPT_SHA.to_string()) |
112 | | } |
113 | | |
114 | 154 | "PSUBSCRIBE" => { |
115 | | // This does nothing at the moment, maybe we need to implement it later. |
116 | 6 | Value::Int(0) |
117 | | } |
118 | | |
119 | 148 | "PUBLISH" => { |
120 | 14 | if let Some(subscription_manager) = |
121 | 15 | self.subscription_manager.lock().unwrap().as_ref() |
122 | | { |
123 | 14 | subscription_manager.notify_for_test( |
124 | 14 | str::from_utf8(args[1].as_bytes().expect("Notification not bytes")) |
125 | 14 | .expect("Notification not UTF-8") |
126 | 14 | .into(), |
127 | 14 | ); |
128 | 14 | Value::Int(1) |
129 | | } else { |
130 | 1 | Value::Int(0) |
131 | | } |
132 | | } |
133 | | |
134 | 133 | "FT.AGGREGATE" => { |
135 | | // The query is either "*" (match all) or @field:{ value }. |
136 | 20 | let OwnedFrame::BulkString(ref raw_query) = args[1] else { |
137 | 0 | panic!("Aggregate query should be a string: {args:?}"); |
138 | | }; |
139 | 20 | let query = str::from_utf8(raw_query).unwrap(); |
140 | | // Lazy implementation making assumptions. |
141 | 20 | assert_eq!( |
142 | 20 | args[2..6], |
143 | 20 | vec![ |
144 | 20 | OwnedFrame::BulkString(b"LOAD".to_vec()), |
145 | 20 | OwnedFrame::BulkString(b"2".to_vec()), |
146 | 20 | OwnedFrame::BulkString(b"data".to_vec()), |
147 | 20 | OwnedFrame::BulkString(b"version".to_vec()) |
148 | | ] |
149 | | ); |
150 | 20 | let mut results = vec![Value::Int(0)]; |
151 | | |
152 | 20 | if query == "*" { |
153 | | // Wildcard query - return all records that have both data and version fields. |
154 | | // Some entries (e.g., from HSET) may not have version field. |
155 | 0 | for fields in self.table.lock().unwrap().values() { |
156 | 0 | if let (Some(data), Some(version)) = |
157 | 0 | (fields.get("data"), fields.get("version")) |
158 | 0 | { |
159 | 0 | results.push(Value::Array(vec![ |
160 | 0 | Value::BulkString(b"data".to_vec()), |
161 | 0 | data.clone(), |
162 | 0 | Value::BulkString(b"version".to_vec()), |
163 | 0 | version.clone(), |
164 | 0 | ])); |
165 | 0 | } |
166 | | } |
167 | | } else { |
168 | | // Field-specific query: @field:{ value } |
169 | 20 | assert_eq!(&query[..1], "@"); |
170 | 20 | let mut parts = query[1..].split(':'); |
171 | 20 | let field = parts.next().expect("No field name"); |
172 | 20 | let value = parts.next().expect("No value"); |
173 | 20 | let value = value |
174 | 20 | .strip_prefix("{ ") |
175 | 20 | .and_then(|s| s.strip_suffix(" }")) |
176 | 20 | .unwrap_or(value); |
177 | 56 | for fields in self.table.lock().unwrap().values()20 { |
178 | 56 | if let Some(key_value19 ) = fields.get(field) |
179 | 19 | && *key_value == Value::BulkString(value.as_bytes().to_vec()) |
180 | 10 | { |
181 | 10 | results.push(Value::Array(vec![ |
182 | 10 | Value::BulkString(b"data".to_vec()), |
183 | 10 | fields.get("data").expect("No data field").clone(), |
184 | 10 | Value::BulkString(b"version".to_vec()), |
185 | 10 | fields.get("version").expect("No version field").clone(), |
186 | 10 | ])); |
187 | 46 | } |
188 | | } |
189 | | } |
190 | | |
191 | 20 | results[0] = |
192 | 20 | Value::Int(i64::try_from(results.len() - 1).unwrap_or(i64::MAX)); |
193 | 20 | Value::Array(vec![ |
194 | 20 | Value::Array(results), |
195 | 20 | Value::Int(0), // Means no more items in cursor. |
196 | 20 | ]) |
197 | | } |
198 | | |
199 | 113 | "EVALSHA" => { |
200 | 16 | assert_eq!( |
201 | 16 | args[0], |
202 | 16 | OwnedFrame::BulkString(FAKE_SCRIPT_SHA.as_bytes().to_vec()) |
203 | | ); |
204 | 16 | assert_eq!(args[1], OwnedFrame::BulkString(b"1".to_vec())); |
205 | 16 | let mut value: HashMap<_, Value> = HashMap::new(); |
206 | 16 | value.insert( |
207 | 16 | "data".into(), |
208 | 16 | Value::BulkString(args[4].as_bytes().unwrap().to_vec()), |
209 | | ); |
210 | 48 | for pair in args[5..]16 .chunks16 (2) { |
211 | 48 | value.insert( |
212 | 48 | str::from_utf8(pair[0].as_bytes().expect("Field name not bytes")) |
213 | 48 | .expect("Unable to parse field name as string") |
214 | 48 | .into(), |
215 | 48 | Value::BulkString(pair[1].as_bytes().unwrap().to_vec()), |
216 | 48 | ); |
217 | 48 | } |
218 | 16 | let mut ret: Option<Value> = None; |
219 | 16 | let key: String = |
220 | 16 | str::from_utf8(args[2].as_bytes().expect("Key not bytes")) |
221 | 16 | .expect("Key cannot be parsed as string") |
222 | 16 | .into(); |
223 | 16 | let expected_existing_version: i64 = |
224 | 16 | str::from_utf8(args[3].as_bytes().unwrap()) |
225 | 16 | .unwrap() |
226 | 16 | .parse() |
227 | 16 | .expect("Unable to parse existing version field"); |
228 | 16 | trace!(%key, %expected_existing_version, ?value, "Want to insert with EVALSHA"); |
229 | 16 | let version = match self.table.lock().unwrap().entry(key.clone()) { |
230 | 13 | Entry::Occupied(mut occupied_entry) => { |
231 | 13 | let version = occupied_entry |
232 | 13 | .get() |
233 | 13 | .get("version") |
234 | 13 | .expect("No version field"); |
235 | 13 | let Value::BulkString(version_bytes) = version else { |
236 | 0 | panic!("Non-bulkstring version: {version:?}"); |
237 | | }; |
238 | 13 | let version_int: i64 = str::from_utf8(version_bytes) |
239 | 13 | .expect("Version field not valid string") |
240 | 13 | .parse() |
241 | 13 | .expect("Unable to parse version field"); |
242 | 13 | if version_int == expected_existing_version { |
243 | 9 | let new_version = version_int + 1; |
244 | 9 | debug!(%key, %new_version, "Version update"); |
245 | 9 | value.insert( |
246 | 9 | "version".into(), |
247 | 9 | Value::BulkString( |
248 | 9 | format!("{new_version}").as_bytes().to_vec(), |
249 | 9 | ), |
250 | | ); |
251 | 9 | occupied_entry.insert(value); |
252 | 9 | new_version |
253 | | } else { |
254 | | // Version mismatch. |
255 | 4 | debug!(%key, %version_int, %expected_existing_version, "Version mismatch"); |
256 | 4 | ret = Some(Value::Array(vec![ |
257 | 4 | Value::Int(0), |
258 | 4 | Value::Int(version_int), |
259 | 4 | ])); |
260 | 4 | -1 |
261 | | } |
262 | | } |
263 | 3 | Entry::Vacant(vacant_entry) => { |
264 | 3 | if expected_existing_version != 0 { |
265 | | // Version mismatch. |
266 | 0 | debug!(%key, %expected_existing_version, "Version mismatch, expected zero"); |
267 | 0 | ret = Some(Value::Array(vec![Value::Int(0), Value::Int(0)])); |
268 | 0 | -1 |
269 | | } else { |
270 | 3 | debug!(%key, "Version insert"); |
271 | 3 | value |
272 | 3 | .insert("version".into(), Value::BulkString(b"1".to_vec())); |
273 | 3 | vacant_entry.insert_entry(value); |
274 | 3 | 1 |
275 | | } |
276 | | } |
277 | | }; |
278 | 16 | if let Some(r4 ) = ret { |
279 | 4 | r |
280 | | } else { |
281 | 12 | Value::Array(vec![Value::Int(1), Value::Int(version)]) |
282 | | } |
283 | | } |
284 | | |
285 | 97 | "HMSET" => { |
286 | 3 | let mut values = HashMap::new(); |
287 | 3 | assert_eq!( |
288 | 3 | (args.len() - 1).rem_euclid(2), |
289 | | 0, |
290 | | "Non-even args for hmset: {args:?}" |
291 | | ); |
292 | 3 | let chunks = args[1..].chunks_exact(2); |
293 | 3 | for chunk in chunks { |
294 | 3 | let [key, value] = chunk else { |
295 | 0 | panic!("Uneven hmset args"); |
296 | | }; |
297 | 3 | let key_name: String = |
298 | 3 | str::from_utf8(key.as_bytes().expect("Key argument is not bytes")) |
299 | 3 | .expect("Unable to parse key as string") |
300 | 3 | .into(); |
301 | 3 | values.insert( |
302 | 3 | key_name, |
303 | 3 | Value::BulkString(value.as_bytes().unwrap().to_vec()), |
304 | | ); |
305 | | } |
306 | 3 | let key = |
307 | 3 | str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes")) |
308 | 3 | .expect("Unable to parse key as string") |
309 | 3 | .into(); |
310 | 3 | debug!(%key, ?values, "Inserting with HMSET"); |
311 | 3 | self.table.lock().unwrap().insert(key, values); |
312 | 3 | Value::Okay |
313 | | } |
314 | | |
315 | 94 | "HMGET" => { |
316 | 94 | let key_name = |
317 | 94 | str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes")) |
318 | 94 | .expect("Unable to parse key name"); |
319 | | |
320 | 94 | if let Some(fields48 ) = self.table.lock().unwrap().get(key_name) { |
321 | 48 | trace!(%key_name, keys = ?fields.keys(), "Getting keys with HMGET, some keys"); |
322 | 48 | let mut result = vec![]; |
323 | 96 | for key in &args[1..]48 { |
324 | 96 | let field_name = str::from_utf8( |
325 | 96 | key.as_bytes().expect("Field argument is not bytes"), |
326 | | ) |
327 | 96 | .expect("Unable to parse requested field"); |
328 | 96 | if let Some(value94 ) = fields.get(field_name) { |
329 | 94 | result.push(value.clone()); |
330 | 94 | } else { |
331 | 2 | debug!(%key_name, %field_name, "Missing field"); |
332 | 2 | result.push(Value::Nil); |
333 | | } |
334 | | } |
335 | 48 | Value::Array(result) |
336 | | } else { |
337 | 46 | trace!(%key_name, "Getting keys with HMGET, empty"); |
338 | 46 | let null_count = i64::try_from(args.len() - 1).unwrap(); |
339 | 46 | Value::Array(vec![Value::Nil, Value::Int(null_count)]) |
340 | | } |
341 | | } |
342 | 0 | actual => { |
343 | 0 | panic!("Mock command not implemented! {actual:?}"); |
344 | | } |
345 | | }; |
346 | | |
347 | 166 | arg_as_string(&mut output, ret); |
348 | 166 | if buf_index == buf.len() { |
349 | 154 | break; |
350 | 12 | } |
351 | | } |
352 | 154 | output |
353 | 154 | }; |
354 | 3 | fake_redis_internal(listener, vec![inner]).await; |
355 | 0 | } |
356 | | |
357 | 3 | pub async fn run(self) -> u16 { |
358 | 3 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); |
359 | 3 | let port = listener.local_addr().unwrap().port(); |
360 | 3 | info!("Using port {port}"); |
361 | | |
362 | 3 | background_spawn!("listener", async move { |
363 | 3 | self.dynamic_fake_redis(listener).await; |
364 | 0 | }); |
365 | | |
366 | 3 | port |
367 | 3 | } |
368 | | } |