Coverage Report

Created: 2026-04-16 01:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-redis-tester/src/dynamic_fake_redis.rs
Line
Count
Source
1
// Copyright 2026 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::fmt;
16
use std::collections::HashMap;
17
use std::collections::hash_map::Entry;
18
use std::sync::{Arc, Mutex};
19
20
use nativelink_util::background_spawn;
21
use redis::Value;
22
use redis_protocol::resp2::decode::decode;
23
use redis_protocol::resp2::types::{OwnedFrame, Resp2Frame};
24
use tokio::net::TcpListener;
25
use tracing::{debug, info, trace};
26
27
use crate::fake_redis::{arg_as_string, fake_redis_internal};
28
29
pub trait SubscriptionManagerNotify {
30
    fn notify_for_test(&self, value: String);
31
}
32
33
#[derive(Clone)]
34
pub struct FakeRedisBackend<S: SubscriptionManagerNotify> {
35
    /// Contains a list of all of the Redis keys -> fields.
36
    pub table: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>,
37
    subscription_manager: Arc<Mutex<Option<Arc<S>>>>,
38
}
39
40
impl<S: SubscriptionManagerNotify + Send + 'static + Sync> Default for FakeRedisBackend<S> {
41
0
    fn default() -> Self {
42
0
        Self::new()
43
0
    }
44
}
45
46
impl<S: SubscriptionManagerNotify> fmt::Debug for FakeRedisBackend<S> {
47
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48
0
        f.debug_struct("FakeRedisBackend").finish()
49
0
    }
50
}
51
52
const FAKE_SCRIPT_SHA: &str = "b22b9926cbce9dd9ba97fa7ba3626f89feea1ed5";
53
54
impl<S: SubscriptionManagerNotify + Send + 'static + Sync> FakeRedisBackend<S> {
55
3
    pub fn new() -> Self {
56
3
        Self {
57
3
            table: Arc::new(Mutex::new(HashMap::new())),
58
3
            subscription_manager: Arc::new(Mutex::new(None)),
59
3
        }
60
3
    }
61
62
2
    pub fn set_subscription_manager(&self, subscription_manager: Arc<S>) {
63
2
        self.subscription_manager
64
2
            .lock()
65
2
            .unwrap()
66
2
            .replace(subscription_manager);
67
2
    }
68
69
3
    async fn dynamic_fake_redis(self, listener: TcpListener) {
70
154
        let 
inner3
= move |buf: &[u8]| -> String {
71
154
            let mut output = String::new();
72
154
            let mut buf_index = 0;
73
            loop {
74
166
                let frame = match decode(&buf[buf_index..]).unwrap() {
75
166
                    Some((frame, amt)) => {
76
166
                        buf_index += amt;
77
166
                        frame
78
                    }
79
                    None => {
80
0
                        panic!("No frame!");
81
                    }
82
                };
83
166
                let (cmd, args) = {
84
166
                    if let OwnedFrame::Array(a) = frame {
85
166
                        if let OwnedFrame::BulkString(s) = a.first().unwrap() {
86
166
                            let args: Vec<_> = a[1..].to_vec();
87
166
                            (str::from_utf8(s).unwrap().to_string(), args)
88
                        } else {
89
0
                            panic!("Array not starting with cmd: {a:?}");
90
                        }
91
                    } else {
92
0
                        panic!("Non array cmd: {frame:?}");
93
                    }
94
                };
95
96
166
                let ret: Value = match cmd.as_str() {
97
166
                    "HELLO" => 
Value::Map(3
vec!3
[(
98
3
                        Value::SimpleString("server".into()),
99
3
                        Value::SimpleString("redis".into()),
100
3
                    )]),
101
163
                    "CLIENT" => {
102
                        // We can safely ignore these, as it's just setting the library name/version
103
6
                        Value::Int(0)
104
                    }
105
157
                    "SCRIPT" => {
106
3
                        assert_eq!(args[0], OwnedFrame::BulkString(b"LOAD".to_vec()));
107
108
3
                        let OwnedFrame::BulkString(ref _script) = args[1] else {
109
0
                            panic!("Script should be a bulkstring: {args:?}");
110
                        };
111
3
                        Value::SimpleString(FAKE_SCRIPT_SHA.to_string())
112
                    }
113
114
154
                    "PSUBSCRIBE" => {
115
                        // This does nothing at the moment, maybe we need to implement it later.
116
6
                        Value::Int(0)
117
                    }
118
119
148
                    "PUBLISH" => {
120
14
                        if let Some(subscription_manager) =
121
15
                            self.subscription_manager.lock().unwrap().as_ref()
122
                        {
123
14
                            subscription_manager.notify_for_test(
124
14
                                str::from_utf8(args[1].as_bytes().expect("Notification not bytes"))
125
14
                                    .expect("Notification not UTF-8")
126
14
                                    .into(),
127
14
                            );
128
14
                            Value::Int(1)
129
                        } else {
130
1
                            Value::Int(0)
131
                        }
132
                    }
133
134
133
                    "FT.AGGREGATE" => {
135
                        // The query is either "*" (match all) or @field:{ value }.
136
20
                        let OwnedFrame::BulkString(ref raw_query) = args[1] else {
137
0
                            panic!("Aggregate query should be a string: {args:?}");
138
                        };
139
20
                        let query = str::from_utf8(raw_query).unwrap();
140
                        // Lazy implementation making assumptions.
141
20
                        assert_eq!(
142
20
                            args[2..6],
143
20
                            vec![
144
20
                                OwnedFrame::BulkString(b"LOAD".to_vec()),
145
20
                                OwnedFrame::BulkString(b"2".to_vec()),
146
20
                                OwnedFrame::BulkString(b"data".to_vec()),
147
20
                                OwnedFrame::BulkString(b"version".to_vec())
148
                            ]
149
                        );
150
20
                        let mut results = vec![Value::Int(0)];
151
152
20
                        if query == "*" {
153
                            // Wildcard query - return all records that have both data and version fields.
154
                            // Some entries (e.g., from HSET) may not have version field.
155
0
                            for fields in self.table.lock().unwrap().values() {
156
0
                                if let (Some(data), Some(version)) =
157
0
                                    (fields.get("data"), fields.get("version"))
158
0
                                {
159
0
                                    results.push(Value::Array(vec![
160
0
                                        Value::BulkString(b"data".to_vec()),
161
0
                                        data.clone(),
162
0
                                        Value::BulkString(b"version".to_vec()),
163
0
                                        version.clone(),
164
0
                                    ]));
165
0
                                }
166
                            }
167
                        } else {
168
                            // Field-specific query: @field:{ value }
169
20
                            assert_eq!(&query[..1], "@");
170
20
                            let mut parts = query[1..].split(':');
171
20
                            let field = parts.next().expect("No field name");
172
20
                            let value = parts.next().expect("No value");
173
20
                            let value = value
174
20
                                .strip_prefix("{ ")
175
20
                                .and_then(|s| s.strip_suffix(" }"))
176
20
                                .unwrap_or(value);
177
56
                            for fields in 
self.table.lock().unwrap().values()20
{
178
56
                                if let Some(
key_value19
) = fields.get(field)
179
19
                                    && *key_value == Value::BulkString(value.as_bytes().to_vec())
180
10
                                {
181
10
                                    results.push(Value::Array(vec![
182
10
                                        Value::BulkString(b"data".to_vec()),
183
10
                                        fields.get("data").expect("No data field").clone(),
184
10
                                        Value::BulkString(b"version".to_vec()),
185
10
                                        fields.get("version").expect("No version field").clone(),
186
10
                                    ]));
187
46
                                }
188
                            }
189
                        }
190
191
20
                        results[0] =
192
20
                            Value::Int(i64::try_from(results.len() - 1).unwrap_or(i64::MAX));
193
20
                        Value::Array(vec![
194
20
                            Value::Array(results),
195
20
                            Value::Int(0), // Means no more items in cursor.
196
20
                        ])
197
                    }
198
199
113
                    "EVALSHA" => {
200
16
                        assert_eq!(
201
16
                            args[0],
202
16
                            OwnedFrame::BulkString(FAKE_SCRIPT_SHA.as_bytes().to_vec())
203
                        );
204
16
                        assert_eq!(args[1], OwnedFrame::BulkString(b"1".to_vec()));
205
16
                        let mut value: HashMap<_, Value> = HashMap::new();
206
16
                        value.insert(
207
16
                            "data".into(),
208
16
                            Value::BulkString(args[4].as_bytes().unwrap().to_vec()),
209
                        );
210
48
                        for pair in 
args[5..]16
.
chunks16
(2) {
211
48
                            value.insert(
212
48
                                str::from_utf8(pair[0].as_bytes().expect("Field name not bytes"))
213
48
                                    .expect("Unable to parse field name as string")
214
48
                                    .into(),
215
48
                                Value::BulkString(pair[1].as_bytes().unwrap().to_vec()),
216
48
                            );
217
48
                        }
218
16
                        let mut ret: Option<Value> = None;
219
16
                        let key: String =
220
16
                            str::from_utf8(args[2].as_bytes().expect("Key not bytes"))
221
16
                                .expect("Key cannot be parsed as string")
222
16
                                .into();
223
16
                        let expected_existing_version: i64 =
224
16
                            str::from_utf8(args[3].as_bytes().unwrap())
225
16
                                .unwrap()
226
16
                                .parse()
227
16
                                .expect("Unable to parse existing version field");
228
16
                        trace!(%key, %expected_existing_version, ?value, "Want to insert with EVALSHA");
229
16
                        let version = match self.table.lock().unwrap().entry(key.clone()) {
230
13
                            Entry::Occupied(mut occupied_entry) => {
231
13
                                let version = occupied_entry
232
13
                                    .get()
233
13
                                    .get("version")
234
13
                                    .expect("No version field");
235
13
                                let Value::BulkString(version_bytes) = version else {
236
0
                                    panic!("Non-bulkstring version: {version:?}");
237
                                };
238
13
                                let version_int: i64 = str::from_utf8(version_bytes)
239
13
                                    .expect("Version field not valid string")
240
13
                                    .parse()
241
13
                                    .expect("Unable to parse version field");
242
13
                                if version_int == expected_existing_version {
243
9
                                    let new_version = version_int + 1;
244
9
                                    debug!(%key, %new_version, "Version update");
245
9
                                    value.insert(
246
9
                                        "version".into(),
247
9
                                        Value::BulkString(
248
9
                                            format!("{new_version}").as_bytes().to_vec(),
249
9
                                        ),
250
                                    );
251
9
                                    occupied_entry.insert(value);
252
9
                                    new_version
253
                                } else {
254
                                    // Version mismatch.
255
4
                                    debug!(%key, %version_int, %expected_existing_version, "Version mismatch");
256
4
                                    ret = Some(Value::Array(vec![
257
4
                                        Value::Int(0),
258
4
                                        Value::Int(version_int),
259
4
                                    ]));
260
4
                                    -1
261
                                }
262
                            }
263
3
                            Entry::Vacant(vacant_entry) => {
264
3
                                if expected_existing_version != 0 {
265
                                    // Version mismatch.
266
0
                                    debug!(%key, %expected_existing_version, "Version mismatch, expected zero");
267
0
                                    ret = Some(Value::Array(vec![Value::Int(0), Value::Int(0)]));
268
0
                                    -1
269
                                } else {
270
3
                                    debug!(%key, "Version insert");
271
3
                                    value
272
3
                                        .insert("version".into(), Value::BulkString(b"1".to_vec()));
273
3
                                    vacant_entry.insert_entry(value);
274
3
                                    1
275
                                }
276
                            }
277
                        };
278
16
                        if let Some(
r4
) = ret {
279
4
                            r
280
                        } else {
281
12
                            Value::Array(vec![Value::Int(1), Value::Int(version)])
282
                        }
283
                    }
284
285
97
                    "HMSET" => {
286
3
                        let mut values = HashMap::new();
287
3
                        assert_eq!(
288
3
                            (args.len() - 1).rem_euclid(2),
289
                            0,
290
                            "Non-even args for hmset: {args:?}"
291
                        );
292
3
                        let chunks = args[1..].chunks_exact(2);
293
3
                        for chunk in chunks {
294
3
                            let [key, value] = chunk else {
295
0
                                panic!("Uneven hmset args");
296
                            };
297
3
                            let key_name: String =
298
3
                                str::from_utf8(key.as_bytes().expect("Key argument is not bytes"))
299
3
                                    .expect("Unable to parse key as string")
300
3
                                    .into();
301
3
                            values.insert(
302
3
                                key_name,
303
3
                                Value::BulkString(value.as_bytes().unwrap().to_vec()),
304
                            );
305
                        }
306
3
                        let key =
307
3
                            str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes"))
308
3
                                .expect("Unable to parse key as string")
309
3
                                .into();
310
3
                        debug!(%key, ?values, "Inserting with HMSET");
311
3
                        self.table.lock().unwrap().insert(key, values);
312
3
                        Value::Okay
313
                    }
314
315
94
                    "HMGET" => {
316
94
                        let key_name =
317
94
                            str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes"))
318
94
                                .expect("Unable to parse key name");
319
320
94
                        if let Some(
fields48
) = self.table.lock().unwrap().get(key_name) {
321
48
                            trace!(%key_name, keys = ?fields.keys(), "Getting keys with HMGET, some keys");
322
48
                            let mut result = vec![];
323
96
                            for key in 
&args[1..]48
{
324
96
                                let field_name = str::from_utf8(
325
96
                                    key.as_bytes().expect("Field argument is not bytes"),
326
                                )
327
96
                                .expect("Unable to parse requested field");
328
96
                                if let Some(
value94
) = fields.get(field_name) {
329
94
                                    result.push(value.clone());
330
94
                                } else {
331
2
                                    debug!(%key_name, %field_name, "Missing field");
332
2
                                    result.push(Value::Nil);
333
                                }
334
                            }
335
48
                            Value::Array(result)
336
                        } else {
337
46
                            trace!(%key_name, "Getting keys with HMGET, empty");
338
46
                            let null_count = i64::try_from(args.len() - 1).unwrap();
339
46
                            Value::Array(vec![Value::Nil, Value::Int(null_count)])
340
                        }
341
                    }
342
0
                    actual => {
343
0
                        panic!("Mock command not implemented! {actual:?}");
344
                    }
345
                };
346
347
166
                arg_as_string(&mut output, ret);
348
166
                if buf_index == buf.len() {
349
154
                    break;
350
12
                }
351
            }
352
154
            output
353
154
        };
354
3
        fake_redis_internal(listener, vec![inner]).await;
355
0
    }
356
357
3
    pub async fn run(self) -> u16 {
358
3
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
359
3
        let port = listener.local_addr().unwrap().port();
360
3
        info!("Using port {port}");
361
362
3
        background_spawn!("listener", async move {
363
3
            self.dynamic_fake_redis(listener).await;
364
0
        });
365
366
3
        port
367
3
    }
368
}