Coverage Report

Created: 2026-02-23 10:49

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/build/source/nativelink-redis-tester/src/dynamic_fake_redis.rs
Line
Count
Source
1
// Copyright 2026 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    See LICENSE file for details
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use core::fmt;
16
use std::collections::HashMap;
17
use std::collections::hash_map::Entry;
18
use std::sync::{Arc, Mutex};
19
20
use nativelink_util::background_spawn;
21
use redis::Value;
22
use redis_protocol::resp2::decode::decode;
23
use redis_protocol::resp2::types::{OwnedFrame, Resp2Frame};
24
use tokio::net::TcpListener;
25
use tracing::{debug, info, trace};
26
27
use crate::fake_redis::{arg_as_string, fake_redis_internal};
28
29
pub trait SubscriptionManagerNotify {
30
    fn notify_for_test(&self, value: String);
31
}
32
33
#[derive(Clone)]
34
pub struct FakeRedisBackend<S: SubscriptionManagerNotify> {
35
    /// Contains a list of all of the Redis keys -> fields.
36
    pub table: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>,
37
    subscription_manager: Arc<Mutex<Option<Arc<S>>>>,
38
}
39
40
impl<S: SubscriptionManagerNotify + Send + 'static + Sync> Default for FakeRedisBackend<S> {
41
0
    fn default() -> Self {
42
0
        Self::new()
43
0
    }
44
}
45
46
impl<S: SubscriptionManagerNotify> fmt::Debug for FakeRedisBackend<S> {
47
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48
0
        f.debug_struct("FakeRedisBackend").finish()
49
0
    }
50
}
51
52
const FAKE_SCRIPT_SHA: &str = "b22b9926cbce9dd9ba97fa7ba3626f89feea1ed5";
53
54
impl<S: SubscriptionManagerNotify + Send + 'static + Sync> FakeRedisBackend<S> {
55
3
    pub fn new() -> Self {
56
3
        Self {
57
3
            table: Arc::new(Mutex::new(HashMap::new())),
58
3
            subscription_manager: Arc::new(Mutex::new(None)),
59
3
        }
60
3
    }
61
62
2
    pub fn set_subscription_manager(&self, subscription_manager: Arc<S>) {
63
2
        self.subscription_manager
64
2
            .lock()
65
2
            .unwrap()
66
2
            .replace(subscription_manager);
67
2
    }
68
69
3
    async fn dynamic_fake_redis(self, listener: TcpListener) {
70
154
        let 
inner3
= move |buf: &[u8]| -> String {
71
154
            let mut output = String::new();
72
154
            let mut buf_index = 0;
73
            loop {
74
166
                let frame = match decode(&buf[buf_index..]).unwrap() {
75
166
                    Some((frame, amt)) => {
76
166
                        buf_index += amt;
77
166
                        frame
78
                    }
79
                    None => {
80
0
                        panic!("No frame!");
81
                    }
82
                };
83
166
                let (cmd, args) = {
84
166
                    if let OwnedFrame::Array(a) = frame {
  Branch (84:28): [Folded - Ignored]
  Branch (84:28): [Folded - Ignored]
  Branch (84:28): [True: 166, False: 0]
85
166
                        if let OwnedFrame::BulkString(s) = a.first().unwrap() {
  Branch (85:32): [Folded - Ignored]
  Branch (85:32): [Folded - Ignored]
  Branch (85:32): [True: 166, False: 0]
86
166
                            let args: Vec<_> = a[1..].to_vec();
87
166
                            (str::from_utf8(s).unwrap().to_string(), args)
88
                        } else {
89
0
                            panic!("Array not starting with cmd: {a:?}");
90
                        }
91
                    } else {
92
0
                        panic!("Non array cmd: {frame:?}");
93
                    }
94
                };
95
96
166
                let ret: Value = match cmd.as_str() {
97
166
                    "CLIENT" => {
98
                        // We can safely ignore these, as it's just setting the library name/version
99
12
                        Value::Int(0)
100
                    }
101
154
                    "SCRIPT" => {
102
3
                        assert_eq!(args[0], OwnedFrame::BulkString(b"LOAD".to_vec()));
103
104
3
                        let OwnedFrame::BulkString(ref _script) = args[1] else {
  Branch (104:29): [Folded - Ignored]
  Branch (104:29): [Folded - Ignored]
  Branch (104:29): [True: 3, False: 0]
105
0
                            panic!("Script should be a bulkstring: {args:?}");
106
                        };
107
3
                        Value::SimpleString(FAKE_SCRIPT_SHA.to_string())
108
                    }
109
110
151
                    "PSUBSCRIBE" => {
111
                        // This does nothing at the moment, maybe we need to implement it later.
112
3
                        Value::Int(0)
113
                    }
114
115
148
                    "PUBLISH" => {
116
14
                        if let Some(subscription_manager) =
  Branch (116:32): [Folded - Ignored]
  Branch (116:32): [Folded - Ignored]
  Branch (116:32): [True: 14, False: 1]
117
15
                            self.subscription_manager.lock().unwrap().as_ref()
118
                        {
119
14
                            subscription_manager.notify_for_test(
120
14
                                str::from_utf8(args[1].as_bytes().expect("Notification not bytes"))
121
14
                                    .expect("Notification not UTF-8")
122
14
                                    .into(),
123
14
                            );
124
14
                            Value::Int(1)
125
                        } else {
126
1
                            Value::Int(0)
127
                        }
128
                    }
129
130
133
                    "FT.AGGREGATE" => {
131
                        // The query is either "*" (match all) or @field:{ value }.
132
20
                        let OwnedFrame::BulkString(ref raw_query) = args[1] else {
  Branch (132:29): [Folded - Ignored]
  Branch (132:29): [Folded - Ignored]
  Branch (132:29): [True: 20, False: 0]
133
0
                            panic!("Aggregate query should be a string: {args:?}");
134
                        };
135
20
                        let query = str::from_utf8(raw_query).unwrap();
136
                        // Lazy implementation making assumptions.
137
20
                        assert_eq!(
138
20
                            args[2..6],
139
20
                            vec![
140
20
                                OwnedFrame::BulkString(b"LOAD".to_vec()),
141
20
                                OwnedFrame::BulkString(b"2".to_vec()),
142
20
                                OwnedFrame::BulkString(b"data".to_vec()),
143
20
                                OwnedFrame::BulkString(b"version".to_vec())
144
                            ]
145
                        );
146
20
                        let mut results = vec![Value::Int(0)];
147
148
20
                        if query == "*" {
  Branch (148:28): [Folded - Ignored]
  Branch (148:28): [Folded - Ignored]
  Branch (148:28): [True: 0, False: 20]
149
                            // Wildcard query - return all records that have both data and version fields.
150
                            // Some entries (e.g., from HSET) may not have version field.
151
0
                            for fields in self.table.lock().unwrap().values() {
152
0
                                if let (Some(data), Some(version)) =
  Branch (152:40): [Folded - Ignored]
  Branch (152:40): [Folded - Ignored]
  Branch (152:40): [True: 0, False: 0]
153
0
                                    (fields.get("data"), fields.get("version"))
154
0
                                {
155
0
                                    results.push(Value::Array(vec![
156
0
                                        Value::BulkString(b"data".to_vec()),
157
0
                                        data.clone(),
158
0
                                        Value::BulkString(b"version".to_vec()),
159
0
                                        version.clone(),
160
0
                                    ]));
161
0
                                }
162
                            }
163
                        } else {
164
                            // Field-specific query: @field:{ value }
165
20
                            assert_eq!(&query[..1], "@");
166
20
                            let mut parts = query[1..].split(':');
167
20
                            let field = parts.next().expect("No field name");
168
20
                            let value = parts.next().expect("No value");
169
20
                            let value = value
170
20
                                .strip_prefix("{ ")
171
20
                                .and_then(|s| s.strip_suffix(" }"))
172
20
                                .unwrap_or(value);
173
56
                            for fields in 
self.table.lock().unwrap()20
.
values20
() {
174
56
                                if let Some(
key_value19
) = fields.get(field) {
  Branch (174:40): [Folded - Ignored]
  Branch (174:40): [Folded - Ignored]
  Branch (174:40): [True: 19, False: 37]
175
19
                                    if *key_value == Value::BulkString(value.as_bytes().to_vec()) {
  Branch (175:40): [Folded - Ignored]
  Branch (175:40): [Folded - Ignored]
  Branch (175:40): [True: 10, False: 9]
176
10
                                        results.push(Value::Array(vec![
177
10
                                            Value::BulkString(b"data".to_vec()),
178
10
                                            fields.get("data").expect("No data field").clone(),
179
10
                                            Value::BulkString(b"version".to_vec()),
180
10
                                            fields
181
10
                                                .get("version")
182
10
                                                .expect("No version field")
183
10
                                                .clone(),
184
10
                                        ]));
185
10
                                    
}9
186
37
                                }
187
                            }
188
                        }
189
190
20
                        results[0] =
191
20
                            Value::Int(i64::try_from(results.len() - 1).unwrap_or(i64::MAX));
192
20
                        Value::Array(vec![
193
20
                            Value::Array(results),
194
20
                            Value::Int(0), // Means no more items in cursor.
195
20
                        ])
196
                    }
197
198
113
                    "EVALSHA" => {
199
16
                        assert_eq!(
200
16
                            args[0],
201
16
                            OwnedFrame::BulkString(FAKE_SCRIPT_SHA.as_bytes().to_vec())
202
                        );
203
16
                        assert_eq!(args[1], OwnedFrame::BulkString(b"1".to_vec()));
204
16
                        let mut value: HashMap<_, Value> = HashMap::new();
205
16
                        value.insert(
206
16
                            "data".into(),
207
16
                            Value::BulkString(args[4].as_bytes().unwrap().to_vec()),
208
                        );
209
48
                        for pair in 
args[5..]16
.
chunks16
(2) {
210
48
                            value.insert(
211
48
                                str::from_utf8(pair[0].as_bytes().expect("Field name not bytes"))
212
48
                                    .expect("Unable to parse field name as string")
213
48
                                    .into(),
214
48
                                Value::BulkString(pair[1].as_bytes().unwrap().to_vec()),
215
48
                            );
216
48
                        }
217
16
                        let mut ret: Option<Value> = None;
218
16
                        let key: String =
219
16
                            str::from_utf8(args[2].as_bytes().expect("Key not bytes"))
220
16
                                .expect("Key cannot be parsed as string")
221
16
                                .into();
222
16
                        let expected_existing_version: i64 =
223
16
                            str::from_utf8(args[3].as_bytes().unwrap())
224
16
                                .unwrap()
225
16
                                .parse()
226
16
                                .expect("Unable to parse existing version field");
227
16
                        trace!(%key, %expected_existing_version, ?value, "Want to insert with EVALSHA");
228
16
                        let version = match self.table.lock().unwrap().entry(key.clone()) {
229
13
                            Entry::Occupied(mut occupied_entry) => {
230
13
                                let version = occupied_entry
231
13
                                    .get()
232
13
                                    .get("version")
233
13
                                    .expect("No version field");
234
13
                                let Value::BulkString(version_bytes) = version else {
  Branch (234:37): [Folded - Ignored]
  Branch (234:37): [Folded - Ignored]
  Branch (234:37): [True: 13, False: 0]
235
0
                                    panic!("Non-bulkstring version: {version:?}");
236
                                };
237
13
                                let version_int: i64 = str::from_utf8(version_bytes)
238
13
                                    .expect("Version field not valid string")
239
13
                                    .parse()
240
13
                                    .expect("Unable to parse version field");
241
13
                                if version_int == expected_existing_version {
  Branch (241:36): [Folded - Ignored]
  Branch (241:36): [Folded - Ignored]
  Branch (241:36): [True: 9, False: 4]
242
9
                                    let new_version = version_int + 1;
243
9
                                    debug!(%key, %new_version, "Version update");
244
9
                                    value.insert(
245
9
                                        "version".into(),
246
9
                                        Value::BulkString(
247
9
                                            format!("{new_version}").as_bytes().to_vec(),
248
9
                                        ),
249
                                    );
250
9
                                    occupied_entry.insert(value);
251
9
                                    new_version
252
                                } else {
253
                                    // Version mismatch.
254
4
                                    debug!(%key, %version_int, %expected_existing_version, "Version mismatch");
255
4
                                    ret = Some(Value::Array(vec![
256
4
                                        Value::Int(0),
257
4
                                        Value::Int(version_int),
258
4
                                    ]));
259
4
                                    -1
260
                                }
261
                            }
262
3
                            Entry::Vacant(vacant_entry) => {
263
3
                                if expected_existing_version != 0 {
  Branch (263:36): [Folded - Ignored]
  Branch (263:36): [Folded - Ignored]
  Branch (263:36): [True: 0, False: 3]
264
                                    // Version mismatch.
265
0
                                    debug!(%key, %expected_existing_version, "Version mismatch, expected zero");
266
0
                                    ret = Some(Value::Array(vec![Value::Int(0), Value::Int(0)]));
267
0
                                    -1
268
                                } else {
269
3
                                    debug!(%key, "Version insert");
270
3
                                    value
271
3
                                        .insert("version".into(), Value::BulkString(b"1".to_vec()));
272
3
                                    vacant_entry.insert_entry(value);
273
3
                                    1
274
                                }
275
                            }
276
                        };
277
16
                        if let Some(
r4
) = ret {
  Branch (277:32): [Folded - Ignored]
  Branch (277:32): [Folded - Ignored]
  Branch (277:32): [True: 4, False: 12]
278
4
                            r
279
                        } else {
280
12
                            Value::Array(vec![Value::Int(1), Value::Int(version)])
281
                        }
282
                    }
283
284
97
                    "HMSET" => {
285
3
                        let mut values = HashMap::new();
286
3
                        assert_eq!(
287
3
                            (args.len() - 1).rem_euclid(2),
288
                            0,
289
0
                            "Non-even args for hmset: {args:?}"
290
                        );
291
3
                        let chunks = args[1..].chunks_exact(2);
292
6
                        for 
chunk3
in chunks {
293
3
                            let [key, value] = chunk else {
  Branch (293:33): [Folded - Ignored]
  Branch (293:33): [Folded - Ignored]
  Branch (293:33): [True: 3, False: 0]
294
0
                                panic!("Uneven hmset args");
295
                            };
296
3
                            let key_name: String =
297
3
                                str::from_utf8(key.as_bytes().expect("Key argument is not bytes"))
298
3
                                    .expect("Unable to parse key as string")
299
3
                                    .into();
300
3
                            values.insert(
301
3
                                key_name,
302
3
                                Value::BulkString(value.as_bytes().unwrap().to_vec()),
303
                            );
304
                        }
305
3
                        let key =
306
3
                            str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes"))
307
3
                                .expect("Unable to parse key as string")
308
3
                                .into();
309
3
                        debug!(%key, ?values, "Inserting with HMSET");
310
3
                        self.table.lock().unwrap().insert(key, values);
311
3
                        Value::Okay
312
                    }
313
314
94
                    "HMGET" => {
315
94
                        let key_name =
316
94
                            str::from_utf8(args[0].as_bytes().expect("Key argument is not bytes"))
317
94
                                .expect("Unable to parse key name");
318
319
94
                        if let Some(
fields48
) = self.table.lock().unwrap().get(key_name) {
  Branch (319:32): [Folded - Ignored]
  Branch (319:32): [Folded - Ignored]
  Branch (319:32): [True: 48, False: 46]
320
48
                            trace!(%key_name, keys = ?fields.keys(), "Getting keys with HMGET, some keys");
321
48
                            let mut result = vec![];
322
96
                            for key in &
args[1..]48
{
323
96
                                let field_name = str::from_utf8(
324
96
                                    key.as_bytes().expect("Field argument is not bytes"),
325
                                )
326
96
                                .expect("Unable to parse requested field");
327
96
                                if let Some(
value94
) = fields.get(field_name) {
  Branch (327:40): [Folded - Ignored]
  Branch (327:40): [Folded - Ignored]
  Branch (327:40): [True: 94, False: 2]
328
94
                                    result.push(value.clone());
329
94
                                } else {
330
2
                                    debug!(%key_name, %field_name, "Missing field");
331
2
                                    result.push(Value::Nil);
332
                                }
333
                            }
334
48
                            Value::Array(result)
335
                        } else {
336
46
                            trace!(%key_name, "Getting keys with HMGET, empty");
337
46
                            let null_count = i64::try_from(args.len() - 1).unwrap();
338
46
                            Value::Array(vec![Value::Nil, Value::Int(null_count)])
339
                        }
340
                    }
341
0
                    actual => {
342
0
                        panic!("Mock command not implemented! {actual:?}");
343
                    }
344
                };
345
346
166
                arg_as_string(&mut output, ret);
347
166
                if buf_index == buf.len() {
  Branch (347:20): [Folded - Ignored]
  Branch (347:20): [Folded - Ignored]
  Branch (347:20): [True: 154, False: 12]
348
154
                    break;
349
12
                }
350
            }
351
154
            output
352
154
        };
353
3
        fake_redis_internal(listener, inner).await;
354
0
    }
355
356
3
    pub async fn run(self) -> u16 {
357
3
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
358
3
        let port = listener.local_addr().unwrap().port();
359
3
        info!("Using port {port}");
360
361
3
        background_spawn!("listener", async move {
362
3
            self.dynamic_fake_redis(listener).await;
363
0
        });
364
365
3
        port
366
3
    }
367
}