From 2cc309590f99cbc4493f51e9a65e0700cf4f1aa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Gr=C3=BCnewald?= Date: Mon, 1 Jan 2024 21:50:57 +0100 Subject: [PATCH] Add request limiter --- Cargo.lock | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/client.rs | 37 +++++++++++++++++----- 3 files changed, 116 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 126ad50..1fe4710 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -236,6 +236,12 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + [[package]] name = "http" version = "0.2.11" @@ -366,6 +372,16 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.20" @@ -422,6 +438,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.32.2" @@ -481,6 +507,29 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.5", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -533,6 +582,7 @@ dependencies = [ "serde", "serde_json", "thiserror", + "tokio", "url", ] @@ -617,6 +667,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "security-framework" version = "2.9.2" @@ -683,6 +739,15 @@ dependencies = [ "serde", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.9" @@ -692,6 +757,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + [[package]] name = "socket2" version = "0.5.5" @@ -792,11 +863,26 @@ dependencies = [ "bytes", "libc", "mio", + "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", + "tokio-macros", "windows-sys 0.48.0", ] +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index eda8684..5e83f73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0.193", features = ["derive", "std"] } serde_json = "1.0.108" thiserror = "1.0.52" url = "2.5.0" +tokio = { version = "1", features = ["full"] } [features] dump = [] \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 0fce8c9..19af285 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,17 +4,20 @@ use anyhow::Result; use base64::{engine::GeneralPurpose, Engine}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use reqwest::Response; +use std::time::{Duration, Instant}; + +const MAX_REQ_PER_SEC: u8 = 2; #[derive(Clone, Debug)] -pub struct Client<'a> { - auth: Authentication<'a>, - useragent: &'a str, +pub struct Client { host: &'static str, http_client: reqwest::Client, + last_request_time: Instant, + request_counter: u8 } -impl<'a> Client<'a> { - pub fn new(auth: Authentication<'a>, useragent: &'a str) -> Result { +impl Client { + pub fn new(auth: Authentication, useragent: &str) -> Result { let mut header_map = HeaderMap::new(); header_map.append(USER_AGENT, HeaderValue::from_str(useragent)?); if let Authentication::Authorized { username, apikey } = auth { @@ -27,14 +30,30 @@ impl<'a> Client<'a> { .build()?; Ok(Client { - auth, - useragent, host: "https://e621.net", http_client, + last_request_time: Instant::now(), + request_counter: 0 }) } - fn get_authorization_value(username: &'a str, apikey: &'a str) -> String { + async fn request_limiter(&mut self) { + let wait_time = Instant::now() - self.last_request_time; + if Instant::now() - self.last_request_time > Duration::from_secs(1) { + self.last_request_time = Instant::now(); + self.request_counter = 0; + return; + } + + if self.request_counter >= MAX_REQ_PER_SEC { + tokio::time::sleep(wait_time).await; + self.last_request_time = Instant::now(); + self.request_counter = 0; + return; + } + } + + fn get_authorization_value(username: &str, apikey: &str) -> String { let base64_engine = GeneralPurpose::new(&base64::alphabet::STANDARD, Default::default()); base64_engine.encode(format!("{username}:{apikey}")) } @@ -64,6 +83,7 @@ impl<'a> Client<'a> { url.set_query(Some(&query_params.join("&"))); + self.request_counter = self.request_counter+1; Ok(self.http_client.get(url.as_str()).send().await?) } @@ -73,6 +93,7 @@ impl<'a> Client<'a> { tags: Option, page: Option, ) -> Result> { + self.request_limiter().await; let res = self.list_posts_raw(limit, tags, page).await?; let text = res.text().await?;