use anyhow::{anyhow, Result}; use base64::engine::general_purpose::STANDARD; use base64::Engine; use reqwest::{Client, Response, StatusCode}; use sha2::{Digest, Sha256}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::fs::{self, File}; use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tracing::{instrument, trace}; #[allow(dead_code)] type ProgressCallback = Arc; #[instrument(level = "trace", skip_all, err)] pub async fn download_model( url: &str, model_path: &Path, model_filename: &str, progress_callback: Option, cancel_token: Option, ) -> Result { let client = Client::new(); let mut response = make_request(&client, url, None).await?; let total_size_in_bytes = response.content_length().unwrap_or(0); let partial_path = model_path.join(format!("{}.part", model_filename)); let download_path = model_path.join(model_filename); let mut part_file = File::create(&partial_path).await?; let mut downloaded: u64 = 0; let debounce_duration = Duration::from_millis(100); let mut last_update = Instant::now() .checked_sub(debounce_duration) .unwrap_or(Instant::now()); while let Some(chunk) = response.chunk().await? { if let Some(cancel_token) = &cancel_token { if cancel_token.is_cancelled() { trace!("Download canceled by client"); fs::remove_file(&partial_path).await?; return Err(anyhow!("Download canceled")); } } part_file.write_all(&chunk).await?; downloaded += chunk.len() as u64; if let Some(progress_callback) = &progress_callback { let now = Instant::now(); if now.duration_since(last_update) >= debounce_duration { progress_callback(downloaded, total_size_in_bytes); last_update = now; } } } // Verify file integrity let header_sha256 = response .headers() .get("SHA256") .and_then(|value| value.to_str().ok()) .and_then(|value| STANDARD.decode(value).ok()); part_file.seek(tokio::io::SeekFrom::Start(0)).await?; let mut hasher = Sha256::new(); let block_size = 2_usize.pow(20); // 1 MB let mut buffer = vec![0; block_size]; while let Ok(bytes_read) = part_file.read(&mut buffer).await { if bytes_read == 0 { break; } hasher.update(&buffer[..bytes_read]); } let calculated_sha256 = hasher.finalize(); if let Some(header_sha256) = header_sha256 { if calculated_sha256.as_slice() != header_sha256.as_slice() { trace!( "Header Sha256: {:?}, calculated Sha256:{:?}", header_sha256, calculated_sha256 ); fs::remove_file(&partial_path).await?; return Err(anyhow!( "Sha256 mismatch: expected {:?}, got {:?}", header_sha256, calculated_sha256 )); } } fs::rename(&partial_path, &download_path).await?; Ok(download_path) } #[allow(dead_code)] async fn make_request( client: &Client, url: &str, offset: Option, ) -> Result { let mut request = client.get(url); if let Some(offset) = offset { println!( "\nDownload interrupted, resuming from byte position {}", offset ); request = request.header("Range", format!("bytes={}-", offset)); } let response = request.send().await?; if !(response.status().is_success() || response.status() == StatusCode::PARTIAL_CONTENT) { return Err(anyhow!(response.text().await?)); } Ok(response) } #[cfg(test)] mod test { use super::*; use std::env::temp_dir; #[tokio::test] async fn retrieve_gpt4all_model_test() { for url in [ // "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf", "https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q3_K_L.gguf?download=true", // "https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.Q4_K_M.gguf?download=true", ] { let temp_dir = temp_dir().join("download_llm"); if !temp_dir.exists() { fs::create_dir(&temp_dir).await.unwrap(); } let file_name = "llm_model.gguf"; let cancel_token = CancellationToken::new(); let token = cancel_token.clone(); tokio::spawn(async move { tokio::time::sleep(tokio::time::Duration::from_secs(120)).await; token.cancel(); }); let download_file = download_model( url, &temp_dir, file_name, Some(Arc::new(|a, b| { println!("{}/{}", a, b); })), Some(cancel_token), ).await.unwrap(); let file_path = temp_dir.join(file_name); assert_eq!(download_file, file_path); println!("File path: {:?}", file_path); assert!(file_path.exists()); std::fs::remove_file(file_path).unwrap(); } } }