Utilize NEW_TOKEN frames
This is currently a draft PR so I can get feedback on whether the overall design is good. If it is, I will address the remaining TODO points and polish it up some more before marking it ready for review.
Goal and motivation:
The server now sends the client NEW_TOKEN frames, and the client now stores and utilizes them.
The main motivation is that this allows 0.5-RTT data to avoid being subject to anti-amplification limits. This scenario is likely to occur in HTTP/3 requests, as one example: a client makes a 0-RTT GET request for something like a jpeg, so the response will be much bigger than the request. Unless NEW_TOKEN frames are used, the server may begin transmitting the response, hit the anti-amplification limit, and have to pause until the full 1-RTT handshake completes.
For example, here's some experimental data from a setup that should be similar in the relevant ways:
- The client sends the server an integer and the server responds with that number of bytes
- They do it in 0-RTT if they can
- For each iteration, the client endpoint performs the exchange twice and measures the request/response time of the second attempt (so that it has both 0-RTT and NEW_TOKEN material available)
- 100 ms of localhost latency was simulated by running `sudo tc qdisc add dev lo root netem delay 100ms` (and undone with `sudo tc qdisc del dev lo root netem`)
For responses in a certain size range, avoiding the anti-amplification limits by using NEW_TOKEN frames made the request/response complete in 1 RTT on this branch versus 2 RTT on main.
Reproducible experimental setup
`newtoken.rs` can be placed into `quinn/examples/`:
```rust
use std::{net::ToSocketAddrs as _, sync::Arc};

use anyhow::Error;
use quinn::*;
use tracing::*;
use tracing_subscriber::prelude::*;

#[tokio::main]
async fn main() -> Result<(), Error> {
    // init logging
    let log_fmt = tracing_subscriber::fmt::format()
        .compact()
        .with_timer(tracing_subscriber::fmt::time::uptime())
        .with_line_number(true);
    let stdout_log = tracing_subscriber::fmt::layer()
        .event_format(log_fmt)
        .with_writer(std::io::stderr);
    let log_filter = tracing_subscriber::EnvFilter::new(
        std::env::var(tracing_subscriber::EnvFilter::DEFAULT_ENV).unwrap_or("info".into()),
    );
    let log_subscriber = tracing_subscriber::Registry::default()
        .with(log_filter)
        .with(stdout_log);
    tracing::subscriber::set_global_default(log_subscriber).expect("unable to install logger");

    // get args
    let args = std::env::args().collect::<Vec<_>>();
    anyhow::ensure!(args.len() == 2, "wrong number of args");
    let num_bytes = args[1].parse::<u32>()?;

    // generate keys
    let rcgen_cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    let key = rustls::pki_types::PrivatePkcs8KeyDer::from(rcgen_cert.key_pair.serialize_der());
    let cert = rustls::pki_types::CertificateDer::from(rcgen_cert.cert);
    let mut roots = rustls::RootCertStore::empty();
    roots.add(cert.clone()).unwrap();
    let certs = vec![cert];

    let mut tasks = tokio::task::JoinSet::new();

    // start server
    let (send_stop_server, mut recv_stop_server) = tokio::sync::oneshot::channel();
    tasks.spawn(log_err(
        async move {
            let mut server_crypto = rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(certs, key.into())?;
            // make sure to configure this:
            server_crypto.max_early_data_size = u32::MAX;
            let server_crypto =
                quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(server_crypto))?;
            let server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
            let endpoint = Endpoint::server(
                server_config,
                "127.0.0.1:4433".to_socket_addrs().unwrap().next().unwrap(),
            )?;
            loop {
                let incoming = tokio::select! {
                    option = endpoint.accept() => match option {
                        Some(incoming) => incoming,
                        None => break,
                    },
                    result = &mut recv_stop_server => if result.is_ok() { break } else { continue },
                };
                // spawn subtask for connection
                tokio::spawn(log_err(
                    async move {
                        // attempt to accept 0-RTT data
                        let conn = match incoming.accept()?.into_0rtt() {
                            Ok((conn, _)) => conn,
                            Err(connecting) => connecting.await?,
                        };
                        loop {
                            let (mut send, mut recv) = match conn.accept_bi().await {
                                Ok(stream) => stream,
                                Err(ConnectionError::ApplicationClosed(_)) => break,
                                Err(e) => Err(e)?,
                            };
                            // spawn subtask for stream
                            tokio::spawn(log_err(
                                async move {
                                    let requested_len_le_vec = recv.read_to_end(4).await?;
                                    anyhow::ensure!(
                                        requested_len_le_vec.len() == 4,
                                        "malformed request {:?}",
                                        requested_len_le_vec
                                    );
                                    let mut requested_len_le = [0; 4];
                                    requested_len_le.copy_from_slice(&requested_len_le_vec);
                                    let requested_len = u32::from_le_bytes(requested_len_le) as usize;
                                    info!(%requested_len, "received request");
                                    const BUF_LEN: usize = 8 << 10;
                                    let mut buf = [0; BUF_LEN];
                                    for i in 0..requested_len {
                                        buf[i % BUF_LEN] = (i % 0xff) as u8;
                                        if i % BUF_LEN == BUF_LEN - 1 {
                                            send.write_all(&buf).await?;
                                        }
                                    }
                                    if requested_len % BUF_LEN != 0 {
                                        send.write_all(&buf[..requested_len % BUF_LEN]).await?;
                                    }
                                    info!("wrote response");
                                    Ok(())
                                }
                                .instrument(info_span!("server stream")),
                            ));
                        }
                        Ok(())
                    }
                    .instrument(info_span!("server conn")),
                ));
            }
            // shut down server endpoint cleanly
            endpoint.wait_idle().await;
            Ok(())
        }
        .instrument(info_span!("server")),
    ));

    // start client
    async fn send_request(conn: &Connection, num_bytes: u32) -> Result<std::time::Duration, Error> {
        let (mut send, mut recv) = conn.open_bi().await?;
        let start_time = std::time::Instant::now();
        debug!("sending request");
        send.write_all(&num_bytes.to_le_bytes()).await?;
        send.finish()?;
        debug!("receiving response");
        let response = recv.read_to_end(num_bytes as _).await?;
        anyhow::ensure!(
            response.len() == num_bytes as usize,
            "response is the wrong number of bytes"
        );
        debug!("response received");
        let end_time = std::time::Instant::now();
        Ok(end_time.duration_since(start_time))
    }
    tasks.spawn(log_err(
        async move {
            let mut client_crypto = rustls::ClientConfig::builder()
                .with_root_certificates(roots)
                .with_no_client_auth();
            // make sure to configure this:
            client_crypto.enable_early_data = true;
            let mut endpoint =
                Endpoint::client("0.0.0.0:0".to_socket_addrs().unwrap().next().unwrap())?;
            let client_crypto =
                quinn::crypto::rustls::QuicClientConfig::try_from(Arc::new(client_crypto))?;
            endpoint.set_default_client_config(ClientConfig::new(Arc::new(client_crypto)));
            // twice, so as to allow 0-rtt to work on the second time
            for i in 0..2 {
                info!(%i, "client iteration");
                let connecting = endpoint.connect(
                    "127.0.0.1:4433".to_socket_addrs().unwrap().next().unwrap(),
                    "localhost",
                )?;
                // attempt to transmit 0-RTT data
                let duration = match connecting.into_0rtt() {
                    Ok((conn, zero_rtt_accepted)) => {
                        debug!("attempting 0-rtt request");
                        let send_request_0rtt = send_request(&conn, num_bytes);
                        let mut send_request_0rtt_pinned = std::pin::pin!(send_request_0rtt);
                        tokio::select! {
                            result = &mut send_request_0rtt_pinned => result?,
                            accepted = zero_rtt_accepted => {
                                if accepted {
                                    debug!("0-rtt accepted");
                                    send_request_0rtt_pinned.await?
                                } else {
                                    debug!("0-rtt rejected");
                                    send_request(&conn, num_bytes).await?
                                }
                            }
                        }
                    }
                    Err(connecting) => {
                        debug!("not attempting 0-rtt request");
                        let conn = connecting.await?;
                        send_request(&conn, num_bytes).await?
                    }
                };
                if i == 1 {
                    println!("{}", duration.as_millis());
                }
                println!();
            }
            // tell the server to shut down so this process doesn't idle forever
            let _ = send_stop_server.send(());
            Ok(())
        }
        .instrument(info_span!("client")),
    ));
    while tasks.join_next().await.is_some() {}
    Ok(())
}

async fn log_err<F: std::future::IntoFuture<Output = Result<(), Error>>>(task: F) {
    if let Err(e) = task.await {
        error!("{}", e);
    }
}
```
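To collect a single data point, run, e.g., `cargo run --example newtoken -- 20000`; the number printed is the second iteration's request/response time in milliseconds, which is what the harness below records.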
`science.py` creates the data:
```python
import subprocess
import csv
import os


def run_cargo_command(n):
    try:
        result = subprocess.run(
            ["cargo", "run", "--example", "newtoken", "--", str(n)],
            capture_output=True, text=True, check=True
        )
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
        return None


def initialize_from_file():
    try:
        with open('0rtt_time.csv', mode='r', newline='') as file:
            last_line = list(csv.reader(file))[-1]
            return int(last_line[0])
    except (FileNotFoundError, IndexError):
        return -100  # Start from -100 since 0 is the first increment


def main():
    start_n = initialize_from_file() + 100
    with open('0rtt_time.csv', mode='a', newline='') as file:
        writer = csv.writer(file)
        if os.stat('0rtt_time.csv').st_size == 0:
            writer.writerow(['n', 'output'])  # Write header if file is empty
        for n in range(start_n, 20001, 100):
            output = run_cargo_command(n)
            if output is not None:
                writer.writerow([n, output])
                file.flush()  # Flush after every write operation
                print(f"Written: {n}, {output}")
            else:
                print(f"Failed to get output for n = {n}")


if __name__ == "__main__":
    main()
```
`graph_it.py` graphs the data, after you've manually renamed the output files to `0rtt_time_feature.csv` and `0rtt_time_main.csv`:
```python
import matplotlib.pyplot as plt
import csv


def read_data(filename):
    response_sizes = []
    response_times = []
    try:
        with open(filename, mode='r') as file:
            reader = csv.reader(file)
            next(reader)  # Skip the header row
            for row in reader:
                response_sizes.append(int(row[0]))
                response_times.append(int(row[1]))
    except FileNotFoundError:
        print(f"The file {filename} was not found. Please ensure the file exists.")
    except Exception as e:
        print(f"An error occurred while reading {filename}: {e}")
    return response_sizes, response_times


def plot_data(response_sizes1, response_times1, response_sizes2, response_times2):
    plt.figure(figsize=(10, 5))
    # Plotting points with lines for the feature data
    plt.plot(response_sizes1, response_times1, 'o-', color='blue', label='Feature Data', alpha=0.5, markersize=5)
    # Plotting points with lines for the main data
    plt.plot(response_sizes2, response_times2, 'o-', color='red', label='Main Data', alpha=0.5, markersize=5)
    plt.title('Comparison of Feature and Main Data')
    plt.xlabel('Response Size')
    plt.ylabel('Request/Response Time')
    plt.grid(True)
    plt.ylim(bottom=0)  # Ensuring the y-axis starts at 0
    plt.legend()
    plt.show()


def main():
    response_sizes1, response_times1 = read_data('0rtt_time_feature.csv')
    response_sizes2, response_times2 = read_data('0rtt_time_main.csv')
    if response_sizes1 and response_times1 and response_sizes2 and response_times2:
        plot_data(response_sizes1, response_times1, response_sizes2, response_times2)


if __name__ == "__main__":
    main()
```
Here's a nix-shell for the Python graphing:
```nix
{ pkgs ? import <nixpkgs> {} }:
pkgs.mkShell {
  buildInputs = [
    pkgs.python3
    pkgs.python3Packages.matplotlib
  ];
  shellHook = ''
    echo "Python with matplotlib is ready to use."
  '';
}
```
Other motivations may include:
- A server may wish for all connections to be validated before it serves them. If it responds to every initial connection attempt with `.retry()`, requests take a minimum of 3 round trips to complete even for 1-RTT data, and 0-RTT becomes impossible. If NEW_TOKEN frames are used, however, 1-RTT requests can once more be done in only 2 round trips, and 0-RTT requests become possible again (see the sketch below).
- A system may wish to allow 0-RTT data while mitigating, or even making impossible, replay attacks. If a server only accepts 0-RTT requests when their connection is validated, then replays are only possible to the extent that the server's `TokenReusePreventer` has false negatives, which may range from "sometimes" to "never," in contrast to the current situation of "always."
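As a rough illustration of the first point, here's what a "validate before serving" accept loop can look like; `remote_address_validated` and `retry` are existing methods on quinn's `Incoming`, while `handle_incoming` is just a stand-in for application logic:

```rust
// Sketch: only serve connections whose source address is validated.
async fn accept_loop(endpoint: quinn::Endpoint) -> anyhow::Result<()> {
    while let Some(incoming) = endpoint.accept().await {
        if incoming.remote_address_validated() {
            // The client presented a token, either from an earlier Retry or
            // (with this branch) from a NEW_TOKEN frame; in the latter case the
            // request can complete in 2 round trips, and 0-RTT is possible.
            tokio::spawn(handle_incoming(incoming));
        } else {
            // No token: demand address validation, costing the client an
            // extra round trip.
            incoming.retry()?;
        }
    }
    Ok(())
}

// Stand-in for application logic: accept the handshake and serve streams.
async fn handle_incoming(incoming: quinn::Incoming) {
    let _conn = incoming.await;
}
```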
Code change:
Key points:
- `Token` is generalized to have both a "retry token" variant and a "new token frame token" variant.
  - An additional byte acts as the discriminant. It is not encrypted, but rather treated as the "additional data" of the token's AEAD encryption. Other than that, the "retry token" variant remains the same.
  - The `NewToken` variant's `aead_from_hkdf` key derivation is based on an empty byte slice `&[]` rather than the `retry_src_cid`.
  - The `NewToken` variant's encrypted data consists of: 128 randomly generated bits, the IP address (not including the port), and the issued timestamp.
- The server sends the client 2 NEW_TOKEN frames whenever the client's path is validated (configurable through `ServerConfig.new_tokens_sent_upon_validation`).
- A `ClientConfig.new_token_store: Option<Arc<dyn NewTokenStore>>` object stores NEW_TOKEN tokens received by the client, and dispenses them for one-time use when connecting to the same `server_name` again.
  - The default implementation `InMemNewTokenStore` stores the 2 newest unused tokens for up to 256 servers, with an LRU eviction policy on server names, so as to pair well with `rustls::client::ClientSessionMemoryCache`.
- A `ServerConfig.token_reuse_preventer: Option<Arc<Mutex<Box<dyn TokenReusePreventer>>>>` object is responsible for mitigating reuse of NEW_TOKEN tokens.
  - The default implementation, `BloomTokenReusePreventer` (see the sketch after this list):
    - Divides all time into periods of length `new_token_lifetime`, starting at the unix epoch. It always maintains two "filters," which track used tokens expiring in the current and next period. Turning over the filters as time passes prevents unbounded accumulation of tracked tokens.
    - Filters start out as `FxHashSet`s. This achieves the desirable property of linear-ish memory usage: if few NEW_TOKEN tokens are actually being used, the server's bloom token reuse preventer uses negligible memory.
    - Once a hash-set filter would exceed a configurable maximum memory consumption, it is converted to a bloom filter. This places an upper bound on the number of bytes allocated by the reuse preventer; instead, as more tokens are added to the bloom filter, the false positive rate (tokens not actually reused, but considered reused and thus ignored anyway) increases.
- `ServerConfig.new_token_lifetime` is distinct from `ServerConfig.retry_token_lifetime` and defaults to 2 weeks.
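To make the filter-turnover scheme concrete, here is a minimal sketch of the idea under illustrative names and signatures (this is not the PR's exact `TokenReusePreventer` API, and the hash-set-to-bloom-filter conversion is elided):

```rust
use std::collections::HashSet;

// Illustrative sketch of two-period reuse tracking; the real filters would
// convert from hash sets to bloom filters past a configurable memory cap.
struct SketchReusePreventer {
    lifetime_secs: u64,          // analogous to ServerConfig.new_token_lifetime
    current_period: u64,         // index of the current period since the unix epoch
    filters: [HashSet<u128>; 2], // used tokens expiring in the current/next period
}

impl SketchReusePreventer {
    /// Returns true if the token is fresh, false if expired or already used.
    fn check_and_insert(&mut self, token_rand: u128, issued_secs: u64, now_secs: u64) -> bool {
        if issued_secs + self.lifetime_secs < now_secs {
            return false; // token has expired
        }
        // Turn over filters as time passes: everything in the outgoing filter
        // has expired, so dropping it wholesale cannot forget a live token.
        let now_period = now_secs / self.lifetime_secs;
        while self.current_period < now_period {
            self.filters.swap(0, 1);
            self.filters[1].clear();
            self.current_period += 1;
        }
        // A token expires one lifetime after issuance, so its expiration falls
        // into either the current period or the next one.
        let expires_period = (issued_secs + self.lifetime_secs) / self.lifetime_secs;
        let idx = (expires_period - self.current_period).min(1) as usize;
        // `insert` returns false if the token was already present: a reuse.
        self.filters[idx].insert(token_rand)
    }
}
```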
TODO:
- [x] Send when validated rather than upon first connecting
- [x] Send upon path change
- [ ] Update stats
- [ ] Tests
- [x] Reuse prevention
  - [x] Simplify it; it's not even used concurrently
- [ ] Make sure encryption is good
- [x] Don't break if a Retry is received in response to a request that used a NEW_TOKEN token
- [x] NEW_TOKEN tokens should not encode the port (?)
- [ ] We don't need a top-level `Token.encode`