/build/source/src/bin/cas_speed_check.rs
Line | Count | Source |
1 | | use core::time::Duration; |
2 | | use std::sync::Arc; |
3 | | |
4 | | use clap::Parser; |
5 | | use nativelink_error::{Error, ResultExt}; |
6 | | use nativelink_proto::build::bazel::remote::execution::v2::content_addressable_storage_client::ContentAddressableStorageClient; |
7 | | use nativelink_proto::build::bazel::remote::execution::v2::{ |
8 | | Digest, FindMissingBlobsRequest, digest_function, |
9 | | }; |
10 | | use nativelink_util::spawn; |
11 | | use nativelink_util::telemetry::init_tracing; |
12 | | use nativelink_util::tls_utils::endpoint_from; |
13 | | use rand::{Rng, RngCore}; |
14 | | use sha2::{Digest as _, Sha256}; |
15 | | use tokio::sync::Mutex; |
16 | | use tokio::time::Instant; |
17 | | use tonic::Request; |
18 | | use tonic::transport::ClientTlsConfig; |
19 | | use tracing::info; |
20 | | |
21 | | #[derive(Parser, Debug)] |
22 | | #[command(version, about)] |
23 | | struct Args { |
24 | | #[arg(short, long)] |
25 | | endpoint: String, |
26 | | |
27 | | #[arg(short, long)] |
28 | | nativelink_key: Option<String>, |
29 | | } |
30 | | |
31 | 0 | fn main() -> Result<(), Box<dyn core::error::Error>> { |
32 | 0 | let args = Args::parse(); |
33 | | #[expect( |
34 | | clippy::disallowed_methods, |
35 | | reason = "It's the top-level, so we need the function" |
36 | | )] |
37 | 0 | tokio::runtime::Builder::new_multi_thread() |
38 | 0 | .enable_all() |
39 | 0 | .build() |
40 | 0 | .unwrap() |
41 | 0 | .block_on(async { |
42 | 0 | init_tracing()?; |
43 | 0 | let timings = Arc::new(Mutex::new(Vec::new())); |
44 | 0 | let spawns: Vec<_> = (0..200) |
45 | 0 | .map(|_| { |
46 | 0 | let local_timings = timings.clone(); |
47 | 0 | let local_endpoint = args.endpoint.clone(); |
48 | 0 | let local_api_key = args.nativelink_key.clone(); |
49 | 0 | spawn!("CAS requester", async move { |
50 | 0 | let tls_config = ClientTlsConfig::new().with_enabled_roots(); |
51 | 0 | let endpoint = endpoint_from(&local_endpoint, Some(tls_config))?; |
52 | 0 | let channel = endpoint.connect().await.unwrap(); |
53 | | |
54 | 0 | let mut client = ContentAddressableStorageClient::new(channel); |
55 | | |
56 | 0 | for _ in 0..100 { |
57 | 0 | let raw_data: String = rand::rng() |
58 | 0 | .sample_iter::<char, _>(rand::distr::StandardUniform) |
59 | 0 | .take(300) |
60 | 0 | .collect(); |
61 | 0 | let hashed = Sha256::digest(raw_data.as_bytes()); |
62 | 0 | let rand_hash = hex::encode(hashed); |
63 | 0 | let digest = Digest { |
64 | 0 | hash: rand_hash, |
65 | 0 | size_bytes: i64::from(rand::rng().next_u32()), |
66 | 0 | }; |
67 | | |
68 | 0 | let mut request = Request::new(FindMissingBlobsRequest { |
69 | 0 | instance_name: String::new(), |
70 | 0 | blob_digests: vec![digest.clone()], |
71 | 0 | digest_function: digest_function::Value::Sha256.into(), |
72 | 0 | }); |
73 | 0 | if let Some(ref api_key) = local_api_key { |
74 | 0 | request |
75 | 0 | .metadata_mut() |
76 | 0 | .insert("x-nativelink-api-key", api_key.parse().unwrap()); |
77 | 0 | } |
78 | 0 | let start = Instant::now(); |
79 | 0 | client |
80 | 0 | .find_missing_blobs(request) |
81 | 0 | .await |
82 | 0 | .err_tip(|| "in find_missing_blobs")? |
83 | 0 | .into_inner(); |
84 | 0 | let duration = Instant::now().checked_duration_since(start).unwrap(); |
85 | | |
86 | | // info!("response duration={duration:?} res={:?}", res); |
87 | 0 | local_timings.lock().await.push(duration); |
88 | | } |
89 | 0 | Ok::<(), Error>(()) |
90 | 0 | }) |
91 | 0 | }) |
92 | 0 | .collect(); |
93 | 0 | for thread in spawns { |
94 | 0 | let res = thread.await; |
95 | 0 | res.err_tip(|| "with spawn")??; |
96 | | } |
97 | 0 | let avg = Duration::from_secs_f64({ |
98 | 0 | let locked = timings.lock().await; |
99 | 0 | locked.iter().map(Duration::as_secs_f64).sum::<f64>() / locked.len() as f64 |
100 | | }); |
101 | 0 | info!(?avg, "avg"); |
102 | 0 | Ok::<(), Error>(()) |
103 | 0 | })?; |
104 | 0 | Ok(()) |
105 | 0 | } |