/build/source/nativelink-service/src/bytestream_server.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Functional Source License, Version 1.1, Apache 2.0 Future License (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // See LICENSE file for details |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::convert::Into; |
16 | | use core::fmt::{Debug, Formatter}; |
17 | | use core::pin::Pin; |
18 | | use core::sync::atomic::{AtomicU64, Ordering}; |
19 | | use core::time::Duration; |
20 | | use std::collections::HashMap; |
21 | | use std::collections::hash_map::Entry; |
22 | | use std::sync::Arc; |
23 | | use std::time::{SystemTime, UNIX_EPOCH}; |
24 | | |
25 | | use futures::future::{BoxFuture, pending}; |
26 | | use futures::stream::unfold; |
27 | | use futures::{Future, Stream, TryFutureExt, try_join}; |
28 | | use nativelink_config::cas_server::{ByteStreamConfig, InstanceName, WithInstanceName}; |
29 | | use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err}; |
30 | | use nativelink_proto::google::bytestream::byte_stream_server::{ |
31 | | ByteStream, ByteStreamServer as Server, |
32 | | }; |
33 | | use nativelink_proto::google::bytestream::{ |
34 | | QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, |
35 | | WriteResponse, |
36 | | }; |
37 | | use nativelink_store::grpc_store::GrpcStore; |
38 | | use nativelink_store::store_manager::StoreManager; |
39 | | use nativelink_util::buf_channel::{ |
40 | | DropCloserReadHalf, DropCloserWriteHalf, make_buf_channel_pair, |
41 | | }; |
42 | | use nativelink_util::common::DigestInfo; |
43 | | use nativelink_util::digest_hasher::{ |
44 | | DigestHasherFunc, default_digest_hasher_func, make_ctx_for_hash_func, |
45 | | }; |
46 | | use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; |
47 | | use nativelink_util::resource_info::ResourceInfo; |
48 | | use nativelink_util::spawn; |
49 | | use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; |
50 | | use nativelink_util::task::JoinHandleDropGuard; |
51 | | use opentelemetry::context::FutureExt; |
52 | | use parking_lot::Mutex; |
53 | | use tokio::time::sleep; |
54 | | use tonic::{Request, Response, Status, Streaming}; |
55 | | use tracing::{Instrument, Level, debug, error, error_span, info, instrument, trace, warn}; |
56 | | |
57 | | /// If this value changes update the documentation in the config definition. |
58 | | const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60); |
59 | | |
60 | | /// If this value changes update the documentation in the config definition. |
61 | | const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024; |
62 | | |
63 | | type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>); |
64 | | type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>; |
65 | | |
66 | | pub struct InstanceInfo { |
67 | | store: Store, |
68 | | // Max number of bytes to send on each grpc stream chunk. |
69 | | max_bytes_per_stream: usize, |
70 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
71 | | sleep_fn: SleepFn, |
72 | | } |
73 | | |
74 | | impl Debug for InstanceInfo { |
75 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { |
76 | 0 | f.debug_struct("InstanceInfo") |
77 | 0 | .field("store", &self.store) |
78 | 0 | .field("max_bytes_per_stream", &self.max_bytes_per_stream) |
79 | 0 | .field("active_uploads", &self.active_uploads) |
80 | 0 | .finish() |
81 | 0 | } |
82 | | } |
83 | | |
84 | | type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; |
85 | | type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>; |
86 | | |
87 | | struct StreamState { |
88 | | uuid: String, |
89 | | tx: DropCloserWriteHalf, |
90 | | store_update_fut: StoreUpdateFuture, |
91 | | } |
92 | | |
93 | | impl Debug for StreamState { |
94 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { |
95 | 0 | f.debug_struct("StreamState") |
96 | 0 | .field("uuid", &self.uuid) |
97 | 0 | .finish() |
98 | 0 | } |
99 | | } |
100 | | |
101 | | /// If a stream is in this state, it will automatically be put back into an `IdleStream` and |
102 | | /// placed back into the `active_uploads` map as an `IdleStream` after it is dropped. |
103 | | /// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`. |
104 | | struct ActiveStreamGuard { |
105 | | stream_state: Option<StreamState>, |
106 | | bytes_received: Arc<AtomicU64>, |
107 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
108 | | sleep_fn: SleepFn, |
109 | | } |
110 | | |
111 | | impl ActiveStreamGuard { |
112 | | /// Consumes the guard. The stream is considered "finished" and its entry is |
113 | | /// removed from `active_uploads`. |
114 | 8 | fn graceful_finish(mut self) { |
115 | 8 | let stream_state = self.stream_state.take().unwrap(); |
116 | 8 | self.active_uploads.lock().remove(&stream_state.uuid); |
117 | 8 | } |
118 | | } |
119 | | |
120 | | impl Drop for ActiveStreamGuard { |
121 | 15 | fn drop(&mut self) { |
122 | 15 | let Some(stream_state7 ) = self.stream_state.take() else { Branch (122:13): [True: 7, False: 8]
Branch (122:13): [Folded - Ignored]
|
123 | 8 | return; // If None it means we don't want it put back into an IdleStream. |
124 | | }; |
125 | 7 | let weak_active_uploads = Arc::downgrade(&self.active_uploads); |
126 | 7 | let mut active_uploads = self.active_uploads.lock(); |
127 | 7 | let uuid = stream_state.uuid.clone(); |
128 | 7 | let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else { Branch (128:13): [True: 7, False: 0]
Branch (128:13): [Folded - Ignored]
|
129 | 0 | error!( |
130 | | err = "Failed to find active upload. This should never happen.", |
131 | | uuid = ?uuid, |
132 | | ); |
133 | 0 | return; |
134 | | }; |
135 | 7 | let sleep_fn = self.sleep_fn.clone(); |
136 | 7 | active_uploads_slot.1 = Some(IdleStream { |
137 | 7 | stream_state, |
138 | 7 | _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move {3 |
139 | 3 | (*sleep_fn)().await; |
140 | 0 | if let Some(active_uploads) = weak_active_uploads.upgrade() { Branch (140:24): [True: 0, False: 0]
Branch (140:24): [Folded - Ignored]
|
141 | 0 | let mut active_uploads = active_uploads.lock(); |
142 | 0 | info!(msg = "Removing idle stream", uuid = ?uuid); |
143 | 0 | active_uploads.remove(&uuid); |
144 | 0 | } |
145 | 0 | }), |
146 | | }); |
147 | 15 | } |
148 | | } |
149 | | |
150 | | /// Represents a stream that is in the "idle" state. This means it is not currently being used |
151 | | /// by a client. If it is not used within a certain amount of time it will be removed from the |
152 | | /// `active_uploads` map automatically. |
153 | | #[derive(Debug)] |
154 | | struct IdleStream { |
155 | | stream_state: StreamState, |
156 | | _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, |
157 | | } |
158 | | |
159 | | impl IdleStream { |
160 | 3 | fn into_active_stream( |
161 | 3 | self, |
162 | 3 | bytes_received: Arc<AtomicU64>, |
163 | 3 | instance_info: &InstanceInfo, |
164 | 3 | ) -> ActiveStreamGuard { |
165 | 3 | ActiveStreamGuard { |
166 | 3 | stream_state: Some(self.stream_state), |
167 | 3 | bytes_received, |
168 | 3 | active_uploads: instance_info.active_uploads.clone(), |
169 | 3 | sleep_fn: instance_info.sleep_fn.clone(), |
170 | 3 | } |
171 | 3 | } |
172 | | } |
173 | | |
174 | | #[derive(Debug)] |
175 | | pub struct ByteStreamServer { |
176 | | instance_infos: HashMap<InstanceName, InstanceInfo>, |
177 | | } |
178 | | |
179 | | impl ByteStreamServer { |
180 | | /// Generate a unique UUID by appending a nanosecond timestamp to avoid collisions. |
181 | 0 | fn generate_unique_uuid(base_uuid: &str) -> String { |
182 | 0 | let timestamp = SystemTime::now() |
183 | 0 | .duration_since(UNIX_EPOCH) |
184 | 0 | .unwrap_or_default() |
185 | 0 | .as_nanos(); |
186 | 0 | format!("{base_uuid}-{timestamp:x}") |
187 | 0 | } |
188 | | |
189 | 15 | pub fn new( |
190 | 15 | configs: &[WithInstanceName<ByteStreamConfig>], |
191 | 15 | store_manager: &StoreManager, |
192 | 15 | ) -> Result<Self, Error> { |
193 | 15 | let mut instance_infos: HashMap<String, InstanceInfo> = HashMap::new(); |
194 | 30 | for config15 in configs { |
195 | 15 | let persist_stream_on_disconnect_timeout = |
196 | 15 | if config.persist_stream_on_disconnect_timeout == 0 { Branch (196:20): [True: 15, False: 0]
Branch (196:20): [Folded - Ignored]
|
197 | 15 | DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT |
198 | | } else { |
199 | 0 | Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64) |
200 | | }; |
201 | 15 | let _old_value = instance_infos.insert( |
202 | 15 | config.instance_name.clone(), |
203 | 15 | Self::new_with_sleep_fn( |
204 | 15 | config, |
205 | 15 | store_manager, |
206 | 15 | Arc::new(move || Box::pin3 (sleep3 (persist_stream_on_disconnect_timeout3 ))), |
207 | 0 | )?, |
208 | | ); |
209 | | } |
210 | 15 | Ok(Self { instance_infos }) |
211 | 15 | } |
212 | | |
213 | 15 | pub fn new_with_sleep_fn( |
214 | 15 | config: &WithInstanceName<ByteStreamConfig>, |
215 | 15 | store_manager: &StoreManager, |
216 | 15 | sleep_fn: SleepFn, |
217 | 15 | ) -> Result<InstanceInfo, Error> { |
218 | 15 | let store = store_manager |
219 | 15 | .get_store(&config.cas_store) |
220 | 15 | .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", config.cas_store0 ))?0 ; |
221 | 15 | let max_bytes_per_stream = if config.max_bytes_per_stream == 0 { Branch (221:39): [True: 1, False: 14]
Branch (221:39): [Folded - Ignored]
|
222 | 1 | DEFAULT_MAX_BYTES_PER_STREAM |
223 | | } else { |
224 | 14 | config.max_bytes_per_stream |
225 | | }; |
226 | 15 | Ok(InstanceInfo { |
227 | 15 | store, |
228 | 15 | max_bytes_per_stream, |
229 | 15 | active_uploads: Arc::new(Mutex::new(HashMap::new())), |
230 | 15 | sleep_fn, |
231 | 15 | }) |
232 | 15 | } |
233 | | |
234 | 1 | pub fn into_service(self) -> Server<Self> { |
235 | 1 | Server::new(self) |
236 | 1 | } |
237 | | |
238 | | /// Creates or joins an upload stream for the given UUID. |
239 | | /// |
240 | | /// This function handles three scenarios: |
241 | | /// 1. UUID doesn't exist - creates a new upload stream |
242 | | /// 2. UUID exists but is idle - resumes the existing stream |
243 | | /// 3. UUID exists and is active - generates a unique UUID by appending a nanosecond |
244 | | /// timestamp to avoid collision, then creates a new stream with that UUID |
245 | | /// |
246 | | /// The nanosecond timestamp ensures virtually zero probability of collision since |
247 | | /// two concurrent uploads would need to both collide on the original UUID AND |
248 | | /// generate the unique UUID in the exact same nanosecond. |
249 | 15 | fn create_or_join_upload_stream( |
250 | 15 | &self, |
251 | 15 | uuid: &str, |
252 | 15 | instance: &InstanceInfo, |
253 | 15 | digest: DigestInfo, |
254 | 15 | ) -> ActiveStreamGuard { |
255 | 15 | let (uuid12 , bytes_received12 ) = match instance.active_uploads.lock().entry(uuid.to_string()) { |
256 | 3 | Entry::Occupied(mut entry) => { |
257 | 3 | let maybe_idle_stream = entry.get_mut(); |
258 | 3 | if let Some(idle_stream) = maybe_idle_stream.1.take() { Branch (258:24): [True: 3, False: 0]
Branch (258:24): [Folded - Ignored]
|
259 | | // Case 2: Stream exists but is idle, we can resume it |
260 | 3 | let bytes_received = maybe_idle_stream.0.clone(); |
261 | 3 | info!(msg = "Joining existing stream", entry = ?entry.key()); |
262 | 3 | return idle_stream.into_active_stream(bytes_received, instance); |
263 | 0 | } |
264 | | // Case 3: Stream is active - generate a unique UUID to avoid collision |
265 | | // Using nanosecond timestamp makes collision probability essentially zero |
266 | 0 | let original_uuid = entry.key().clone(); |
267 | 0 | let unique_uuid = Self::generate_unique_uuid(&original_uuid); |
268 | 0 | warn!( |
269 | | msg = "UUID collision detected, generating unique UUID to prevent conflict", |
270 | | original_uuid = ?original_uuid, |
271 | | unique_uuid = ?unique_uuid |
272 | | ); |
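273 | | // NOTE(review): the guard locked in the `match` scrutinee is still held here (temporaries live until the `match` ends), so the `lock()` below re-locks a non-reentrant mutex on the same thread. |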
274 | | |
275 | 0 | let bytes_received = Arc::new(AtomicU64::new(0)); |
276 | 0 | let mut active_uploads = instance.active_uploads.lock(); |
277 | | // Insert with the unique UUID - this should never collide due to nanosecond precision |
278 | 0 | active_uploads.insert(unique_uuid.clone(), (bytes_received.clone(), None)); |
279 | 0 | (unique_uuid, bytes_received) |
280 | | } |
281 | 12 | Entry::Vacant(entry) => { |
282 | | // Case 1: UUID doesn't exist, create new stream |
283 | 12 | let bytes_received = Arc::new(AtomicU64::new(0)); |
284 | 12 | let uuid = entry.key().clone(); |
285 | | // Our stream is "in use" if the key is in the map, but the value is None. |
286 | 12 | entry.insert((bytes_received.clone(), None)); |
287 | 12 | (uuid, bytes_received) |
288 | | } |
289 | | }; |
290 | | |
291 | | // Important: Do not return an error from this point onwards without |
292 | | // removing the entry from the map, otherwise that UUID becomes |
293 | | // unusable. |
294 | | |
295 | 12 | let (tx, rx) = make_buf_channel_pair(); |
296 | 12 | let store = instance.store.clone(); |
297 | 12 | let store_update_fut = Box::pin(async move {8 |
298 | | // We need to wrap `Store::update()` in another future because we need to capture |
299 | | // `store` to ensure its lifetime follows the future and not the caller. |
300 | 8 | store |
301 | 8 | // Bytestream always uses digest size as the actual byte size. |
302 | 8 | .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes())) |
303 | 8 | .await |
304 | 8 | }); |
305 | 12 | ActiveStreamGuard { |
306 | 12 | stream_state: Some(StreamState { |
307 | 12 | uuid, |
308 | 12 | tx, |
309 | 12 | store_update_fut, |
310 | 12 | }), |
311 | 12 | bytes_received, |
312 | 12 | active_uploads: instance.active_uploads.clone(), |
313 | 12 | sleep_fn: instance.sleep_fn.clone(), |
314 | 12 | } |
315 | 15 | } |
316 | | |
317 | 3 | async fn inner_read( |
318 | 3 | &self, |
319 | 3 | instance: &InstanceInfo, |
320 | 3 | digest: DigestInfo, |
321 | 3 | read_request: ReadRequest, |
322 | 3 | ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + use<>, Error> { |
323 | | struct ReaderState { |
324 | | max_bytes_per_stream: usize, |
325 | | rx: DropCloserReadHalf, |
326 | | maybe_get_part_result: Option<Result<(), Error>>, |
327 | | get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, |
328 | | } |
329 | | |
330 | 3 | let read_limit = u64::try_from(read_request.read_limit) |
331 | 3 | .err_tip(|| "Could not convert read_limit to u64")?0 ; |
332 | | |
333 | 3 | let (tx, rx) = make_buf_channel_pair(); |
334 | | |
335 | 3 | let read_limit = if read_limit != 0 { Branch (335:29): [True: 3, False: 0]
Branch (335:29): [Folded - Ignored]
|
336 | 3 | Some(read_limit) |
337 | | } else { |
338 | 0 | None |
339 | | }; |
340 | | |
341 | | // This allows us to call a destructor when the object is dropped. |
342 | 3 | let store = instance.store.clone(); |
343 | 3 | let state = Some(ReaderState { |
344 | 3 | rx, |
345 | 3 | max_bytes_per_stream: instance.max_bytes_per_stream, |
346 | 3 | maybe_get_part_result: None, |
347 | 3 | get_part_fut: Box::pin(async move { |
348 | 3 | store |
349 | 3 | .get_part( |
350 | 3 | digest, |
351 | 3 | tx, |
352 | 3 | u64::try_from(read_request.read_offset) |
353 | 3 | .err_tip(|| "Could not convert read_offset to u64")?0 , |
354 | 3 | read_limit, |
355 | | ) |
356 | 3 | .await |
357 | 3 | }), |
358 | | }); |
359 | | |
360 | 3 | let read_stream_span = error_span!("read_stream"); |
361 | | |
362 | 9.77k | Ok(Box::pin3 (unfold3 (state3 , move |state| { |
363 | 9.77k | async { |
364 | 9.77k | let mut state = state?0 ; // If None our stream is done. |
365 | 9.77k | let mut response = ReadResponse::default(); |
366 | | { |
367 | 9.77k | let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream)); |
368 | 9.77k | tokio::pin!(consume_fut); |
369 | | loop { |
370 | 9.77k | tokio::select! { |
371 | 9.77k | read_result9.77k = &mut consume_fut => { |
372 | 9.77k | match read_result { |
373 | 9.76k | Ok(bytes) => { |
374 | 9.76k | if bytes.is_empty() { Branch (374:40): [True: 2, False: 9.76k]
Branch (374:40): [Folded - Ignored]
|
375 | | // EOF. |
376 | 2 | return None; |
377 | 9.76k | } |
378 | 9.76k | if bytes.len() > state.max_bytes_per_stream { Branch (378:40): [True: 0, False: 9.76k]
Branch (378:40): [Folded - Ignored]
|
379 | 0 | let err = make_err!(Code::Internal, "Returned store size was larger than read size"); |
380 | 0 | return Some((Err(err.into()), None)); |
381 | 9.76k | } |
382 | 9.76k | response.data = bytes; |
383 | 9.76k | trace!(response = ?response); |
384 | 9.76k | debug!(response.data = format!("<redacted len({})>", response.data.len())); |
385 | 9.76k | break; |
386 | | } |
387 | 1 | Err(mut e) => { |
388 | | // We may need to propagate the error from reading the data through first. |
389 | | // For example, the NotFound error will come through `get_part_fut`, and |
390 | | // will not be present in `e`, but we need to ensure we pass NotFound error |
391 | | // code or the client won't know why it failed. |
392 | 1 | let get_part_result = if let Some(result) = state.maybe_get_part_result { Branch (392:66): [True: 1, False: 0]
Branch (392:66): [Folded - Ignored]
|
393 | 1 | result |
394 | | } else { |
395 | | // This should never be `future::pending()` if maybe_get_part_result is |
396 | | // not set. |
397 | 0 | state.get_part_fut.await |
398 | | }; |
399 | 1 | if let Err(err) = get_part_result { Branch (399:44): [True: 1, False: 0]
Branch (399:44): [Folded - Ignored]
|
400 | 1 | e = err.merge(e); |
401 | 1 | }0 |
402 | 1 | if e.code == Code::NotFound { Branch (402:40): [True: 1, False: 0]
Branch (402:40): [Folded - Ignored]
|
403 | 1 | // Trim the error messages. Not Found is quite common and we don't want to send a large |
404 | 1 | // error (debug) message for something that is common. We resize to just the last |
405 | 1 | // message as it will be the most relevant. |
406 | 1 | e.messages.truncate(1); |
407 | 1 | }0 |
408 | 1 | error!(response = ?e); |
409 | 1 | return Some((Err(e.into()), None)) |
410 | | } |
411 | | } |
412 | | }, |
413 | 9.77k | result3 = &mut state.get_part_fut => { |
414 | 3 | state.maybe_get_part_result = Some(result); |
415 | 3 | // It is non-deterministic on which future will finish in what order. |
416 | 3 | // It is also possible that the `state.rx.consume()` call above may not be able to |
417 | 3 | // respond even though the publishing future is done. |
418 | 3 | // Because of this we set the writing future to pending so it never finishes. |
419 | 3 | // The `state.rx.consume()` future will eventually finish and return either the |
420 | 3 | // data or an error. |
421 | 3 | // An EOF will terminate the `state.rx.consume()` future, but we are also protected |
422 | 3 | // because we are dropping the writing future, it will drop the `tx` channel |
423 | 3 | // which will eventually propagate an error to the `state.rx.consume()` future if |
424 | 3 | // the EOF was not sent due to some other error. |
425 | 3 | state.get_part_fut = Box::pin(pending()); |
426 | 3 | }, |
427 | | } |
428 | | } |
429 | | } |
430 | 9.76k | Some((Ok(response), Some(state))) |
431 | 9.77k | }.instrument(read_stream_span.clone()) |
432 | 9.77k | }))) |
433 | 3 | } |
434 | | |
435 | | // We instrument tracing here as well as below because `stream` has a hash on it |
436 | | // that is extracted from the first stream message. If we only implemented it below |
437 | | // we would not have the hash available to us. |
438 | | #[instrument( |
439 | | ret(level = Level::DEBUG), |
440 | | level = Level::ERROR, |
441 | | skip(self, instance_info), |
442 | | )] |
443 | 15 | async fn inner_write( |
444 | 15 | &self, |
445 | 15 | instance_info: &InstanceInfo, |
446 | 15 | digest: DigestInfo, |
447 | 15 | stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>, |
448 | 15 | ) -> Result<Response<WriteResponse>, Error> { |
449 | 15 | async fn process_client_stream( |
450 | 15 | mut stream: WriteRequestStreamWrapper< |
451 | 15 | impl Stream<Item = Result<WriteRequest, Status>> + Unpin, |
452 | 15 | >, |
453 | 15 | tx: &mut DropCloserWriteHalf, |
454 | 15 | outer_bytes_received: &Arc<AtomicU64>, |
455 | 15 | expected_size: u64, |
456 | 15 | ) -> Result<(), Error> { |
457 | | loop { |
458 | 25 | let write_request20 = match stream.next().await { |
459 | | // Code path for when client tries to gracefully close the stream. |
460 | | // If this happens it means there's a problem with the data sent, |
461 | | // because we always close the stream from our end before this point |
462 | | // by counting the number of bytes sent from the client. If they send |
463 | | // less than the amount they said they were going to send and then |
464 | | // close the stream, we know there's a problem. |
465 | | None => { |
466 | 0 | return Err(make_input_err!( |
467 | 0 | "Client closed stream before sending all data" |
468 | 0 | )); |
469 | | } |
470 | | // Code path for client stream error. Probably client disconnect. |
471 | 5 | Some(Err(err)) => return Err(err), |
472 | | // Code path for received chunk of data. |
473 | 20 | Some(Ok(write_request)) => write_request, |
474 | | }; |
475 | | |
476 | 20 | if write_request.write_offset < 0 { Branch (476:20): [True: 1, False: 19]
Branch (476:20): [Folded - Ignored]
|
477 | 1 | return Err(make_input_err!( |
478 | 1 | "Invalid negative write offset in write request: {}", |
479 | 1 | write_request.write_offset |
480 | 1 | )); |
481 | 19 | } |
482 | 19 | let write_offset = write_request.write_offset as u64; |
483 | | |
484 | | // If we get duplicate data because a client didn't know where |
485 | | // it left off from, then we can simply skip it. |
486 | 19 | let data18 = if write_offset < tx.get_bytes_written() { Branch (486:31): [True: 2, False: 17]
Branch (486:31): [Folded - Ignored]
|
487 | 2 | if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() { Branch (487:24): [True: 0, False: 2]
Branch (487:24): [Folded - Ignored]
|
488 | 0 | if write_request.finish_write { Branch (488:28): [True: 0, False: 0]
Branch (488:28): [Folded - Ignored]
|
489 | 0 | return Err(make_input_err!( |
490 | 0 | "Resumed stream finished at {} bytes when we already received {} bytes.", |
491 | 0 | write_offset + write_request.data.len() as u64, |
492 | 0 | tx.get_bytes_written() |
493 | 0 | )); |
494 | 0 | } |
495 | 0 | continue; |
496 | 2 | } |
497 | 2 | write_request |
498 | 2 | .data |
499 | 2 | .slice((tx.get_bytes_written() - write_offset) as usize..) |
500 | | } else { |
501 | 17 | if write_offset != tx.get_bytes_written() { Branch (501:24): [True: 1, False: 16]
Branch (501:24): [Folded - Ignored]
|
502 | 1 | return Err(make_input_err!( |
503 | 1 | "Received out of order data. Got {}, expected {}", |
504 | 1 | write_offset, |
505 | 1 | tx.get_bytes_written() |
506 | 1 | )); |
507 | 16 | } |
508 | 16 | write_request.data |
509 | | }; |
510 | | |
511 | | // Do not process EOF or weird stuff will happen. |
512 | 18 | if !data.is_empty() { Branch (512:20): [True: 13, False: 5]
Branch (512:20): [Folded - Ignored]
|
513 | | // We also need to process the possible EOF branch, so we can't early return. |
514 | 13 | if let Err(mut err0 ) = tx.send(data).await { Branch (514:28): [True: 0, False: 13]
Branch (514:28): [Folded - Ignored]
|
515 | 0 | err.code = Code::Internal; |
516 | 0 | return Err(err); |
517 | 13 | } |
518 | 13 | outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); |
519 | 5 | } |
520 | | |
521 | 18 | if expected_size < tx.get_bytes_written() { Branch (521:20): [True: 0, False: 18]
Branch (521:20): [Folded - Ignored]
|
522 | 0 | return Err(make_input_err!("Received more bytes than expected")); |
523 | 18 | } |
524 | 18 | if write_request.finish_write { Branch (524:20): [True: 8, False: 10]
Branch (524:20): [Folded - Ignored]
|
525 | | // Gracefully close our stream. |
526 | 8 | tx.send_eof() |
527 | 8 | .err_tip(|| "Failed to send EOF in ByteStream::write")?0 ; |
528 | 8 | return Ok(()); |
529 | 10 | } |
530 | | // Continue. |
531 | | } |
532 | | // Unreachable. |
533 | 15 | } |
534 | | |
535 | | let uuid = stream |
536 | | .resource_info |
537 | | .uuid |
538 | | .as_ref() |
539 | | .ok_or_else(|| make_input_err!("UUID must be set if writing data"))?; |
540 | | let mut active_stream_guard = |
541 | | self.create_or_join_upload_stream(uuid, instance_info, digest); |
542 | | let expected_size = stream.resource_info.expected_size as u64; |
543 | | |
544 | | let active_stream = active_stream_guard.stream_state.as_mut().unwrap(); |
545 | | try_join!( |
546 | | process_client_stream( |
547 | | stream, |
548 | | &mut active_stream.tx, |
549 | | &active_stream_guard.bytes_received, |
550 | | expected_size |
551 | | ), |
552 | | (&mut active_stream.store_update_fut) |
553 | 0 | .map_err(|err| { err.append("Error updating inner store") }) |
554 | | )?; |
555 | | |
556 | | // Close our guard and consider the stream no longer active. |
557 | | active_stream_guard.graceful_finish(); |
558 | | |
559 | | Ok(Response::new(WriteResponse { |
560 | | committed_size: expected_size as i64, |
561 | | })) |
562 | 15 | } |
563 | | |
564 | 3 | async fn inner_query_write_status( |
565 | 3 | &self, |
566 | 3 | query_request: &QueryWriteStatusRequest, |
567 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Error> { |
568 | 3 | let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)?0 ; |
569 | | |
570 | 3 | let instance = self |
571 | 3 | .instance_infos |
572 | 3 | .get(resource_info.instance_name.as_ref()) |
573 | 3 | .err_tip(|| {0 |
574 | 0 | format!( |
575 | 0 | "'instance_name' not configured for '{}'", |
576 | 0 | &resource_info.instance_name |
577 | | ) |
578 | 0 | })?; |
579 | 3 | let store_clone = instance.store.clone(); |
580 | | |
581 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
582 | | |
583 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
584 | 3 | if let Some(grpc_store0 ) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (584:16): [True: 0, False: 3]
Branch (584:16): [Folded - Ignored]
|
585 | 0 | return grpc_store |
586 | 0 | .query_write_status(Request::new(query_request.clone())) |
587 | 0 | .await; |
588 | 3 | } |
589 | | |
590 | 3 | let uuid = resource_info |
591 | 3 | .uuid |
592 | 3 | .take() |
593 | 3 | .ok_or_else(|| make_input_err!("UUID must be set if querying write status"))?0 ; |
594 | | |
595 | | { |
596 | 3 | let active_uploads = instance.active_uploads.lock(); |
597 | 3 | if let Some((received_bytes1 , _maybe_idle_stream1 )) = active_uploads.get(uuid.as_ref()) { Branch (597:20): [True: 1, False: 2]
Branch (597:20): [Folded - Ignored]
|
598 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
599 | 1 | committed_size: received_bytes.load(Ordering::Acquire) as i64, |
600 | 1 | // If we are in the active_uploads map, but the value is None, |
601 | 1 | // it means the stream is not complete. |
602 | 1 | complete: false, |
603 | 1 | })); |
604 | 2 | } |
605 | | } |
606 | | |
607 | 2 | let has_fut = store_clone.has(digest); |
608 | 2 | let Some(item_size1 ) = has_fut.await.err_tip(|| "Failed to call .has() on store")?0 else { Branch (608:13): [True: 1, False: 1]
Branch (608:13): [Folded - Ignored]
|
609 | | // We lie here and say that the stream needs to start over, even though |
610 | | // it was never started. This can happen when the client disconnects |
611 | | // before sending the first payload, but the client thinks it did send |
612 | | // the payload. |
613 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
614 | 1 | committed_size: 0, |
615 | 1 | complete: false, |
616 | 1 | })); |
617 | | }; |
618 | 1 | Ok(Response::new(QueryWriteStatusResponse { |
619 | 1 | committed_size: item_size as i64, |
620 | 1 | complete: true, |
621 | 1 | })) |
622 | 3 | } |
623 | | } |
624 | | |
625 | | #[tonic::async_trait] |
626 | | impl ByteStream for ByteStreamServer { |
627 | | type ReadStream = ReadStream; |
628 | | |
629 | | #[instrument( |
630 | | err, |
631 | | level = Level::ERROR, |
632 | | skip_all, |
633 | | fields(request = ?grpc_request.get_ref()) |
634 | | )] |
635 | | async fn read( |
636 | | &self, |
637 | | grpc_request: Request<ReadRequest>, |
638 | | ) -> Result<Response<Self::ReadStream>, Status> { |
639 | | let read_request = grpc_request.into_inner(); |
640 | | let resource_info = ResourceInfo::new(&read_request.resource_name, false)?; |
641 | | let instance_name = resource_info.instance_name.as_ref(); |
642 | | let instance = self |
643 | | .instance_infos |
644 | | .get(instance_name) |
645 | 0 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?; |
646 | | let store = instance.store.clone(); |
647 | | |
648 | | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?; |
649 | | |
650 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
651 | | if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { |
652 | | let stream = Box::pin(grpc_store.read(Request::new(read_request)).await?); |
653 | | return Ok(Response::new(stream)); |
654 | | } |
655 | | |
656 | | let digest_function = resource_info.digest_function.as_deref().map_or_else( |
657 | 3 | || Ok(default_digest_hasher_func()), |
658 | | DigestHasherFunc::try_from, |
659 | | )?; |
660 | | |
661 | | let resp = self |
662 | | .inner_read(instance, digest, read_request) |
663 | | .instrument(error_span!("bytestream_read")) |
664 | | .with_context( |
665 | | make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::read")?, |
666 | | ) |
667 | | .await |
668 | | .err_tip(|| "In ByteStreamServer::read") |
669 | 3 | .map(|stream| -> Response<Self::ReadStream> { Response::new(Box::pin(stream)) }) |
670 | | .map_err(Into::into); |
671 | | |
672 | | if resp.is_ok() { |
673 | | debug!(return = "Ok(<stream>)"); |
674 | | } |
675 | | |
676 | | resp |
677 | | } |
678 | | |
679 | | #[instrument( |
680 | | err, |
681 | | level = Level::ERROR, |
682 | | skip_all, |
683 | | fields(request = ?grpc_request.get_ref()) |
684 | | )] |
685 | | async fn write( |
686 | | &self, |
687 | | grpc_request: Request<Streaming<WriteRequest>>, |
688 | | ) -> Result<Response<WriteResponse>, Status> { |
689 | | let request = grpc_request.into_inner(); |
690 | | let stream = WriteRequestStreamWrapper::from(request) |
691 | | .await |
692 | | .err_tip(|| "Could not unwrap first stream message") |
693 | | .map_err(Into::<Status>::into)?; |
694 | | |
695 | | let instance_name = stream.resource_info.instance_name.as_ref(); |
696 | | let instance = self |
697 | | .instance_infos |
698 | | .get(instance_name) |
699 | 0 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"))?; |
700 | | let store = instance.store.clone(); |
701 | | |
702 | | let digest = DigestInfo::try_new( |
703 | | &stream.resource_info.hash, |
704 | | stream.resource_info.expected_size, |
705 | | ) |
706 | | .err_tip(|| "Invalid digest input in ByteStream::write")?; |
707 | | |
708 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
709 | | if let Some(grpc_store) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { |
710 | | let resp = grpc_store.write(stream).await.map_err(Into::into); |
711 | | return resp; |
712 | | } |
713 | | |
714 | | let digest_function = stream |
715 | | .resource_info |
716 | | .digest_function |
717 | | .as_deref() |
718 | | .map_or_else( |
719 | 15 | || Ok(default_digest_hasher_func()), |
720 | | DigestHasherFunc::try_from, |
721 | | )?; |
722 | | |
723 | | self.inner_write(instance, digest, stream) |
724 | | .instrument(error_span!("bytestream_write")) |
725 | | .with_context( |
726 | | make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::write")?, |
727 | | ) |
728 | | .await |
729 | | .err_tip(|| "In ByteStreamServer::write") |
730 | | .map_err(Into::into) |
731 | | } |
732 | | |
733 | | #[instrument( |
734 | | err, |
735 | | ret(level = Level::INFO), |
736 | | level = Level::ERROR, |
737 | | skip_all, |
738 | | fields(request = ?grpc_request.get_ref()) |
739 | | )] |
740 | | async fn query_write_status( |
741 | | &self, |
742 | | grpc_request: Request<QueryWriteStatusRequest>, |
743 | | ) -> Result<Response<QueryWriteStatusResponse>, Status> { |
744 | | let request = grpc_request.into_inner(); |
745 | | self.inner_query_write_status(&request) |
746 | | .await |
747 | | .err_tip(|| "Failed on query_write_status() command") |
748 | | .map_err(Into::into) |
749 | | } |
750 | | } |