
Large percentage of cpu time in memset when reading with a larger buffer

Open jhorstmann opened this issue 1 year ago • 9 comments

Reproduced using a slightly modified version of examples/client.rs that allows specifying a path and reads into a user-provided buffer. The problem was originally noticed via the rusoto S3 client, which has a configuration option for the read buffer size that gets mapped to hyper's http1_read_buf_exact_size. Since AWS S3 recommends fetching large ranges, using a larger buffer seemed like a good idea and was not expected to cause any CPU overhead.

[flamegraph image]

use std::fs::File;
use std::io;
use std::io::BufReader;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
use std::sync::Arc;

use argh::FromArgs;
use tokio::io::{split, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;

/// Tokio Rustls client example
#[derive(FromArgs)]
struct Options {
    /// host
    #[argh(positional)]
    host: String,

    /// path
    #[argh(positional)]
    path: Option<String>,

    /// port
    #[argh(option, short = 'p', default = "443")]
    port: u16,

    /// domain
    #[argh(option, short = 'd')]
    domain: Option<String>,

    /// cafile
    #[argh(option, short = 'c')]
    cafile: Option<PathBuf>,
}

#[tokio::main]
async fn main() -> io::Result<()> {
    let options: Options = argh::from_env();

    let addr = (options.host.as_str(), options.port)
        .to_socket_addrs()?
        .next()
        .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
    let domain = options.domain.unwrap_or(options.host);
    let path = options.path.as_ref().map(|p| p.as_str()).unwrap_or("/");
    let content = format!(
        "GET {} HTTP/1.0\r\nConnection: close\r\nHost: {}\r\n\r\n",
        path, domain
    );

    let mut root_cert_store = rustls::RootCertStore::empty();
    if let Some(cafile) = &options.cafile {
        let mut pem = BufReader::new(File::open(cafile)?);
        for cert in rustls_pemfile::certs(&mut pem) {
            root_cert_store.add(cert?).unwrap();
        }
    } else {
        root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth(); // i guess this was previously the default?
    let connector = TlsConnector::from(Arc::new(config));

    let stream = TcpStream::connect(&addr).await?;

    let domain = pki_types::ServerName::try_from(domain.as_str())
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?
        .to_owned();

    let mut stream = connector.connect(domain, stream).await?;
    stream.write_all(content.as_bytes()).await?;

    let (mut reader, _writer) = split(stream);

    let mut buffer = Vec::with_capacity(4 * 1024 * 1024);
    let mut total_len = 0_usize;

    loop {
        match reader.read_buf(&mut buffer).await {
            Ok(len) => {
                total_len += len;
                buffer.clear();
                if len == 0 {
                    break;
                }
            }
            Err(e) => {
                eprintln!("{:?}", e);
                break;
            }
        }
    }

    println!("Size: {}", total_len);

    Ok(())
}

jhorstmann avatar Mar 13 '24 19:03 jhorstmann

I believe this is because the poll_read call passes in a large uninitialized ReadBuf, causing the stream to zero the ReadBuf unnecessarily.

https://github.com/rustls/tokio-rustls/blob/main/src/common/mod.rs#L218
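To illustrate the cost being described, here is a std-only sketch (tokio itself is not used; initialize_unfilled below is a hypothetical stand-in mimicking what tokio's ReadBuf::initialize_unfilled does): before handing out the unfilled region, every uninitialized byte of spare capacity gets zeroed, even if the subsequent read only fills a few bytes.

```rust
// Hypothetical std-only analogue of tokio's ReadBuf::initialize_unfilled():
// the entire spare capacity is zeroed before the reader sees it. This zeroing
// is the memset that shows up in the flamegraph.
fn initialize_unfilled(buf: &mut Vec<u8>) -> &mut [u8] {
    let cap = buf.capacity();
    // Zero all remaining capacity, regardless of how much the read will use.
    buf.resize(cap, 0);
    &mut buf[..]
}

fn main() {
    let mut buf: Vec<u8> = Vec::with_capacity(4 * 1024 * 1024);
    let slice = initialize_unfilled(&mut buf);
    // A single read typically yields far fewer bytes than the ~4 MiB we zeroed.
    slice[..4].copy_from_slice(b"HTTP");
    assert_eq!(&buf[..4], b"HTTP");
    assert!(buf.len() >= 4 * 1024 * 1024); // the whole buffer was zero-initialized
}
```

If this loop runs once per TLS record, the memset is repeated on every read, which matches the profile described above.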

We could do some optimization behind the read_buf cfg. https://github.com/rustls/rustls/blob/main/rustls/src/conn.rs#L230

quininer avatar Mar 14 '24 02:03 quininer

This problem can also be fixed by changing Vec::with_capacity(4 * 1024 * 1024) to vec![0; 4 * 1024 * 1024] and not calling clear().
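A sketch of this workaround, using std::io::Read and an in-memory Cursor standing in for the TLS stream (the drain helper is invented for illustration): the buffer is zero-initialized once up front and reused, so reads always target already-initialized memory.

```rust
use std::io::{Cursor, Read};

// Workaround sketch: allocate the buffer zero-initialized once and reuse it,
// instead of reading into uninitialized spare capacity on every call.
fn drain<R: Read>(mut reader: R) -> std::io::Result<usize> {
    // Zeroed once here; subsequent reads overwrite initialized memory,
    // so no per-read memset is needed.
    let mut buffer = vec![0u8; 4 * 1024 * 1024];
    let mut total_len = 0usize;
    loop {
        let len = reader.read(&mut buffer)?;
        if len == 0 {
            break;
        }
        total_len += len;
    }
    Ok(total_len)
}

fn main() -> std::io::Result<()> {
    // A Cursor stands in for the decrypted TLS stream in this sketch.
    let total = drain(Cursor::new(vec![7u8; 10_000]))?;
    assert_eq!(total, 10_000);
    println!("Size: {}", total);
    Ok(())
}
```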

quininer avatar Mar 14 '24 02:03 quininer

A possible solution could be the following. I think this is sound, since Reader is a concrete type whose behavior we know, and not a trait where adversarial implementations would be possible. But it seems a bit hacky.

I'm profiling this on an AWS m6in.8xlarge instance, downloading from S3. In the flamegraph, the memset with default buffer sizes was taking about 5% of the time. The throughput varies too much to quantify the benefit.

The call stack comes from hyper's h1 poll_read_from_io, which seems to create a new ReadBuf on every call, so there may also be potential for improvement on that side.

diff --git a/src/common/mod.rs b/src/common/mod.rs
index fde34c0..a9e3115 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -248,6 +248,10 @@ where
             }
         }
 
+        // Safety: We trust `read` to only write initialized bytes to the slice and never read from it.
+        unsafe {
+            buf.assume_init(buf.remaining());
+        }
         match self.session.reader().read(buf.initialize_unfilled()) {
             // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
             // connection with a `CloseNotify` message and no more data will be forthcoming.

jhorstmann avatar Apr 02 '24 18:04 jhorstmann

@seanmonstar do you think it's feasible to avoid recreating the ReadBuf on every call?

djc avatar Apr 02 '24 18:04 djc

I suppose theoretically, but realistically at the moment, that ReadBuf is backed by the unfilled part of a BytesMut, which doesn't expose a way to know how much of it is initialized.

seanmonstar avatar Apr 02 '24 20:04 seanmonstar

I think the best way for now, apart from stabilizing the read_buf feature in std, is for rustls to provide a ReadBuf of its own (or a trait, to avoid unsafe in rustls).

like

pub trait ReadBuf {
    fn append(&mut self, buf: &[u8]);
}

pub fn read_buf(&mut self, buf: &mut dyn ReadBuf) {
    //
}
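Fleshing out that sketch (the Vec impl and read_plaintext_into below are hypothetical, not rustls API): if the caller's buffer type implements an append-style trait, rustls only ever copies into memory the caller has already managed, so no uninitialized bytes are exposed and no unsafe is needed on either side.

```rust
// Hypothetical fleshed-out version of the trait sketched above.
pub trait ReadBuf {
    fn append(&mut self, buf: &[u8]);
}

// A Vec<u8> grows itself as needed; no zeroing of unused capacity occurs.
impl ReadBuf for Vec<u8> {
    fn append(&mut self, buf: &[u8]) {
        self.extend_from_slice(buf);
    }
}

// Stand-in for rustls's reader: it would hand each decrypted plaintext
// chunk to the caller's buffer through the trait.
fn read_plaintext_into(chunks: &[&[u8]], out: &mut dyn ReadBuf) {
    for chunk in chunks {
        out.append(chunk);
    }
}

fn main() {
    let mut out = Vec::new();
    read_plaintext_into(&[b"hello ", b"world"], &mut out);
    assert_eq!(out, b"hello world");
}
```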

quininer avatar Apr 05 '24 11:04 quininer

rustls can be asked how much data it would write into an infinitely sized buffer provided to reader().read() -- self.session.process_new_packets().unwrap().plaintext_bytes_to_read[^1] -- and I think it would be reasonable to expose that quantity as a new method on reader(). I think that would resolve the immediate issue (zeroing a 4 MB buffer to use the first 4 bytes).

[^1]: unwrap() is for brevity only!
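A sketch of how that count could be used (sized_read and its plaintext parameter are invented here to simulate the idea; the proposed accessor on reader() does not exist yet): initialize only as many bytes as the pending-plaintext count says will be written, rather than the whole buffer.

```rust
// Simulation of sizing the initialization to the pending plaintext count:
// `pending` plays the role of the value process_new_packets() reports via
// plaintext_bytes_to_read, and `plaintext` stands in for the decrypted data.
fn sized_read(pending: usize, buf: &mut Vec<u8>, plaintext: &[u8]) -> usize {
    let n = pending
        .min(buf.capacity() - buf.len())
        .min(plaintext.len());
    let start = buf.len();
    // Only these n bytes get initialized, not the full spare capacity.
    buf.resize(start + n, 0);
    buf[start..start + n].copy_from_slice(&plaintext[..n]);
    n
}

fn main() {
    let mut buf = Vec::with_capacity(4 * 1024 * 1024);
    // Suppose rustls reports 4 pending plaintext bytes.
    let n = sized_read(4, &mut buf, b"HTTP");
    assert_eq!(n, 4);
    assert_eq!(buf.len(), 4); // 4 bytes touched instead of ~4 MiB zeroed
}
```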

ctz avatar Apr 05 '24 12:04 ctz

I am considering an unbuffered API refactor, which would also solve this problem.

quininer avatar Apr 05 '24 16:04 quininer