/build/source/nativelink-service/src/bytestream_server.rs
Line | Count | Source |
1 | | // Copyright 2024 The NativeLink Authors. All rights reserved. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use core::convert::Into; |
16 | | use core::fmt::{Debug, Formatter}; |
17 | | use core::pin::Pin; |
18 | | use core::sync::atomic::{AtomicU64, Ordering}; |
19 | | use core::time::Duration; |
20 | | use std::collections::HashMap; |
21 | | use std::collections::hash_map::Entry; |
22 | | use std::sync::Arc; |
23 | | |
24 | | use futures::future::{BoxFuture, pending}; |
25 | | use futures::stream::unfold; |
26 | | use futures::{Future, Stream, TryFutureExt, try_join}; |
27 | | use nativelink_config::cas_server::ByteStreamConfig; |
28 | | use nativelink_error::{Code, Error, ResultExt, make_err, make_input_err}; |
29 | | use nativelink_proto::google::bytestream::byte_stream_server::{ |
30 | | ByteStream, ByteStreamServer as Server, |
31 | | }; |
32 | | use nativelink_proto::google::bytestream::{ |
33 | | QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, |
34 | | WriteResponse, |
35 | | }; |
36 | | use nativelink_store::grpc_store::GrpcStore; |
37 | | use nativelink_store::store_manager::StoreManager; |
38 | | use nativelink_util::buf_channel::{ |
39 | | DropCloserReadHalf, DropCloserWriteHalf, make_buf_channel_pair, |
40 | | }; |
41 | | use nativelink_util::common::DigestInfo; |
42 | | use nativelink_util::digest_hasher::{ |
43 | | DigestHasherFunc, default_digest_hasher_func, make_ctx_for_hash_func, |
44 | | }; |
45 | | use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; |
46 | | use nativelink_util::resource_info::ResourceInfo; |
47 | | use nativelink_util::spawn; |
48 | | use nativelink_util::store_trait::{Store, StoreLike, UploadSizeInfo}; |
49 | | use nativelink_util::task::JoinHandleDropGuard; |
50 | | use opentelemetry::context::FutureExt; |
51 | | use parking_lot::Mutex; |
52 | | use tokio::time::sleep; |
53 | | use tonic::{Request, Response, Status, Streaming}; |
54 | | use tracing::{Instrument, Level, debug, error, error_span, info, instrument, trace}; |
55 | | |
56 | | /// If this value changes update the documentation in the config definition. |
57 | | const DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT: Duration = Duration::from_secs(60); |
58 | | |
59 | | /// If this value changes update the documentation in the config definition. |
60 | | const DEFAULT_MAX_BYTES_PER_STREAM: usize = 64 * 1024; |
61 | | |
62 | | /// If this value changes update the documentation in the config definition. |
63 | | const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 4 * 1024 * 1024; |
64 | | |
65 | | type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; |
66 | | type StoreUpdateFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>; |
67 | | |
68 | | struct StreamState { |
69 | | uuid: String, |
70 | | tx: DropCloserWriteHalf, |
71 | | store_update_fut: StoreUpdateFuture, |
72 | | } |
73 | | |
74 | | impl Debug for StreamState { |
75 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { |
76 | 0 | f.debug_struct("StreamState") |
77 | 0 | .field("uuid", &self.uuid) |
78 | 0 | .finish() |
79 | 0 | } |
80 | | } |
81 | | |
82 | | /// If a stream is in this state, it will automatically be put back into an `IdleStream` and |
83 | | /// placed back into the `active_uploads` map as an `IdleStream` after it is dropped. |
84 | | /// To prevent it from being put back into an `IdleStream` you must call `.graceful_finish()`. |
85 | | struct ActiveStreamGuard<'a> { |
86 | | stream_state: Option<StreamState>, |
87 | | bytes_received: Arc<AtomicU64>, |
88 | | bytestream_server: &'a ByteStreamServer, |
89 | | } |
90 | | |
91 | | impl ActiveStreamGuard<'_> { |
92 | | /// Consumes the guard. The stream will be considered "finished", will |
93 | | /// remove it from the `active_uploads`. |
94 | 8 | fn graceful_finish(mut self) { |
95 | 8 | let stream_state = self.stream_state.take().unwrap(); |
96 | 8 | self.bytestream_server |
97 | 8 | .active_uploads |
98 | 8 | .lock() |
99 | 8 | .remove(&stream_state.uuid); |
100 | 8 | } |
101 | | } |
102 | | |
103 | | impl Drop for ActiveStreamGuard<'_> { |
104 | 15 | fn drop(&mut self) { |
105 | 15 | let Some(stream_state7 ) = self.stream_state.take() else { Branch (105:13): [True: 7, False: 8]
Branch (105:13): [Folded - Ignored]
|
106 | 8 | return; // If None it means we don't want it put back into an IdleStream. |
107 | | }; |
108 | 7 | let weak_active_uploads = Arc::downgrade(&self.bytestream_server.active_uploads); |
109 | 7 | let mut active_uploads = self.bytestream_server.active_uploads.lock(); |
110 | 7 | let uuid = stream_state.uuid.clone(); |
111 | 7 | let Some(active_uploads_slot) = active_uploads.get_mut(&uuid) else { Branch (111:13): [True: 7, False: 0]
Branch (111:13): [Folded - Ignored]
|
112 | 0 | error!( |
113 | | err = "Failed to find active upload. This should never happen.", |
114 | | uuid = ?uuid, |
115 | | ); |
116 | 0 | return; |
117 | | }; |
118 | 7 | let sleep_fn = self.bytestream_server.sleep_fn.clone(); |
119 | 7 | active_uploads_slot.1 = Some(IdleStream { |
120 | 7 | stream_state, |
121 | 7 | _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move {3 |
122 | 3 | (*sleep_fn)().await; |
123 | 0 | if let Some(active_uploads) = weak_active_uploads.upgrade() { Branch (123:24): [True: 0, False: 0]
Branch (123:24): [Folded - Ignored]
|
124 | 0 | let mut active_uploads = active_uploads.lock(); |
125 | 0 | info!(msg = "Removing idle stream", uuid = ?uuid); |
126 | 0 | active_uploads.remove(&uuid); |
127 | 0 | } |
128 | 0 | }), |
129 | | }); |
130 | 15 | } |
131 | | } |
132 | | |
133 | | /// Represents a stream that is in the "idle" state. this means it is not currently being used |
134 | | /// by a client. If it is not used within a certain amount of time it will be removed from the |
135 | | /// `active_uploads` map automatically. |
136 | | #[derive(Debug)] |
137 | | struct IdleStream { |
138 | | stream_state: StreamState, |
139 | | _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, |
140 | | } |
141 | | |
142 | | impl IdleStream { |
143 | 3 | fn into_active_stream( |
144 | 3 | self, |
145 | 3 | bytes_received: Arc<AtomicU64>, |
146 | 3 | bytestream_server: &ByteStreamServer, |
147 | 3 | ) -> ActiveStreamGuard<'_> { |
148 | 3 | ActiveStreamGuard { |
149 | 3 | stream_state: Some(self.stream_state), |
150 | 3 | bytes_received, |
151 | 3 | bytestream_server, |
152 | 3 | } |
153 | 3 | } |
154 | | } |
155 | | |
156 | | type BytesWrittenAndIdleStream = (Arc<AtomicU64>, Option<IdleStream>); |
157 | | type SleepFn = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>; |
158 | | |
159 | | pub struct ByteStreamServer { |
160 | | stores: HashMap<String, Store>, |
161 | | // Max number of bytes to send on each grpc stream chunk. |
162 | | max_bytes_per_stream: usize, |
163 | | max_decoding_message_size: usize, |
164 | | active_uploads: Arc<Mutex<HashMap<String, BytesWrittenAndIdleStream>>>, |
165 | | sleep_fn: SleepFn, |
166 | | } |
167 | | |
168 | | impl Debug for ByteStreamServer { |
169 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { |
170 | 0 | f.debug_struct("ByteStreamServer") |
171 | 0 | .field("stores", &self.stores) |
172 | 0 | .field("max_bytes_per_stream", &self.max_bytes_per_stream) |
173 | 0 | .field("max_decoding_message_size", &self.max_decoding_message_size) |
174 | 0 | .field("active_uploads", &self.active_uploads) |
175 | 0 | .finish_non_exhaustive() |
176 | 0 | } |
177 | | } |
178 | | |
179 | | impl ByteStreamServer { |
180 | 15 | pub fn new(config: &ByteStreamConfig, store_manager: &StoreManager) -> Result<Self, Error> { |
181 | 15 | let persist_stream_on_disconnect_timeout = |
182 | 15 | if config.persist_stream_on_disconnect_timeout == 0 { Branch (182:16): [True: 15, False: 0]
Branch (182:16): [Folded - Ignored]
|
183 | 15 | DEFAULT_PERSIST_STREAM_ON_DISCONNECT_TIMEOUT |
184 | | } else { |
185 | 0 | Duration::from_secs(config.persist_stream_on_disconnect_timeout as u64) |
186 | | }; |
187 | 15 | Self::new_with_sleep_fn( |
188 | 15 | config, |
189 | 15 | store_manager, |
190 | 15 | Arc::new(move || Box::pin3 (sleep3 (persist_stream_on_disconnect_timeout3 ))), |
191 | | ) |
192 | 15 | } |
193 | | |
194 | 15 | pub fn new_with_sleep_fn( |
195 | 15 | config: &ByteStreamConfig, |
196 | 15 | store_manager: &StoreManager, |
197 | 15 | sleep_fn: SleepFn, |
198 | 15 | ) -> Result<Self, Error> { |
199 | 15 | let mut stores = HashMap::with_capacity(config.cas_stores.len()); |
200 | 30 | for (instance_name15 , store_name15 ) in &config.cas_stores { |
201 | 15 | let store = store_manager |
202 | 15 | .get_store(store_name) |
203 | 15 | .ok_or_else(|| make_input_err!("'cas_store': '{}' does not exist", store_name))?0 ; |
204 | 15 | stores.insert(instance_name.to_string(), store); |
205 | | } |
206 | 15 | let max_bytes_per_stream = if config.max_bytes_per_stream == 0 { Branch (206:39): [True: 1, False: 14]
Branch (206:39): [Folded - Ignored]
|
207 | 1 | DEFAULT_MAX_BYTES_PER_STREAM |
208 | | } else { |
209 | 14 | config.max_bytes_per_stream |
210 | | }; |
211 | 15 | let max_decoding_message_size = if config.max_decoding_message_size == 0 { Branch (211:44): [True: 14, False: 1]
Branch (211:44): [Folded - Ignored]
|
212 | 14 | DEFAULT_MAX_DECODING_MESSAGE_SIZE |
213 | | } else { |
214 | 1 | config.max_decoding_message_size |
215 | | }; |
216 | 15 | Ok(Self { |
217 | 15 | stores, |
218 | 15 | max_bytes_per_stream, |
219 | 15 | max_decoding_message_size, |
220 | 15 | active_uploads: Arc::new(Mutex::new(HashMap::new())), |
221 | 15 | sleep_fn, |
222 | 15 | }) |
223 | 15 | } |
224 | | |
225 | 1 | pub fn into_service(self) -> Server<Self> { |
226 | 1 | let max_decoding_message_size = self.max_decoding_message_size; |
227 | 1 | Server::new(self).max_decoding_message_size(max_decoding_message_size) |
228 | 1 | } |
229 | | |
230 | 15 | fn create_or_join_upload_stream( |
231 | 15 | &self, |
232 | 15 | uuid: String, |
233 | 15 | store: Store, |
234 | 15 | digest: DigestInfo, |
235 | 15 | ) -> Result<ActiveStreamGuard<'_>, Error> { |
236 | 15 | let (uuid12 , bytes_received12 ) = match self.active_uploads.lock().entry(uuid) { |
237 | 3 | Entry::Occupied(mut entry) => { |
238 | 3 | let maybe_idle_stream = entry.get_mut(); |
239 | 3 | let Some(idle_stream) = maybe_idle_stream.1.take() else { Branch (239:21): [True: 3, False: 0]
Branch (239:21): [Folded - Ignored]
|
240 | 0 | return Err(make_input_err!("Cannot upload same UUID simultaneously")); |
241 | | }; |
242 | 3 | let bytes_received = maybe_idle_stream.0.clone(); |
243 | 3 | info!(msg = "Joining existing stream", entry = ?entry.key()); |
244 | 3 | return Ok(idle_stream.into_active_stream(bytes_received, self)); |
245 | | } |
246 | 12 | Entry::Vacant(entry) => { |
247 | 12 | let bytes_received = Arc::new(AtomicU64::new(0)); |
248 | 12 | let uuid = entry.key().clone(); |
249 | | // Our stream is "in use" if the key is in the map, but the value is None. |
250 | 12 | entry.insert((bytes_received.clone(), None)); |
251 | 12 | (uuid, bytes_received) |
252 | | } |
253 | | }; |
254 | | |
255 | | // Important: Do not return an error from this point onwards without |
256 | | // removing the entry from the map, otherwise that UUID becomes |
257 | | // unusable. |
258 | | |
259 | 12 | let (tx, rx) = make_buf_channel_pair(); |
260 | 12 | let store_update_fut = Box::pin(async move {8 |
261 | | // We need to wrap `Store::update()` in a another future because we need to capture |
262 | | // `store` to ensure its lifetime follows the future and not the caller. |
263 | 8 | store |
264 | 8 | // Bytestream always uses digest size as the actual byte size. |
265 | 8 | .update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes())) |
266 | 8 | .await |
267 | 8 | }); |
268 | 12 | Ok(ActiveStreamGuard { |
269 | 12 | stream_state: Some(StreamState { |
270 | 12 | uuid, |
271 | 12 | tx, |
272 | 12 | store_update_fut, |
273 | 12 | }), |
274 | 12 | bytes_received, |
275 | 12 | bytestream_server: self, |
276 | 12 | }) |
277 | 15 | } |
278 | | |
279 | 3 | async fn inner_read( |
280 | 3 | &self, |
281 | 3 | store: Store, |
282 | 3 | digest: DigestInfo, |
283 | 3 | read_request: ReadRequest, |
284 | 3 | ) -> Result<impl Stream<Item = Result<ReadResponse, Status>> + Send + use<>, Error> { |
285 | | struct ReaderState { |
286 | | max_bytes_per_stream: usize, |
287 | | rx: DropCloserReadHalf, |
288 | | maybe_get_part_result: Option<Result<(), Error>>, |
289 | | get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, |
290 | | } |
291 | | |
292 | 3 | let read_limit = u64::try_from(read_request.read_limit) |
293 | 3 | .err_tip(|| "Could not convert read_limit to u64")?0 ; |
294 | | |
295 | 3 | let (tx, rx) = make_buf_channel_pair(); |
296 | | |
297 | 3 | let read_limit = if read_limit != 0 { Branch (297:29): [True: 3, False: 0]
Branch (297:29): [Folded - Ignored]
|
298 | 3 | Some(read_limit) |
299 | | } else { |
300 | 0 | None |
301 | | }; |
302 | | |
303 | | // This allows us to call a destructor when the the object is dropped. |
304 | 3 | let state = Some(ReaderState { |
305 | 3 | rx, |
306 | 3 | max_bytes_per_stream: self.max_bytes_per_stream, |
307 | 3 | maybe_get_part_result: None, |
308 | 3 | get_part_fut: Box::pin(async move { |
309 | 3 | store |
310 | 3 | .get_part( |
311 | 3 | digest, |
312 | 3 | tx, |
313 | 3 | u64::try_from(read_request.read_offset) |
314 | 3 | .err_tip(|| "Could not convert read_offset to u64")?0 , |
315 | 3 | read_limit, |
316 | | ) |
317 | 3 | .await |
318 | 3 | }), |
319 | | }); |
320 | | |
321 | 3 | let read_stream_span = error_span!("read_stream"); |
322 | | |
323 | 9.77k | Ok(Box::pin3 (unfold3 (state3 , move |state| { |
324 | 9.77k | async { |
325 | 9.77k | let mut state = state?0 ; // If None our stream is done. |
326 | 9.77k | let mut response = ReadResponse::default(); |
327 | | { |
328 | 9.77k | let consume_fut = state.rx.consume(Some(state.max_bytes_per_stream)); |
329 | 9.77k | tokio::pin!(consume_fut); |
330 | | loop { |
331 | 9.77k | tokio::select! { |
332 | 9.77k | read_result9.77k = &mut consume_fut => { |
333 | 9.77k | match read_result { |
334 | 9.76k | Ok(bytes) => { |
335 | 9.76k | if bytes.is_empty() { Branch (335:40): [True: 2, False: 9.76k]
Branch (335:40): [Folded - Ignored]
|
336 | | // EOF. |
337 | 2 | return None; |
338 | 9.76k | } |
339 | 9.76k | if bytes.len() > state.max_bytes_per_stream { Branch (339:40): [True: 0, False: 9.76k]
Branch (339:40): [Folded - Ignored]
|
340 | 0 | let err = make_err!(Code::Internal, "Returned store size was larger than read size"); |
341 | 0 | return Some((Err(err.into()), None)); |
342 | 9.76k | } |
343 | 9.76k | response.data = bytes; |
344 | 9.76k | trace!(response = ?response); |
345 | 9.76k | debug!(response.data = format!0 ("<redacted len({})>"0 , response.data0 .len0 ())); |
346 | 9.76k | break; |
347 | | } |
348 | 1 | Err(mut e) => { |
349 | | // We may need to propagate the error from reading the data through first. |
350 | | // For example, the NotFound error will come through `get_part_fut`, and |
351 | | // will not be present in `e`, but we need to ensure we pass NotFound error |
352 | | // code or the client won't know why it failed. |
353 | 1 | let get_part_result = if let Some(result) = state.maybe_get_part_result { Branch (353:66): [True: 1, False: 0]
Branch (353:66): [Folded - Ignored]
|
354 | 1 | result |
355 | | } else { |
356 | | // This should never be `future::pending()` if maybe_get_part_result is |
357 | | // not set. |
358 | 0 | state.get_part_fut.await |
359 | | }; |
360 | 1 | if let Err(err) = get_part_result { Branch (360:44): [True: 1, False: 0]
Branch (360:44): [Folded - Ignored]
|
361 | 1 | e = err.merge(e); |
362 | 1 | }0 |
363 | 1 | if e.code == Code::NotFound { Branch (363:40): [True: 1, False: 0]
Branch (363:40): [Folded - Ignored]
|
364 | 1 | // Trim the error code. Not Found is quite common and we don't want to send a large |
365 | 1 | // error (debug) message for something that is common. We resize to just the last |
366 | 1 | // message as it will be the most relevant. |
367 | 1 | e.messages.truncate(1); |
368 | 1 | }0 |
369 | 1 | error!(response = ?e); |
370 | 1 | return Some((Err(e.into()), None)) |
371 | | } |
372 | | } |
373 | | }, |
374 | 9.77k | result3 = &mut state.get_part_fut => { |
375 | 3 | state.maybe_get_part_result = Some(result); |
376 | 3 | // It is non-deterministic on which future will finish in what order. |
377 | 3 | // It is also possible that the `state.rx.consume()` call above may not be able to |
378 | 3 | // respond even though the publishing future is done. |
379 | 3 | // Because of this we set the writing future to pending so it never finishes. |
380 | 3 | // The `state.rx.consume()` future will eventually finish and return either the |
381 | 3 | // data or an error. |
382 | 3 | // An EOF will terminate the `state.rx.consume()` future, but we are also protected |
383 | 3 | // because we are dropping the writing future, it will drop the `tx` channel |
384 | 3 | // which will eventually propagate an error to the `state.rx.consume()` future if |
385 | 3 | // the EOF was not sent due to some other error. |
386 | 3 | state.get_part_fut = Box::pin(pending()); |
387 | 3 | }, |
388 | | } |
389 | | } |
390 | | } |
391 | 9.76k | Some((Ok(response), Some(state))) |
392 | 9.77k | }.instrument(read_stream_span.clone()) |
393 | 9.77k | }))) |
394 | 3 | } |
395 | | |
396 | | // We instrument tracing here as well as below because `stream` has a hash on it |
397 | | // that is extracted from the first stream message. If we only implemented it below |
398 | | // we would not have the hash available to us. |
399 | | #[instrument( |
400 | | ret(level = Level::DEBUG), |
401 | 15 | level = Level::ERROR, |
402 | | skip(self, store), |
403 | | fields(stream.first_msg = "<redacted>") |
404 | | )] |
405 | | async fn inner_write( |
406 | | &self, |
407 | | store: Store, |
408 | | digest: DigestInfo, |
409 | | stream: WriteRequestStreamWrapper<impl Stream<Item = Result<WriteRequest, Status>> + Unpin>, |
410 | | ) -> Result<Response<WriteResponse>, Error> { |
411 | 15 | async fn process_client_stream( |
412 | 15 | mut stream: WriteRequestStreamWrapper< |
413 | 15 | impl Stream<Item = Result<WriteRequest, Status>> + Unpin, |
414 | 15 | >, |
415 | 15 | tx: &mut DropCloserWriteHalf, |
416 | 15 | outer_bytes_received: &Arc<AtomicU64>, |
417 | 15 | expected_size: u64, |
418 | 15 | ) -> Result<(), Error> { |
419 | | loop { |
420 | 25 | let write_request20 = match stream.next().await { |
421 | | // Code path for when client tries to gracefully close the stream. |
422 | | // If this happens it means there's a problem with the data sent, |
423 | | // because we always close the stream from our end before this point |
424 | | // by counting the number of bytes sent from the client. If they send |
425 | | // less than the amount they said they were going to send and then |
426 | | // close the stream, we know there's a problem. |
427 | | None => { |
428 | 0 | return Err(make_input_err!( |
429 | 0 | "Client closed stream before sending all data" |
430 | 0 | )); |
431 | | } |
432 | | // Code path for client stream error. Probably client disconnect. |
433 | 5 | Some(Err(err)) => return Err(err), |
434 | | // Code path for received chunk of data. |
435 | 20 | Some(Ok(write_request)) => write_request, |
436 | | }; |
437 | | |
438 | 20 | if write_request.write_offset < 0 { Branch (438:20): [True: 1, False: 19]
Branch (438:20): [Folded - Ignored]
|
439 | 1 | return Err(make_input_err!( |
440 | 1 | "Invalid negative write offset in write request: {}", |
441 | 1 | write_request.write_offset |
442 | 1 | )); |
443 | 19 | } |
444 | 19 | let write_offset = write_request.write_offset as u64; |
445 | | |
446 | | // If we get duplicate data because a client didn't know where |
447 | | // it left off from, then we can simply skip it. |
448 | 19 | let data18 = if write_offset < tx.get_bytes_written() { Branch (448:31): [True: 2, False: 17]
Branch (448:31): [Folded - Ignored]
|
449 | 2 | if (write_offset + write_request.data.len() as u64) < tx.get_bytes_written() { Branch (449:24): [True: 0, False: 2]
Branch (449:24): [Folded - Ignored]
|
450 | 0 | if write_request.finish_write { Branch (450:28): [True: 0, False: 0]
Branch (450:28): [Folded - Ignored]
|
451 | 0 | return Err(make_input_err!( |
452 | 0 | "Resumed stream finished at {} bytes when we already received {} bytes.", |
453 | 0 | write_offset + write_request.data.len() as u64, |
454 | 0 | tx.get_bytes_written() |
455 | 0 | )); |
456 | 0 | } |
457 | 0 | continue; |
458 | 2 | } |
459 | 2 | write_request |
460 | 2 | .data |
461 | 2 | .slice((tx.get_bytes_written() - write_offset) as usize..) |
462 | | } else { |
463 | 17 | if write_offset != tx.get_bytes_written() { Branch (463:24): [True: 1, False: 16]
Branch (463:24): [Folded - Ignored]
|
464 | 1 | return Err(make_input_err!( |
465 | 1 | "Received out of order data. Got {}, expected {}", |
466 | 1 | write_offset, |
467 | 1 | tx.get_bytes_written() |
468 | 1 | )); |
469 | 16 | } |
470 | 16 | write_request.data |
471 | | }; |
472 | | |
473 | | // Do not process EOF or weird stuff will happen. |
474 | 18 | if !data.is_empty() { Branch (474:20): [True: 13, False: 5]
Branch (474:20): [Folded - Ignored]
|
475 | | // We also need to process the possible EOF branch, so we can't early return. |
476 | 13 | if let Err(mut err0 ) = tx.send(data).await { Branch (476:28): [True: 0, False: 13]
Branch (476:28): [Folded - Ignored]
|
477 | 0 | err.code = Code::Internal; |
478 | 0 | return Err(err); |
479 | 13 | } |
480 | 13 | outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); |
481 | 5 | } |
482 | | |
483 | 18 | if expected_size < tx.get_bytes_written() { Branch (483:20): [True: 0, False: 18]
Branch (483:20): [Folded - Ignored]
|
484 | 0 | return Err(make_input_err!("Received more bytes than expected")); |
485 | 18 | } |
486 | 18 | if write_request.finish_write { Branch (486:20): [True: 8, False: 10]
Branch (486:20): [Folded - Ignored]
|
487 | | // Gracefully close our stream. |
488 | 8 | tx.send_eof() |
489 | 8 | .err_tip(|| "Failed to send EOF in ByteStream::write")?0 ; |
490 | 8 | return Ok(()); |
491 | 10 | } |
492 | | // Continue. |
493 | | } |
494 | | // Unreachable. |
495 | 15 | } |
496 | | |
497 | | let uuid = stream |
498 | | .resource_info |
499 | | .uuid |
500 | | .as_ref() |
501 | | .ok_or_else(|| make_input_err!("UUID must be set if writing data"))? |
502 | | .to_string(); |
503 | | let mut active_stream_guard = self.create_or_join_upload_stream(uuid, store, digest)?; |
504 | | let expected_size = stream.resource_info.expected_size as u64; |
505 | | |
506 | | let active_stream = active_stream_guard.stream_state.as_mut().unwrap(); |
507 | | try_join!( |
508 | | process_client_stream( |
509 | | stream, |
510 | | &mut active_stream.tx, |
511 | | &active_stream_guard.bytes_received, |
512 | | expected_size |
513 | | ), |
514 | | (&mut active_stream.store_update_fut) |
515 | 0 | .map_err(|err| { err.append("Error updating inner store") }) |
516 | | )?; |
517 | | |
518 | | // Close our guard and consider the stream no longer active. |
519 | | active_stream_guard.graceful_finish(); |
520 | | |
521 | | Ok(Response::new(WriteResponse { |
522 | | committed_size: expected_size as i64, |
523 | | })) |
524 | | } |
525 | | |
526 | 3 | async fn inner_query_write_status( |
527 | 3 | &self, |
528 | 3 | query_request: &QueryWriteStatusRequest, |
529 | 3 | ) -> Result<Response<QueryWriteStatusResponse>, Error> { |
530 | 3 | let mut resource_info = ResourceInfo::new(&query_request.resource_name, true)?0 ; |
531 | | |
532 | 3 | let store_clone = self |
533 | 3 | .stores |
534 | 3 | .get(resource_info.instance_name.as_ref()) |
535 | 3 | .err_tip(|| {0 |
536 | 0 | format!( |
537 | 0 | "'instance_name' not configured for '{}'", |
538 | 0 | &resource_info.instance_name |
539 | | ) |
540 | 0 | })? |
541 | 3 | .clone(); |
542 | | |
543 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
544 | | |
545 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
546 | 3 | if let Some(grpc_store0 ) = store_clone.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (546:16): [True: 0, False: 3]
Branch (546:16): [Folded - Ignored]
|
547 | 0 | return grpc_store |
548 | 0 | .query_write_status(Request::new(query_request.clone())) |
549 | 0 | .await; |
550 | 3 | } |
551 | | |
552 | 3 | let uuid = resource_info |
553 | 3 | .uuid |
554 | 3 | .take() |
555 | 3 | .ok_or_else(|| make_input_err!("UUID must be set if querying write status"))?0 ; |
556 | | |
557 | | { |
558 | 3 | let active_uploads = self.active_uploads.lock(); |
559 | 3 | if let Some((received_bytes1 , _maybe_idle_stream1 )) = active_uploads.get(uuid.as_ref()) { Branch (559:20): [True: 1, False: 2]
Branch (559:20): [Folded - Ignored]
|
560 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
561 | 1 | committed_size: received_bytes.load(Ordering::Acquire) as i64, |
562 | 1 | // If we are in the active_uploads map, but the value is None, |
563 | 1 | // it means the stream is not complete. |
564 | 1 | complete: false, |
565 | 1 | })); |
566 | 2 | } |
567 | | } |
568 | | |
569 | 2 | let has_fut = store_clone.has(digest); |
570 | 2 | let Some(item_size1 ) = has_fut.await.err_tip(|| "Failed to call .has() on store")?0 else { Branch (570:13): [True: 1, False: 1]
Branch (570:13): [Folded - Ignored]
|
571 | | // We lie here and say that the stream needs to start over, even though |
572 | | // it was never started. This can happen when the client disconnects |
573 | | // before sending the first payload, but the client thinks it did send |
574 | | // the payload. |
575 | 1 | return Ok(Response::new(QueryWriteStatusResponse { |
576 | 1 | committed_size: 0, |
577 | 1 | complete: false, |
578 | 1 | })); |
579 | | }; |
580 | 1 | Ok(Response::new(QueryWriteStatusResponse { |
581 | 1 | committed_size: item_size as i64, |
582 | 1 | complete: true, |
583 | 1 | })) |
584 | 3 | } |
585 | | } |
586 | | |
587 | | #[tonic::async_trait] |
588 | | impl ByteStream for ByteStreamServer { |
589 | | type ReadStream = ReadStream; |
590 | | |
591 | | #[instrument( |
592 | | err, |
593 | | level = Level::ERROR, |
594 | | skip_all, |
595 | | fields(request = ?grpc_request.get_ref()) |
596 | | )] |
597 | | async fn read( |
598 | | &self, |
599 | | grpc_request: Request<ReadRequest>, |
600 | 6 | ) -> Result<Response<Self::ReadStream>, Status> { |
601 | 3 | let read_request = grpc_request.into_inner(); |
602 | 3 | let resource_info = ResourceInfo::new(&read_request.resource_name, false)?0 ; |
603 | 3 | let instance_name = resource_info.instance_name.as_ref(); |
604 | 3 | let store = self |
605 | 3 | .stores |
606 | 3 | .get(instance_name) |
607 | 3 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"0 ))?0 |
608 | 3 | .clone(); |
609 | | |
610 | 3 | let digest = DigestInfo::try_new(resource_info.hash.as_ref(), resource_info.expected_size)?0 ; |
611 | | |
612 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
613 | 3 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (613:16): [True: 0, False: 3]
Branch (613:16): [Folded - Ignored]
|
614 | 0 | let stream = Box::pin(grpc_store.read(Request::new(read_request)).await?); |
615 | 0 | return Ok(Response::new(stream)); |
616 | 3 | } |
617 | | |
618 | 3 | let digest_function = resource_info.digest_function.as_deref().map_or_else( |
619 | 3 | || Ok(default_digest_hasher_func()), |
620 | | DigestHasherFunc::try_from, |
621 | 0 | )?; |
622 | | |
623 | 3 | let resp = self |
624 | 3 | .inner_read(store, digest, read_request) |
625 | 3 | .instrument(error_span!("bytestream_read")) |
626 | 3 | .with_context( |
627 | 3 | make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::read")?0 , |
628 | | ) |
629 | 3 | .await |
630 | 3 | .err_tip(|| "In ByteStreamServer::read") |
631 | 3 | .map(|stream| -> Response<Self::ReadStream> { Response::new(Box::pin(stream)) }) |
632 | 3 | .map_err(Into::into); |
633 | | |
634 | 3 | if resp.is_ok() { Branch (634:12): [True: 3, False: 0]
Branch (634:12): [Folded - Ignored]
|
635 | 3 | debug!(return = "Ok(<stream>)"); |
636 | 0 | } |
637 | | |
638 | 3 | resp |
639 | 6 | } |
640 | | |
641 | | #[instrument( |
642 | | err, |
643 | | level = Level::ERROR, |
644 | | skip_all, |
645 | | fields(request = ?grpc_request.get_ref()) |
646 | | )] |
647 | | async fn write( |
648 | | &self, |
649 | | grpc_request: Request<Streaming<WriteRequest>>, |
650 | 32 | ) -> Result<Response<WriteResponse>, Status> { |
651 | 16 | let request = grpc_request.into_inner(); |
652 | 16 | let stream15 = WriteRequestStreamWrapper::from(request) |
653 | 16 | .await |
654 | 16 | .err_tip(|| "Could not unwrap first stream message") |
655 | 16 | .map_err(Into::<Status>::into)?1 ; |
656 | | |
657 | 15 | let instance_name = stream.resource_info.instance_name.as_ref(); |
658 | 15 | let store = self |
659 | 15 | .stores |
660 | 15 | .get(instance_name) |
661 | 15 | .err_tip(|| format!("'instance_name' not configured for '{instance_name}'"0 ))?0 |
662 | 15 | .clone(); |
663 | | |
664 | 15 | let digest = DigestInfo::try_new( |
665 | 15 | &stream.resource_info.hash, |
666 | 15 | stream.resource_info.expected_size, |
667 | | ) |
668 | 15 | .err_tip(|| "Invalid digest input in ByteStream::write")?0 ; |
669 | | |
670 | | // If we are a GrpcStore we shortcut here, as this is a special store. |
671 | 15 | if let Some(grpc_store0 ) = store.downcast_ref::<GrpcStore>(Some(digest.into())) { Branch (671:16): [True: 0, False: 15]
Branch (671:16): [Folded - Ignored]
|
672 | 0 | let resp = grpc_store.write(stream).await.map_err(Into::into); |
673 | 0 | return resp; |
674 | 15 | } |
675 | | |
676 | 15 | let digest_function = stream |
677 | 15 | .resource_info |
678 | 15 | .digest_function |
679 | 15 | .as_deref() |
680 | 15 | .map_or_else( |
681 | 15 | || Ok(default_digest_hasher_func()), |
682 | | DigestHasherFunc::try_from, |
683 | 0 | )?; |
684 | | |
685 | 15 | self.inner_write(store, digest, stream) |
686 | 15 | .instrument(error_span!("bytestream_write")) |
687 | 15 | .with_context( |
688 | 15 | make_ctx_for_hash_func(digest_function).err_tip(|| "In BytestreamServer::write")?0 , |
689 | | ) |
690 | 15 | .await |
691 | 15 | .err_tip(|| "In ByteStreamServer::write") |
692 | 15 | .map_err(Into::into) |
693 | 32 | } |
694 | | |
695 | | #[instrument( |
696 | | err, |
697 | | ret(level = Level::INFO), |
698 | | level = Level::ERROR, |
699 | | skip_all, |
700 | | fields(request = ?grpc_request.get_ref()) |
701 | | )] |
702 | | async fn query_write_status( |
703 | | &self, |
704 | | grpc_request: Request<QueryWriteStatusRequest>, |
705 | 6 | ) -> Result<Response<QueryWriteStatusResponse>, Status> { |
706 | 3 | let request = grpc_request.into_inner(); |
707 | 3 | self.inner_query_write_status(&request) |
708 | 3 | .await |
709 | 3 | .err_tip(|| "Failed on query_write_status() command") |
710 | 3 | .map_err(Into::into) |
711 | 6 | } |
712 | | } |