From 673b330741fd3bdeecc63cd904bc9d43ebd79ac9 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Apr 2025 13:11:36 +0200 Subject: [PATCH 1/5] download: merge integration test files --- download/tests/all.rs | 365 ++++++++++++++++++++++ download/tests/download-curl-resume.rs | 79 ----- download/tests/download-reqwest-resume.rs | 79 ----- download/tests/read-proxy-env.rs | 94 ------ download/tests/support/mod.rs | 119 ------- 5 files changed, 365 insertions(+), 371 deletions(-) create mode 100644 download/tests/all.rs delete mode 100644 download/tests/download-curl-resume.rs delete mode 100644 download/tests/download-reqwest-resume.rs delete mode 100644 download/tests/read-proxy-env.rs delete mode 100644 download/tests/support/mod.rs diff --git a/download/tests/all.rs b/download/tests/all.rs new file mode 100644 index 0000000000..aea33ac229 --- /dev/null +++ b/download/tests/all.rs @@ -0,0 +1,365 @@ +use std::convert::Infallible; +use std::fs; +use std::io; +use std::net::SocketAddr; +use std::path::Path; +use std::sync::mpsc::{Sender, channel}; +use std::thread; + +use http_body_util::Full; +use hyper::Request; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use tempfile::TempDir; + +#[cfg(feature = "curl-backend")] +mod curl { + use std::sync::Mutex; + use std::sync::atomic::{AtomicBool, Ordering}; + + use url::Url; + + use super::{serve_file, tmp_dir, write_file}; + use download::*; + + #[tokio::test] + async fn partially_downloaded_file_gets_resumed_from_byte_offset() { + let tmpdir = tmp_dir(); + let from_path = tmpdir.path().join("download-source"); + write_file(&from_path, "xxx45"); + + let target_path = tmpdir.path().join("downloaded"); + write_file(&target_path, "123"); + + let from_url = Url::from_file_path(&from_path).unwrap(); + Backend::Curl + .download_to_path(&from_url, &target_path, true, None) + .await + .expect("Test download failed"); + + assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); + } + + #[tokio::test] + async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { + let tmpdir = tmp_dir(); + let target_path = tmpdir.path().join("downloaded"); + write_file(&target_path, "123"); + + let addr = serve_file(b"xxx45".to_vec()); + + let from_url = format!("http://{addr}").parse().unwrap(); + + let callback_partial = AtomicBool::new(false); + let callback_len = Mutex::new(None); + let received_in_callback = Mutex::new(Vec::new()); + + Backend::Curl + .download_to_path( + &from_url, + &target_path, + true, + Some(&|msg| { + match msg { + Event::ResumingPartialDownload => { + assert!(!callback_partial.load(Ordering::SeqCst)); + callback_partial.store(true, Ordering::SeqCst); + } + Event::DownloadContentLengthReceived(len) => { + let mut flag = callback_len.lock().unwrap(); + assert!(flag.is_none()); + *flag = Some(len); + } + Event::DownloadDataReceived(data) => { + for b in data.iter() { + received_in_callback.lock().unwrap().push(*b); + } + } + } + + Ok(()) + }), + ) + .await + .expect("Test download failed"); + + assert!(callback_partial.into_inner()); + assert_eq!(*callback_len.lock().unwrap(), Some(5)); + let observed_bytes = received_in_callback.into_inner().unwrap(); + assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); + assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); + } +} + +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +mod reqwest { + use std::env::{remove_var, set_var}; + use std::error::Error; + use std::net::TcpListener; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::sync::{LazyLock, Mutex}; + use std::thread; + use std::time::Duration; + + use env_proxy::for_url; + use reqwest::{Client, Proxy}; + use url::Url; + + use super::{serve_file, tmp_dir, write_file}; + use download::{Backend, Event, TlsBackend}; + + static SERIALISE_TESTS: LazyLock> = + LazyLock::new(|| tokio::sync::Mutex::new(())); + + unsafe fn scrub_env() { + unsafe { + remove_var("http_proxy"); + remove_var("https_proxy"); + remove_var("HTTPS_PROXY"); + remove_var("ftp_proxy"); + remove_var("FTP_PROXY"); + remove_var("all_proxy"); + remove_var("ALL_PROXY"); + remove_var("no_proxy"); + remove_var("NO_PROXY"); + } + } + + // Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy + #[tokio::test] + async fn read_basic_proxy_params() { + let _guard = SERIALISE_TESTS.lock().await; + // SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked, + // and those environment variables in question are not relevant elsewhere in the test suite. + unsafe { + scrub_env(); + set_var("https_proxy", "http://proxy.example.com:8080"); + } + let u = Url::parse("https://www.example.org").ok().unwrap(); + assert_eq!( + for_url(&u).host_port(), + Some(("proxy.example.com".to_string(), 8080)) + ); + } + + // Tests to verify if socks feature is available and being used + #[tokio::test] + async fn socks_proxy_request() { + static CALL_COUNT: AtomicUsize = AtomicUsize::new(0); + let _guard = SERIALISE_TESTS.lock().await; + + // SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked, + // and those environment variables in question are not relevant elsewhere in the test suite. + unsafe { + scrub_env(); + set_var("all_proxy", "socks5://127.0.0.1:1080"); + } + + thread::spawn(move || { + let listener = TcpListener::bind("127.0.0.1:1080").unwrap(); + let incoming = listener.incoming(); + for _ in incoming { + CALL_COUNT.fetch_add(1, Ordering::SeqCst); + } + }); + + let env_proxy = |url: &Url| for_url(url).to_url(); + let url = Url::parse("http://192.168.0.1/").unwrap(); + + let client = Client::builder() + // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying + // `hyper` library that causes the `reqwest` client to hang in some cases. + // See for more details. + .pool_max_idle_per_host(0) + .proxy(Proxy::custom(env_proxy)) + .timeout(Duration::from_secs(1)) + .build() + .unwrap(); + let res = client.get(url.as_str()).send().await; + + if let Err(e) = res { + let s = e.source().unwrap(); + assert!( + s.to_string().contains("client error (Connect)"), + "Expected socks connect error, got: {s}", + ); + assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1); + } else { + panic!("Socks proxy was ignored") + } + } + + #[tokio::test] + async fn resume_partial_from_file_url() { + let tmpdir = tmp_dir(); + let from_path = tmpdir.path().join("download-source"); + write_file(&from_path, "xxx45"); + + let target_path = tmpdir.path().join("downloaded"); + write_file(&target_path, "123"); + + let from_url = Url::from_file_path(&from_path).unwrap(); + Backend::Reqwest(TlsBackend::NativeTls) + .download_to_path(&from_url, &target_path, true, None) + .await + .expect("Test download failed"); + + assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); + } + + #[tokio::test] + async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { + let tmpdir = tmp_dir(); + let target_path = tmpdir.path().join("downloaded"); + write_file(&target_path, "123"); + + let addr = serve_file(b"xxx45".to_vec()); + + let from_url = format!("http://{addr}").parse().unwrap(); + + let callback_partial = AtomicBool::new(false); + let callback_len = Mutex::new(None); + let received_in_callback = Mutex::new(Vec::new()); + + Backend::Reqwest(TlsBackend::NativeTls) + .download_to_path( + &from_url, + &target_path, + true, + Some(&|msg| { + match msg { + Event::ResumingPartialDownload => { + assert!(!callback_partial.load(Ordering::SeqCst)); + callback_partial.store(true, Ordering::SeqCst); + } + Event::DownloadContentLengthReceived(len) => { + let mut flag = callback_len.lock().unwrap(); + assert!(flag.is_none()); + *flag = Some(len); + } + Event::DownloadDataReceived(data) => { + for b in data.iter() { + received_in_callback.lock().unwrap().push(*b); + } + } + } + + Ok(()) + }), + ) + .await + .expect("Test download failed"); + + assert!(callback_partial.into_inner()); + assert_eq!(*callback_len.lock().unwrap(), Some(5)); + let observed_bytes = received_in_callback.into_inner().unwrap(); + assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); + assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); + } +} + +pub fn tmp_dir() -> TempDir { + tempfile::Builder::new() + .prefix("rustup-download-test-") + .tempdir() + .expect("creating tempdir for test") +} + +pub fn write_file(path: &Path, contents: &str) { + let mut file = fs::OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(path) + .expect("writing test data"); + + io::Write::write_all(&mut file, contents.as_bytes()).expect("writing test data"); + + file.sync_data().expect("writing test data"); +} + +// A dead simple hyper server implementation. +// For more info, see: +// https://hyper.rs/guides/1/server/hello-world/ +async fn run_server(addr_tx: Sender, addr: SocketAddr, contents: Vec) { + let svc = service_fn(move |req: Request| { + let contents = contents.clone(); + async move { + let res = serve_contents(req, contents); + Ok::<_, Infallible>(res) + } + }); + + let listener = tokio::net::TcpListener::bind(&addr) + .await + .expect("can not bind"); + + let addr = listener.local_addr().unwrap(); + addr_tx.send(addr).unwrap(); + + loop { + let (stream, _) = listener + .accept() + .await + .expect("could not accept connection"); + let io = hyper_util::rt::TokioIo::new(stream); + + let svc = svc.clone(); + tokio::spawn(async move { + if let Err(err) = http1::Builder::new().serve_connection(io, svc).await { + eprintln!("failed to serve connection: {:?}", err); + } + }); + } +} + +pub fn serve_file(contents: Vec) -> SocketAddr { + let addr = ([127, 0, 0, 1], 0).into(); + let (addr_tx, addr_rx) = channel(); + + thread::spawn(move || { + let server = run_server(addr_tx, addr, contents); + let rt = tokio::runtime::Runtime::new().expect("could not creating Runtime"); + rt.block_on(server); + }); + + let addr = addr_rx.recv(); + addr.unwrap() +} + +fn serve_contents( + req: hyper::Request, + contents: Vec, +) -> hyper::Response> { + let mut range_header = None; + let (status, body) = if let Some(range) = req.headers().get(hyper::header::RANGE) { + // extract range "bytes={start}-" + let range = range.to_str().expect("unexpected Range header"); + assert!(range.starts_with("bytes=")); + let range = range.trim_start_matches("bytes="); + assert!(range.ends_with('-')); + let range = range.trim_end_matches('-'); + assert_eq!(range.split('-').count(), 1); + let start: u64 = range.parse().expect("unexpected Range header"); + + range_header = Some(format!("bytes {}-{len}/{len}", start, len = contents.len())); + ( + hyper::StatusCode::PARTIAL_CONTENT, + contents[start as usize..].to_vec(), + ) + } else { + (hyper::StatusCode::OK, contents) + }; + + let mut res = hyper::Response::builder() + .status(status) + .header(hyper::header::CONTENT_LENGTH, body.len()) + .body(Full::new(Bytes::from(body))) + .unwrap(); + if let Some(range) = range_header { + res.headers_mut() + .insert(hyper::header::CONTENT_RANGE, range.parse().unwrap()); + } + res +} diff --git a/download/tests/download-curl-resume.rs b/download/tests/download-curl-resume.rs deleted file mode 100644 index e715f763d7..0000000000 --- a/download/tests/download-curl-resume.rs +++ /dev/null @@ -1,79 +0,0 @@ -#![cfg(feature = "curl-backend")] - -use std::sync::Mutex; -use std::sync::atomic::{AtomicBool, Ordering}; - -use url::Url; - -use download::*; - -mod support; -use crate::support::{serve_file, tmp_dir, write_file}; - -#[tokio::test] -async fn partially_downloaded_file_gets_resumed_from_byte_offset() { - let tmpdir = tmp_dir(); - let from_path = tmpdir.path().join("download-source"); - write_file(&from_path, "xxx45"); - - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let from_url = Url::from_file_path(&from_path).unwrap(); - Backend::Curl - .download_to_path(&from_url, &target_path, true, None) - .await - .expect("Test download failed"); - - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); -} - -#[tokio::test] -async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { - let tmpdir = tmp_dir(); - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let addr = serve_file(b"xxx45".to_vec()); - - let from_url = format!("http://{addr}").parse().unwrap(); - - let callback_partial = AtomicBool::new(false); - let callback_len = Mutex::new(None); - let received_in_callback = Mutex::new(Vec::new()); - - Backend::Curl - .download_to_path( - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); - } - } - } - - Ok(()) - }), - ) - .await - .expect("Test download failed"); - - assert!(callback_partial.into_inner()); - assert_eq!(*callback_len.lock().unwrap(), Some(5)); - let observed_bytes = received_in_callback.into_inner().unwrap(); - assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); -} diff --git a/download/tests/download-reqwest-resume.rs b/download/tests/download-reqwest-resume.rs deleted file mode 100644 index 881e8bfbff..0000000000 --- a/download/tests/download-reqwest-resume.rs +++ /dev/null @@ -1,79 +0,0 @@ -#![cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - -use std::sync::Mutex; -use std::sync::atomic::{AtomicBool, Ordering}; - -use url::Url; - -use download::*; - -mod support; -use crate::support::{serve_file, tmp_dir, write_file}; - -#[tokio::test] -async fn resume_partial_from_file_url() { - let tmpdir = tmp_dir(); - let from_path = tmpdir.path().join("download-source"); - write_file(&from_path, "xxx45"); - - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let from_url = Url::from_file_path(&from_path).unwrap(); - Backend::Reqwest(TlsBackend::NativeTls) - .download_to_path(&from_url, &target_path, true, None) - .await - .expect("Test download failed"); - - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); -} - -#[tokio::test] -async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { - let tmpdir = tmp_dir(); - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let addr = serve_file(b"xxx45".to_vec()); - - let from_url = format!("http://{addr}").parse().unwrap(); - - let callback_partial = AtomicBool::new(false); - let callback_len = Mutex::new(None); - let received_in_callback = Mutex::new(Vec::new()); - - Backend::Reqwest(TlsBackend::NativeTls) - .download_to_path( - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); - } - } - } - - Ok(()) - }), - ) - .await - .expect("Test download failed"); - - assert!(callback_partial.into_inner()); - assert_eq!(*callback_len.lock().unwrap(), Some(5)); - let observed_bytes = received_in_callback.into_inner().unwrap(); - assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); -} diff --git a/download/tests/read-proxy-env.rs b/download/tests/read-proxy-env.rs deleted file mode 100644 index 1ec8998b05..0000000000 --- a/download/tests/read-proxy-env.rs +++ /dev/null @@ -1,94 +0,0 @@ -#![cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - -use std::env::{remove_var, set_var}; -use std::error::Error; -use std::net::TcpListener; -use std::sync::LazyLock; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::thread; -use std::time::Duration; - -use env_proxy::for_url; -use reqwest::{Client, Proxy}; -use tokio::sync::Mutex; -use url::Url; - -static SERIALISE_TESTS: LazyLock> = LazyLock::new(|| Mutex::new(())); - -unsafe fn scrub_env() { - unsafe { - remove_var("http_proxy"); - remove_var("https_proxy"); - remove_var("HTTPS_PROXY"); - remove_var("ftp_proxy"); - remove_var("FTP_PROXY"); - remove_var("all_proxy"); - remove_var("ALL_PROXY"); - remove_var("no_proxy"); - remove_var("NO_PROXY"); - } -} - -// Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy -#[tokio::test] -async fn read_basic_proxy_params() { - let _guard = SERIALISE_TESTS.lock().await; - // SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked, - // and those environment variables in question are not relevant elsewhere in the test suite. - unsafe { - scrub_env(); - set_var("https_proxy", "http://proxy.example.com:8080"); - } - let u = Url::parse("https://www.example.org").ok().unwrap(); - assert_eq!( - for_url(&u).host_port(), - Some(("proxy.example.com".to_string(), 8080)) - ); -} - -// Tests to verify if socks feature is available and being used -#[tokio::test] -async fn socks_proxy_request() { - static CALL_COUNT: AtomicUsize = AtomicUsize::new(0); - let _guard = SERIALISE_TESTS.lock().await; - - // SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked, - // and those environment variables in question are not relevant elsewhere in the test suite. - unsafe { - scrub_env(); - set_var("all_proxy", "socks5://127.0.0.1:1080"); - } - - thread::spawn(move || { - let listener = TcpListener::bind("127.0.0.1:1080").unwrap(); - let incoming = listener.incoming(); - for _ in incoming { - CALL_COUNT.fetch_add(1, Ordering::SeqCst); - } - }); - - let env_proxy = |url: &Url| for_url(url).to_url(); - let url = Url::parse("http://192.168.0.1/").unwrap(); - - let client = Client::builder() - // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying - // `hyper` library that causes the `reqwest` client to hang in some cases. - // See for more details. - .pool_max_idle_per_host(0) - .proxy(Proxy::custom(env_proxy)) - .timeout(Duration::from_secs(1)) - .build() - .unwrap(); - let res = client.get(url.as_str()).send().await; - - if let Err(e) = res { - let s = e.source().unwrap(); - assert!( - s.to_string().contains("client error (Connect)"), - "Expected socks connect error, got: {s}", - ); - assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1); - } else { - panic!("Socks proxy was ignored") - } -} diff --git a/download/tests/support/mod.rs b/download/tests/support/mod.rs deleted file mode 100644 index 169377cf45..0000000000 --- a/download/tests/support/mod.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::convert::Infallible; -use std::fs; -use std::io; -use std::net::SocketAddr; -use std::path::Path; -use std::sync::mpsc::{Sender, channel}; -use std::thread; - -use http_body_util::Full; -use hyper::Request; -use hyper::body::Bytes; -use hyper::server::conn::http1; -use hyper::service::service_fn; -use tempfile::TempDir; - -pub fn tmp_dir() -> TempDir { - tempfile::Builder::new() - .prefix("rustup-download-test-") - .tempdir() - .expect("creating tempdir for test") -} - -pub fn write_file(path: &Path, contents: &str) { - let mut file = fs::OpenOptions::new() - .write(true) - .truncate(true) - .create(true) - .open(path) - .expect("writing test data"); - - io::Write::write_all(&mut file, contents.as_bytes()).expect("writing test data"); - - file.sync_data().expect("writing test data"); -} - -// A dead simple hyper server implementation. -// For more info, see: -// https://hyper.rs/guides/1/server/hello-world/ -async fn run_server(addr_tx: Sender, addr: SocketAddr, contents: Vec) { - let svc = service_fn(move |req: Request| { - let contents = contents.clone(); - async move { - let res = serve_contents(req, contents); - Ok::<_, Infallible>(res) - } - }); - - let listener = tokio::net::TcpListener::bind(&addr) - .await - .expect("can not bind"); - - let addr = listener.local_addr().unwrap(); - addr_tx.send(addr).unwrap(); - - loop { - let (stream, _) = listener - .accept() - .await - .expect("could not accept connection"); - let io = hyper_util::rt::TokioIo::new(stream); - - let svc = svc.clone(); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new().serve_connection(io, svc).await { - eprintln!("failed to serve connection: {:?}", err); - } - }); - } -} - -pub fn serve_file(contents: Vec) -> SocketAddr { - let addr = ([127, 0, 0, 1], 0).into(); - let (addr_tx, addr_rx) = channel(); - - thread::spawn(move || { - let server = run_server(addr_tx, addr, contents); - let rt = tokio::runtime::Runtime::new().expect("could not creating Runtime"); - rt.block_on(server); - }); - - let addr = addr_rx.recv(); - addr.unwrap() -} - -fn serve_contents( - req: hyper::Request, - contents: Vec, -) -> hyper::Response> { - let mut range_header = None; - let (status, body) = if let Some(range) = req.headers().get(hyper::header::RANGE) { - // extract range "bytes={start}-" - let range = range.to_str().expect("unexpected Range header"); - assert!(range.starts_with("bytes=")); - let range = range.trim_start_matches("bytes="); - assert!(range.ends_with('-')); - let range = range.trim_end_matches('-'); - assert_eq!(range.split('-').count(), 1); - let start: u64 = range.parse().expect("unexpected Range header"); - - range_header = Some(format!("bytes {}-{len}/{len}", start, len = contents.len())); - ( - hyper::StatusCode::PARTIAL_CONTENT, - contents[start as usize..].to_vec(), - ) - } else { - (hyper::StatusCode::OK, contents) - }; - - let mut res = hyper::Response::builder() - .status(status) - .header(hyper::header::CONTENT_LENGTH, body.len()) - .body(Full::new(Bytes::from(body))) - .unwrap(); - if let Some(range) = range_header { - res.headers_mut() - .insert(hyper::header::CONTENT_RANGE, range.parse().unwrap()); - } - res -} From 437391659e4fda3b34faa168381715a1d9a95d1e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Apr 2025 13:16:23 +0200 Subject: [PATCH 2/5] Fold download crate back into rustup --- Cargo.lock | 29 ++++-------- Cargo.toml | 24 +++++++--- ci/run.bash | 32 ++++++-------- download/Cargo.toml | 44 ------------------- download/src/lib.rs => src/download/mod.rs | 14 ++++-- .../tests/all.rs => src/download/tests.rs | 4 +- src/lib.rs | 1 + src/utils/mod.rs | 7 ++- 8 files changed, 55 insertions(+), 100 deletions(-) delete mode 100644 download/Cargo.toml rename download/src/lib.rs => src/download/mod.rs (98%) rename download/tests/all.rs => src/download/tests.rs (99%) diff --git a/Cargo.lock b/Cargo.lock index 9df0b54f73..e8d0c744fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -574,26 +574,6 @@ dependencies = [ "syn", ] -[[package]] -name = "download" -version = "1.28.1" -dependencies = [ - "anyhow", - "curl", - "env_proxy", - "http-body-util", - "hyper", - "hyper-util", - "reqwest", - "rustls", - "rustls-platform-verifier", - "tempfile", - "thiserror 2.0.12", - "tokio", - "tokio-stream", - "url", -] - [[package]] name = "dunce" version = "1.0.5" @@ -2257,13 +2237,17 @@ dependencies = [ "chrono", "clap", "clap_complete", - "download", + "curl", "effective-limits", "enum-map", + "env_proxy", "flate2", "fs_at", "git-testament", "home", + "http-body-util", + "hyper", + "hyper-util", "itertools 0.14.0", "libc", "opener", @@ -2277,8 +2261,11 @@ dependencies = [ "rand 0.9.0", "regex", "remove_dir_all", + "reqwest", "retry", "rs_tracing", + "rustls", + "rustls-platform-verifier", "same-file", "scopeguard", "semver", diff --git a/Cargo.toml b/Cargo.toml index f6ce1fef9c..ca5bf4395a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,19 @@ repository = "https://github.com/rust-lang/rustup" build = "build.rs" [features] -curl-backend = ["download/curl-backend"] +curl-backend = ["dep:curl"] default = ["curl-backend", "reqwest-native-tls", "reqwest-rustls-tls"] vendored-openssl = ['openssl/vendored'] -reqwest-native-tls = ["download/reqwest-native-tls"] -reqwest-rustls-tls = ["download/reqwest-rustls-tls"] +reqwest-native-tls = ["reqwest/native-tls", "dep:reqwest", "dep:env_proxy"] +reqwest-rustls-tls = [ + "reqwest/rustls-tls-manual-roots-no-provider", + "dep:env_proxy", + "dep:reqwest", + "dep:rustls", + "dep:rustls-platform-verifier", +] # Include in the default set to disable self-update and uninstall. no-self-update = [] @@ -40,9 +46,10 @@ cfg-if = "1.0" chrono = { version = "0.4", default-features = false, features = ["std"] } clap = { version = "4", features = ["derive", "wrap_help"] } clap_complete = "4" -download = { path = "download", default-features = false } +curl = { version = "0.4.44", optional = true } effective-limits = "0.5.5" enum-map = "2.5.0" +env_proxy = { version = "0.4.1", optional = true } flate2 = "1" fs_at.workspace = true git-testament = "0.2" @@ -60,8 +67,11 @@ pulldown-cmark = { version = "0.13", default-features = false } rand = "0.9" regex = "1" remove_dir_all = { version = "1.0.0", features = ["parallel"] } +reqwest = { version = "0.12", default-features = false, features = ["blocking", "gzip", "socks", "stream"], optional = true } retry = { version = "2", default-features = false, features = ["random"] } rs_tracing = { version = "1.1", features = ["rs_tracing"] } +rustls = { version = "0.23", optional = true, default-features = false, features = ["logging", "aws_lc_rs", "tls12"] } +rustls-platform-verifier = { version = "0.5", optional = true } same-file = "1" semver = "1.0" serde = { version = "1.0", features = ["derive"] } @@ -115,6 +125,9 @@ version = "0.59" [dev-dependencies] enum-map = "2.5.0" +http-body-util = "0.1.0" +hyper = { version = "1.0", default-features = false, features = ["server", "http1"] } +hyper-util = { version = "0.1.1", features = ["tokio"] } platforms.workspace = true proptest.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } @@ -126,9 +139,6 @@ platforms.workspace = true [lints] workspace = true -[workspace] -members = ["download"] - [workspace.package] version = "1.28.1" edition = "2024" diff --git a/ci/run.bash b/ci/run.bash index 73f683bbd3..8f6d780946 100644 --- a/ci/run.bash +++ b/ci/run.bash @@ -50,10 +50,18 @@ target_cargo() { target_cargo build -download_pkg_test() { - features=('--no-default-features' '--features' 'curl-backend,reqwest-native-tls') +# Machines have 7GB of RAM, and our target/ contents is large enough that +# thrashing will occur if we build-run-build-run rather than +# build-build-build-run-run-run. Since this is used solely for non-release +# artifacts, we try to keep features consistent across the builds, whether for +# docs/test/runs etc. +build_test() { + cmd="$1" + shift + + features=('--features' 'curl-backend,reqwest-native-tls') case "$TARGET" in - # these platforms aren't supported by ring: + # these platforms aren't supported by aws-lc-rs: powerpc* ) ;; mips* ) ;; riscv* ) ;; @@ -62,23 +70,11 @@ download_pkg_test() { * ) features+=('--features' 'reqwest-rustls-tls') ;; esac - cargo "$1" --locked --profile "$BUILD_PROFILE" --target "$TARGET" "${features[@]}" -p download -} - -# Machines have 7GB of RAM, and our target/ contents is large enough that -# thrashing will occur if we build-run-build-run rather than -# build-build-build-run-run-run. Since this is used solely for non-release -# artifacts, we try to keep features consistent across the builds, whether for -# docs/test/runs etc. -build_test() { - cmd="$1" - shift - download_pkg_test "${cmd}" if [ "build" = "${cmd}" ]; then - target_cargo "${cmd}" --workspace --all-targets --features test + target_cargo "${cmd}" --workspace --all-targets "${features[@]}" --features test else - target_cargo "${cmd}" --workspace --features test --tests - target_cargo "${cmd}" --doc --workspace --features test + target_cargo "${cmd}" --workspace "${features[@]}" --features test --tests + target_cargo "${cmd}" --doc --workspace "${features[@]}" --features test fi } diff --git a/download/Cargo.toml b/download/Cargo.toml deleted file mode 100644 index ed99ed2fdb..0000000000 --- a/download/Cargo.toml +++ /dev/null @@ -1,44 +0,0 @@ -[package] -name = "download" -version.workspace = true -edition.workspace = true -license.workspace = true - -[features] -default = ["reqwest-rustls-tls", "reqwest-native-tls"] -curl-backend = ["curl"] -reqwest-native-tls = [ - "reqwest/native-tls", - "dep:reqwest", - "dep:env_proxy", - "dep:tokio-stream", -] -reqwest-rustls-tls = [ - "reqwest/rustls-tls-manual-roots-no-provider", - "dep:env_proxy", - "dep:reqwest", - "dep:rustls", - "dep:rustls-platform-verifier", - "dep:tokio-stream", -] - -[dependencies] -anyhow.workspace = true -curl = { version = "0.4.44", optional = true } -env_proxy = { version = "0.4.1", optional = true } -reqwest = { version = "0.12", default-features = false, features = ["blocking", "gzip", "socks", "stream"], optional = true } -rustls = { version = "0.23", optional = true, default-features = false, features = ["logging", "aws_lc_rs", "tls12"] } -rustls-platform-verifier = { version = "0.5", optional = true } -thiserror.workspace = true -tokio-stream = { workspace = true, optional = true } -url.workspace = true - -[dev-dependencies] -http-body-util = "0.1.0" -hyper = { version = "1.0", default-features = false, features = ["server", "http1"] } -hyper-util = { version = "0.1.1", features = ["tokio"] } -tempfile.workspace = true -tokio = { workspace = true, default-features = false, features = ["sync"] } - -[lints] -workspace = true diff --git a/download/src/lib.rs b/src/download/mod.rs similarity index 98% rename from download/src/lib.rs rename to src/download/mod.rs index 9266129058..143cdd8b10 100644 --- a/download/src/lib.rs +++ b/src/download/mod.rs @@ -3,11 +3,19 @@ use std::fs::remove_file; use std::path::Path; -use anyhow::Context; -pub use anyhow::Result; +#[cfg(any( + not(feature = "curl-backend"), + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls") +))] +use anyhow::anyhow; +use anyhow::{Context, Result}; use thiserror::Error; use url::Url; +#[cfg(test)] +mod tests; + /// User agent header value for HTTP request. /// See: https://github.com/rust-lang/rustup/issues/2860. #[cfg(feature = "curl-backend")] @@ -488,8 +496,6 @@ pub enum DownloadError { HttpStatus(u32), #[error("file not found")] FileNotFound, - #[error("download backend '{0}' unavailable")] - BackendUnavailable(&'static str), #[error("{0}")] Message(String), #[error(transparent)] diff --git a/download/tests/all.rs b/src/download/tests.rs similarity index 99% rename from download/tests/all.rs rename to src/download/tests.rs index aea33ac229..edcfb39cb7 100644 --- a/download/tests/all.rs +++ b/src/download/tests.rs @@ -21,7 +21,7 @@ mod curl { use url::Url; use super::{serve_file, tmp_dir, write_file}; - use download::*; + use crate::download::{Backend, Event}; #[tokio::test] async fn partially_downloaded_file_gets_resumed_from_byte_offset() { @@ -107,7 +107,7 @@ mod reqwest { use url::Url; use super::{serve_file, tmp_dir, write_file}; - use download::{Backend, Event, TlsBackend}; + use crate::download::{Backend, Event, TlsBackend}; static SERIALISE_TESTS: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(())); diff --git a/src/lib.rs b/src/lib.rs index 52d5786c60..08f855a425 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,6 +74,7 @@ mod command; mod config; mod diskio; pub mod dist; +mod download; pub mod env_var; pub mod errors; mod fallback_settings; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 52c76ac2e0..821d74c5f0 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -164,7 +164,7 @@ pub(crate) async fn download_file_with_resume( notify_handler: &dyn Fn(Notification<'_>), process: &Process, ) -> Result<()> { - use download::DownloadError as DEK; + use crate::download::DownloadError as DEK; match download_file_( url, path, @@ -213,8 +213,7 @@ async fn download_file_( process: &Process, ) -> Result<()> { #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use download::TlsBackend; - use download::{Backend, Event}; + use crate::download::{TlsBackend, Backend, Event}; use sha2::Digest; use std::cell::RefCell; @@ -224,7 +223,7 @@ async fn download_file_( // This callback will write the download to disk and optionally // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> download::Result<()> = &|msg| { + let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { if let Event::DownloadDataReceived(data) = msg { if let Some(h) = hasher.borrow_mut().as_mut() { h.update(data); From 586e32d1fca68226dc197823db9c2f995b599177 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Apr 2025 13:24:57 +0200 Subject: [PATCH 3/5] Drop workspace indirection --- Cargo.toml | 75 +++++++++++++++++------------------------------------- 1 file changed, 23 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ca5bf4395a..32fbbee2ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "rustup" -version.workspace = true -edition.workspace = true -license.workspace = true +version = "1.28.1" +edition = "2024" +license = "MIT OR Apache-2.0" description = "Manage multiple rust installations with ease" homepage = "https://github.com/rust-lang/rustup" keywords = ["rustup", "multirust", "install", "proxy"] @@ -41,7 +41,7 @@ test = ["dep:walkdir"] # Sorted by alphabetic order [dependencies] -anyhow.workspace = true +anyhow = "1.0.69" cfg-if = "1.0" chrono = { version = "0.4", default-features = false, features = ["std"] } clap = { version = "4", features = ["derive", "wrap_help"] } @@ -51,7 +51,7 @@ effective-limits = "0.5.5" enum-map = "2.5.0" env_proxy = { version = "0.4.1", optional = true } flate2 = "1" -fs_at.workspace = true +fs_at = "0.2.1" git-testament = "0.2" home = "0.5.4" itertools = "0.14" @@ -60,9 +60,9 @@ opener = "0.7.0" # `openssl` is used by `curl` or `reqwest` backend although it isn't imported by rustup: this # allows controlling the vendoring status without exposing the presence of the download crate. openssl = { version = "0.10", optional = true } -opentelemetry = { workspace = true, optional = true } -opentelemetry-otlp = { workspace = true, optional = true } -opentelemetry_sdk = { workspace = true, optional = true } +opentelemetry = { version = "0.29", optional = true } +opentelemetry-otlp = { version = "0.29", features = ["grpc-tonic"], optional = true } +opentelemetry_sdk = { version = "0.29", features = ["rt-tokio"], optional = true } pulldown-cmark = { version = "0.13", default-features = false } rand = "0.9" regex = "1" @@ -79,20 +79,20 @@ sha2 = "0.10" sharded-slab = "0.1.1" strsim = "0.11" tar = "0.4.26" -tempfile.workspace = true -termcolor.workspace = true -thiserror.workspace = true +tempfile = "3.8" +termcolor = "1.2" +thiserror = "2" threadpool = "1" -tokio.workspace = true -tokio-retry.workspace = true -tokio-stream.workspace = true +tokio = { version = "1.26.0", default-features = false, features = ["macros", "rt-multi-thread", "sync"] } +tokio-retry = "0.3.0" +tokio-stream = "0.1.14" toml = "0.8" -tracing.workspace = true -tracing-opentelemetry = { workspace = true, optional = true } -tracing-subscriber = { workspace = true, features = ["env-filter"] } -url.workspace = true +tracing = "0.1" +tracing-opentelemetry = { version = "0.30", optional = true } +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +url = "2.4" wait-timeout = "0.2" -walkdir = { workspace = true, optional = true } +walkdir = { version = "2", optional = true } xz2 = "0.1.3" zstd = "0.13" @@ -128,46 +128,17 @@ enum-map = "2.5.0" http-body-util = "0.1.0" hyper = { version = "1.0", default-features = false, features = ["server", "http1"] } hyper-util = { version = "0.1.1", features = ["tokio"] } -platforms.workspace = true -proptest.workspace = true -tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +platforms = "3.4" +proptest = "1.1.0" trycmd = "0.15.0" [build-dependencies] -platforms.workspace = true - -[lints] -workspace = true - -[workspace.package] -version = "1.28.1" -edition = "2024" -license = "MIT OR Apache-2.0" - -[workspace.dependencies] -anyhow = "1.0.69" -fs_at = "0.2.1" -opentelemetry = "0.29" -opentelemetry-otlp = { version = "0.29", features = ["grpc-tonic"] } -opentelemetry_sdk = { version = "0.29", features = ["rt-tokio"] } platforms = "3.4" -proptest = "1.1.0" -tempfile = "3.8" -termcolor = "1.2" -thiserror = "2" -tokio = { version = "1.26.0", default-features = false, features = ["macros", "rt-multi-thread"] } -tokio-retry = { version = "0.3.0" } -tokio-stream = { version = "0.1.14" } -tracing = "0.1" -tracing-opentelemetry = "0.30" -tracing-subscriber = "0.3.16" -url = "2.4" -walkdir = "2" -[workspace.lints.rust] +[lints.rust] rust_2018_idioms = "deny" -[workspace.lints.clippy] +[lints.clippy] # `dbg!()` and `todo!()` clearly shouldn't make it to production: dbg_macro = "warn" todo = "warn" From 336a9b9fb83037b8328c02c5308281092f8dea48 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Apr 2025 13:30:41 +0200 Subject: [PATCH 4/5] Simplify download API abstraction --- src/cli/self_update.rs | 5 +- src/cli/self_update/windows.rs | 3 +- src/dist/download.rs | 8 +- src/dist/manifestation/tests.rs | 3 +- src/download/mod.rs | 198 ++++++++++++++++++++++++++++++-- src/utils/mod.rs | 178 ---------------------------- 6 files changed, 201 insertions(+), 194 deletions(-) diff --git a/src/cli/self_update.rs b/src/cli/self_update.rs index bfb5cf486c..16dc43e769 100644 --- a/src/cli/self_update.rs +++ b/src/cli/self_update.rs @@ -48,6 +48,7 @@ use same_file::Handle; use serde::{Deserialize, Serialize}; use tracing::{error, info, trace, warn}; +use crate::download::download_file; use crate::{ DUP_TOOLS, TOOLS, cli::{ @@ -1134,7 +1135,7 @@ pub(crate) async fn prepare_update(process: &Process) -> Result> // Download new version info!("downloading self-update"); - utils::download_file(&download_url, &setup_path, None, &|_| (), process).await?; + download_file(&download_url, &setup_path, None, &|_| (), process).await?; // Mark as executable utils::make_executable(&setup_path)?; @@ -1153,7 +1154,7 @@ async fn get_available_rustup_version(process: &Process) -> Result { let release_file_url = format!("{update_root}/release-stable.toml"); let release_file_url = utils::parse_url(&release_file_url)?; let release_file = tempdir.path().join("release-stable.toml"); - utils::download_file(&release_file_url, &release_file, None, &|_| (), process).await?; + download_file(&release_file_url, &release_file, None, &|_| (), process).await?; let release_toml_str = utils::read_file("rustup release", &release_file)?; let release_toml = toml::from_str::(&release_toml_str) .context("unable to parse rustup release file")?; diff --git a/src/cli/self_update/windows.rs b/src/cli/self_update/windows.rs index fbd0d0caac..24caab6a38 100644 --- a/src/cli/self_update/windows.rs +++ b/src/cli/self_update/windows.rs @@ -22,6 +22,7 @@ use super::common; use super::{InstallOpts, install_bins, report_error}; use crate::cli::{download_tracker::DownloadTracker, markdown::md}; use crate::dist::TargetTriple; +use crate::download::download_file; use crate::process::{Process, terminalsource::ColorableTerminal}; use crate::utils::{self, Notification}; @@ -276,7 +277,7 @@ pub(crate) async fn try_install_msvc( download_tracker.lock().unwrap().download_finished(); info!("downloading Visual Studio installer"); - utils::download_file( + download_file( &visual_studio_url, &visual_studio, None, diff --git a/src/dist/download.rs b/src/dist/download.rs index 47c924b7bf..38a8e0f125 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -8,6 +8,8 @@ use url::Url; use crate::dist::notifications::*; use crate::dist::temp; +use crate::download::download_file; +use crate::download::download_file_with_resume; use crate::errors::*; use crate::process::Process; use crate::utils; @@ -73,7 +75,7 @@ impl<'a> DownloadCfg<'a> { let mut hasher = Sha256::new(); - if let Err(e) = utils::download_file_with_resume( + if let Err(e) = download_file_with_resume( url, &partial_file_path, Some(&mut hasher), @@ -134,7 +136,7 @@ impl<'a> DownloadCfg<'a> { let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?; let hash_file = self.tmp_cx.new_file()?; - utils::download_file( + download_file( &hash_url, &hash_file, None, @@ -179,7 +181,7 @@ impl<'a> DownloadCfg<'a> { let file = self.tmp_cx.new_file_with_ext("", ext)?; let mut hasher = Sha256::new(); - utils::download_file( + download_file( &url, &file, Some(&mut hasher), diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index 15f0eeae61..fcc743ed68 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -23,6 +23,7 @@ use crate::{ prefix::InstallPrefix, temp, }, + download::download_file, errors::RustupError, process::TestProcess, test::{ @@ -495,7 +496,7 @@ impl TestContext { // Download the dist manifest and place it into the installation prefix let manifest_url = make_manifest_url(&self.url, &self.toolchain)?; let manifest_file = self.tmp_cx.new_file()?; - utils::download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?; + download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?; let manifest_str = utils::read_file("manifest", &manifest_file)?; let manifest = Manifest::parse(&manifest_str)?; diff --git a/src/download/mod.rs b/src/download/mod.rs index 143cdd8b10..c84e1e660d 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -10,12 +10,192 @@ use std::path::Path; ))] use anyhow::anyhow; use anyhow::{Context, Result}; +use sha2::Sha256; use thiserror::Error; +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +use tracing::info; use url::Url; +use crate::{errors::RustupError, process::Process, utils::Notification}; + #[cfg(test)] mod tests; +pub(crate) async fn download_file( + url: &Url, + path: &Path, + hasher: Option<&mut Sha256>, + notify_handler: &dyn Fn(Notification<'_>), + process: &Process, +) -> Result<()> { + download_file_with_resume(url, path, hasher, false, ¬ify_handler, process).await +} + +pub(crate) async fn download_file_with_resume( + url: &Url, + path: &Path, + hasher: Option<&mut Sha256>, + resume_from_partial: bool, + notify_handler: &dyn Fn(Notification<'_>), + process: &Process, +) -> Result<()> { + use crate::download::DownloadError as DEK; + match download_file_( + url, + path, + hasher, + resume_from_partial, + notify_handler, + process, + ) + .await + { + Ok(_) => Ok(()), + Err(e) => { + if e.downcast_ref::().is_some() { + return Err(e); + } + let is_client_error = match e.downcast_ref::() { + // Specifically treat the bad partial range error as not our + // fault in case it was something odd which happened. + Some(DEK::HttpStatus(416)) => false, + Some(DEK::HttpStatus(400..=499)) | Some(DEK::FileNotFound) => true, + _ => false, + }; + Err(e).with_context(|| { + if is_client_error { + RustupError::DownloadNotExists { + url: url.clone(), + path: path.to_path_buf(), + } + } else { + RustupError::DownloadingFile { + url: url.clone(), + path: path.to_path_buf(), + } + } + }) + } + } +} + +async fn download_file_( + url: &Url, + path: &Path, + hasher: Option<&mut Sha256>, + resume_from_partial: bool, + notify_handler: &dyn Fn(Notification<'_>), + process: &Process, +) -> Result<()> { + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + use crate::download::{Backend, Event, TlsBackend}; + use sha2::Digest; + use std::cell::RefCell; + + notify_handler(Notification::DownloadingFile(url, path)); + + let hasher = RefCell::new(hasher); + + // This callback will write the download to disk and optionally + // hash the contents, then forward the notification up the stack + let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { + if let Event::DownloadDataReceived(data) = msg { + if let Some(h) = hasher.borrow_mut().as_mut() { + h.update(data); + } + } + + match msg { + Event::DownloadContentLengthReceived(len) => { + notify_handler(Notification::DownloadContentLengthReceived(len)); + } + Event::DownloadDataReceived(data) => { + notify_handler(Notification::DownloadDataReceived(data)); + } + Event::ResumingPartialDownload => { + notify_handler(Notification::ResumingPartialDownload); + } + } + + Ok(()) + }; + + // Download the file + + // Keep the curl env var around for a bit + let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0"); + let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + + let backend = match (use_curl_backend, use_rustls) { + // If environment specifies a backend that's unavailable, error out + #[cfg(not(feature = "reqwest-rustls-tls"))] + (_, Some(true)) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" + )); + } + #[cfg(not(feature = "reqwest-native-tls"))] + (_, Some(false)) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" + )); + } + #[cfg(not(feature = "curl-backend"))] + (Some(true), _) => { + return Err(anyhow!( + "RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature" + )); + } + + // Positive selections, from least preferred to most preferred + #[cfg(feature = "curl-backend")] + (Some(true), None) => Backend::Curl, + #[cfg(feature = "reqwest-native-tls")] + (_, Some(false)) => { + if use_curl_backend == Some(true) { + info!( + "RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls" + ); + } + Backend::Reqwest(TlsBackend::NativeTls) + } + #[cfg(feature = "reqwest-rustls-tls")] + _ => { + if use_curl_backend == Some(true) { + info!( + "both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls" + ); + } + Backend::Reqwest(TlsBackend::Rustls) + } + + // Falling back if only one backend is available + #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] + _ => Backend::Reqwest(TlsBackend::NativeTls), + #[cfg(all( + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls"), + feature = "curl-backend" + ))] + _ => Backend::Curl, + }; + + notify_handler(match backend { + #[cfg(feature = "curl-backend")] + Backend::Curl => Notification::UsingCurl, + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Backend::Reqwest(_) => Notification::UsingReqwest, + }); + + let res = backend + .download_to_path(url, path, resume_from_partial, Some(callback)) + .await; + + notify_handler(Notification::DownloadFinished); + + res +} + /// User agent header value for HTTP request. /// See: https://github.com/rust-lang/rustup/issues/2860. #[cfg(feature = "curl-backend")] @@ -33,7 +213,7 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str = concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); #[derive(Debug, Copy, Clone)] -pub enum Backend { +enum Backend { #[cfg(feature = "curl-backend")] Curl, #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] @@ -41,7 +221,7 @@ pub enum Backend { } impl Backend { - pub async fn download_to_path( + async fn download_to_path( self, url: &Url, path: &Path, @@ -175,7 +355,7 @@ impl Backend { #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[derive(Debug, Copy, Clone)] -pub enum TlsBackend { +enum TlsBackend { #[cfg(feature = "reqwest-rustls-tls")] Rustls, #[cfg(feature = "reqwest-native-tls")] @@ -202,7 +382,7 @@ impl TlsBackend { } #[derive(Debug, Copy, Clone)] -pub enum Event<'a> { +enum Event<'a> { ResumingPartialDownload, /// Received the Content-Length of the to-be downloaded data. DownloadContentLengthReceived(u64), @@ -215,7 +395,7 @@ type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>; /// Download via libcurl; encrypt with the native (or OpenSSl) TLS /// stack via libcurl #[cfg(feature = "curl-backend")] -pub mod curl { +mod curl { use std::cell::RefCell; use std::str; use std::time::Duration; @@ -226,7 +406,7 @@ pub mod curl { use super::{DownloadError, Event}; - pub fn download( + pub(super) fn download( url: &Url, resume_from: u64, callback: &dyn Fn(Event<'_>) -> Result<()>, @@ -327,7 +507,7 @@ pub mod curl { } #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] -pub mod reqwest_be { +mod reqwest_be { use std::io; #[cfg(feature = "reqwest-rustls-tls")] use std::sync::Arc; @@ -346,7 +526,7 @@ pub mod reqwest_be { use super::{DownloadError, Event}; - pub async fn download( + pub(super) async fn download( url: &Url, resume_from: u64, callback: &dyn Fn(Event<'_>) -> Result<()>, @@ -491,7 +671,7 @@ pub mod reqwest_be { } #[derive(Debug, Error)] -pub enum DownloadError { +enum DownloadError { #[error("http request returned an unsuccessful status code: {0}")] HttpStatus(u32), #[error("file not found")] diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 821d74c5f0..faecdcc407 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -11,9 +11,6 @@ use std::process::ExitStatus; use anyhow::{Context, Result, anyhow, bail}; use retry::delay::{Fibonacci, jitter}; use retry::{OperationResult, retry}; -use sha2::Sha256; -#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] -use tracing::info; use url::Url; use crate::errors::*; @@ -146,181 +143,6 @@ where }) } -pub async fn download_file( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - notify_handler: &dyn Fn(Notification<'_>), - process: &Process, -) -> Result<()> { - download_file_with_resume(url, path, hasher, false, ¬ify_handler, process).await -} - -pub(crate) async fn download_file_with_resume( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - notify_handler: &dyn Fn(Notification<'_>), - process: &Process, -) -> Result<()> { - use crate::download::DownloadError as DEK; - match download_file_( - url, - path, - hasher, - resume_from_partial, - notify_handler, - process, - ) - .await - { - Ok(_) => Ok(()), - Err(e) => { - if e.downcast_ref::().is_some() { - return Err(e); - } - let is_client_error = match e.downcast_ref::() { - // Specifically treat the bad partial range error as not our - // fault in case it was something odd which happened. - Some(DEK::HttpStatus(416)) => false, - Some(DEK::HttpStatus(400..=499)) | Some(DEK::FileNotFound) => true, - _ => false, - }; - Err(e).with_context(|| { - if is_client_error { - RustupError::DownloadNotExists { - url: url.clone(), - path: path.to_path_buf(), - } - } else { - RustupError::DownloadingFile { - url: url.clone(), - path: path.to_path_buf(), - } - } - }) - } - } -} - -async fn download_file_( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - notify_handler: &dyn Fn(Notification<'_>), - process: &Process, -) -> Result<()> { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use crate::download::{TlsBackend, Backend, Event}; - use sha2::Digest; - use std::cell::RefCell; - - notify_handler(Notification::DownloadingFile(url, path)); - - let hasher = RefCell::new(hasher); - - // This callback will write the download to disk and optionally - // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { - if let Event::DownloadDataReceived(data) = msg { - if let Some(h) = hasher.borrow_mut().as_mut() { - h.update(data); - } - } - - match msg { - Event::DownloadContentLengthReceived(len) => { - notify_handler(Notification::DownloadContentLengthReceived(len)); - } - Event::DownloadDataReceived(data) => { - notify_handler(Notification::DownloadDataReceived(data)); - } - Event::ResumingPartialDownload => { - notify_handler(Notification::ResumingPartialDownload); - } - } - - Ok(()) - }; - - // Download the file - - // Keep the curl env var around for a bit - let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0"); - let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); - - let backend = match (use_curl_backend, use_rustls) { - // If environment specifies a backend that's unavailable, error out - #[cfg(not(feature = "reqwest-rustls-tls"))] - (_, Some(true)) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" - )); - } - #[cfg(not(feature = "reqwest-native-tls"))] - (_, Some(false)) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" - )); - } - #[cfg(not(feature = "curl-backend"))] - (Some(true), _) => { - return Err(anyhow!( - "RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature" - )); - } - - // Positive selections, from least preferred to most preferred - #[cfg(feature = "curl-backend")] - (Some(true), None) => Backend::Curl, - #[cfg(feature = "reqwest-native-tls")] - (_, Some(false)) => { - if use_curl_backend == Some(true) { - info!( - "RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls" - ); - } - Backend::Reqwest(TlsBackend::NativeTls) - } - #[cfg(feature = "reqwest-rustls-tls")] - _ => { - if use_curl_backend == Some(true) { - info!( - "both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls" - ); - } - Backend::Reqwest(TlsBackend::Rustls) - } - - // Falling back if only one backend is available - #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - _ => Backend::Reqwest(TlsBackend::NativeTls), - #[cfg(all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls"), - feature = "curl-backend" - ))] - _ => Backend::Curl, - }; - - notify_handler(match backend { - #[cfg(feature = "curl-backend")] - Backend::Curl => Notification::UsingCurl, - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Backend::Reqwest(_) => Notification::UsingReqwest, - }); - - let res = backend - .download_to_path(url, path, resume_from_partial, Some(callback)) - .await; - - notify_handler(Notification::DownloadFinished); - - res -} - pub(crate) fn parse_url(url: &str) -> Result { Url::parse(url).with_context(|| format!("failed to parse url: {url}")) } From dea2bf476edb4d39e107e692d1d6489cfa837a17 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 4 Apr 2025 13:33:10 +0200 Subject: [PATCH 5/5] Warn about using curl --- src/download/mod.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index c84e1e660d..82511a7710 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -14,6 +14,7 @@ use sha2::Sha256; use thiserror::Error; #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] use tracing::info; +use tracing::warn; use url::Url; use crate::{errors::RustupError, process::Process, utils::Notification}; @@ -124,8 +125,14 @@ async fn download_file_( // Keep the curl env var around for a bit let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0"); - let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + if use_curl_backend == Some(true) { + warn!( + "RUSTUP_USE_CURL is set; the curl backend is deprecated, please file an issue if the \ + default download backend does not work for your use case" + ); + } + let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); let backend = match (use_curl_backend, use_rustls) { // If environment specifies a backend that's unavailable, error out #[cfg(not(feature = "reqwest-rustls-tls"))]