Coverage Report

Created: 2024-10-22 12:33

/build/source/nativelink-util/src/write_counter.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2024 The NativeLink Authors. All rights reserved.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//    http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::pin::Pin;
16
use std::task::{Context, Poll};
17
18
use pin_project_lite::pin_project;
19
use tokio::io::AsyncWrite;
20
21
pin_project! {
    /// Utility struct that counts the number of bytes sent and passes everything else through to
    /// the underlying writer.
    ///
    /// Only byte counts reported by successful writes of the inner writer are added to the
    /// total. A write error is remembered in a flag that is never cleared again.
    #[derive(Debug)]
    pub struct WriteCounter<T: AsyncWrite> {
        // The wrapped writer. Marked `#[pin]` so its poll methods can be reached through the
        // pinned projection of `WriteCounter`.
        #[pin]
        inner: T,
        // Running total of bytes the inner writer reported as written.
        bytes_written: u64,
        // Becomes `true` the first time a write on the inner writer returns an error.
        failed: bool,
    }
}
31
32
impl<T: AsyncWrite> WriteCounter<T> {
    /// Wraps `inner` in a counter that starts at zero bytes with no failure recorded.
    pub fn new(inner: T) -> Self {
        Self {
            failed: false,
            bytes_written: 0,
            inner,
        }
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn inner_ref(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns an exclusive reference to the wrapped writer.
    ///
    /// Bytes written directly through this reference bypass the counter.
    pub fn inner_mut(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the total number of bytes successfully written through this wrapper so far.
    pub fn get_bytes_written(&self) -> u64 {
        self.bytes_written
    }

    /// Returns `true` if any write through this wrapper has returned an error.
    pub fn did_fail(&self) -> bool {
        self.failed
    }
}
58
59
impl<T: AsyncWrite> AsyncWrite for WriteCounter<T> {
60
0
    fn poll_write(
61
0
        self: Pin<&mut Self>,
62
0
        cx: &mut Context<'_>,
63
0
        buf: &[u8],
64
0
    ) -> Poll<Result<usize, std::io::Error>> {
65
0
        let me = self.project();
66
0
        let result = me.inner.poll_write(cx, buf);
67
0
        match &result {
68
0
            Poll::Ready(Result::Ok(sz)) => *me.bytes_written += *sz as u64,
69
0
            Poll::Ready(Result::Err(_)) => *me.failed = true,
70
0
            _ => {}
71
        }
72
0
        result
73
0
    }
74
75
0
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
76
0
        let me = self.project();
77
0
        me.inner.poll_flush(cx)
78
0
    }
79
80
0
    fn poll_shutdown(
81
0
        self: Pin<&mut Self>,
82
0
        cx: &mut Context<'_>,
83
0
    ) -> Poll<Result<(), std::io::Error>> {
84
0
        let me = self.project();
85
0
        me.inner.poll_shutdown(cx)
86
0
    }
87
}