/build/source/nativelink-util/src/origin_event_middleware.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use std::sync::Arc; |
16 | | |
17 | | use base64::prelude::BASE64_STANDARD_NO_PAD; |
18 | | use base64::Engine; |
19 | | use futures::future::BoxFuture; |
20 | | use futures::task::{Context, Poll}; |
21 | | use hyper::http::{self, StatusCode}; |
22 | | use nativelink_config::cas_server::IdentityHeaderSpec; |
23 | | use nativelink_proto::build::bazel::remote::execution::v2::RequestMetadata; |
24 | | use nativelink_proto::com::github::trace_machina::nativelink::events::OriginEvent; |
25 | | use prost::Message; |
26 | | use tokio::sync::mpsc; |
27 | | use tower::layer::Layer; |
28 | | use tower::Service; |
29 | | use tracing::trace_span; |
30 | | |
31 | | use crate::origin_context::{ActiveOriginContext, ORIGIN_IDENTITY}; |
32 | | use crate::origin_event::{OriginEventCollector, ORIGIN_EVENT_COLLECTOR}; |
33 | | |
34 | | /// Default identity header name. |
35 | | /// Note: If this is changed, the default value in the [`IdentityHeaderSpec`] |
36 | | // TODO(allada) This has a mirror in bep_server.rs. |
37 | | // We should consolidate these. |
38 | | const DEFAULT_IDENTITY_HEADER: &str = "x-identity"; |
39 | | |
40 | | #[derive(Default, Clone)] |
41 | | pub struct OriginRequestMetadata { |
42 | | pub identity: String, |
43 | | pub bazel_metadata: Option<RequestMetadata>, |
44 | | } |
45 | | |
46 | | #[derive(Clone)] |
47 | | pub struct OriginEventMiddlewareLayer { |
48 | | maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>, |
49 | | idenity_header_config: Arc<IdentityHeaderSpec>, |
50 | | } |
51 | | |
52 | | impl OriginEventMiddlewareLayer { |
53 | 0 | pub fn new( |
54 | 0 | maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>, |
55 | 0 | idenity_header_config: IdentityHeaderSpec, |
56 | 0 | ) -> Self { |
57 | 0 | Self { |
58 | 0 | maybe_origin_event_tx, |
59 | 0 | idenity_header_config: Arc::new(idenity_header_config), |
60 | 0 | } |
61 | 0 | } |
62 | | } |
63 | | |
64 | | impl<S> Layer<S> for OriginEventMiddlewareLayer { |
65 | | type Service = OriginEventMiddleware<S>; |
66 | | |
67 | 0 | fn layer(&self, service: S) -> Self::Service { |
68 | 0 | OriginEventMiddleware { |
69 | 0 | inner: service, |
70 | 0 | maybe_origin_event_tx: self.maybe_origin_event_tx.clone(), |
71 | 0 | idenity_header_config: self.idenity_header_config.clone(), |
72 | 0 | } |
73 | 0 | } |
74 | | } |
75 | | |
76 | | #[derive(Clone)] |
77 | | pub struct OriginEventMiddleware<S> { |
78 | | inner: S, |
79 | | maybe_origin_event_tx: Option<mpsc::Sender<OriginEvent>>, |
80 | | idenity_header_config: Arc<IdentityHeaderSpec>, |
81 | | } |
82 | | |
83 | | impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for OriginEventMiddleware<S> |
84 | | where |
85 | | S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static, |
86 | | S::Future: Send + 'static, |
87 | | ReqBody: std::fmt::Debug + Send + 'static, |
88 | | ResBody: From<String> + Send + 'static, |
89 | | { |
90 | | type Response = S::Response; |
91 | | type Error = S::Error; |
92 | | type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; |
93 | | |
94 | 0 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
95 | 0 | self.inner.poll_ready(cx) |
96 | 0 | } |
97 | | |
98 | 0 | fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future { |
99 | 0 | // We must take the current `inner` and not the clone. |
100 | 0 | // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services |
101 | 0 | let clone = self.inner.clone(); |
102 | 0 | let mut inner = std::mem::replace(&mut self.inner, clone); |
103 | 0 |
|
104 | 0 | let mut context = ActiveOriginContext::fork().unwrap_or_default(); |
105 | 0 | let identity = { |
106 | 0 | let identity_header = self |
107 | 0 | .idenity_header_config |
108 | 0 | .header_name |
109 | 0 | .as_deref() |
110 | 0 | .unwrap_or(DEFAULT_IDENTITY_HEADER); |
111 | 0 | let identity = if !identity_header.is_empty() { Branch (111:31): [Folded - Ignored]
Branch (111:31): [Folded - Ignored]
|
112 | 0 | req.headers() |
113 | 0 | .get(identity_header) |
114 | 0 | .and_then(|header| header.to_str().ok().map(str::to_string)) |
115 | 0 | .unwrap_or_default() |
116 | | } else { |
117 | 0 | String::new() |
118 | | }; |
119 | | |
120 | 0 | if identity.is_empty() && self.idenity_header_config.required { Branch (120:16): [Folded - Ignored]
Branch (120:39): [Folded - Ignored]
Branch (120:16): [Folded - Ignored]
Branch (120:39): [Folded - Ignored]
|
121 | 0 | return Box::pin(async move { |
122 | 0 | Ok(http::Response::builder() |
123 | 0 | .status(StatusCode::UNAUTHORIZED) |
124 | 0 | .body("'identity_header' header is required".to_string().into()) |
125 | 0 | .unwrap()) |
126 | 0 | }); |
127 | 0 | } |
128 | 0 | context.set_value(&ORIGIN_IDENTITY, Arc::new(identity.clone())); |
129 | 0 | identity |
130 | | }; |
131 | 0 | if let Some(origin_event_tx) = &self.maybe_origin_event_tx { Branch (131:16): [Folded - Ignored]
Branch (131:16): [Folded - Ignored]
|
132 | 0 | let bazel_metadata = req |
133 | 0 | .headers() |
134 | 0 | .get("build.bazel.remote.execution.v2.requestmetadata-bin") |
135 | 0 | .and_then(|header| BASE64_STANDARD_NO_PAD.decode(header.as_bytes()).ok()) |
136 | 0 | .and_then(|data| RequestMetadata::decode(data.as_slice()).ok()); |
137 | 0 | context.set_value( |
138 | 0 | &ORIGIN_EVENT_COLLECTOR, |
139 | 0 | Arc::new(OriginEventCollector::new( |
140 | 0 | origin_event_tx.clone(), |
141 | 0 | identity, |
142 | 0 | bazel_metadata, |
143 | 0 | )), |
144 | 0 | ); |
145 | 0 | } |
146 | | |
147 | 0 | Box::pin(async move { |
148 | 0 | Arc::new(context) |
149 | 0 | .wrap_async(trace_span!("OriginEventMiddleware"), inner.call(req)) |
150 | 0 | .await |
151 | 0 | }) |
152 | 0 | } |
153 | | } |