diff --git a/Cargo.lock b/Cargo.lock index d4e9fe9..b860a96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,19 +82,20 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "async-compression" @@ -120,6 +121,7 @@ dependencies = [ "futures-io", "futures-timer", "pin-project-lite", + "tokio", ] [[package]] @@ -146,9 +148,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" dependencies = [ "proc-macro2", "quote", @@ -167,17 +169,6 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "atomic_enum" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -266,9 +257,9 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" [[package]] name = "block-buffer" @@ -292,9 +283,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.1" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -302,9 +293,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "byteorder" @@ -320,9 +311,9 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cc" -version = "1.2.3" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "jobserver", "libc", @@ -343,9 +334,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "clap" -version = "4.5.23" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -353,9 +344,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -365,9 +356,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck", "proc-macro2", @@ -442,10 +433,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" [[package]] -name = "cpufeatures" -version = "0.2.16" +name = "core-foundation" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -461,18 +468,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -548,9 +555,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f" [[package]] name = "der" @@ -632,7 +639,6 @@ checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" dependencies = [ "pkcs8", "signature", - "zeroize", ] [[package]] @@ -669,9 +675,9 @@ dependencies = [ [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -679,9 +685,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -717,7 +723,7 @@ dependencies = [ "rustls-pki-types", "rustls-webpki", "send_wrapper", - "thiserror 2.0.6", + "thiserror 2.0.11", "tokio", "wasm-bindgen", "wasm-bindgen-futures", @@ -738,10 +744,10 @@ dependencies = [ "bytes", "cfg-if", "clap", + "console-subscriber", "ed25519-dalek", "env_logger", "event-listener", - "fastwebsockets", "futures-util", "hickory-resolver", "http-body-util", @@ -766,6 +772,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-util", + "tokio-websockets", "toml", "uuid", "vergen-git2", @@ -790,28 +797,28 @@ dependencies = [ [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", "pin-project-lite", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fastwebsockets" version = "0.8.0" source = "git+https://github.com/r58Playz/fastwebsockets#1064f64add235a9295633e88d8761011c7af51d5" dependencies = [ - "base64 0.21.7", "bytes", - "http-body-util", - "hyper", - "hyper-util", - "pin-project", "rand", - "sha1", "simdutf8", "thiserror 1.0.69", "tokio", @@ -852,6 +859,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -998,9 +1020,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "git2" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b903b73e45dc0c6c596f2d37eccece7c1c8bb6e4407b001096387c63d0d93724" +checksum = "3fda788993cc341f69012feba8bf45c0ba4f3291fcc08e214b4d5a7332d88aff" dependencies = [ "bitflags", "libc", @@ -1020,7 +1042,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.7.0", + "indexmap 2.7.1", "instant", "slab", "tokio", @@ -1151,9 +1173,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -1169,9 +1191,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", @@ -1393,9 +1415,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -1427,9 +1449,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -1463,9 +1485,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", @@ -1479,15 +1501,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libgit2-sys" -version = "0.17.0+1.8.1" +version = "0.18.0+1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10472326a8a6477c3c20a64547b0059e4b0d086869eee31e6d7da728a8eb7224" +checksum = "e1a117465e7e1597e8febea8bb0c410f1c7fb93b1e1cddf34363f8390367ffec" dependencies = [ "cc", "libc", @@ -1497,9 +1519,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.20" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2d16453e800a8cf6dd2fc3eb4bc99b786a9b90c663b8559a5b1a041bf89e472" +checksum = "df9b68e50e6e0b26f672573834882eb57759f6db9b3be2ea3c35c91188bb4eaa" dependencies = [ "cc", "libc", @@ -1515,9 +1537,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -1537,9 +1559,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" dependencies = [ "serde", ] @@ -1594,9 +1616,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", ] @@ -1621,6 +1643,23 @@ dependencies = [ "getrandom", ] +[[package]] +name = "native-tls" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.29.0" @@ -1658,6 +1697,27 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "num_threads" version = "0.1.7" @@ -1669,9 +1729,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -1682,6 +1742,50 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "openssl" +version = "0.10.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5e534d133a060a3c19daec1eb3e98ec6f4685978834f2dbadfe2ec215bab64e" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking" version = "2.2.1" @@ -1734,18 +1838,18 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" +checksum = "1e2ec53ad785f4d35dac0adea7f7dc6f1bb277ad84a680c7afefeae05d1f5916" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" +checksum = "d56a66c0c55993aa927429d0f8a0abfd74f084e4d9c192cffed01e418d83eefb" dependencies = [ "proc-macro2", "quote", @@ -1754,9 +1858,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -1796,10 +1900,19 @@ dependencies = [ ] [[package]] -name = "proc-macro2" -version = "1.0.92" +name = "proc-macro-crate" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro2" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -1855,9 +1968,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1955,12 +2068,6 @@ dependencies = [ "quick-error", ] -[[package]] -name = "reusable-box-future" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e0e61cd21fbddd85fbd9367b775660a01d388c08a61c6d2824af480b0309bb9" - [[package]] name = "ring" version = "0.17.8" @@ -1999,9 +2106,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags", "errno", @@ -2013,9 +2120,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.20" +version = "0.23.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" +checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" dependencies = [ "once_cell", "ring", @@ -2036,9 +2143,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" dependencies = [ "web-time", ] @@ -2056,15 +2163,24 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" + +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] [[package]] name = "scopeguard" @@ -2073,10 +2189,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] -name = "semver" -version = "1.0.23" +name = "security-framework" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "send_wrapper" @@ -2089,18 +2228,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.216" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", @@ -2109,9 +2248,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -2134,7 +2273,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.7.1", "itoa", "ryu", "serde", @@ -2152,6 +2291,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.8" @@ -2217,16 +2362,14 @@ dependencies = [ "clap", "console-subscriber", "ed25519-dalek", - "fastwebsockets", "futures", - "http-body-util", "humantime", "hyper", - "hyper-util", "sha2", "simple_moving_average", "tikv-jemallocator", "tokio", + "tokio-websockets", "wisp-mux", ] @@ -2303,9 +2446,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.90" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -2329,6 +2472,20 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +dependencies = [ + "cfg-if", + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2340,11 +2497,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.6" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.6", + "thiserror-impl 2.0.11", ] [[package]] @@ -2360,9 +2517,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.6" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", @@ -2455,9 +2612,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -2470,9 +2627,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.42.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -2489,15 +2646,25 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.1" @@ -2519,6 +2686,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4bf6fecd69fcdede0ec680aaf474cdab988f9de6bc73d3758f0160e3b7025a" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.13" @@ -2533,6 +2712,26 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-websockets" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8db988691af99fe6f957c1faa50b2fbaed8cbeb25662ac73083481bc10a511c0" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-core", + "futures-sink", + "http", + "httparse", + "rand", + "sha1_smol", + "simdutf8", + "tokio", + "tokio-native-tls", + "tokio-util", +] + [[package]] name = "toml" version = "0.8.19" @@ -2560,7 +2759,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.7.1", "serde", "serde_spanned", "toml_datetime", @@ -2696,6 +2895,20 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24" +dependencies = [ + "byteorder", + "bytes", + "log", + "rand", + "thiserror 2.0.11", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -2704,9 +2917,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unsafe-libyaml" @@ -2757,18 +2970,18 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ "getrandom", ] [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "vcpkg" @@ -2778,9 +2991,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "9.0.2" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f25fc8f8f05df455c7941e87f093ad22522a9ff33d7a027774815acf6f0639" +checksum = "e0d2f179f8075b805a43a2a21728a46f0cc2921b3c58695b28fa8817e103cd9a" dependencies = [ "anyhow", "derive_builder", @@ -2791,9 +3004,9 @@ dependencies = [ [[package]] name = "vergen-git2" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e63e069d8749fead1e3bab7a9d79e8fb90516b2ec66fc2243a798ecdc1a31d7" +checksum = "d86bae87104cb2790cdee615c2bb54729804d307191732ab27b1c5357ea6ddc5" dependencies = [ "anyhow", "derive_builder", @@ -2806,9 +3019,9 @@ dependencies = [ [[package]] name = "vergen-lib" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0c767e6751c09fc85cde58722cf2f1007e80e4c8d5a4321fc90d83dc54ca147" +checksum = "9b07e6010c0f3e59fcb164e0163834597da68d1f864e2b8ca49f74de01e9c166" dependencies = [ "anyhow", "derive_builder", @@ -2838,20 +3051,21 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", @@ -2863,9 +3077,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.49" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", @@ -2876,9 +3090,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2886,9 +3100,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", @@ -2899,9 +3113,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-streams" @@ -2932,9 +3149,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -3137,9 +3354,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "ad699df48212c6cc6eb4435f35500ac6fd3b9913324f938aea302022ce19d310" dependencies = [ "memchr", ] @@ -3156,23 +3373,23 @@ dependencies = [ [[package]] name = "wisp-mux" -version = "6.0.0" +version = "7.0.0" dependencies = [ "async-trait", - "atomic_enum", "bitflags", "bytes", "ed25519", - "event-listener", - "fastwebsockets", "flume", "futures", "getrandom", - "pin-project-lite", - "reusable-box-future", + "num_enum", + "pin-project", "rustc-hash", - "thiserror 2.0.6", + "slab", + "thiserror 2.0.11", "tokio", + "tokio-tungstenite", + "tokio-websockets", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index e3aa14a..b00f5f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = ["server", "client", "wisp", "simple-wisp-client"] [profile.release] lto = true debug = true +strip = false panic = "abort" codegen-units = 1 opt-level = 3 diff --git a/client/Cargo.toml b/client/Cargo.toml index 3dc3b64..78b5d2a 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -30,12 +30,12 @@ rustls-webpki = { version = "0.102.7", optional = true } send_wrapper = { version = "0.6.0", features = ["futures"] } thiserror = "2.0.3" tokio = "1.39.3" -wasm-bindgen = "0.2.93" +wasm-bindgen = "0.2.100" wasm-bindgen-futures = "0.4.43" wasm-streams = "0.4.0" web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "Url", "WebSocket"] } webpki-roots = "0.26.3" -wisp-mux = { version = "*", path = "../wisp", features = ["wasm", "generic_stream"], default-features = false } +wisp-mux = { version = "*", path = "../wisp", features = ["wasm"], default-features = false } [dependencies.getrandom] version = "*" diff --git a/client/build.sh b/client/build.sh index 336ba80..40db028 100755 --- a/client/build.sh +++ b/client/build.sh @@ -12,7 +12,7 @@ else CARGOFLAGS="" fi -WBG="wasm-bindgen 0.2.99" +WBG="wasm-bindgen 0.2.100" if [ "$(wasm-bindgen -V)" != "$WBG" ]; then echo "Incorrect wasm-bindgen version: '$(wasm-bindgen -V)' != '$WBG'" exit 1 diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index ddb6657..07279fb 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -1,6 +1,6 @@ use std::pin::Pin; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use futures_util::{AsyncReadExt, AsyncWriteExt, Sink, SinkExt, Stream, TryStreamExt}; use js_sys::{Object, Uint8Array}; use wasm_bindgen::prelude::*; @@ -14,7 +14,7 @@ use crate::{ fn create_iostream( stream: Pin>>>, - sink: Pin>>, + sink: Pin>>, ) -> EpoxyIoStream { let read = ReadableStream::from_stream( stream @@ -27,7 +27,7 @@ fn create_iostream( convert_body(x) .await .map_err(|_| EpoxyError::InvalidPayload) - .map(|x| BytesMut::from(x.0.to_vec().as_slice())) + .map(|x| Bytes::from(x.0.to_vec())) }) .sink_map_err(Into::into), ) @@ -50,7 +50,7 @@ pub fn iostream_from_asyncrw(asyncrw: ProviderAsyncRW, buffer_size: usize) -> Ep pub fn iostream_from_stream(stream: ProviderUnencryptedStream) -> EpoxyIoStream { let (rx, tx) = stream.into_split(); create_iostream( - Box::pin(rx.map_ok(Bytes::from).map_err(EpoxyError::Io)), - Box::pin(tx.sink_map_err(EpoxyError::Io)), + Box::pin(rx.map_ok(Bytes::from).map_err(EpoxyError::Wisp)), + Box::pin(tx.sink_map_err(EpoxyError::Wisp)), ) } diff --git a/client/src/lib.rs b/client/src/lib.rs index ea23fb3..cbbdf23 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -1,13 +1,12 @@ #![feature(let_chains, impl_trait_in_assoc_type)] -use std::{error::Error, pin::Pin, str::FromStr, sync::Arc}; +use std::{error::Error, str::FromStr, sync::Arc}; #[cfg(feature = "full")] use async_compression::futures::bufread as async_comp; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use cfg_if::cfg_if; -use futures_util::future::Either; -use futures_util::{Stream, StreamExt, TryStreamExt}; +use futures_util::{future::Either, StreamExt, TryStreamExt}; use http::{ header::{ InvalidHeaderName, InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, @@ -24,7 +23,10 @@ use hyper_util_wasm::client::legacy::Client; use io_stream::{iostream_from_asyncrw, iostream_from_stream}; use js_sys::{Array, ArrayBuffer, Function, Object, Promise, Uint8Array}; use send_wrapper::SendWrapper; -use stream_provider::{ProviderWispTransportGenerator, StreamProvider, StreamProviderService}; +use stream_provider::{ + ProviderWispTransportGenerator, ProviderWispTransportRead, ProviderWispTransportWrite, + StreamProvider, StreamProviderService, +}; use thiserror::Error; use utils::{ asyncread_to_readablestream, convert_streaming_body, entries_of_object, from_entries, @@ -36,11 +38,9 @@ use wasm_bindgen_futures::JsFuture; use web_sys::{ResponseInit, Url, WritableStream}; #[cfg(feature = "full")] use websocket::EpoxyWebSocket; -use wisp_mux::StreamType; use wisp_mux::{ - generic::GenericWebSocketRead, - ws::{EitherWebSocketRead, EitherWebSocketWrite}, - CloseReason, + packet::{CloseReason, StreamType}, + WispError, }; use ws_wrapper::WebSocketWrapper; @@ -341,29 +341,31 @@ fn create_wisp_transport(function: Function) -> ProviderWispTransportGenerator { } .into(); - let read = GenericWebSocketRead::new(Box::pin(SendWrapper::new( + let read = Box::pin(SendWrapper::new( wasm_streams::ReadableStream::from_raw(object_get(&transport, "read").into()) - .into_stream() + .try_into_stream() + .map_err(|x| EpoxyError::wisp_transport(x.0.into()))? .map(|x| { - let pkt = x.map_err(EpoxyError::wisp_transport)?; + let pkt = x + .map_err(EpoxyError::wisp_transport) + .map_err(|x| WispError::WsImplError(Box::new(x)))?; let arr: ArrayBuffer = pkt.dyn_into().map_err(|x| { - EpoxyError::InvalidWispTransportPacket(format!("{x:?}")) + WispError::WsImplError(Box::new( + EpoxyError::InvalidWispTransportPacket(format!("{x:?}")), + )) })?; - Ok::(BytesMut::from( - Uint8Array::new(&arr).to_vec().as_slice(), - )) + Ok::(Bytes::from(Uint8Array::new(&arr).to_vec())) }), - )) - as Pin> + Send>>); - let write: WritableStream = object_get(&transport, "write").into(); - let write = WispTransportWrite { - inner: SendWrapper::new(write.get_writer().map_err(EpoxyError::wisp_transport)?), - }; + )) as ProviderWispTransportRead; - Ok(( - EitherWebSocketRead::Right(read), - EitherWebSocketWrite::Right(write), - )) + let write: WritableStream = object_get(&transport, "write").into(); + let write = Box::pin(WispTransportWrite( + wasm_streams::WritableStream::from_raw(write) + .try_into_sink() + .map_err(|x| EpoxyError::wisp_transport(x.0.into()))?, + )) as ProviderWispTransportWrite; + + Ok((read, write)) })) }) } @@ -419,10 +421,7 @@ impl EpoxyClient { )); } } - Ok(( - EitherWebSocketRead::Left(read), - EitherWebSocketWrite::Left(write), - )) + Ok((read.into_read(), write.into_write())) }) }), &options, diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 1d637e8..5690c85 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -1,6 +1,5 @@ use std::{io::ErrorKind, pin::Pin, sync::Arc, task::Poll}; -use bytes::BytesMut; use cfg_if::cfg_if; use futures_rustls::{ rustls::{ClientConfig, RootCertStore}, @@ -9,38 +8,33 @@ use futures_rustls::{ use futures_util::{ future::Either, lock::{Mutex, MutexGuard}, - AsyncRead, AsyncWrite, Future, Stream, + AsyncRead, AsyncWrite, Future, }; use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection}; use pin_project_lite::pin_project; +use send_wrapper::SendWrapper; use wasm_bindgen_futures::spawn_local; use webpki_roots::TLS_SERVER_ROOTS; use wisp_mux::{ extensions::{udp::UdpProtocolExtensionBuilder, AnyProtocolExtensionBuilder}, - generic::GenericWebSocketRead, - ws::{EitherWebSocketRead, EitherWebSocketWrite}, - ClientMux, MuxStreamAsyncRW, MuxStreamIo, StreamType, WispV2Handshake, + packet::StreamType, + stream::{MuxStream, MuxStreamAsyncRW}, + ws::{WebSocketRead, WebSocketWrite}, + ClientMux, WispV2Handshake, }; use crate::{ console_error, console_log, - utils::{IgnoreCloseNotify, NoCertificateVerification, WispTransportWrite}, - ws_wrapper::{WebSocketReader, WebSocketWrapper}, + utils::{IgnoreCloseNotify, NoCertificateVerification}, EpoxyClientOptions, EpoxyError, }; -pub type ProviderUnencryptedStream = MuxStreamIo; -pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; +pub type ProviderUnencryptedStream = MuxStream; +pub type ProviderUnencryptedAsyncRW = MuxStreamAsyncRW; pub type ProviderTlsAsyncRW = IgnoreCloseNotify; pub type ProviderAsyncRW = Either; -pub type ProviderWispTransportRead = EitherWebSocketRead< - WebSocketReader, - GenericWebSocketRead< - Pin> + Send>>, - EpoxyError, - >, ->; -pub type ProviderWispTransportWrite = EitherWebSocketWrite; +pub type ProviderWispTransportRead = Pin>; +pub type ProviderWispTransportWrite = Pin>; pub type ProviderWispTransportGenerator = Box< dyn Fn( bool, @@ -137,7 +131,7 @@ impl StreamProvider { let (read, write) = (self.wisp_generator)(self.wisp_v2).await?; - let client = ClientMux::create(read, write, extensions).await?; + let client = ClientMux::new(read, write, extensions).await?; let (mux, fut) = if self.udp_extension { client.with_udp_extension_required().await? } else { @@ -172,8 +166,8 @@ impl StreamProvider { Box::pin(async { let locked = self.current_client.lock().await; if let Some(mux) = locked.as_ref() { - let stream = mux.client_new_stream(stream_type, host, port).await?; - Ok(stream.into_io()) + let stream = mux.new_stream(stream_type, host, port).await?; + Ok(stream) } else { self.create_client(locked).await?; self.get_stream(stream_type, host, port).await @@ -191,7 +185,7 @@ impl StreamProvider { Ok(self .get_stream(stream_type, host, port) .await? - .into_asyncrw()) + .into_async_rw()) } pub async fn get_tls_stream( @@ -316,34 +310,37 @@ impl Connection for HyperIo { #[derive(Clone)] pub struct StreamProviderService(pub Arc); -impl ConnectSvc for StreamProviderService { - type Connection = HyperIo; - type Error = EpoxyError; - type Future = Pin>>>; - - fn connect(self, req: hyper::Uri) -> Self::Future { - let provider = self.0.clone(); - Box::pin(async move { - let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme(None))?; - let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); - let port = req.port_u16().map_or_else( - || match scheme { - "https" | "wss" => Ok(443), - "http" | "ws" => Ok(80), - _ => Err(EpoxyError::NoUrlPort), - }, - Ok, - )?; - Ok(HyperIo { - inner: match scheme { - "https" => Either::Left(provider.get_tls_stream(host, port, true).await?), - "wss" => Either::Left(provider.get_tls_stream(host, port, false).await?), - "http" | "ws" => { - Either::Right(provider.get_asyncread(StreamType::Tcp, host, port).await?) - } - _ => return Err(EpoxyError::InvalidUrlScheme(Some(scheme.to_string()))), - }, - }) +impl StreamProviderService { + async fn connect(self, req: hyper::Uri) -> Result { + let scheme = req.scheme_str().ok_or(EpoxyError::InvalidUrlScheme(None))?; + let host = req.host().ok_or(EpoxyError::NoUrlHost)?.to_string(); + let port = req.port_u16().map_or_else( + || match scheme { + "https" | "wss" => Ok(443), + "http" | "ws" => Ok(80), + _ => Err(EpoxyError::NoUrlPort), + }, + Ok, + )?; + Ok(HyperIo { + inner: match scheme { + "https" => Either::Left(self.0.get_tls_stream(host, port, true).await?), + "wss" => Either::Left(self.0.get_tls_stream(host, port, false).await?), + "http" | "ws" => { + Either::Right(self.0.get_asyncread(StreamType::Tcp, host, port).await?) + } + _ => return Err(EpoxyError::InvalidUrlScheme(Some(scheme.to_string()))), + }, }) } } + +impl ConnectSvc for StreamProviderService { + type Connection = HyperIo; + type Error = EpoxyError; + type Future = impl Future> + Send; + + fn connect(self, req: hyper::Uri) -> Self::Future { + SendWrapper::new(Box::pin(self.connect(req))) + } +} diff --git a/client/src/utils/mod.rs b/client/src/utils/mod.rs index 1c6fde6..776c3fb 100644 --- a/client/src/utils/mod.rs +++ b/client/src/utils/mod.rs @@ -1,7 +1,10 @@ mod js; mod rustls; pub use js::*; +use js_sys::Uint8Array; pub use rustls::*; +use wasm_streams::writable::IntoSink; +use wisp_mux::{ws::Payload, WispError}; use std::{ pin::Pin, @@ -9,19 +12,11 @@ use std::{ }; use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; -use futures_util::{ready, AsyncRead, Future, Stream}; +use futures_util::{ready, AsyncRead, Future, Sink, SinkExt, Stream}; use http::{HeaderValue, Uri}; use hyper::rt::Executor; -use js_sys::Uint8Array; use pin_project_lite::pin_project; -use send_wrapper::SendWrapper; use wasm_bindgen::prelude::*; -use wasm_bindgen_futures::JsFuture; -use web_sys::WritableStreamDefaultWriter; -use wisp_mux::{ - ws::{Frame, WebSocketWrite}, - WispError, -}; use crate::EpoxyError; @@ -131,8 +126,7 @@ pub fn poll_read_buf( let n = { let dst = buf.chunk_mut(); - let dst = - unsafe { &mut *(std::ptr::from_mut::(dst) as *mut [u8]) }; + let dst = unsafe { &mut *(std::ptr::from_mut::(dst) as *mut [u8]) }; ready!(io.poll_read(cx, dst)?) }; @@ -174,26 +168,32 @@ impl Stream for ReaderStream { } } -pub struct WispTransportWrite { - pub inner: SendWrapper, -} +pub struct WispTransportWrite(pub IntoSink<'static>); +unsafe impl Send for WispTransportWrite {} -impl WebSocketWrite for WispTransportWrite { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - SendWrapper::new(async { - let chunk = Uint8Array::from(frame.payload.as_ref()).into(); - JsFuture::from(self.inner.write_with_chunk(&chunk)) - .await - .map(|_| ()) - .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) - }) - .await +impl Sink for WispTransportWrite { + type Error = WispError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0 + .poll_ready_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) } - async fn wisp_close(&mut self) -> Result<(), WispError> { - SendWrapper::new(JsFuture::from(self.inner.abort())) - .await - .map(|_| ()) + fn start_send(mut self: Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + self.0 + .start_send_unpin(Uint8Array::from(item.as_ref()).into()) + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0 + .poll_flush_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) + } + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0 + .poll_close_unpin(cx) .map_err(|x| WispError::WsImplError(Box::new(EpoxyError::wisp_transport(x)))) } } diff --git a/client/src/ws_wrapper.rs b/client/src/ws_wrapper.rs index 11a6dc7..987ddc8 100644 --- a/client/src/ws_wrapper.rs +++ b/client/src/ws_wrapper.rs @@ -3,7 +3,6 @@ use std::sync::{ Arc, }; -use bytes::BytesMut; use event_listener::Event; use flume::Receiver; use futures_util::FutureExt; @@ -13,11 +12,14 @@ use thiserror::Error; use wasm_bindgen::{closure::Closure, JsCast, JsValue}; use web_sys::{BinaryType, MessageEvent, WebSocket}; use wisp_mux::{ - ws::{Frame, LockingWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{async_iterator_transport_read, async_iterator_transport_write, Payload}, WispError, }; -use crate::EpoxyError; +use crate::{ + stream_provider::{ProviderWispTransportRead, ProviderWispTransportWrite}, + EpoxyError, +}; #[derive(Error, Debug)] pub enum WebSocketError { @@ -36,13 +38,12 @@ impl From for WispError { } pub enum WebSocketMessage { - Closed, Error(WebSocketError), Message(Vec), } pub struct WebSocketWrapper { - pub inner: SendWrapper, + pub inner: Arc>, open_event: Arc, error_event: Arc, close_event: Arc, @@ -65,26 +66,27 @@ pub struct WebSocketReader { close_event: Arc, } -impl WebSocketRead for WebSocketReader { - async fn wisp_read_frame( - &mut self, - _: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - use WebSocketMessage as M; - if self.closed.load(Ordering::Acquire) { - return Err(WispError::WsImplSocketClosed); - } - let res = futures_util::select! { - data = self.read_rx.recv_async() => data.ok(), - () = self.close_event.listen().fuse() => Some(M::Closed), - }; - match res.ok_or(WispError::WsImplSocketClosed)? { - M::Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from( - bin.as_slice(), - )))), - M::Error(x) => Err(x.into()), - M::Closed => Err(WispError::WsImplSocketClosed), - } +impl WebSocketReader { + pub fn into_read(self) -> ProviderWispTransportRead { + Box::pin(async_iterator_transport_read(self, |this| { + Box::pin(async { + use WebSocketMessage as M; + if this.closed.load(Ordering::Acquire) { + return Err(WispError::WsImplSocketClosed); + } + + let res = futures_util::select! { + data = this.read_rx.recv_async() => data.ok(), + () = this.close_event.listen().fuse() => None + }; + + match res { + Some(M::Message(x)) => Ok(Some((Payload::from(x), this))), + Some(M::Error(x)) => Err(x.into()), + None => Ok(None), + } + }) + })) } } @@ -153,7 +155,7 @@ impl WebSocketWrapper { Ok(( Self { - inner: SendWrapper::new(ws), + inner: Arc::new(SendWrapper::new(ws)), open_event, error_event, close_event: close_event.clone(), @@ -180,42 +182,35 @@ impl WebSocketWrapper { () = self.error_event.listen().fuse() => false, } } -} -impl WebSocketWrite for WebSocketWrapper { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - use wisp_mux::ws::OpCode::{Binary, Close, Text}; - if self.closed.load(Ordering::Acquire) { - return Err(WispError::WsImplSocketClosed); - } - match frame.opcode { - Binary | Text => self - .inner - .send_with_u8_array(&frame.payload) - .map_err(|x| WebSocketError::SendFailed(format!("{x:?}")).into()), - Close => { - let _ = self.inner.close(); - Ok(()) - } - _ => Err(WispError::WsImplNotSupported), - } - } + pub fn into_write(self) -> ProviderWispTransportWrite { + let ws = self.inner.clone(); + let closed = self.closed.clone(); + let close_event = self.close_event.clone(); + Box::pin(async_iterator_transport_write( + self, + |this, item| { + Box::pin(async move { + this.inner + .send_with_u8_array(&item) + .map_err(|x| WebSocketError::SendFailed(format!("{x:?}").into()))?; + Ok(this) + }) + }, + (ws, closed, close_event), + |(ws, closed, close_event)| { + Box::pin(async move { + ws.set_onopen(None); + ws.set_onclose(None); + ws.set_onerror(None); + ws.set_onmessage(None); + closed.store(true, Ordering::Release); + close_event.notify(usize::MAX); - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.inner - .close() - .map_err(|x| WebSocketError::CloseFailed(format!("{x:?}")).into()) - } -} - -impl Drop for WebSocketWrapper { - fn drop(&mut self) { - self.inner.set_onopen(None); - self.inner.set_onclose(None); - self.inner.set_onerror(None); - self.inner.set_onmessage(None); - self.closed.store(true, Ordering::Release); - self.close_event.notify(usize::MAX); - let _ = self.inner.close(); + ws.close() + .map_err(|x| WebSocketError::CloseFailed(format!("{:?}", x)).into()) + }) + }, + )) } } diff --git a/server/Cargo.toml b/server/Cargo.toml index b3094e6..2f61b63 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -8,16 +8,16 @@ workspace = true [dependencies] anyhow = "1.0.86" -async-speed-limit = { version = "0.4.2", optional = true } +async-speed-limit = { version = "0.4.2", optional = true, features = ["tokio"] } async-trait = "0.1.81" base64 = "0.22.1" bytes = "1.7.1" cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } +console-subscriber = { version = "0.4.1", optional = true } ed25519-dalek = { version = "2.1.1", features = ["pem"] } env_logger = "0.11.5" event-listener = "5.3.1" -fastwebsockets = { version = "0.8.0", features = ["unstable-split"] } futures-util = "0.3.30" hickory-resolver = "0.24.1" http-body-util = "0.1.2" @@ -39,12 +39,13 @@ sha2 = "0.10.8" shell-words = { version = "1.1.0", optional = true } tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats", "use_std"] } tikv-jemallocator = "0.6.0" -tokio = { version = "1.39.3", features = ["full"] } +tokio = { version = "1.43.0", features = ["full"] } tokio-rustls = { version = "0.26.0", features = ["ring", "tls12"], default-features = false } tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } +tokio-websockets = { version = "0.11.1", features = ["server", "simd", "sha1_smol"] } toml = { version = "0.8.19", optional = true } uuid = { version = "1.10.0", features = ["v4"] } -wisp-mux = { version = "*", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] } +wisp-mux = { version = "*", path = "../wisp", features = ["tokio-websockets", "certificate"] } [features] default = ["toml"] @@ -54,6 +55,7 @@ toml = ["dep:toml"] twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"] speed-limit = ["dep:async-speed-limit"] +tokio-console = ["dep:console-subscriber"] [build-dependencies] vergen-git2 = { version = "1.0.0", features = ["rustc"] } diff --git a/server/flamegraph.svg b/server/flamegraph.svg index bf7ea3d..75be8f6 100644 --- a/server/flamegraph.svg +++ b/server/flamegraph.svg @@ -1,4 +1,4 @@ - \ No newline at end of file diff --git a/server/src/config.rs b/server/src/config.rs index d78727d..c03e994 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -324,7 +324,7 @@ impl Default for ServerConfig { bind: (SocketType::default(), "127.0.0.1:4000".to_string()), transport: SocketTransport::default(), resolve_ipv6: false, - tcp_nodelay: false, + tcp_nodelay: true, file_raw_mode: false, tls_keypair: None, @@ -432,8 +432,8 @@ impl WispConfig { impl Default for StreamConfig { fn default() -> Self { Self { - tcp_nodelay: false, - buffer_size: 16384, + tcp_nodelay: true, + buffer_size: 128 * 1024, allow_udp: true, allow_wsproxy_udp: false, diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index c3a2527..7d0d2b8 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -6,81 +6,61 @@ pub mod wispnet; use std::{sync::Arc, time::Duration}; use anyhow::Context; -use bytes::BytesMut; use cfg_if::cfg_if; use event_listener::Event; -use futures_util::FutureExt; +use futures_util::{future::Either, FutureExt, SinkExt, StreamExt}; use log::{debug, trace}; use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + io::{AsyncWriteExt, BufReader}, + net::TcpStream, select, task::JoinSet, time::interval, }; -use tokio_util::compat::FuturesAsyncReadCompatExt; +use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}; use uuid::Uuid; use wisp_mux::{ - ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, + packet::{CloseReason, ConnectPacket}, + stream::MuxStream, ServerMux, }; use wispnet::route_wispnet; use crate::{ - route::{WispResult, WispStreamWrite}, + route::{WispResult, WispStreamWrite, WispWsStreamWrite}, stream::{ClientStream, ResolvedPacket}, CLIENTS, CONFIG, }; -async fn copy_read_fast( - muxrx: MuxStreamAsyncRead, - mut tcptx: OwnedWriteHalf, - #[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter< +async fn copy_fast( + mux: MuxStream, + tcp: TcpStream, + #[cfg(feature = "speed-limit")] read_limit: async_speed_limit::Limiter< + async_speed_limit::clock::StandardClock, + >, + #[cfg(feature = "speed-limit")] write_limit: async_speed_limit::Limiter< async_speed_limit::clock::StandardClock, >, ) -> std::io::Result<()> { + let (muxrx, muxtx) = mux.into_async_rw().into_split(); let mut muxrx = muxrx.compat(); - loop { - let buf = muxrx.fill_buf().await?; - if buf.is_empty() { - tcptx.flush().await?; - return Ok(()); - } + let mut muxtx = muxtx.compat_write(); - #[cfg(feature = "speed-limit")] - limiter.consume(buf.len()).await; + let (tcprx, mut tcptx) = tcp.into_split(); - let i = tcptx.write(buf).await?; - if i == 0 { - return Err(std::io::ErrorKind::WriteZero.into()); - } + #[cfg(feature = "speed-limit")] + let tcprx = read_limit.limit(tcprx); + #[cfg(feature = "speed-limit")] + let mut tcptx = write_limit.limit(tcptx); - muxrx.consume(i); - } -} - -async fn copy_write_fast( - muxtx: MuxStreamWrite, - tcprx: OwnedReadHalf, - #[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter< - async_speed_limit::clock::StandardClock, - >, -) -> anyhow::Result<()> { let mut tcprx = BufReader::with_capacity(CONFIG.stream.buffer_size, tcprx); - loop { - let buf = tcprx.fill_buf().await?; - let len = buf.len(); - if len == 0 { - return Ok(()); - } + select! { + x = tokio::io::copy_buf(&mut muxrx, &mut tcptx) => x?, + x = tokio::io::copy(&mut tcprx, &mut muxtx) => x?, + }; - #[cfg(feature = "speed-limit")] - limiter.consume(buf.len()).await; - - muxtx.write(&buf).await?; - tcprx.consume(len); - } + Ok(()) } async fn resolve_stream( @@ -147,13 +127,15 @@ async fn forward_stream( let closer = muxstream.get_close_handle(); let ret: anyhow::Result<()> = async { - let (muxread, muxwrite) = muxstream.into_split(); - let muxread = muxread.into_stream().into_asyncread(); - let (tcpread, tcpwrite) = stream.into_split(); - select! { - x = copy_read_fast(muxread, tcpwrite, #[cfg(feature = "speed-limit")] write_limit) => x?, - x = copy_write_fast(muxwrite, tcpread, #[cfg(feature = "speed-limit")] read_limit) => x?, - } + copy_fast( + muxstream, + stream, + #[cfg(feature = "speed-limit")] + read_limit, + #[cfg(feature = "speed-limit")] + write_limit, + ) + .await?; Ok(()) } .await; @@ -169,6 +151,8 @@ async fn forward_stream( } ClientStream::Udp(stream) => { let closer = muxstream.get_close_handle(); + let (mut read, write) = muxstream.into_split(); + let mut write = write.into_async_write().compat_write(); let ret: anyhow::Result<()> = async move { let mut data = vec![0u8; 65507]; @@ -176,10 +160,10 @@ async fn forward_stream( select! { size = stream.recv(&mut data) => { let size = size?; - muxstream.write(&data[..size]).await?; + write.write_all(&data[..size]).await?; } - data = muxstream.read() => { - if let Some(data) = data? { + data = read.next() => { + if let Some(data) = data.transpose()? { stream.send(&data).await?; } else { break Ok(()); @@ -202,8 +186,8 @@ async fn forward_stream( #[cfg(feature = "twisp")] ClientStream::Pty(cmd, pty) => { let closer = muxstream.get_close_handle(); - let id = muxstream.stream_id; - let (mut rx, mut tx) = muxstream.into_io().into_asyncrw().into_split(); + let id = muxstream.get_stream_id(); + let (mut rx, mut tx) = muxstream.into_async_rw().into_split(); match twisp::handle_twisp(id, &mut rx, &mut tx, twisp_map.clone(), pty, cmd).await { Ok(()) => { @@ -335,7 +319,7 @@ pub async fn handle_wisp(stream: WispResult, is_v2: bool, id: String) -> anyhow: .build(); let (mux, fut) = Box::pin( - Box::pin(ServerMux::create( + Box::pin(ServerMux::new( read, write, buffer_size, @@ -351,11 +335,8 @@ pub async fn handle_wisp(stream: WispResult, is_v2: bool, id: String) -> anyhow: debug!( "new wisp client id {:?} connected with extensions {:?}, downgraded {:?}", id, - mux.supported_extensions - .iter() - .map(|x| x.get_id()) - .collect::>(), - mux.downgraded + mux.get_extension_ids(), + mux.was_downgraded() ); let mut set: JoinSet<()> = JoinSet::new(); @@ -369,11 +350,19 @@ pub async fn handle_wisp(stream: WispResult, is_v2: bool, id: String) -> anyhow: let ping_id = id.clone(); set.spawn(async move { let mut interval = interval(Duration::from_secs(30)); - while ping_mux - .send_ping(Payload::Bytes(BytesMut::new())) - .await - .is_ok() - { + let send_ping = || async { + let mut locked = ping_mux.lock_ws().await?; + if let Either::Left(ws) = &mut *locked { + >::send( + ws, + tokio_websockets::Message::ping(&[] as &[u8]), + ) + .await?; + } + anyhow::Ok(()) + }; + + while (send_ping)().await.is_ok() { trace!("sent ping to wisp client id {:?}", ping_id); select! { _ = interval.tick() => (), @@ -382,7 +371,7 @@ pub async fn handle_wisp(stream: WispResult, is_v2: bool, id: String) -> anyhow: } }); - while let Some((connect, stream)) = mux.server_new_stream().await { + while let Some((connect, stream)) = mux.wait_for_stream().await { set.spawn(handle_stream( connect, stream, diff --git a/server/src/handle/wisp/twisp.rs b/server/src/handle/wisp/twisp.rs index cf2905d..6f93531 100644 --- a/server/src/handle/wisp/twisp.rs +++ b/server/src/handle/wisp/twisp.rs @@ -14,10 +14,13 @@ use wisp_mux::{ AnyProtocolExtension, AnyProtocolExtensionBuilder, ProtocolExtension, ProtocolExtensionBuilder, }, - ws::{DynWebSocketRead, LockingWebSocketWrite}, - MuxStreamAsyncRead, MuxStreamAsyncWrite, WispError, + stream::{MuxStreamAsyncRead, MuxStreamAsyncWrite}, + ws::{WebSocketRead, WebSocketWrite}, + WispError, }; +use crate::route::WispStreamWrite; + pub type TwispMap = Arc>>; pub const STREAM_TYPE: u8 = 0x03; @@ -50,8 +53,8 @@ impl ProtocolExtension for TWispServerProtocolExtension { async fn handle_handshake( &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, + _: &mut dyn WebSocketRead, + _: &mut dyn WebSocketWrite, ) -> std::result::Result<(), WispError> { Ok(()) } @@ -60,8 +63,8 @@ impl ProtocolExtension for TWispServerProtocolExtension { &mut self, packet_type: u8, mut packet: Bytes, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, + _: &mut dyn WebSocketRead, + _: &mut dyn WebSocketWrite, ) -> std::result::Result<(), WispError> { if packet_type == 0xF0 { if packet.remaining() < 4 + 2 + 2 { @@ -126,8 +129,8 @@ pub fn new_ext(map: TwispMap) -> AnyProtocolExtensionBuilder { pub async fn handle_twisp( id: u32, - streamrx: &mut MuxStreamAsyncRead, - streamtx: &mut MuxStreamAsyncWrite, + streamrx: &mut MuxStreamAsyncRead, + streamtx: &mut MuxStreamAsyncWrite, map: TwispMap, mut pty: Pty, mut cmd: Child, diff --git a/server/src/handle/wisp/wispnet.rs b/server/src/handle/wisp/wispnet.rs index 50aeae0..639c56b 100644 --- a/server/src/handle/wisp/wispnet.rs +++ b/server/src/handle/wisp/wispnet.rs @@ -6,17 +6,19 @@ use std::{ use anyhow::{Context, Result}; use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_util::{SinkExt, StreamExt}; use lazy_static::lazy_static; use log::debug; use tokio::{select, sync::Mutex}; use uuid::Uuid; use wisp_mux::{ extensions::{ - AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder, ProtocolExtensionVecExt, + AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder, ProtocolExtensionListExt, }, - ws::{DynWebSocketRead, Frame, LockingWebSocketWrite, Payload}, - ClientMux, CloseReason, ConnectPacket, MuxStream, MuxStreamRead, MuxStreamWrite, Role, - WispError, WispV2Handshake, + packet::{CloseReason, ConnectPacket}, + stream::{MuxStream, MuxStreamRead, MuxStreamWrite}, + ws::{WebSocketRead, WebSocketWrite}, + ClientMux, Role, WispError, WispV2Handshake, }; use crate::{ @@ -96,8 +98,8 @@ impl ProtocolExtension for WispnetServerProtocolExtension { async fn handle_handshake( &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, + _: &mut dyn WebSocketRead, + _: &mut dyn WebSocketWrite, ) -> Result<(), WispError> { Ok(()) } @@ -106,15 +108,16 @@ impl ProtocolExtension for WispnetServerProtocolExtension { &mut self, packet_type: u8, mut packet: Bytes, - _: &mut DynWebSocketRead, - write: &dyn LockingWebSocketWrite, + _: &mut dyn WebSocketRead, + write: &mut dyn WebSocketWrite, ) -> Result<(), WispError> { if packet_type == Self::ID { if packet.remaining() < 4 { return Err(WispError::PacketTooSmall); } - if packet.get_u32_le() != 0 { - return Err(WispError::InvalidStreamId); + let id = packet.get_u32_le(); + if id != 0 { + return Err(WispError::InvalidStreamId(id)); } let mut out = BytesMut::new(); @@ -129,9 +132,7 @@ impl ProtocolExtension for WispnetServerProtocolExtension { } drop(locked); - write - .wisp_write_frame(Frame::binary(Payload::Bytes(out))) - .await?; + write.send(out.into()).await?; } Ok(()) } @@ -145,11 +146,7 @@ pub async fn route_wispnet(server: u32, packet: ConnectPacket) -> Result Result, - tx: MuxStreamWrite, + mut rx: MuxStreamRead, + mut tx: MuxStreamWrite, #[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter< async_speed_limit::clock::StandardClock, >, ) -> Result<()> { - while let Some(data) = rx.read().await? { + while let Some(data) = rx.next().await { + let data = data?; + #[cfg(feature = "speed-limit")] limiter.consume(data.len()).await; - tx.write_payload(Payload::Borrowed(data.as_ref())).await?; + tx.send(data).await?; } Ok(()) } @@ -219,7 +218,7 @@ pub async fn handle_wispnet(stream: WispResult, id: String) -> Result<()> { let extensions = vec![WispnetServerProtocolExtensionBuilder(net_id).into()]; let (mux, fut) = Box::pin( - ClientMux::create(read, write, Some(WispV2Handshake::new(extensions))) + ClientMux::new(read, write, Some(WispV2Handshake::new(extensions))) .await .context("failed to create client multiplexor")? .with_required_extensions(&[WispnetServerProtocolExtension::ID]), @@ -228,7 +227,7 @@ pub async fn handle_wispnet(stream: WispResult, id: String) -> Result<()> { .context("wispnet client did not have wispnet extension")?; let is_private = mux - .supported_extensions + .get_extensions() .find_extension::() .context("failed to find wispnet extension")? .1; diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs index be48232..171af4f 100644 --- a/server/src/handle/wsproxy.rs +++ b/server/src/handle/wsproxy.rs @@ -1,13 +1,14 @@ use std::str::FromStr; -use fastwebsockets::CloseCode; +use futures_util::{SinkExt, StreamExt}; use log::debug; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, select, }; +use tokio_websockets::CloseCode; use uuid::Uuid; -use wisp_mux::{ws::Payload, CloseReason, ConnectPacket, StreamType}; +use wisp_mux::packet::{CloseReason, ConnectPacket, StreamType}; use crate::{ handle::wisp::wispnet::route_wispnet, @@ -25,13 +26,17 @@ pub async fn handle_wsproxy( udp: bool, ) -> anyhow::Result<()> { if udp && !CONFIG.stream.allow_wsproxy_udp { - let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await; + let _ = ws + .close(CloseCode::POLICY_VIOLATION.into(), "udp is blocked") + .await; return Ok(()); } let vec: Vec<&str> = path.split('/').last().unwrap().split(':').collect(); let Ok(port) = FromStr::from_str(vec[1]) else { - let _ = ws.close(CloseCode::Error.into(), b"invalid port").await; + let _ = ws + .close(CloseCode::POLICY_VIOLATION.into(), "invalid port") + .await; return Ok(()); }; let connect = ConnectPacket { @@ -40,15 +45,18 @@ pub async fn handle_wsproxy( } else { StreamType::Tcp }, - destination_hostname: vec[0].to_string(), - destination_port: port, + host: vec[0].to_string(), + port, }; let requested_stream = connect.clone(); let Ok(resolved) = ClientStream::resolve(connect).await else { let _ = ws - .close(CloseCode::Error.into(), b"failed to resolve host") + .close( + CloseCode::INTERNAL_SERVER_ERROR.into(), + "failed to resolve host", + ) .await; return Ok(()); }; @@ -57,7 +65,10 @@ pub async fn handle_wsproxy( let resolved = connect.clone(); let Ok(stream) = ClientStream::connect(connect).await else { let _ = ws - .close(CloseCode::Error.into(), b"failed to connect to host") + .close( + CloseCode::INTERNAL_SERVER_ERROR.into(), + "failed to connect to host", + ) .await; return Ok(()); }; @@ -67,7 +78,10 @@ pub async fn handle_wsproxy( let resolved = connect.clone(); let Ok(stream) = route_wispnet(server, connect).await else { let _ = ws - .close(CloseCode::Error.into(), b"failed to connect to host") + .close( + CloseCode::INTERNAL_SERVER_ERROR.into(), + "failed to connect to host", + ) .await; return Ok(()); }; @@ -76,21 +90,23 @@ pub async fn handle_wsproxy( ResolvedPacket::NoResolvedAddrs => { let _ = ws .close( - CloseCode::Error.into(), - b"host did not resolve to any addrs", + CloseCode::INTERNAL_SERVER_ERROR.into(), + "host did not resolve to any addrs", ) .await; return Ok(()); } ResolvedPacket::Blocked => { - let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; + let _ = ws + .close(CloseCode::POLICY_VIOLATION.into(), "host is blocked") + .await; return Ok(()); } ResolvedPacket::Invalid => { let _ = ws .close( - CloseCode::Error.into(), - b"invalid host/port/type combination", + CloseCode::POLICY_VIOLATION.into(), + "invalid host/port/type combination", ) .await; return Ok(()); @@ -119,19 +135,20 @@ pub async fn handle_wsproxy( loop { select! { x = ws.read() => { - match x? { - WebSocketFrame::Data(data) => { + match x.transpose()? { + Some(WebSocketFrame::Data(data)) => { stream.write_all(&data).await?; } - WebSocketFrame::Close => { + Some(WebSocketFrame::Close) => { stream.shutdown().await?; } - WebSocketFrame::Ignore => {} + Some(WebSocketFrame::Ignore) => {} + None => break Ok(()), } } x = stream.fill_buf() => { let x = x?; - ws.write(x).await?; + ws.write(x.to_vec()).await?; let len = x.len(); stream.consume(len); } @@ -141,11 +158,11 @@ pub async fn handle_wsproxy( .await; match ret { Ok(()) => { - let _ = ws.close(CloseCode::Normal.into(), b"").await; + let _ = ws.close(CloseCode::NORMAL_CLOSURE.into(), "").await; } Err(x) => { let _ = ws - .close(CloseCode::Normal.into(), x.to_string().as_bytes()) + .close(CloseCode::NORMAL_CLOSURE.into(), &x.to_string()) .await; } } @@ -156,15 +173,16 @@ pub async fn handle_wsproxy( loop { select! { x = ws.read() => { - match x? { - WebSocketFrame::Data(data) => { + match x.transpose()? { + Some(WebSocketFrame::Data(data)) => { stream.send(&data).await?; } - WebSocketFrame::Close | WebSocketFrame::Ignore => {} + Some(WebSocketFrame::Close | WebSocketFrame::Ignore) => {} + None => break Ok(()), } } size = stream.recv(&mut data) => { - ws.write(&data[..size?]).await?; + ws.write(data[..size?].to_vec()).await?; } } } @@ -172,11 +190,11 @@ pub async fn handle_wsproxy( .await; match ret { Ok(()) => { - let _ = ws.close(CloseCode::Normal.into(), b"").await; + let _ = ws.close(CloseCode::NORMAL_CLOSURE.into(), "").await; } Err(x) => { let _ = ws - .close(CloseCode::Normal.into(), x.to_string().as_bytes()) + .close(CloseCode::NORMAL_CLOSURE.into(), &x.to_string()) .await; } } @@ -184,10 +202,10 @@ pub async fn handle_wsproxy( #[cfg(feature = "twisp")] ClientStream::Pty(_, _) => { let _ = ws - .close(CloseCode::Error.into(), b"twisp is not supported") + .close(CloseCode::POLICY_VIOLATION, "twisp is not supported") .await; } - ClientStream::Wispnet(stream, mux_id) => { + ClientStream::Wispnet(mut stream, mux_id) => { if let Some(client) = CLIENTS.lock().await.get(&mux_id) { client .0 @@ -200,21 +218,22 @@ pub async fn handle_wsproxy( loop { select! { x = ws.read() => { - match x? { - WebSocketFrame::Data(data) => { - stream.write_payload(Payload::Bytes(data)).await?; + match x.transpose()? { + Some(WebSocketFrame::Data(data)) => { + stream.send(data.into()).await?; } - WebSocketFrame::Close => { + Some(WebSocketFrame::Close) => { stream.close(CloseReason::Voluntary).await?; } - WebSocketFrame::Ignore => {} + Some(WebSocketFrame::Ignore) => {} + None => break, } } - x = stream.read() => { - let Some(x) = x? else { + x = stream.next() => { + let Some(x) = x else { break; }; - ws.write(&x).await?; + ws.write(x?).await?; } } } @@ -228,11 +247,11 @@ pub async fn handle_wsproxy( match ret { Ok(()) => { - let _ = ws.close(CloseCode::Normal.into(), b"").await; + let _ = ws.close(CloseCode::NORMAL_CLOSURE.into(), "").await; } Err(x) => { let _ = ws - .close(CloseCode::Normal.into(), x.to_string().as_bytes()) + .close(CloseCode::NORMAL_CLOSURE.into(), &x.to_string()) .await; } } @@ -240,17 +259,21 @@ pub async fn handle_wsproxy( ClientStream::NoResolvedAddrs => { let _ = ws .close( - CloseCode::Error.into(), - b"host did not resolve to any addrs", + CloseCode::INTERNAL_SERVER_ERROR.into(), + "host did not resolve to any addrs", ) .await; return Ok(()); } ClientStream::Blocked => { - let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; + let _ = ws + .close(CloseCode::POLICY_VIOLATION.into(), "host is blocked") + .await; } ClientStream::Invalid => { - let _ = ws.close(CloseCode::Error.into(), b"host is invalid").await; + let _ = ws + .close(CloseCode::POLICY_VIOLATION.into(), "host is invalid") + .await; } } diff --git a/server/src/main.rs b/server/src/main.rs index 1a5e97c..5f0a098 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -24,7 +24,7 @@ use tokio::{ sync::Mutex, }; use uuid::Uuid; -use wisp_mux::ConnectPacket; +use wisp_mux::packet::ConnectPacket; pub mod config; #[doc(hidden)] @@ -41,6 +41,8 @@ mod stream; mod upgrade; #[doc(hidden)] mod util_chain; +#[doc(hidden)] +mod util_map_err; #[doc(hidden)] type Client = (Mutex>, String); diff --git a/server/src/route.rs b/server/src/route.rs index 5216d86..5395810 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -2,7 +2,7 @@ use std::{fmt::Display, future::Future, io::Cursor}; use anyhow::Context; use bytes::Bytes; -use fastwebsockets::{FragmentCollector, Role, WebSocket, WebSocketRead, WebSocketWrite}; +use futures_util::future::Either; use http_body_util::Full; use hyper::{ body::Incoming, header::SEC_WEBSOCKET_PROTOCOL, server::conn::http1::Builder, @@ -11,9 +11,9 @@ use hyper::{ use hyper_util::rt::TokioIo; use log::{debug, error, trace}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; -use wisp_mux::{ - generic::{GenericWebSocketRead, GenericWebSocketWrite}, - ws::{EitherWebSocketRead, EitherWebSocketWrite}, +use tokio_websockets::Limits; +use wisp_mux::ws::{ + TokioWebsocketsTransport, WebSocketExt, WebSocketSplitRead, WebSocketSplitWrite, }; use crate::{ @@ -23,17 +23,18 @@ use crate::{ stream::WebSocketStreamWrapper, upgrade::{is_upgrade_request, upgrade}, util_chain::{chain, Chain}, + util_map_err::MapErr, CONFIG, }; -pub type WispStreamRead = EitherWebSocketRead< - WebSocketRead, ServerStreamRead>>, - GenericWebSocketRead, std::io::Error>, ->; -pub type WispStreamWrite = EitherWebSocketWrite< - WebSocketWrite, - GenericWebSocketWrite, std::io::Error>, +pub type WispStreamRead = Either< + WebSocketSplitRead, ServerStream>>>, + MapErr>, >; +pub type WispWsStreamWrite = + WebSocketSplitWrite, ServerStream>>>; +pub type WispStreamWrite = + Either>>; pub type WispResult = (WispStreamRead, WispStreamWrite); pub enum ServerRouteResult { @@ -216,38 +217,30 @@ pub async fn route( |fut, res, maybe_ip| async move { let ws = fut.await.context("failed to await upgrade future")?; - let mut ws = - WebSocket::after_handshake(TokioIo::new(ws), Role::Server); - ws.set_max_message_size(CONFIG.server.max_message_size); - ws.set_auto_pong(false); - match res { HttpUpgradeResult::Wisp { has_ws_protocol, is_wispnet, } => { - let (read, write) = ws.split(|x| { - let parts = x - .into_inner() - .downcast::>() - .unwrap(); - let (r, w) = parts.io.into_inner().split(); - (chain(Cursor::new(parts.read_buf), r), w) - }); + let ws = ws.downcast::>().unwrap(); + let ws = + chain(Cursor::new(ws.read_buf), ws.io.into_inner()); + + let ws = tokio_websockets::ServerBuilder::new() + .limits(Limits::default().max_payload_len(Some( + CONFIG.server.max_message_size, + ))) + .serve(ws); + let (read, write) = + TokioWebsocketsTransport(ws).split_fast(); let result = if is_wispnet { ServerRouteResult::Wispnet { - stream: ( - EitherWebSocketRead::Left(read), - EitherWebSocketWrite::Left(write), - ), + stream: (Either::Left(read), Either::Left(write)), } } else { ServerRouteResult::Wisp { - stream: ( - EitherWebSocketRead::Left(read), - EitherWebSocketWrite::Left(write), - ), + stream: (Either::Left(read), Either::Left(write)), has_ws_protocol, } }; @@ -255,7 +248,12 @@ pub async fn route( (callback)(result, maybe_ip); } HttpUpgradeResult::WsProxy { path, udp } => { - let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); + let ws = tokio_websockets::ServerBuilder::new() + .limits(Limits::default().max_payload_len(Some( + CONFIG.server.max_message_size, + ))) + .serve(TokioIo::new(ws)); + let ws = WebSocketStreamWrapper(ws); (callback)( ServerRouteResult::WsProxy { stream: ws, @@ -282,15 +280,12 @@ pub async fn route( .new_codec(); let (read, write) = stream.split(); - let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone())); - let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); + let read = MapErr(FramedRead::new(read, codec.clone())); + let write = MapErr(FramedWrite::new(write, codec)); (callback)( ServerRouteResult::Wisp { - stream: ( - EitherWebSocketRead::Right(read), - EitherWebSocketWrite::Right(write), - ), + stream: (Either::Right(read), Either::Right(write)), has_ws_protocol: true, }, None, diff --git a/server/src/stats.rs b/server/src/stats.rs index a536255..10b5228 100644 --- a/server/src/stats.rs +++ b/server/src/stats.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use serde::Serialize; -use wisp_mux::{ConnectPacket, StreamType}; +use wisp_mux::packet::{ConnectPacket, StreamType}; use crate::{CLIENTS, CONFIG}; @@ -10,8 +10,8 @@ fn format_stream_type(stream_type: StreamType) -> &'static str { StreamType::Tcp => "tcp", StreamType::Udp => "udp", #[cfg(feature = "twisp")] - StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => "twisp", - StreamType::Unknown(_) => unreachable!(), + StreamType::Other(crate::handle::wisp::twisp::STREAM_TYPE) => "twisp", + StreamType::Other(_) => unreachable!(), } } @@ -36,14 +36,8 @@ impl From<(ConnectPacket, ConnectPacket)> for StreamStats { fn from(value: (ConnectPacket, ConnectPacket)) -> Self { Self { stream_type: format_stream_type(value.0.stream_type).to_string(), - requested: format!( - "{}:{}", - value.0.destination_hostname, value.0.destination_port - ), - resolved: format!( - "{}:{}", - value.1.destination_hostname, value.1.destination_port - ), + requested: format!("{}:{}", value.0.host, value.0.port), + resolved: format!("{}:{}", value.1.host, value.1.port), } } } diff --git a/server/src/stream.rs b/server/src/stream.rs index 07a1787..68fa10d 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -7,13 +7,17 @@ use anyhow::Context; use base64::{prelude::BASE64_STANDARD, Engine}; use bytes::BytesMut; use cfg_if::cfg_if; -use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; +use futures_util::{SinkExt, StreamExt}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use log::debug; use regex::RegexSet; use tokio::net::{TcpStream, UdpSocket}; -use wisp_mux::{ConnectPacket, MuxStream, StreamType}; +use tokio_websockets::{CloseCode, Message, Payload, WebSocketStream}; +use wisp_mux::{ + packet::{ConnectPacket, StreamType}, + stream::MuxStream, +}; use crate::{route::WispStreamWrite, CONFIG, RESOLVER}; @@ -25,7 +29,7 @@ fn allowed_set(stream_type: StreamType) -> &'static RegexSet { match stream_type { StreamType::Tcp => CONFIG.stream.allowed_tcp_hosts(), StreamType::Udp => CONFIG.stream.allowed_udp_hosts(), - StreamType::Unknown(_) => unreachable!(), + StreamType::Other(_) => unreachable!(), } } @@ -33,7 +37,7 @@ fn blocked_set(stream_type: StreamType) -> &'static RegexSet { match stream_type { StreamType::Tcp => CONFIG.stream.blocked_tcp_hosts(), StreamType::Udp => CONFIG.stream.blocked_udp_hosts(), - StreamType::Unknown(_) => unreachable!(), + StreamType::Other(_) => unreachable!(), } } @@ -118,8 +122,8 @@ pub enum ResolvedPacket { impl ClientStream { pub async fn resolve(packet: ConnectPacket) -> anyhow::Result { - if CONFIG.wisp.has_wispnet() && packet.destination_hostname.ends_with(".wisp") { - if let Some(wispnet_server) = packet.destination_hostname.split(".wisp").next() { + if CONFIG.wisp.has_wispnet() && packet.host.ends_with(".wisp") { + if let Some(wispnet_server) = packet.host.split(".wisp").next() { debug!("routing {:?} through wispnet", packet); let decoded = BASE64_STANDARD .decode(wispnet_server) @@ -134,14 +138,14 @@ impl ClientStream { cfg_if! { if #[cfg(feature = "twisp")] { - if let StreamType::Unknown(ty) = packet.stream_type { + if let StreamType::Other(ty) = packet.stream_type { if ty == crate::handle::wisp::twisp::STREAM_TYPE && CONFIG.stream.allow_twisp && CONFIG.wisp.wisp_v2 { return Ok(ResolvedPacket::Valid(packet)); } return Ok(ResolvedPacket::Invalid); } } else { - if matches!(packet.stream_type, StreamType::Unknown(_)) { + if matches!(packet.stream_type, StreamType::Other(_)) { return Ok(ResolvedPacket::Invalid); } } @@ -155,17 +159,17 @@ impl ClientStream { .stream .blocked_ports() .iter() - .any(|x| x.contains(&packet.destination_port)) + .any(|x| x.contains(&packet.port)) && !CONFIG .stream .allowed_ports() .iter() - .any(|x| x.contains(&packet.destination_port)) + .any(|x| x.contains(&packet.port)) { return Ok(ResolvedPacket::Blocked); } - if let Ok(addr) = IpAddr::from_str(&packet.destination_hostname) { + if let Ok(addr) = IpAddr::from_str(&packet.host) { if !CONFIG.stream.allow_direct_ip { return Ok(ResolvedPacket::Blocked); } @@ -186,7 +190,7 @@ impl ClientStream { } if match_addr( - &packet.destination_hostname, + &packet.host, allowed_set(packet.stream_type), blocked_set(packet.stream_type), ) { @@ -195,23 +199,23 @@ impl ClientStream { // allow stream type whitelists through if match_addr( - &packet.destination_hostname, + &packet.host, CONFIG.stream.allowed_hosts(), CONFIG.stream.blocked_hosts(), - ) && !allowed_set(packet.stream_type).is_match(&packet.destination_hostname) + ) && !allowed_set(packet.stream_type).is_match(&packet.host) { return Ok(ResolvedPacket::Blocked); } let packet = RESOLVER - .resolve(packet.destination_hostname) + .resolve(packet.host) .await .context("failed to resolve hostname")? .filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4()) .map(|x| ConnectPacket { stream_type: packet.stream_type, - destination_hostname: x.to_string(), - destination_port: packet.destination_port, + host: x.to_string(), + port: packet.port, }) .next(); @@ -221,13 +225,11 @@ impl ClientStream { pub async fn connect(packet: ConnectPacket) -> anyhow::Result { match packet.stream_type { StreamType::Tcp => { - let ipaddr = IpAddr::from_str(&packet.destination_hostname) - .context("failed to parse hostname as ipaddr")?; - let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port)) + let ipaddr = + IpAddr::from_str(&packet.host).context("failed to parse hostname as ipaddr")?; + let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.port)) .await - .with_context(|| { - format!("failed to connect to host {}", packet.destination_hostname) - })?; + .with_context(|| format!("failed to connect to host {}", packet.host))?; if CONFIG.stream.tcp_nodelay { stream @@ -242,8 +244,8 @@ impl ClientStream { return Ok(ClientStream::Blocked); } - let ipaddr = IpAddr::from_str(&packet.destination_hostname) - .context("failed to parse hostname as ipaddr")?; + let ipaddr = + IpAddr::from_str(&packet.host).context("failed to parse hostname as ipaddr")?; let bind_addr = if ipaddr.is_ipv4() { SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0) @@ -253,23 +255,20 @@ impl ClientStream { let stream = UdpSocket::bind(bind_addr).await?; - stream - .connect(SocketAddr::new(ipaddr, packet.destination_port)) - .await?; + stream.connect(SocketAddr::new(ipaddr, packet.port)).await?; Ok(ClientStream::Udp(stream)) } #[cfg(feature = "twisp")] - StreamType::Unknown(crate::handle::wisp::twisp::STREAM_TYPE) => { + StreamType::Other(crate::handle::wisp::twisp::STREAM_TYPE) => { if !CONFIG.stream.allow_twisp { return Ok(ClientStream::Blocked); } - let cmdline: Vec = - shell_words::split(&packet.destination_hostname)? - .into_iter() - .map(Into::into) - .collect(); + let cmdline: Vec = shell_words::split(&packet.host)? + .into_iter() + .map(Into::into) + .collect(); let pty = pty_process::Pty::new()?; let cmd = pty_process::Command::new(&cmdline[0]) @@ -278,7 +277,7 @@ impl ClientStream { Ok(ClientStream::Pty(cmd, pty)) } - StreamType::Unknown(_) => Ok(ClientStream::Invalid), + StreamType::Other(_) => Ok(ClientStream::Invalid), } } } @@ -289,25 +288,31 @@ pub enum WebSocketFrame { Ignore, } -pub struct WebSocketStreamWrapper(pub FragmentCollector>); +pub struct WebSocketStreamWrapper(pub WebSocketStream>); impl WebSocketStreamWrapper { - pub async fn read(&mut self) -> Result { - let frame = self.0.read_frame().await?; - Ok(match frame.opcode { - OpCode::Text | OpCode::Binary => WebSocketFrame::Data(frame.payload.into()), - OpCode::Close => WebSocketFrame::Close, - _ => WebSocketFrame::Ignore, - }) + pub async fn read(&mut self) -> Option> { + let frame = self.0.next().await?; + match frame { + Ok(frame) if frame.is_binary() || frame.is_text() => { + Some(Ok(WebSocketFrame::Data(frame.into_payload().into()))) + } + Ok(frame) if frame.is_close() => Some(Ok(WebSocketFrame::Close)), + Ok(_) => Some(Ok(WebSocketFrame::Ignore)), + Err(err) => Some(Err(err)), + } } - pub async fn write(&mut self, data: &[u8]) -> Result<(), WebSocketError> { - self.0 - .write_frame(Frame::binary(Payload::Borrowed(data))) - .await + pub async fn write(&mut self, data: impl Into) -> Result<(), tokio_websockets::Error> { + self.0.send(Message::binary(data)).await } - pub async fn close(&mut self, code: u16, reason: &[u8]) -> Result<(), WebSocketError> { - self.0.write_frame(Frame::close(code, reason)).await + pub async fn close( + &mut self, + code: CloseCode, + reason: &str, + ) -> Result<(), tokio_websockets::Error> { + self.0.send(Message::close(Some(code), reason)).await?; + self.0.close().await } } diff --git a/server/src/util_chain.rs b/server/src/util_chain.rs index 98003eb..133be60 100644 --- a/server/src/util_chain.rs +++ b/server/src/util_chain.rs @@ -9,7 +9,7 @@ use std::{ use futures_util::ready; use pin_project_lite::pin_project; -use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; pin_project! { pub struct Chain { @@ -99,3 +99,35 @@ where } } } +impl AsyncWrite for Chain +where + U: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().second.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().second.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().second.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().second.poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.second.is_write_vectored() + } +} diff --git a/server/src/util_map_err.rs b/server/src/util_map_err.rs new file mode 100644 index 0000000..b23a640 --- /dev/null +++ b/server/src/util_map_err.rs @@ -0,0 +1,56 @@ +use bytes::BytesMut; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use wisp_mux::{ws::Payload, WispError}; + +pub struct MapErr(pub T); + +impl> + Unpin> Stream for MapErr { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0 + .poll_next_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(x))) + .map_ok(Into::into) + } +} + +impl + Unpin> Sink for MapErr { + type Error = WispError; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0 + .poll_ready_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(x))) + } + + fn start_send(mut self: std::pin::Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + self.0 + .start_send_unpin(item) + .map_err(|x| WispError::WsImplError(Box::new(x))) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0 + .poll_close_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(x))) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0 + .poll_flush_unpin(cx) + .map_err(|x| WispError::WsImplError(Box::new(x))) + } +} diff --git a/simple-wisp-client/Cargo.toml b/simple-wisp-client/Cargo.toml index 8a8f9ab..116f303 100644 --- a/simple-wisp-client/Cargo.toml +++ b/simple-wisp-client/Cargo.toml @@ -12,18 +12,15 @@ bytes = "1.7.1" clap = { version = "4.5.16", features = ["cargo", "derive"] } console-subscriber = { version = "0.4.0", optional = true } ed25519-dalek = { version = "2.1.1", features = ["pem"] } -fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } futures = "0.3.30" -http-body-util = "0.1.2" humantime = "2.1.0" hyper = { version = "1.4.1", features = ["http1", "client"] } -hyper-util = { version = "0.1.7", features = ["tokio"] } sha2 = "0.10.8" simple_moving_average = "1.0.2" tikv-jemallocator = "0.6.0" -tokio = { version = "1.39.3", features = ["full"] } -wisp-mux = { path = "../wisp", features = ["fastwebsockets"]} +tokio = { version = "1.43.0", features = ["full"] } +tokio-websockets = { version = "0.11.1", features = ["client", "simd", "sha1_smol", "rand", "native-tls"] } +wisp-mux = { path = "../wisp", features = ["tokio-websockets"]} [features] tokio-console = ["tokio/tracing", "dep:console-subscriber"] - diff --git a/simple-wisp-client/flamegraph.svg b/simple-wisp-client/flamegraph.svg index 326ccae..f806057 100644 --- a/simple-wisp-client/flamegraph.svg +++ b/simple-wisp-client/flamegraph.svg @@ -1,4 +1,4 @@ - \ No newline at end of file diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index e4e32e6..8697ab7 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -2,31 +2,25 @@ use atomic_counter::{AtomicCounter, RelaxedCounter}; use bytes::Bytes; use clap::Parser; use ed25519_dalek::pkcs8::DecodePrivateKey; -use fastwebsockets::{handshake, WebSocketWrite}; -use futures::{future::select_all, FutureExt, TryFutureExt}; -use http_body_util::Empty; +use futures::{future::select_all, FutureExt, SinkExt}; use humantime::format_duration; -use hyper::{ - header::{CONNECTION, UPGRADE}, - Request, Uri, -}; -use hyper_util::rt::TokioIo; +use hyper::Uri; use sha2::{Digest, Sha256}; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ error::Error, future::Future, - io::{stdout, Cursor, IsTerminal, Write}, + io::{stdout, IsTerminal, Write}, net::SocketAddr, path::PathBuf, pin::Pin, - process::{abort, exit}, - sync::Arc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::{Duration, Instant}, }; use tokio::{ - io::AsyncReadExt, - net::{tcp::OwnedWriteHalf, TcpStream}, select, signal::unix::{signal, SignalKind}, time::{interval, sleep}, @@ -37,44 +31,16 @@ use wisp_mux::{ motd::{MotdProtocolExtension, MotdProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - AnyProtocolExtensionBuilder, + AnyProtocolExtensionBuilder, ProtocolExtensionListExt, }, - ClientMux, StreamType, WispError, WispV2Handshake, + packet::StreamType, + ws::{TokioWebsocketsTransport, WebSocketWrite, WebSocketExt}, + ClientMux, WispError, WispV2Handshake, }; #[global_allocator] static JEMALLOCATOR: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -#[derive(Debug)] -enum WispClientError { - InvalidUriScheme, - UriHasNoHost, -} - -impl std::fmt::Display for WispClientError { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - use WispClientError as E; - match self { - E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), - E::UriHasNoHost => write!(fmt, "URI has no host"), - } - } -} - -impl Error for WispClientError {} - -struct SpawnExecutor; - -impl hyper::rt::Executor for SpawnExecutor -where - Fut: Future + Send + 'static, - Fut::Output: Send + 'static, -{ - fn execute(&self, fut: Fut) { - tokio::task::spawn(fut); - } -} - #[derive(Parser)] #[command(version = clap::crate_version!())] struct Cli { @@ -132,19 +98,11 @@ async fn create_mux( opts: &Cli, ) -> Result< ( - ClientMux>, + ClientMux, impl Future> + Send, ), Box, > { - if opts.wisp.scheme_str().unwrap_or_default() != "ws" { - Err(Box::new(WispClientError::InvalidUriScheme))?; - } - - let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; - let addr_port = opts.wisp.port_u16().unwrap_or(80); - let addr_path = opts.wisp.path(); - let auth = opts.auth.as_ref().map(|auth| { let split: Vec<_> = auth.split(':').collect(); let username = split[0].to_string(); @@ -157,27 +115,13 @@ async fn create_mux( opts.wisp, opts.packet_size, opts.tcp, opts.streams, ); - let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; - let req = Request::builder() - .method("GET") - .uri(addr_path) - .header("Host", addr) - .header(UPGRADE, "websocket") - .header(CONNECTION, "upgrade") - .header( - "Sec-WebSocket-Key", - fastwebsockets::handshake::generate_key(), - ) - .header("Sec-WebSocket-Version", "13") - .body(Empty::::new())?; - - let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; - - let (rx, tx) = ws.split(|x| { - let parts = x.into_inner().downcast::>().unwrap(); - let (r, w) = parts.io.into_inner().into_split(); - (Cursor::new(parts.read_buf).chain(r), w) - }); + let (rx, tx) = TokioWebsocketsTransport( + tokio_websockets::ClientBuilder::from_uri(opts.wisp.clone()) + .connect() + .await? + .0, + ) + .split_fast(); let mut extensions: Vec = Vec::new(); let mut extension_ids: Vec = Vec::new(); @@ -204,12 +148,12 @@ async fn create_mux( } let (mux, fut) = if opts.wisp_v2 { - ClientMux::create(rx, tx, Some(WispV2Handshake::new(extensions))) + ClientMux::new(rx, tx, Some(WispV2Handshake::new(extensions))) .await? .with_required_extensions(extension_ids.as_slice()) .await? } else { - ClientMux::create(rx, tx, None) + ClientMux::new(rx, tx, None) .await? .with_no_required_extensions() }; @@ -228,14 +172,13 @@ async fn real_main() -> Result<(), Box> { let (mux, fut) = create_mux(&opts).await?; let motd_extension = mux - .supported_extensions - .iter() - .find_map(|x| x.downcast_ref::()); + .get_extensions() + .find_extension::(); println!( "connected and created ClientMux, was downgraded {}, extensions supported {:?}, motd {:?}\n\n", - mux.downgraded, - mux.supported_extensions + mux.was_downgraded(), + mux.get_extensions() .iter() .map(|x| x.get_id()) .collect::>(), @@ -244,40 +187,32 @@ async fn real_main() -> Result<(), Box> { let mut threads = Vec::with_capacity((opts.streams * 2) + 3); - threads.push(Box::pin( - tokio::spawn(fut) - .map_err(|x| WispError::Other(Box::new(x))) - .map(|x| x.and_then(|x| x)), - ) + threads.push(Box::pin(tokio::spawn(fut).map(|x| x.unwrap())) as Pin> + Send>>); - let payload = vec![0; 1024 * opts.packet_size]; + let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); let cnt = Arc::new(RelaxedCounter::new(0)); + let top = Arc::new(AtomicUsize::new(0)); let start_time = Instant::now(); for _ in 0..opts.streams { - let (cr, cw) = mux - .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) + let (_, mut cw) = mux + .new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .await? .into_split(); let cnt = cnt.clone(); let payload = payload.clone(); - threads.push(Box::pin(async move { - while let Ok(()) = cw.write(&payload).await { - cnt.inc(); - } - #[allow(unreachable_code)] - Ok::<(), WispError>(()) - })); threads.push(Box::pin(async move { loop { - let _ = cr.read().await; + cw.feed(payload.clone()).await?; + cnt.inc(); } })); } let cnt_avg = cnt.clone(); + let top_avg = top.clone(); threads.push(Box::pin(async move { let mut interval = interval(Duration::from_millis(100)); let mut avg: SingleSumSMA = SingleSumSMA::new(); @@ -303,15 +238,18 @@ async fn real_main() -> Result<(), Box> { } stdout().flush().unwrap(); avg.add_sample(now - last_time); + + let _ = top_avg.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| { + (old < now - last_time).then(|| now - last_time) + }); + last_time = now; } })); threads.push(Box::pin(async move { - let mut interrupt = - signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; - let mut terminate = - signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; + let mut interrupt = signal(SignalKind::interrupt()).unwrap(); + let mut terminate = signal(SignalKind::terminate()).unwrap(); select! { _ = interrupt.recv() => (), _ = terminate.recv() => (), @@ -330,10 +268,7 @@ async fn real_main() -> Result<(), Box> { let duration_since = Instant::now().duration_since(start_time); - if let Err(err) = out.0? { - println!("\n\nerr: {:?}", err); - exit(1); - } + dbg!(out.0)??; out.2.into_iter().for_each(|x| x.abort()); @@ -348,8 +283,15 @@ async fn real_main() -> Result<(), Box> { format_duration(duration_since), (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), ); + let top = top.load(Ordering::Relaxed); + println!( + "top: {} packets of &[0; 1024 * {}] ({} KiB) sent in 100ms ({} KiB/s)", + top, + opts.packet_size, + top * opts.packet_size, + top * opts.packet_size * 10 + ); } - // force everything to die - abort() + Ok(()) } diff --git a/wisp/.gitignore b/wisp/.gitignore deleted file mode 100644 index ea8c4bf..0000000 --- a/wisp/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/target diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index cbc60c0..e5909c6 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "wisp-mux" -version = "6.0.0" -license = "LGPL-3.0-only" +version = "7.0.0" +license = "MIT" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" repository = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -14,28 +14,31 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket workspace = true [dependencies] -async-trait = "0.1.81" -atomic_enum = "0.3.0" -bitflags = { version = "2.6.0", optional = true, features = ["std"] } -bytes = "1.7.1" -ed25519 = { version = "2.2.3", optional = true, features = ["pem", "zeroize"] } -event-listener = "5.3.1" -fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } -flume = "0.11.0" -futures = "0.3.30" -getrandom = { version = "0.2.15", features = ["std"], optional = true } -pin-project-lite = "0.2.14" -reusable-box-future = "0.2.0" +async-trait = "0.1.85" +bitflags = { version = "2.6.0", optional = true } +bytes = "1.9.0" +ed25519 = { version = "2.2.3", optional = true, features = ["std", "alloc"] } +flume = "0.11.1" +futures = { version = "0.3.31", default-features = false, features = ["std", "async-await"] } +getrandom = { version = "0.2.15", optional = true } +num_enum = "0.7.3" +pin-project = "1.1.8" rustc-hash = "2.1.0" -thiserror = "2.0.3" -tokio = { version = "1.39.3", optional = true, default-features = false } +slab = "0.4.9" +thiserror = "2.0.9" +tokio = { version = "1.42.0", optional = true } +tokio-tungstenite = { version = "0.26.1", features = ["stream"], optional = true, default-features = false } +tokio-websockets = { version = "0.11.1", optional = true } [features] -default = ["generic_stream", "certificate"] -fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] -generic_stream = [] +default = ["certificate"] +certificate = ["dep:getrandom", "dep:ed25519", "dep:bitflags"] wasm = ["getrandom/js"] -certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"] +tokio-websockets = ["dep:tokio-websockets", "dep:tokio"] +tokio-tungstenite = ["dep:tokio-tungstenite", "dep:tokio"] + +[dev-dependencies] +tokio = { version = "1.42.0", features = ["macros", "rt", "time"] } [package.metadata.docs.rs] all-features = true diff --git a/wisp/LICENSE b/wisp/LICENSE deleted file mode 100644 index 7b6bec5..0000000 --- a/wisp/LICENSE +++ /dev/null @@ -1,841 +0,0 @@ - GNU LESSER GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - - This version of the GNU Lesser General Public License incorporates -the terms and conditions of version 3 of the GNU General Public -License, supplemented by the additional permissions listed below. - - 0. Additional Definitions. - - As used herein, "this License" refers to version 3 of the GNU Lesser -General Public License, and the "GNU GPL" refers to version 3 of the GNU -General Public License. - - "The Library" refers to a covered work governed by this License, -other than an Application or a Combined Work as defined below. - - An "Application" is any work that makes use of an interface provided -by the Library, but which is not otherwise based on the Library. -Defining a subclass of a class defined by the Library is deemed a mode -of using an interface provided by the Library. - - A "Combined Work" is a work produced by combining or linking an -Application with the Library. The particular version of the Library -with which the Combined Work was made is also called the "Linked -Version". - - The "Minimal Corresponding Source" for a Combined Work means the -Corresponding Source for the Combined Work, excluding any source code -for portions of the Combined Work that, considered in isolation, are -based on the Application, and not on the Linked Version. - - The "Corresponding Application Code" for a Combined Work means the -object code and/or source code for the Application, including any data -and utility programs needed for reproducing the Combined Work from the -Application, but excluding the System Libraries of the Combined Work. - - 1. Exception to Section 3 of the GNU GPL. - - You may convey a covered work under sections 3 and 4 of this License -without being bound by section 3 of the GNU GPL. - - 2. Conveying Modified Versions. - - If you modify a copy of the Library, and, in your modifications, a -facility refers to a function or data to be supplied by an Application -that uses the facility (other than as an argument passed when the -facility is invoked), then you may convey a copy of the modified -version: - - a) under this License, provided that you make a good faith effort to - ensure that, in the event an Application does not supply the - function or data, the facility still operates, and performs - whatever part of its purpose remains meaningful, or - - b) under the GNU GPL, with none of the additional permissions of - this License applicable to that copy. - - 3. Object Code Incorporating Material from Library Header Files. - - The object code form of an Application may incorporate material from -a header file that is part of the Library. You may convey such object -code under terms of your choice, provided that, if the incorporated -material is not limited to numerical parameters, data structure -layouts and accessors, or small macros, inline functions and templates -(ten or fewer lines in length), you do both of the following: - - a) Give prominent notice with each copy of the object code that the - Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the object code with a copy of the GNU GPL and this license - document. - - 4. Combined Works. - - You may convey a Combined Work under terms of your choice that, -taken together, effectively do not restrict modification of the -portions of the Library contained in the Combined Work and reverse -engineering for debugging such modifications, if you also do each of -the following: - - a) Give prominent notice with each copy of the Combined Work that - the Library is used in it and that the Library and its use are - covered by this License. - - b) Accompany the Combined Work with a copy of the GNU GPL and this license - document. - - c) For a Combined Work that displays copyright notices during - execution, include the copyright notice for the Library among - these notices, as well as a reference directing the user to the - copies of the GNU GPL and this license document. - - d) Do one of the following: - - 0) Convey the Minimal Corresponding Source under the terms of this - License, and the Corresponding Application Code in a form - suitable for, and under terms that permit, the user to - recombine or relink the Application with a modified version of - the Linked Version to produce a modified Combined Work, in the - manner specified by section 6 of the GNU GPL for conveying - Corresponding Source. - - 1) Use a suitable shared library mechanism for linking with the - Library. A suitable mechanism is one that (a) uses at run time - a copy of the Library already present on the user's computer - system, and (b) will operate properly with a modified version - of the Library that is interface-compatible with the Linked - Version. - - e) Provide Installation Information, but only if you would otherwise - be required to provide such information under section 6 of the - GNU GPL, and only to the extent that such information is - necessary to install and execute a modified version of the - Combined Work produced by recombining or relinking the - Application with a modified version of the Linked Version. (If - you use option 4d0, the Installation Information must accompany - the Minimal Corresponding Source and Corresponding Application - Code. If you use option 4d1, you must provide the Installation - Information in the manner specified by section 6 of the GNU GPL - for conveying Corresponding Source.) - - 5. Combined Libraries. - - You may place library facilities that are a work based on the -Library side by side in a single library together with other library -facilities that are not Applications and are not covered by this -License, and convey such a combined library under terms of your -choice, if you do both of the following: - - a) Accompany the combined library with a copy of the same work based - on the Library, uncombined with any other library facilities, - conveyed under the terms of this License. - - b) Give prominent notice with the combined library that part of it - is a work based on the Library, and explaining where to find the - accompanying uncombined form of the same work. - - 6. Revised Versions of the GNU Lesser General Public License. - - The Free Software Foundation may publish revised and/or new versions -of the GNU Lesser General Public License from time to time. Such new -versions will be similar in spirit to the present version, but may -differ in detail to address new problems or concerns. - - Each version is given a distinguishing version number. If the -Library as you received it specifies that a certain numbered version -of the GNU Lesser General Public License "or any later version" -applies to it, you have the option of following the terms and -conditions either of that published version or of any later version -published by the Free Software Foundation. If the Library as you -received it does not specify a version number of the GNU Lesser -General Public License, you may choose any version of the GNU Lesser -General Public License ever published by the Free Software Foundation. - - If the Library as you received it specifies that a proxy can decide -whether future versions of the GNU Lesser General Public License shall -apply, that proxy's public statement of acceptance of any version is -permanent authorization for you to choose that version for the -Library. - - GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU General Public License is a free, copyleft license for -software and other kinds of works. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -the GNU General Public License is intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. We, the Free Software Foundation, use the -GNU General Public License for most of our software; it applies also to -any other work released this way by its authors. You can apply it to -your programs, too. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. - - To protect your rights, we need to prevent others from denying you -these rights or asking you to surrender the rights. Therefore, you have -certain responsibilities if you distribute copies of the software, or if -you modify it: responsibilities to respect the freedom of others. - - For example, if you distribute copies of such a program, whether -gratis or for a fee, you must pass on to the recipients the same -freedoms that you received. You must make sure that they, too, receive -or can get the source code. And you must show them these terms so they -know their rights. - - Developers that use the GNU GPL protect your rights with two steps: -(1) assert copyright on the software, and (2) offer you this License -giving you legal permission to copy, distribute and/or modify it. - - For the developers' and authors' protection, the GPL clearly explains -that there is no warranty for this free software. For both users' and -authors' sake, the GPL requires that modified versions be marked as -changed, so that their problems will not be attributed erroneously to -authors of previous versions. - - Some devices are designed to deny users access to install or run -modified versions of the software inside them, although the manufacturer -can do so. This is fundamentally incompatible with the aim of -protecting users' freedom to change the software. The systematic -pattern of such abuse occurs in the area of products for individuals to -use, which is precisely where it is most unacceptable. Therefore, we -have designed this version of the GPL to prohibit the practice for those -products. If such problems arise substantially in other domains, we -stand ready to extend this provision to those domains in future versions -of the GPL, as needed to protect the freedom of users. - - Finally, every program is threatened constantly by software patents. -States should not allow patents to restrict development and use of -software on general-purpose computers, but in those that do, we wish to -avoid the special danger that patents applied to a free program could -make it effectively proprietary. To prevent this, the GPL assures that -patents cannot be used to render the program non-free. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. - - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. - - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. - - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. - - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Use with the GNU Affero General Public License. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU Affero General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the special requirements of the GNU Affero General Public License, -section 13, concerning interaction through a network will apply to the -combination as such. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU General Public License from time to time. Such new versions will -be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. - - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. - - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - - Copyright (C) - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - - If the program does terminal interaction, make it output a short -notice like this when it starts in an interactive mode: - - Copyright (C) - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, your program's commands -might be different; for a GUI interface, you would use an "about box". - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU GPL, see -. - - The GNU General Public License does not permit incorporating your program -into proprietary programs. If your program is a subroutine library, you -may consider it more useful to permit linking proprietary applications with -the library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. But first, please read -. - diff --git a/wisp/README.md b/wisp/README.md deleted file mode 100644 index ee6a62b..0000000 --- a/wisp/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# wisp-mux -A library for easily creating [Wisp](https://github.com/MercuryWorkshop/wisp-protocol) servers and clients. diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index 510c1f5..877ba8e 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -10,10 +10,7 @@ use ed25519::{ Signature, }; -use crate::{ - ws::{DynWebSocketRead, LockingWebSocketWrite}, - Role, WispError, -}; +use crate::{Role, WispError}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; @@ -145,13 +142,6 @@ impl ProtocolExtension for CertAuthProtocolExtension { Self::ID } - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - fn get_congestion_stream_types(&self) -> &'static [u8] { - &[] - } - fn encode(&self) -> Bytes { match self { Self::Server { @@ -180,24 +170,6 @@ impl ProtocolExtension for CertAuthProtocolExtension { } } - async fn handle_handshake( - &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: u8, - _: Bytes, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - fn box_clone(&self) -> Box { Box::new(self.clone()) } diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index c4f92a1..7d46f46 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -12,14 +12,18 @@ use std::{ }; use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, Bytes}; use crate::{ - ws::{DynWebSocketRead, LockingWebSocketWrite}, + ws::{PayloadMut, WebSocketRead, WebSocketWrite}, Role, WispError, }; -/// Type-erased protocol extension that implements Clone. +mod private { + pub struct Sealed; +} + +/// Type-erased protocol extension. #[derive(Debug)] pub struct AnyProtocolExtension(Box); @@ -64,14 +68,12 @@ impl Clone for AnyProtocolExtension { } } -impl From for Bytes { - fn from(value: AnyProtocolExtension) -> Self { - let mut bytes = BytesMut::with_capacity(5); - let payload = value.encode(); - bytes.put_u8(value.get_id()); - bytes.put_u32_le(payload.len() as u32); - bytes.extend(payload); - bytes.freeze() +impl AnyProtocolExtension { + pub(crate) fn encode_into(&self, packet: &mut PayloadMut) { + let payload = self.encode(); + packet.put_u8(self.get_id()); + packet.put_u32_le(payload.len() as u32); + packet.extend(payload); } } @@ -92,11 +94,15 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { /// Get the protocol extension's supported packets. /// /// Used to decide whether to call the protocol extension's packet handler. - fn get_supported_packets(&self) -> &'static [u8]; + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } /// Get stream types that should be treated as TCP. /// /// Used to decide whether to handle congestion control for that stream type. - fn get_congestion_stream_types(&self) -> &'static [u8]; + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[] + } /// Encode self into Bytes. fn encode(&self) -> Bytes; @@ -106,24 +112,31 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { /// This should be used to send or receive data before any streams are created. async fn handle_handshake( &mut self, - read: &mut DynWebSocketRead, - write: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError>; + read: &mut dyn WebSocketRead, + write: &mut dyn WebSocketWrite, + ) -> Result<(), WispError> { + let _ = (read, write); + Ok(()) + } /// Handle receiving a packet. async fn handle_packet( &mut self, packet_type: u8, packet: Bytes, - read: &mut DynWebSocketRead, - write: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError>; + read: &mut dyn WebSocketRead, + write: &mut dyn WebSocketWrite, + ) -> Result<(), WispError> { + let _ = (packet_type, packet, read, write); + Ok(()) + } /// Clone the protocol extension. fn box_clone(&self) -> Box; + #[doc(hidden)] /// Do not override. - fn __internal_type_id(&self) -> TypeId { + fn __internal_type_id(&self, _: private::Sealed) -> TypeId { TypeId::of::() } } @@ -131,7 +144,7 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { impl dyn ProtocolExtension { fn __is(&self) -> bool { let t = TypeId::of::(); - self.__internal_type_id() == t + self.__internal_type_id(private::Sealed) == t } fn __downcast(self: Box) -> Result, Box> { @@ -183,8 +196,9 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static { /// This is called first on the server and second on the client. fn build_to_extension(&mut self, role: Role) -> Result; + #[doc(hidden)] /// Do not override. - fn __internal_type_id(&self) -> TypeId { + fn __internal_type_id(&self, _sealed: private::Sealed) -> TypeId { TypeId::of::() } } @@ -192,7 +206,7 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static { impl dyn ProtocolExtensionBuilder { fn __is(&self) -> bool { let t = TypeId::of::(); - self.__internal_type_id() == t + self.__internal_type_id(private::Sealed) == t } fn __downcast(self: Box) -> Result, Box> { @@ -267,49 +281,78 @@ impl From for AnyProtocolExtensionBuilder { } } -/// Helper functions for `Vec` -pub trait ProtocolExtensionBuilderVecExt { +/// Helper functions for `[AnyProtocolExtensionBuilder]` +pub trait ProtocolExtensionBuilderListExt { /// Returns a reference to the protocol extension builder specified, if it was found. fn find_extension(&self) -> Option<&T>; /// Returns a mutable reference to the protocol extension builder specified, if it was found. fn find_extension_mut(&mut self) -> Option<&mut T>; +} +/// Helper functions for `Vec` +pub trait ProtocolExtensionBuilderVecExt { /// Removes any instances of the protocol extension builder specified, if it was found. fn remove_extension(&mut self); } -impl ProtocolExtensionBuilderVecExt for Vec { +impl ProtocolExtensionBuilderListExt for [AnyProtocolExtensionBuilder] { fn find_extension(&self) -> Option<&T> { self.iter().find_map(|x| x.downcast_ref::()) } fn find_extension_mut(&mut self) -> Option<&mut T> { self.iter_mut().find_map(|x| x.downcast_mut::()) } +} +impl ProtocolExtensionBuilderListExt for Vec { + fn find_extension(&self) -> Option<&T> { + self.as_slice().find_extension() + } + fn find_extension_mut(&mut self) -> Option<&mut T> { + self.as_mut_slice().find_extension_mut() + } +} + +impl ProtocolExtensionBuilderVecExt for Vec { fn remove_extension(&mut self) { self.retain(|x| x.downcast_ref::().is_none()); } } -/// Helper functions for `Vec` -pub trait ProtocolExtensionVecExt { +/// Helper functions for `[AnyProtocolExtension]` +pub trait ProtocolExtensionListExt { /// Returns a reference to the protocol extension specified, if it was found. fn find_extension(&self) -> Option<&T>; /// Returns a mutable reference to the protocol extension specified, if it was found. fn find_extension_mut(&mut self) -> Option<&mut T>; +} +/// Helper functions for `Vec` +pub trait ProtocolExtensionVecExt { /// Removes any instances of the protocol extension specified, if it was found. fn remove_extension(&mut self); } -impl ProtocolExtensionVecExt for Vec { +impl ProtocolExtensionListExt for [AnyProtocolExtension] { fn find_extension(&self) -> Option<&T> { self.iter().find_map(|x| x.downcast_ref::()) } fn find_extension_mut(&mut self) -> Option<&mut T> { self.iter_mut().find_map(|x| x.downcast_mut::()) } +} +impl ProtocolExtensionListExt for Vec { + fn find_extension(&self) -> Option<&T> { + self.as_slice().find_extension() + } + + fn find_extension_mut(&mut self) -> Option<&mut T> { + self.as_mut_slice().find_extension_mut() + } +} + +impl ProtocolExtensionVecExt for Vec { fn remove_extension(&mut self) { self.retain(|x| x.downcast_ref::().is_none()); } diff --git a/wisp/src/extensions/motd.rs b/wisp/src/extensions/motd.rs index f718cad..88ff571 100644 --- a/wisp/src/extensions/motd.rs +++ b/wisp/src/extensions/motd.rs @@ -5,10 +5,7 @@ use async_trait::async_trait; use bytes::Bytes; -use crate::{ - ws::{DynWebSocketRead, LockingWebSocketWrite}, - Role, WispError, -}; +use crate::{Role, WispError}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; @@ -31,14 +28,6 @@ impl ProtocolExtension for MotdProtocolExtension { Self::ID } - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn get_congestion_stream_types(&self) -> &'static [u8] { - &[] - } - fn encode(&self) -> Bytes { match self.role { Role::Server => Bytes::from(self.motd.as_bytes().to_vec()), @@ -46,24 +35,6 @@ impl ProtocolExtension for MotdProtocolExtension { } } - async fn handle_handshake( - &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: u8, - _: Bytes, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - fn box_clone(&self) -> Box { Box::new(self.clone()) } diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index b2da387..c97cf1c 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -8,10 +8,7 @@ use std::collections::HashMap; use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::{ - ws::{DynWebSocketRead, LockingWebSocketWrite}, - Role, WispError, -}; +use crate::{Role, WispError}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; @@ -60,18 +57,6 @@ impl ProtocolExtension for PasswordProtocolExtension { PASSWORD_PROTOCOL_EXTENSION_ID } - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn get_congestion_stream_types(&self) -> &'static [u8] { - &[] - } - fn encode(&self) -> Bytes { match self { Self::ServerBeforeClientInfo { required } => { @@ -92,22 +77,8 @@ impl ProtocolExtension for PasswordProtocolExtension { } } - async fn handle_handshake( - &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: u8, - _: Bytes, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Err(WispError::ExtensionImplNotSupported) + fn box_clone(&self) -> Box { + Box::new(self.clone()) } } diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index 50cc445..e68ca63 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -4,10 +4,7 @@ use async_trait::async_trait; use bytes::Bytes; -use crate::{ - ws::{DynWebSocketRead, LockingWebSocketWrite}, - WispError, -}; +use crate::WispError; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; @@ -26,36 +23,10 @@ impl ProtocolExtension for UdpProtocolExtension { Self::ID } - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn get_congestion_stream_types(&self) -> &'static [u8] { - &[] - } - fn encode(&self) -> Bytes { Bytes::new() } - async fn handle_handshake( - &mut self, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: u8, - _: Bytes, - _: &mut DynWebSocketRead, - _: &dyn LockingWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - fn box_clone(&self) -> Box { Box::new(Self) } diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs deleted file mode 100644 index f7a3f68..0000000 --- a/wisp/src/fastwebsockets.rs +++ /dev/null @@ -1,221 +0,0 @@ -//! `WebSocketRead` + `WebSocketWrite` implementation for the fastwebsockets library. - -use bytes::BytesMut; -use fastwebsockets::{ - CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead, - WebSocketWrite, -}; -use tokio::io::{AsyncRead, AsyncWrite}; - -use crate::{ws::LockingWebSocketWrite, WispError}; - -fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> { - match payload { - Payload::Bytes(x) => crate::ws::Payload::Bytes(x), - Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(&*x)), - Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x), - Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x), - } -} - -fn match_payload_reverse(payload: crate::ws::Payload<'_>) -> Payload<'_> { - match payload { - crate::ws::Payload::Bytes(x) => Payload::Bytes(x), - crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x), - } -} - -fn payload_to_bytesmut(payload: Payload<'_>) -> BytesMut { - match payload { - Payload::Borrowed(borrowed) => BytesMut::from(borrowed), - Payload::BorrowedMut(borrowed_mut) => BytesMut::from(&*borrowed_mut), - Payload::Owned(owned) => BytesMut::from(owned.as_slice()), - Payload::Bytes(b) => b, - } -} - -impl From for crate::ws::OpCode { - fn from(opcode: OpCode) -> Self { - use OpCode as O; - match opcode { - O::Continuation => { - unreachable!("continuation should never be recieved when using a fragmentcollector") - } - O::Text => Self::Text, - O::Binary => Self::Binary, - O::Close => Self::Close, - O::Ping => Self::Ping, - O::Pong => Self::Pong, - } - } -} - -impl<'a> From> for crate::ws::Frame<'a> { - fn from(frame: Frame<'a>) -> Self { - Self { - finished: frame.fin, - opcode: frame.opcode.into(), - payload: match_payload(frame.payload), - } - } -} - -impl<'a> From> for Frame<'a> { - fn from(frame: crate::ws::Frame<'a>) -> Self { - use crate::ws::OpCode as O; - let payload = match_payload_reverse(frame.payload); - match frame.opcode { - O::Text => Self::text(payload), - O::Binary => Self::binary(payload), - O::Close => Self::close_raw(payload), - O::Ping => Self::new(true, OpCode::Ping, None, payload), - O::Pong => Self::pong(payload), - } - } -} - -impl From for crate::WispError { - fn from(err: WebSocketError) -> Self { - if let WebSocketError::ConnectionClosed = err { - Self::WsImplSocketClosed - } else { - Self::WsImplError(Box::new(err)) - } - } -} - -impl crate::ws::WebSocketRead for FragmentCollectorRead { - async fn wisp_read_frame( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - Ok(self - .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) - .await? - .into()) - } -} - -impl crate::ws::WebSocketRead for WebSocketRead { - async fn wisp_read_frame( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - let mut frame = self - .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) - .await?; - - if frame.opcode == OpCode::Continuation { - return Err(WispError::WsImplError(Box::new( - WebSocketError::InvalidContinuationFrame, - ))); - } - - let mut buf = payload_to_bytesmut(frame.payload); - let opcode = frame.opcode; - - while !frame.fin { - frame = self - .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) - .await?; - - if frame.opcode != OpCode::Continuation { - return Err(WispError::WsImplError(Box::new( - WebSocketError::InvalidContinuationFrame, - ))); - } - - buf.extend_from_slice(&frame.payload); - } - - Ok(crate::ws::Frame { - opcode: opcode.into(), - payload: crate::ws::Payload::Bytes(buf), - finished: frame.fin, - }) - } - - async fn wisp_read_split( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result<(crate::ws::Frame<'static>, Option>), WispError> { - let mut frame_cnt = 1; - let mut frame = self - .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) - .await?; - let mut extra_frame = None; - - if frame.opcode == OpCode::Continuation { - return Err(WispError::WsImplError(Box::new( - WebSocketError::InvalidContinuationFrame, - ))); - } - - let mut buf = payload_to_bytesmut(frame.payload); - let opcode = frame.opcode; - - while !frame.fin { - frame = self - .read_frame(&mut |frame| async { tx.wisp_write_frame(frame.into()).await }) - .await?; - - if frame.opcode != OpCode::Continuation { - return Err(WispError::WsImplError(Box::new( - WebSocketError::InvalidContinuationFrame, - ))); - } - if frame_cnt == 1 { - let payload = payload_to_bytesmut(frame.payload); - extra_frame = Some(crate::ws::Frame { - opcode: opcode.into(), - payload: crate::ws::Payload::Bytes(payload), - finished: true, - }); - } else if frame_cnt == 2 { - let extra_payload = extra_frame.take().unwrap().payload; - buf.extend_from_slice(&extra_payload); - buf.extend_from_slice(&frame.payload); - } else { - buf.extend_from_slice(&frame.payload); - } - frame_cnt += 1; - } - - Ok(( - crate::ws::Frame { - opcode: opcode.into(), - payload: crate::ws::Payload::Bytes(buf), - finished: frame.fin, - }, - extra_frame, - )) - } -} - -impl crate::ws::WebSocketWrite for WebSocketWrite { - async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> { - self.write_frame(frame.into()).await.map_err(Into::into) - } - - async fn wisp_write_split( - &mut self, - header: crate::ws::Frame<'_>, - body: crate::ws::Frame<'_>, - ) -> Result<(), WispError> { - let mut header = Frame::from(header); - header.fin = false; - self.write_frame(header).await?; - - let mut body = Frame::from(body); - body.opcode = OpCode::Continuation; - self.write_frame(body).await?; - - Ok(()) - } - - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) - .await - .map_err(Into::into) - } -} diff --git a/wisp/src/generic.rs b/wisp/src/generic.rs deleted file mode 100644 index 316ea70..0000000 --- a/wisp/src/generic.rs +++ /dev/null @@ -1,88 +0,0 @@ -//! `WebSocketRead` and `WebSocketWrite` implementation for generic `Stream`s and `Sink`s. - -use bytes::{Bytes, BytesMut}; -use futures::{Sink, SinkExt, Stream, StreamExt}; -use std::error::Error; - -use crate::{ - ws::{Frame, LockingWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, - WispError, -}; - -/// `WebSocketRead` implementation for generic `Stream`s. -pub struct GenericWebSocketRead< - T: Stream> + Send + Unpin, - E: Error + Sync + Send + 'static, ->(T); - -impl> + Send + Unpin, E: Error + Sync + Send + 'static> - GenericWebSocketRead -{ - /// Create a new wrapper `WebSocketRead` implementation. - pub fn new(stream: T) -> Self { - Self(stream) - } - - /// Get the inner `Stream` from the wrapper. - pub fn into_inner(self) -> T { - self.0 - } -} - -impl> + Send + Unpin, E: Error + Sync + Send + 'static> - WebSocketRead for GenericWebSocketRead -{ - async fn wisp_read_frame( - &mut self, - _tx: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - match self.0.next().await { - Some(data) => Ok(Frame::binary(Payload::Bytes( - data.map_err(|x| WispError::WsImplError(Box::new(x)))?, - ))), - None => Ok(Frame::close(Payload::Bytes(BytesMut::new()))), - } - } -} - -/// `WebSocketWrite` implementation for generic `Sink`s. -pub struct GenericWebSocketWrite< - T: Sink + Send + Unpin, - E: Error + Sync + Send + 'static, ->(T); - -impl + Send + Unpin, E: Error + Sync + Send + 'static> - GenericWebSocketWrite -{ - /// Create a new wrapper `WebSocketWrite` implementation. - pub fn new(stream: T) -> Self { - Self(stream) - } - - /// Get the inner `Sink` from the wrapper. - pub fn into_inner(self) -> T { - self.0 - } -} - -impl + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite - for GenericWebSocketWrite -{ - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - if frame.opcode == OpCode::Binary { - self.0 - .send(BytesMut::from(frame.payload).freeze()) - .await - .map_err(|x| WispError::WsImplError(Box::new(x))) - } else { - Ok(()) - } - } - - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.0 - .close() - .await - .map_err(|x| WispError::WsImplError(Box::new(x))) - } -} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index f48b557..64800a9 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,88 +1,51 @@ -#![cfg_attr(docsrs, feature(doc_cfg))] -#![deny(missing_docs, clippy::todo)] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -//! A library for easily creating [Wisp] clients and servers. -//! -//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol +use std::error::Error; + +use packet::WispVersion; +use thiserror::Error as ErrorDerive; pub mod extensions; -#[cfg(feature = "fastwebsockets")] -#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))] -mod fastwebsockets; -#[cfg(feature = "generic_stream")] -#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))] -pub mod generic; +mod locked_sink; mod mux; -mod packet; -mod stream; +pub mod packet; +pub mod stream; pub mod ws; -pub use crate::{mux::*, packet::*, stream::*}; +pub use mux::*; -use thiserror::Error; +use locked_sink::LockedWebSocketWrite; +pub use locked_sink::{LockedSinkGuard, LockedWebSocketWriteGuard}; -/// Wisp version supported by this crate. pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; -/// The role of the multiplexor. -#[derive(Debug, PartialEq, Copy, Clone)] -pub enum Role { - /// Client side, can create new streams. - Client, - /// Server side, can listen for streams created by the client. - Server, -} - -/// Errors the Wisp implementation can return. -#[derive(Error, Debug)] +#[derive(Debug, ErrorDerive)] pub enum WispError { - /// The packet received did not have enough data. + /// Stream ID was invalid. + #[error("Invalid stream ID: {0}")] + InvalidStreamId(u32), + /// Packet type was invalid. + #[error("Invalid packet type: {0:#02X}")] + InvalidPacketType(u8), + /// Packet was too small. #[error("Packet too small")] PacketTooSmall, - /// The packet received had an invalid type. - #[error("Invalid packet type")] - InvalidPacketType, - /// The stream had an invalid ID. - #[error("Invalid steam ID")] - InvalidStreamId, - /// The close packet had an invalid reason. - #[error("Invalid close reason")] - InvalidCloseReason, - /// The max stream count was reached. - #[error("Maximum stream count reached")] - MaxStreamCountReached, /// The Wisp protocol version was incompatible. #[error("Incompatible Wisp protocol version: found {0} but needed {1}")] IncompatibleProtocolVersion(WispVersion, WispVersion), - /// The stream had already been closed. + + /// The stream was closed already. #[error("Stream already closed")] StreamAlreadyClosed, + /// The max stream count was reached. + #[error("Maximum stream count reached")] + MaxStreamCountReached, - /// The websocket frame received had an invalid type. - #[error("Invalid websocket frame type: {0:?}")] - WsFrameInvalidType(ws::OpCode), - /// The websocket frame received was not finished. - #[error("Unfinished websocket frame")] - WsFrameNotFinished, - /// Error specific to the websocket implementation. - #[error("Websocket implementation error: {0:?}")] - WsImplError(Box), - /// The websocket implementation socket closed. - #[error("Websocket implementation error: socket closed")] - WsImplSocketClosed, - /// The websocket implementation did not support the action. - #[error("Websocket implementation error: not supported")] - WsImplNotSupported, - - /// The string was invalid UTF-8. - #[error("UTF-8 error: {0}")] - Utf8Error(#[from] std::str::Utf8Error), - /// The integer failed to convert. + /// Failed to parse bytes as UTF-8. + #[error("Invalid UTF-8: {0}")] + Utf8(#[from] std::str::Utf8Error), #[error("Integer conversion error: {0}")] TryFromIntError(#[from] std::num::TryFromIntError), - /// Other error. - #[error("Other: {0:?}")] - Other(Box), /// Failed to send message to multiplexor task. #[error("Failed to send multiplexor message")] @@ -97,6 +60,13 @@ pub enum WispError { #[error("Multiplexor task already started")] MuxTaskStarted, + /// Error specific to the websocket implementation. + #[error("Websocket implementation error: {0}")] + WsImplError(Box), + /// Websocket implementation: websocket closed + #[error("Websocket implementation error: websocket closed")] + WsImplSocketClosed, + /// Error specific to the protocol extension implementation. #[error("Protocol extension implementation error: {0:?}")] ExtensionImplError(Box), @@ -124,3 +94,15 @@ pub enum WispError { #[error("Password protocol extension: No signing key provided")] CertAuthExtensionNoKey, } + +impl From for WispError { + fn from(value: std::string::FromUtf8Error) -> Self { + Self::Utf8(value.utf8_error()) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum Role { + Server, + Client, +} diff --git a/wisp/src/locked_sink.rs b/wisp/src/locked_sink.rs new file mode 100644 index 0000000..ce27528 --- /dev/null +++ b/wisp/src/locked_sink.rs @@ -0,0 +1,330 @@ +//! unfair async mutex that doesn't have guards by default + +use std::{ + cell::UnsafeCell, + future::poll_fn, + marker::PhantomData, + ops::{Deref, DerefMut}, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, MutexGuard, + }, + task::{Context, Poll, Waker}, +}; + +use futures::Sink; +use slab::Slab; + +use crate::ws::{Payload, WebSocketWrite}; + +// it would be nice to have type_alias_bounds but oh well +#[expect(type_alias_bounds)] +pub(crate) type LockedWebSocketWrite = LockedSink; +#[expect(type_alias_bounds)] +pub type LockedWebSocketWriteGuard = LockedSinkGuard; + +pub(crate) enum Waiter { + Sleeping(Waker), + Woken, +} + +impl Waiter { + pub fn new(cx: &mut Context<'_>) -> Self { + Self::Sleeping(cx.waker().clone()) + } + + pub fn register(&mut self, cx: &mut Context<'_>) { + match self { + Self::Sleeping(x) => x.clone_from(cx.waker()), + Self::Woken => *self = Self::Sleeping(cx.waker().clone()), + } + } + + pub fn wake(&mut self) -> Option { + match std::mem::replace(self, Self::Woken) { + Self::Sleeping(x) => Some(x), + Self::Woken => None, + } + } +} + +struct WakerList { + inner: Slab, +} + +impl WakerList { + pub fn new() -> Self { + Self { inner: Slab::new() } + } + + pub fn add(&mut self, cx: &mut Context<'_>) -> usize { + self.inner.insert(Waiter::new(cx)) + } + + pub fn update(&mut self, key: usize, cx: &mut Context<'_>) { + self.inner + .get_mut(key) + .expect("task should never have invalid key") + .register(cx); + } + + pub fn remove(&mut self, key: usize) { + self.inner.remove(key); + } + + pub fn get_next(&mut self) -> Option { + self.inner.iter_mut().find_map(|x| x.1.wake()) + } +} + +enum LockStatus { + /// was locked, you are now in the list + Joined(usize), + /// was locked, you were already in the list + Waiting, + /// was unlocked, lock is yours now + Unlocked, +} + +struct SinkState, I> { + sink: UnsafeCell, + locked: AtomicBool, + waiters: Mutex, + + phantom: PhantomData, +} + +unsafe impl + Send, I> Send for SinkState {} +unsafe impl, I> Sync for SinkState {} + +impl, I> SinkState { + pub fn new(sink: S) -> Self { + Self { + sink: UnsafeCell::new(sink), + locked: AtomicBool::new(false), + waiters: Mutex::new(WakerList::new()), + + phantom: PhantomData, + } + } + + fn lock_waiters(&self) -> MutexGuard<'_, WakerList> { + self.waiters.lock().expect("waiters mutex was poisoned") + } + + /// caller must make sure they are the ones locking the sink + #[expect(clippy::mut_from_ref)] + pub unsafe fn get_unpin(&self) -> &mut S { + // SAFETY: we are locked + unsafe { &mut *self.sink.get() } + } + + /// caller must make sure they are the ones locking the sink + pub unsafe fn get(&self) -> Pin<&mut S> { + // SAFETY: we are locked + let inner = unsafe { self.get_unpin() }; + // SAFETY: we never touch the UnsafeCell + unsafe { Pin::new_unchecked(inner) } + } + + pub fn lock(&self, key: Option, cx: &mut Context<'_>) -> LockStatus { + let old_state = self.locked.swap(true, Ordering::AcqRel); + match (key, old_state) { + (Some(key), true) => { + self.lock_waiters().update(key, cx); + LockStatus::Waiting + } + (None, true) => { + let pos = self.lock_waiters().add(cx); + LockStatus::Joined(pos) + } + (_, false) => LockStatus::Unlocked, + } + } + + pub fn unlock(&self) { + let mut locked = self.lock_waiters(); + self.locked.store(false, Ordering::Release); + if let Some(next) = locked.get_next() { + drop(locked); + + next.wake(); + } + } + + pub fn remove(&self, key: usize) { + self.lock_waiters().remove(key); + } +} + +pub(crate) struct LockedSink, I> { + inner: Arc>, + + pos: Option, + locked: bool, +} + +impl, I> Clone for LockedSink { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + + pos: None, + locked: false, + } + } +} + +impl, I> Drop for LockedSink { + fn drop(&mut self) { + self.unlock(); + } +} + +impl, I> LockedSink { + pub fn new(sink: S) -> Self { + Self { + inner: Arc::new(SinkState::new(sink)), + + pos: None, + locked: false, + } + } + + pub fn poll_lock(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.locked { + Poll::Ready(()) + } else { + match self.inner.lock(self.pos, cx) { + LockStatus::Joined(pos) => { + self.pos = Some(pos); + + // make sure we haven't raced an unlock + if matches!(self.inner.lock(self.pos, cx), LockStatus::Unlocked) { + if let Some(pos) = self.pos.take() { + self.inner.remove(pos); + } + self.locked = true; + return Poll::Ready(()); + } + + Poll::Pending + } + LockStatus::Waiting => { + // make sure we haven't raced an unlock + if matches!(self.inner.lock(self.pos, cx), LockStatus::Unlocked) { + if let Some(pos) = self.pos.take() { + self.inner.remove(pos); + } + self.locked = true; + return Poll::Ready(()); + } + + Poll::Pending + } + LockStatus::Unlocked => { + if let Some(pos) = self.pos.take() { + self.inner.remove(pos); + } + self.locked = true; + Poll::Ready(()) + } + } + } + } + pub async fn lock(&mut self) { + poll_fn(|cx| self.poll_lock(cx)).await; + } + + pub fn unlock(&mut self) { + if self.locked { + self.locked = false; + self.inner.unlock(); + } + } + + pub fn get(&self) -> Pin<&mut S> { + debug_assert!(self.locked); + // SAFETY: we are locked + unsafe { self.inner.get() } + } + pub fn get_handle(&mut self) -> LockedSinkHandle { + debug_assert!(self.locked); + self.locked = false; + + LockedSinkHandle { + inner: self.inner.clone(), + } + } + pub fn get_guard(&mut self) -> LockedSinkGuard { + debug_assert!(self.locked); + self.locked = false; + + LockedSinkGuard { + inner: self.inner.clone(), + } + } +} + +// always locked sink "guard" of lockedsink +pub(crate) struct LockedSinkHandle, I> { + inner: Arc>, +} + +impl, I> Sink for LockedSinkHandle { + type Error = S::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.inner.get() }.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + unsafe { self.inner.get() }.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.inner.get() }.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + unsafe { self.inner.get() }.poll_close(cx) + } +} + +impl, I> Drop for LockedSinkHandle { + fn drop(&mut self) { + self.inner.unlock(); + } +} + +// always locked "guard" of lockedsink +pub struct LockedSinkGuard, I> { + inner: Arc>, +} + +impl, I> Deref for LockedSinkGuard { + type Target = S; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.inner.get_unpin() } + } +} + +impl + Unpin, I> DerefMut for LockedSinkGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.inner.get_unpin() } + } +} + +impl, I> LockedSinkGuard { + pub fn deref_pin(&mut self) -> Pin<&mut S> { + unsafe { self.inner.get() } + } +} + +impl, I> Drop for LockedSinkGuard { + fn drop(&mut self) { + self.inner.unlock(); + } +} diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs index 49d7be1..a298551 100644 --- a/wisp/src/mux/client.rs +++ b/wisp/src/mux/client.rs @@ -1,159 +1,157 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; - -use flume as mpsc; use futures::channel::oneshot; use crate::{ - extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, + extensions::udp::UdpProtocolExtension, mux::send_info_packet, - ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, - CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, - WispError, + packet::{ConnectPacket, ContinuePacket, MaybeInfoPacket, Packet, StreamType}, + stream::MuxStream, + ws::{WebSocketRead, WebSocketReadExt, WebSocketWrite}, + LockedWebSocketWrite, Role, WispError, }; use super::{ - get_supported_extensions, - inner::{MuxInner, WsEvent}, - validate_continue_packet, Multiplexor, MuxResult, WispHandshakeResult, WispHandshakeResultKind, - WispV2Handshake, + get_supported_extensions, handle_handshake, + inner::{FlowControl, MultiplexorActor, StreamMap, WsEvent}, + validate_continue_packet, Multiplexor, MultiplexorImpl, MuxResult, WispHandshakeResult, + WispHandshakeResultKind, WispV2Handshake, }; -async fn handshake( - rx: &mut R, - tx: &LockedWebSocketWrite, - v2_info: Option, -) -> Result<(WispHandshakeResult, u32), WispError> { - if let Some(WispV2Handshake { - mut builders, - closure, - }) = v2_info - { - let packet = - Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Client, &mut builders)?; +pub(crate) struct ClientActor; - if let PacketType::Info(info) = packet.packet_type { - // v2 server - let buffer_size = validate_continue_packet(&rx.wisp_read_frame(tx).await?.try_into()?)?; +impl MultiplexorActor for ClientActor { + fn handle_connect_packet( + &mut self, + _: crate::stream::MuxStream, + _: crate::packet::ConnectPacket, + ) -> Result<(), WispError> { + Err(WispError::InvalidPacketType(0x01)) + } - (closure)(&mut builders).await?; - send_info_packet(tx, &mut builders).await?; - - let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); - - for extension in &mut supported_extensions { - extension - .handle_handshake(DynWebSocketRead::from_mut(rx), tx) - .await?; + fn handle_continue_packet( + &mut self, + id: u32, + pkt: ContinuePacket, + streams: &mut StreamMap, + ) -> Result<(), WispError> { + if let Some(stream) = streams.get(&id) { + if stream.info.flow_status == FlowControl::EnabledTrackAmount { + stream.info.flow_set(pkt.buffer_remaining); + stream.info.flow_wake(); } - - Ok(( - WispHandshakeResult { - kind: WispHandshakeResultKind::V2 { - extensions: supported_extensions, - }, - downgraded: false, - }, - buffer_size, - )) - } else { - // downgrade to v1 - let buffer_size = validate_continue_packet(&packet)?; - - Ok(( - WispHandshakeResult { - kind: WispHandshakeResultKind::V1 { frame: None }, - downgraded: true, - }, - buffer_size, - )) } - } else { - // user asked for a v1 client - let buffer_size = validate_continue_packet(&rx.wisp_read_frame(tx).await?.try_into()?)?; - Ok(( - WispHandshakeResult { - kind: WispHandshakeResultKind::V1 { frame: None }, - downgraded: false, - }, - buffer_size, - )) + Ok(()) + } + + fn get_flow_control(ty: StreamType, flow_stream_types: &[u8]) -> FlowControl { + if flow_stream_types.contains(&ty.into()) { + FlowControl::EnabledTrackAmount + } else { + FlowControl::Disabled + } } } -/// Client side multiplexor. -pub struct ClientMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extensions: Vec, - actor_tx: mpsc::Sender>, - tx: LockedWebSocketWrite, - actor_exited: Arc, +pub struct ClientImpl; + +impl MultiplexorImpl for ClientImpl { + type Actor = ClientActor; + + async fn handshake( + &mut self, + rx: &mut R, + tx: &mut LockedWebSocketWrite, + v2: Option, + ) -> Result { + if let Some(WispV2Handshake { + mut builders, + closure, + }) = v2 + { + let packet = + MaybeInfoPacket::decode(rx.next_erroring().await?, &mut builders, Role::Client)?; + + match packet { + MaybeInfoPacket::Info(info) => { + // v2 server + let buffer_size = + validate_continue_packet(&Packet::decode(rx.next_erroring().await?)?)?; + + (closure)(&mut builders).await?; + send_info_packet(tx, &mut builders, Role::Client).await?; + + let mut supported_extensions = + get_supported_extensions(info.extensions, &mut builders); + + handle_handshake(rx, tx, &mut supported_extensions).await?; + + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V2 { + extensions: supported_extensions, + }, + downgraded: false, + buffer_size, + }) + } + MaybeInfoPacket::Packet(packet) => { + // downgrade to v1 + let buffer_size = validate_continue_packet(&packet)?; + + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { packet: None }, + downgraded: true, + buffer_size, + }) + } + } + } else { + // user asked for a v1 client + let buffer_size = + validate_continue_packet(&Packet::decode(rx.next_erroring().await?)?)?; + + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { packet: None }, + downgraded: false, + buffer_size, + }) + } + } + + async fn handle_error( + &mut self, + err: WispError, + _: &mut LockedWebSocketWrite, + ) -> Result { + Ok(err) + } } -impl ClientMux { +impl Multiplexor { /// Create a new client side multiplexor. /// - /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// If `wisp_v2` is None a Wisp v1 connection is created, otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut rx: R, + #[expect(clippy::new_ret_no_self)] + pub async fn new( + rx: R, tx: W, wisp_v2: Option, - ) -> Result< - MuxResult, impl Future> + Send>, - WispError, - > - where - R: WebSocketRead + 'static, - { - let tx = LockedWebSocketWrite::new(tx); - - let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?; - let (extensions, extra_packet) = handshake_result.kind.into_parts(); - - let mux_inner = MuxInner::new_client( - rx, - extra_packet, - tx.clone(), - extensions.clone(), - buffer_size, - ); - - Ok(MuxResult( - Self { - actor_tx: mux_inner.actor_tx, - actor_exited: mux_inner.actor_exited, - - tx, - - downgraded: handshake_result.downgraded, - supported_extensions: extensions, - }, - mux_inner.mux.into_future(), - )) + ) -> Result, WispError> { + Self::create(rx, tx, wisp_v2, ClientImpl, ClientActor).await } /// Create a new stream, multiplexed through Wisp. - pub async fn client_new_stream( + pub async fn new_stream( &self, stream_type: StreamType, host: String, port: u16, ) -> Result, WispError> { - if self.actor_exited.load(Ordering::Acquire) { + if self.actor_tx.is_disconnected() { return Err(WispError::MuxTaskEnded); } + if stream_type == StreamType::Udp && !self .supported_extensions @@ -164,74 +162,19 @@ impl ClientMux { UdpProtocolExtension::ID, ])); } + let (tx, rx) = oneshot::channel(); self.actor_tx - .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) + .send_async(WsEvent::CreateStream( + ConnectPacket { + stream_type, + host, + port, + }, + tx, + )) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } - - /// Send a ping to the server. - pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - let (tx, rx) = oneshot::channel(); - self.actor_tx - .send_async(WsEvent::SendPing(payload, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } - - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.actor_tx - .send_async(WsEvent::EndFut(reason)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend) - } - - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await - } - - /// Close all streams and send a close reason on stream ID 0. - /// - /// Also terminates the multiplexor future. - pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_internal(Some(reason)).await - } - - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.actor_exited.clone(), - } - } -} - -impl Drop for ClientMux { - fn drop(&mut self) { - let _ = self.actor_tx.send(WsEvent::EndFut(None)); - } -} - -impl Multiplexor for ClientMux { - fn has_extension(&self, extension_id: u8) -> bool { - self.supported_extensions - .iter() - .any(|x| x.get_id() == extension_id) - } - async fn exit(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_with_reason(reason).await - } } diff --git a/wisp/src/mux/inner.rs b/wisp/src/mux/inner.rs index 6a6a794..43e679c 100644 --- a/wisp/src/mux/inner.rs +++ b/wisp/src/mux/inner.rs @@ -1,493 +1,427 @@ -use std::{collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, -}}; +use std::{ + pin::pin, + sync::{ + atomic::{AtomicU32, AtomicU8, Ordering}, + Arc, Mutex, + }, + task::Context, +}; + +use futures::{ + channel::oneshot, + stream::{select, unfold}, + SinkExt, StreamExt, +}; +use rustc_hash::FxHashMap; use crate::{ extensions::AnyProtocolExtension, - ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite}, - AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, - Role, StreamType, WispError, + locked_sink::Waiter, + packet::{ + ClosePacket, CloseReason, ConnectPacket, ContinuePacket, MaybeExtensionPacket, Packet, + PacketType, StreamType, + }, + stream::MuxStream, + ws::{Payload, WebSocketRead, WebSocketWrite}, + LockedWebSocketWrite, WispError, }; -use bytes::BytesMut; -use event_listener::Event; -use flume as mpsc; -use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt}; -use rustc_hash::FxHashMap; -pub(crate) enum WsEvent { - Close(Packet<'static>, oneshot::Sender>), +pub(crate) enum WsEvent { + Close(u32, ClosePacket, oneshot::Sender>), CreateStream( - StreamType, - String, - u16, + ConnectPacket, oneshot::Sender, WispError>>, ), - SendPing(Payload<'static>, oneshot::Sender>), - SendPong(Payload<'static>), - WispMessage(Option>, Option>), + WispMessage(Packet<'static>), EndFut(Option), - Noop, } -struct MuxMapValue { - stream: mpsc::Sender>, - stream_type: StreamType, +pub(crate) type StreamMap = FxHashMap; - should_flow_control: bool, - flow_control: Arc, - flow_control_event: Arc, - - is_closed: Arc, - close_reason: Arc, - is_closed_event: Arc, +pub(crate) struct StreamMapValue { + pub info: Arc, + pub stream: flume::Sender, } -pub(crate) struct MuxInner { - // gets taken by the mux task - rx: Option, - // gets taken by the mux task - maybe_downgrade_packet: Option>, +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub(crate) enum FlowControl { + /// flow control completely disabled + Disabled, + /// flow control enabled + /// - incoming: do not send buffer updates and no buffer + /// - outgoing: track sent amount and wait + EnabledTrackAmount, + /// flow control enabled + /// - incoming: send buffer updates and force buffer + /// - outgoing: do not track sent amount and do not wait + EnabledSendMessages, +} +pub(crate) struct StreamInfo { + pub id: u32, + + pub flow_status: FlowControl, + pub target_flow_control: u32, + flow_control: AtomicU32, + close_reason: AtomicU8, + flow_waker: Mutex, +} + +impl StreamInfo { + pub fn new(id: u32, flow_status: FlowControl, buffer_size: u32) -> Self { + debug_assert_ne!(id, 0); + + // 90% + #[expect(clippy::cast_possible_truncation)] + let target = ((u64::from(buffer_size) * 90) / 100) as u32; + + Self { + id, + + flow_status, + target_flow_control: target, + flow_control: AtomicU32::new(buffer_size), + flow_waker: Mutex::new(Waiter::Woken), + close_reason: AtomicU8::new(CloseReason::Unknown.into()), + } + } + + pub fn flow_set(&self, amt: u32) { + self.flow_control.store(amt, Ordering::Relaxed); + } + pub fn flow_add(&self, amt: u32) -> u32 { + let new = self + .flow_control + .load(Ordering::Relaxed) + .saturating_add(amt); + self.flow_control.store(new, Ordering::Relaxed); + new + } + pub fn flow_sub(&self, amt: u32) -> u32 { + let new = self + .flow_control + .load(Ordering::Relaxed) + .saturating_sub(amt); + self.flow_control.store(new, Ordering::Relaxed); + new + } + pub fn flow_dec(&self) { + self.flow_sub(1); + } + pub fn flow_empty(&self) -> bool { + self.flow_control.load(Ordering::Relaxed) == 0 + } + + pub fn flow_register(&self, cx: &mut Context<'_>) { + self.flow_waker + .lock() + .expect("flow_waker was poisoned") + .register(cx); + } + pub fn flow_wake(&self) { + let mut waiter = self.flow_waker.lock().expect("flow_waker was poisoned"); + if let Some(waker) = waiter.wake() { + drop(waiter); + + waker.wake(); + } + } + + pub fn get_reason(&self) -> CloseReason { + self.close_reason.load(Ordering::Relaxed).into() + } + pub fn set_reason(&self, reason: CloseReason) { + self.close_reason.store(reason.into(), Ordering::Relaxed); + } +} + +pub(crate) trait MultiplexorActor: Send { + fn handle_connect_packet( + &mut self, + stream: MuxStream, + pkt: ConnectPacket, + ) -> Result<(), WispError>; + + fn handle_data_packet( + &mut self, + id: u32, + pkt: Payload, + streams: &mut StreamMap, + ) -> Result<(), WispError> { + if let Some(stream) = streams.get(&id) { + let _ = stream.stream.try_send(pkt); + } + Ok(()) + } + + fn handle_continue_packet( + &mut self, + id: u32, + pkt: ContinuePacket, + streams: &mut StreamMap, + ) -> Result<(), WispError>; + + fn get_flow_control(ty: StreamType, flow_stream_types: &[u8]) -> FlowControl; +} + +struct MuxStart { + rx: R, + downgrade: Option>, + extensions: Vec, + actor_rx: flume::Receiver>, +} + +pub(crate) struct MuxInner> { + start: Option>, tx: LockedWebSocketWrite, - // gets taken by the mux task - extensions: Option>, - tcp_extensions: Vec, - role: Role, + flow_stream_types: Box<[u8]>, - // gets taken by the mux task - actor_rx: Option>>, - actor_tx: mpsc::Sender>, - fut_exited: Arc, - - stream_map: FxHashMap, + mux: M, + streams: StreamMap, + current_id: u32, buffer_size: u32, - target_buffer_size: u32, - server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, + actor_tx: flume::Sender>, } -pub(crate) struct MuxInnerResult { - pub mux: MuxInner, - pub actor_exited: Arc, - pub actor_tx: mpsc::Sender>, +pub(crate) struct MuxInnerResult> { + pub mux: MuxInner, + pub actor_tx: flume::Sender>, } -impl MuxInner { - fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec { - extensions +impl> MuxInner { + #[expect(clippy::new_ret_no_self)] + pub fn new( + rx: R, + tx: LockedWebSocketWrite, + mux: M, + downgrade: Option>, + extensions: Vec, + buffer_size: u32, + ) -> MuxInnerResult { + let (actor_tx, actor_rx) = flume::unbounded(); + + let flow_extensions = extensions .iter() .flat_map(|x| x.get_congestion_stream_types()) .copied() .chain(std::iter::once(StreamType::Tcp.into())) - .collect() - } - - #[expect(clippy::type_complexity)] - pub fn new_server( - rx: R, - maybe_downgrade_packet: Option>, - tx: LockedWebSocketWrite, - extensions: Vec, - buffer_size: u32, - ) -> ( - MuxInnerResult, - mpsc::Receiver<(ConnectPacket, MuxStream)>, - ) { - let (fut_tx, fut_rx) = mpsc::bounded::>(256); - let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let ret_fut_tx = fut_tx.clone(); - let fut_exited = Arc::new(AtomicBool::new(false)); - - // 90% of the buffer size, not possible to overflow - #[expect(clippy::cast_possible_truncation)] - let target_buffer_size = ((u64::from(buffer_size) * 90) / 100) as u32; - - ( - MuxInnerResult { - mux: Self { - rx: Some(rx), - maybe_downgrade_packet, - tx, - - actor_rx: Some(fut_rx), - actor_tx: fut_tx, - fut_exited: fut_exited.clone(), - - tcp_extensions: Self::get_tcp_extensions(&extensions), - extensions: Some(extensions), - buffer_size, - target_buffer_size, - - role: Role::Server, - - stream_map: HashMap::default(), - - server_tx, - }, - actor_exited: fut_exited, - actor_tx: ret_fut_tx, - }, - server_rx, - ) - } - - pub fn new_client( - rx: R, - maybe_downgrade_packet: Option>, - tx: LockedWebSocketWrite, - extensions: Vec, - buffer_size: u32, - ) -> MuxInnerResult { - let (fut_tx, fut_rx) = mpsc::bounded::>(256); - let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let ret_fut_tx = fut_tx.clone(); - let fut_exited = Arc::new(AtomicBool::new(false)); + .collect(); MuxInnerResult { + actor_tx: actor_tx.clone(), mux: Self { - rx: Some(rx), - maybe_downgrade_packet, + start: Some(MuxStart { + rx, + downgrade, + extensions, + actor_rx, + }), tx, + flow_stream_types: flow_extensions, - actor_rx: Some(fut_rx), - actor_tx: fut_tx, - fut_exited: fut_exited.clone(), + mux, - tcp_extensions: Self::get_tcp_extensions(&extensions), - extensions: Some(extensions), + streams: StreamMap::default(), + current_id: 0, buffer_size, - target_buffer_size: 0, - role: Role::Client, - - stream_map: HashMap::default(), - - server_tx, + actor_tx, }, - actor_exited: fut_exited, - actor_tx: ret_fut_tx, } } pub async fn into_future(mut self) -> Result<(), WispError> { - let ret = self.stream_loop().await; + let ret = self.entry().await; - self.fut_exited.store(true, Ordering::Release); - - for stream in self.stream_map.values() { - Self::close_stream(stream, ClosePacket::new(CloseReason::Unknown)); + for stream in self.streams.drain() { + Self::close_stream( + stream.1, + ClosePacket { + reason: CloseReason::Unknown, + }, + ); } - self.stream_map.clear(); - let _ = self.tx.close().await; + self.tx.lock().await; + let _ = self.tx.get().close().await; + self.tx.unlock(); + ret } - fn create_new_stream( - &mut self, - stream_id: u32, - stream_type: StreamType, - ) -> (MuxMapValue, MuxStream) { - let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { - self.buffer_size as usize - } else { - usize::MAX - 8 - }); + async fn entry(&mut self) -> Result<(), WispError> { + let MuxStart { + rx, + downgrade, + extensions, + actor_rx, + } = self.start.take().ok_or(WispError::MuxTaskStarted)?; - let should_flow_control = self.tcp_extensions.contains(&stream_type.into()); - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); - let close_reason: Arc = - AtomicCloseReason::new(CloseReason::Unknown).into(); - let is_closed_event: Arc = Event::new().into(); - - ( - MuxMapValue { - stream: ch_tx, - stream_type, - - should_flow_control, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - - is_closed: is_closed.clone(), - close_reason: close_reason.clone(), - is_closed_event: is_closed_event.clone(), - }, - MuxStream::new( - stream_id, - self.role, - stream_type, - ch_rx, - self.actor_tx.clone(), - self.tx.clone(), - is_closed, - is_closed_event, - close_reason, - should_flow_control, - flow_control, - flow_control_event, - self.target_buffer_size, - ), - ) - } - - fn close_stream(stream: &MuxMapValue, close_packet: ClosePacket) { - stream - .close_reason - .store(close_packet.reason, Ordering::Release); - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - stream.flow_control.store(u32::MAX, Ordering::Release); - stream.flow_control_event.notify(usize::MAX); - } - - async fn process_wisp_message( - rx: &mut R, - tx: &LockedWebSocketWrite, - extensions: &mut [AnyProtocolExtension], - msg: (Frame<'static>, Option>), - ) -> Result>, WispError> { - let (mut frame, optional_frame) = msg; - if frame.opcode == OpCode::Close { - return Ok(None); - } else if frame.opcode == OpCode::Ping { - return Ok(Some(WsEvent::SendPong(frame.payload))); - } else if frame.opcode == OpCode::Pong { - return Ok(Some(WsEvent::Noop)); - } - - if let Some(ref extra_frame) = optional_frame { - if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { - let mut payload = BytesMut::from(frame.payload); - payload.extend_from_slice(&extra_frame.payload); - frame.payload = Payload::Bytes(payload); - } - } - - let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?; - - Ok(Some(WsEvent::WispMessage(packet, optional_frame))) - } - - async fn stream_loop(&mut self) -> Result<(), WispError> { - let mut next_free_stream_id: u32 = 1; - - let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; - let maybe_downgrade_packet = self.maybe_downgrade_packet.take(); - - let tx = self.tx.clone(); - let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?; - - let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?; - - if let Some(downgrade_packet) = maybe_downgrade_packet { - if self.handle_packet(downgrade_packet, None).await? { + if let Some(packet) = downgrade { + if self.handle_packet(packet)? { return Ok(()); } } - let mut read_stream = Box::pin(unfold( - (rx, tx.clone(), extensions), - |(mut rx, tx, mut extensions)| async { - let ret = async { - let msg = rx.wisp_read_split(&tx).await?; - Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await + let read_stream = pin!(unfold( + (rx, self.tx.clone(), extensions), + |(mut rx, mut tx, mut extensions)| async { + let ret: Result>, WispError> = async { + if let Some(msg) = rx.next().await { + match MaybeExtensionPacket::decode(msg?, &mut extensions, &mut rx, &mut tx) + .await? + { + MaybeExtensionPacket::Packet(x) => Ok(Some(WsEvent::WispMessage(x))), + MaybeExtensionPacket::ExtensionHandled => Ok(None), + } + } else { + Ok(None) + } } .await; ret.transpose().map(|x| (x, (rx, tx, extensions))) }, - )) - .fuse(); + )); - let mut recv_fut = fut_rx.recv_async().fuse(); - while let Some(msg) = select! { - x = recv_fut => { - drop(recv_fut); - recv_fut = fut_rx.recv_async().fuse(); - Ok(x.ok()) - }, - x = read_stream.next() => { - x.transpose() - } - }? { - match msg { - WsEvent::CreateStream(stream_type, host, port, channel) => { + let mut stream = select(read_stream, actor_rx.into_stream().map(Ok)); + + while let Some(msg) = stream.next().await { + match msg? { + WsEvent::CreateStream(connect, channel) => { let ret: Result, WispError> = async { - let stream_id = next_free_stream_id; - let next_stream_id = next_free_stream_id - .checked_add(1) - .ok_or(WispError::MaxStreamCountReached)?; - - let (map_value, stream) = self.create_new_stream(stream_id, stream_type); + let (stream, stream_id) = self.create_stream(connect.stream_type)?; + self.tx.lock().await; self.tx - .write_frame( - Packet::new_connect(stream_id, stream_type, port, host).into(), + .get() + .send( + Packet { + stream_id, + packet_type: PacketType::Connect(connect), + } + .encode(), ) .await?; - - self.stream_map.insert(stream_id, map_value); - - next_free_stream_id = next_stream_id; + self.tx.unlock(); Ok(stream) } .await; let _ = channel.send(ret); } - WsEvent::Close(packet, channel) => { - if let Some(stream) = self.stream_map.remove(&packet.stream_id) { - if let PacketType::Close(close) = packet.packet_type { - Self::close_stream(&stream, close); + WsEvent::Close(id, close, channel) => { + if let Some(stream) = self.streams.remove(&id) { + Self::close_stream(stream, close); + let pkt = Packet { + stream_id: id, + packet_type: PacketType::Close(close), } - let _ = channel.send(self.tx.write_frame(packet.into()).await); + .encode(); + + self.tx.lock().await; + let ret = self.tx.get().send(pkt).await; + self.tx.unlock(); + + let _ = channel.send(ret); } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); + let _ = channel.send(Err(WispError::InvalidStreamId(id))); } } - WsEvent::SendPing(payload, channel) => { - let _ = channel.send( - self.tx - .write_frame(Frame::new(OpCode::Ping, payload, true)) - .await, - ); - } - WsEvent::SendPong(payload) => { - self.tx - .write_frame(Frame::new(OpCode::Pong, payload, true)) - .await?; - } WsEvent::EndFut(x) => { if let Some(reason) = x { + self.tx.lock().await; let _ = self .tx - .write_frame(Packet::new_close(0, reason).into()) + .get() + .send(Packet::new_close(0, reason).encode()) .await; + self.tx.unlock(); } break; } - WsEvent::WispMessage(packet, optional_frame) => { - if let Some(packet) = packet { - let should_break = self.handle_packet(packet, optional_frame).await?; - if should_break { - break; - } + WsEvent::WispMessage(packet) => { + let should_break = self.handle_packet(packet)?; + if should_break { + break; } } - WsEvent::Noop => {} } } Ok(()) } - fn handle_close_packet(&mut self, stream_id: u32, inner_packet: ClosePacket) -> bool { + fn create_stream(&mut self, ty: StreamType) -> Result<(MuxStream, u32), WispError> { + let id = self + .current_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?; + self.current_id = id; + Ok((self.add_stream(id, ty), id)) + } + + fn add_stream(&mut self, id: u32, ty: StreamType) -> MuxStream { + let flow = M::get_flow_control(ty, &self.flow_stream_types); + let (data_tx, data_rx) = if flow == FlowControl::EnabledSendMessages { + flume::bounded(self.buffer_size as usize) + } else { + flume::unbounded() + }; + + let info = Arc::new(StreamInfo::new(id, flow, self.buffer_size)); + let val = StreamMapValue { + info: info.clone(), + stream: data_tx, + }; + self.streams.insert(id, val); + + MuxStream::new(data_rx, self.actor_tx.clone(), self.tx.clone(), info) + } + + fn close_stream(stream: StreamMapValue, close: ClosePacket) { + drop(stream.stream); + stream.info.set_reason(close.reason); + } + + fn handle_packet(&mut self, packet: Packet<'static>) -> Result { + use PacketType as P; + match packet.packet_type { + P::Connect(connect) => { + let stream = self.add_stream(packet.stream_id, connect.stream_type); + self.mux.handle_connect_packet(stream, connect)?; + Ok(false) + } + + P::Data(data) => { + self.mux.handle_data_packet( + packet.stream_id, + data.into_owned(), + &mut self.streams, + )?; + Ok(false) + } + + P::Continue(cont) => { + self.mux + .handle_continue_packet(packet.stream_id, cont, &mut self.streams)?; + Ok(false) + } + + P::Close(close) => Ok(self.handle_close_packet(packet.stream_id, close)), + } + } + + fn handle_close_packet(&mut self, stream_id: u32, close: ClosePacket) -> bool { if stream_id == 0 { return true; } - if let Some(stream) = self.stream_map.remove(&stream_id) { - Self::close_stream(&stream, inner_packet); + if let Some(stream) = self.streams.remove(&stream_id) { + Self::close_stream(stream, close); } false } - - fn handle_data_packet( - &mut self, - stream_id: u32, - optional_frame: Option>, - data: Payload<'static>, - ) -> bool { - let mut data = BytesMut::from(data); - - if let Some(stream) = self.stream_map.get(&stream_id) { - if let Some(extra_frame) = optional_frame { - if data.is_empty() { - data = extra_frame.payload.into(); - } else { - data.extend_from_slice(&extra_frame.payload); - } - } - let _ = stream.stream.try_send(Payload::Bytes(data)); - if self.role == Role::Server && stream.should_flow_control { - stream.flow_control.store( - stream - .flow_control - .load(Ordering::Acquire) - .saturating_sub(1), - Ordering::Release, - ); - } - } - - false - } - - async fn handle_packet( - &mut self, - packet: Packet<'static>, - optional_frame: Option>, - ) -> Result { - use PacketType as P; - match packet.packet_type { - P::Data(data) => Ok(self.handle_data_packet(packet.stream_id, optional_frame, data)), - P::Close(inner_packet) => Ok(self.handle_close_packet(packet.stream_id, inner_packet)), - - _ => match self.role { - Role::Server => self.server_handle_packet(packet, optional_frame).await, - Role::Client => self.client_handle_packet(&packet), - }, - } - } - - async fn server_handle_packet( - &mut self, - packet: Packet<'static>, - _optional_frame: Option>, - ) -> Result { - use PacketType as P; - match packet.packet_type { - P::Connect(inner_packet) => { - let (map_value, stream) = - self.create_new_stream(packet.stream_id, inner_packet.stream_type); - self.server_tx - .send_async((inner_packet, stream)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - self.stream_map.insert(packet.stream_id, map_value); - Ok(false) - } - - // Continue | Info => invalid packet type - // Data | Close => specialcased - _ => Err(WispError::InvalidPacketType), - } - } - - fn client_handle_packet(&mut self, packet: &Packet<'static>) -> Result { - use PacketType as P; - match packet.packet_type { - P::Continue(inner_packet) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - if stream.stream_type == StreamType::Tcp { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); - } - } - Ok(false) - } - - // Connect | Info => invalid packet type - // Data | Close => specialcased - _ => Err(WispError::InvalidPacketType), - } - } } diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index 9a26700..8bbb7e0 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -3,18 +3,26 @@ pub(crate) mod inner; mod server; use std::{future::Future, pin::Pin}; -pub use client::ClientMux; -pub use server::ServerMux; +use futures::SinkExt; +use inner::{MultiplexorActor, MuxInner, WsEvent}; + +pub use client::ClientImpl; +pub use server::ServerImpl; + +pub type ServerMux = Multiplexor, W>; +pub type ClientMux = Multiplexor; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::{LockedWebSocketWrite, WebSocketWrite}, - CloseReason, Packet, PacketType, Role, WispError, + packet::{CloseReason, InfoPacket, Packet, PacketType}, + ws::{WebSocketRead, WebSocketWrite}, + LockedWebSocketWrite, LockedWebSocketWriteGuard, Role, WispError, WISP_VERSION, }; struct WispHandshakeResult { kind: WispHandshakeResultKind, downgraded: bool, + buffer_size: u32, } enum WispHandshakeResultKind { @@ -22,7 +30,7 @@ enum WispHandshakeResultKind { extensions: Vec, }, V1 { - frame: Option>, + packet: Option>, }, } @@ -30,35 +38,56 @@ impl WispHandshakeResultKind { pub fn into_parts(self) -> (Vec, Option>) { match self { Self::V2 { extensions } => (extensions, None), - Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame), + Self::V1 { packet } => (vec![UdpProtocolExtension.into()], packet), } } } -async fn send_info_packet( - write: &LockedWebSocketWrite, - builders: &mut [AnyProtocolExtensionBuilder], +async fn handle_handshake( + read: &mut R, + write: &mut LockedWebSocketWrite, + extensions: &mut [AnyProtocolExtension], ) -> Result<(), WispError> { - write - .write_frame( - Packet::new_info( - builders - .iter_mut() - .map(|x| x.build_to_extension(Role::Server)) - .collect::, _>>()?, - ) - .into(), - ) - .await + write.lock().await; + let mut handle = write.get_handle(); + for extension in extensions { + extension.handle_handshake(read, &mut handle).await?; + } + drop(handle); + + Ok(()) } -fn validate_continue_packet(packet: &Packet<'_>) -> Result { +async fn send_info_packet( + write: &mut LockedWebSocketWrite, + builders: &mut [AnyProtocolExtensionBuilder], + role: Role, +) -> Result<(), WispError> { + let extensions = builders + .iter_mut() + .map(|x| x.build_to_extension(role)) + .collect::, _>>()?; + + let packet = InfoPacket { + version: WISP_VERSION, + extensions, + } + .encode(); + + write.lock().await; + let ret = write.get().send(packet).await; + write.unlock(); + + ret +} + +fn validate_continue_packet(packet: &Packet) -> Result { if packet.stream_id != 0 { - return Err(WispError::InvalidStreamId); + return Err(WispError::InvalidStreamId(packet.stream_id)); } let PacketType::Continue(continue_packet) = packet.packet_type else { - return Err(WispError::InvalidPacketType); + return Err(WispError::InvalidPacketType(packet.packet_type.get_type())); }; Ok(continue_packet.buffer_remaining) @@ -75,35 +104,185 @@ fn get_supported_extensions( .collect() } -trait Multiplexor { - fn has_extension(&self, extension_id: u8) -> bool; - async fn exit(&self, reason: CloseReason) -> Result<(), WispError>; +trait MultiplexorImpl { + type Actor: MultiplexorActor + 'static; + + async fn handshake( + &mut self, + rx: &mut R, + tx: &mut LockedWebSocketWrite, + v2: Option, + ) -> Result; + + async fn handle_error( + &mut self, + err: WispError, + tx: &mut LockedWebSocketWrite, + ) -> Result; } +#[expect(private_bounds)] +pub struct Multiplexor, W: WebSocketWrite> { + mux: M, + + downgraded: bool, + supported_extensions: Vec, + + actor_tx: flume::Sender>, + tx: LockedWebSocketWrite, +} + +#[expect(private_bounds)] +impl, W: WebSocketWrite> Multiplexor { + async fn create( + mut rx: R, + tx: W, + wisp_v2: Option, + mut muxer: M, + actor: M::Actor, + ) -> Result, WispError> + where + R: WebSocketRead, + { + let mut tx = LockedWebSocketWrite::new(tx); + + let ret = async { + let handshake_result = muxer.handshake(&mut rx, &mut tx, wisp_v2).await?; + let (extensions, extra_packet) = handshake_result.kind.into_parts(); + + Ok(( + MuxInner::new( + rx, + tx.clone(), + actor, + extra_packet, + extensions.clone(), + handshake_result.buffer_size, + ), + handshake_result.downgraded, + extensions, + )) + } + .await; + + match ret { + Ok((mux_result, downgraded, extensions)) => Ok(MuxResult( + Self { + mux: muxer, + + downgraded, + supported_extensions: extensions, + + actor_tx: mux_result.actor_tx, + tx, + }, + Box::pin(mux_result.mux.into_future()), + )), + Err(x) => Err(muxer.handle_error(x, &mut tx).await?), + } + } + + /// Whether the connection was downgraded to Wisp v1. + pub fn was_downgraded(&self) -> bool { + self.downgraded + } + + /// Get a shared reference to the extensions that are supported by both sides. + pub fn get_extensions(&self) -> &[AnyProtocolExtension] { + &self.supported_extensions + } + + /// Get a mutable reference to the extensions that are supported by both sides. + pub fn get_extensions_mut(&mut self) -> &mut [AnyProtocolExtension] { + &mut self.supported_extensions + } + + /// Get a `Vec` of all extension IDs that are supported by both sides. + pub fn get_extension_ids(&self) -> Vec { + self.supported_extensions + .iter() + .map(|x| x.get_id()) + .collect() + } + + /// Get a locked guard to the write half of the websocket. + pub async fn lock_ws(&self) -> Result, WispError> { + if self.actor_tx.is_disconnected() { + Err(WispError::WsImplSocketClosed) + } else { + let mut cloned = self.tx.clone(); + cloned.lock().await; + Ok(cloned.get_guard()) + } + } + + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + self.actor_tx + .send_async(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + + /// Close all streams. + /// + /// Also terminates the multiplexor future. + pub async fn close(&self) -> Result<(), WispError> { + self.close_internal(None).await + } + + /// Close all streams and send a close reason on stream ID 0. + /// + /// Also terminates the multiplexor future. + pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { + self.close_internal(Some(reason)).await + } + + /* TODO + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.actor_exited.clone(), + } + } + */ +} + +pub type MultiplexorActorFuture = Pin> + Send>>; + /// Result of creating a multiplexor. Helps require protocol extensions. #[expect(private_bounds)] -pub struct MuxResult(M, F) +pub struct MuxResult(Multiplexor, MultiplexorActorFuture) where - M: Multiplexor, - F: Future> + Send; + M: MultiplexorImpl, + W: WebSocketWrite; #[expect(private_bounds)] -impl MuxResult +impl MuxResult where - M: Multiplexor, - F: Future> + Send, + M: MultiplexorImpl, + W: WebSocketWrite, { /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (M, F) { + pub fn with_no_required_extensions(self) -> (Multiplexor, MultiplexorActorFuture) { (self.0, self.1) } /// Require protocol extensions by their ID. Will close the multiplexor connection if /// extensions are not supported. - pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(M, F), WispError> { + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(Multiplexor, MultiplexorActorFuture), WispError> { let mut unsupported_extensions = Vec::new(); + let supported_extensions = self.0.get_extensions(); + for extension in extensions { - if !self.0.has_extension(*extension) { + if !supported_extensions + .iter() + .any(|x| x.get_id() == *extension) + { unsupported_extensions.push(*extension); } } @@ -111,14 +290,18 @@ where if unsupported_extensions.is_empty() { Ok((self.0, self.1)) } else { - self.0.exit(CloseReason::ExtensionsIncompatible).await?; + self.0 + .close_with_reason(CloseReason::ExtensionsIncompatible) + .await?; self.1.await?; Err(WispError::ExtensionsNotSupported(unsupported_extensions)) } } /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(M, F), WispError> { + pub async fn with_udp_extension_required( + self, + ) -> Result<(Multiplexor, MultiplexorActorFuture), WispError> { self.with_required_extensions(&[UdpProtocolExtension::ID]) .await } diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs index 8af6383..46399aa 100644 --- a/wisp/src/mux/server.rs +++ b/wisp/src/mux/server.rs @@ -1,241 +1,196 @@ -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; - -use flume as mpsc; -use futures::channel::oneshot; +use futures::SinkExt; use crate::{ - extensions::AnyProtocolExtension, - ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, - CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, - WispError, + locked_sink::LockedWebSocketWrite, + packet::{CloseReason, ConnectPacket, MaybeInfoPacket, Packet, StreamType}, + stream::MuxStream, + ws::{Payload, WebSocketRead, WebSocketReadExt, WebSocketWrite}, + Role, WispError, }; use super::{ - get_supported_extensions, - inner::{MuxInner, WsEvent}, - send_info_packet, Multiplexor, MuxResult, WispHandshakeResult, WispHandshakeResultKind, - WispV2Handshake, + get_supported_extensions, handle_handshake, + inner::{FlowControl, MultiplexorActor, StreamMap}, + send_info_packet, Multiplexor, MultiplexorImpl, MuxResult, WispHandshakeResult, + WispHandshakeResultKind, WispV2Handshake, }; -async fn handshake( - rx: &mut R, - tx: &LockedWebSocketWrite, - buffer_size: u32, - v2_info: Option, -) -> Result { - if let Some(WispV2Handshake { - mut builders, - closure, - }) = v2_info - { - send_info_packet(tx, &mut builders).await?; - tx.write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; - - (closure)(&mut builders).await?; - - let packet = - Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Server, &mut builders)?; - - if let PacketType::Info(info) = packet.packet_type { - let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); - - for extension in &mut supported_extensions { - extension - .handle_handshake(DynWebSocketRead::from_mut(rx), tx) - .await?; - } - - // v2 client - Ok(WispHandshakeResult { - kind: WispHandshakeResultKind::V2 { - extensions: supported_extensions, - }, - downgraded: false, - }) - } else { - // downgrade to v1 - Ok(WispHandshakeResult { - kind: WispHandshakeResultKind::V1 { - frame: Some(packet), - }, - downgraded: true, - }) - } - } else { - // user asked for v1 server - tx.write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; - - Ok(WispHandshakeResult { - kind: WispHandshakeResultKind::V1 { frame: None }, - downgraded: false, - }) - } +pub(crate) struct ServerActor { + stream_tx: flume::Sender<(ConnectPacket, MuxStream)>, } -/// Server-side multiplexor. -pub struct ServerMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extensions: Vec, - actor_tx: mpsc::Sender>, - muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, - tx: LockedWebSocketWrite, - actor_exited: Arc, -} - -impl ServerMux { - /// Create a new server-side multiplexor. - /// - /// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. - /// **It is not guaranteed that all extensions you specify are available.** You must manually check - /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut rx: R, - tx: W, - buffer_size: u32, - wisp_v2: Option, - ) -> Result< - MuxResult, impl Future> + Send>, - WispError, - > - where - R: WebSocketRead + Send + 'static, - { - let tx = LockedWebSocketWrite::new(tx); - let ret_tx = tx.clone(); - let ret = async { - let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?; - let (extensions, extra_packet) = handshake_result.kind.into_parts(); - - let (mux_result, muxstream_recv) = MuxInner::new_server( - rx, - extra_packet, - tx.clone(), - extensions.clone(), - buffer_size, - ); - - Ok(MuxResult( - Self { - actor_tx: mux_result.actor_tx, - actor_exited: mux_result.actor_exited, - muxstream_recv, - - tx, - - downgraded: handshake_result.downgraded, - supported_extensions: extensions, - }, - mux_result.mux.into_future(), - )) - } - .await; - - match ret { - Ok(x) => Ok(x), - Err(x) => match x { - WispError::PasswordExtensionCredsInvalid => { - ret_tx - .write_frame( - Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(), - ) - .await?; - ret_tx.close().await?; - Err(x) - } - WispError::CertAuthExtensionSigInvalid => { - ret_tx - .write_frame( - Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(), - ) - .await?; - ret_tx.close().await?; - Err(x) - } - x => Err(x), - }, - } - } - - /// Wait for a stream to be created. - pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { - if self.actor_exited.load(Ordering::Acquire) { - return None; - } - self.muxstream_recv.recv_async().await.ok() - } - - /// Send a ping to the client. - pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - let (tx, rx) = oneshot::channel(); - self.actor_tx - .send_async(WsEvent::SendPing(payload, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } - - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.actor_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.actor_tx - .send_async(WsEvent::EndFut(reason)) - .await +impl MultiplexorActor for ServerActor { + fn handle_connect_packet( + &mut self, + stream: MuxStream, + pkt: ConnectPacket, + ) -> Result<(), WispError> { + self.stream_tx + .send((pkt, stream)) .map_err(|_| WispError::MuxMessageFailedToSend) } - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await + fn handle_data_packet( + &mut self, + id: u32, + pkt: Payload, + streams: &mut StreamMap, + ) -> Result<(), WispError> { + if let Some(stream) = streams.get(&id) { + if stream.stream.try_send(pkt).is_ok() { + stream.info.flow_dec(); + } + } + Ok(()) } - /// Close all streams and send a close reason on stream ID 0. - /// - /// Also terminates the multiplexor future. - pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_internal(Some(reason)).await + fn handle_continue_packet( + &mut self, + _: u32, + _: crate::packet::ContinuePacket, + _: &mut StreamMap, + ) -> Result<(), WispError> { + Err(WispError::InvalidPacketType(0x03)) } - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.actor_exited.clone(), + fn get_flow_control(ty: StreamType, flow_stream_types: &[u8]) -> FlowControl { + if flow_stream_types.contains(&ty.into()) { + FlowControl::EnabledSendMessages + } else { + FlowControl::Disabled } } } -impl Drop for ServerMux { - fn drop(&mut self) { - let _ = self.actor_tx.send(WsEvent::EndFut(None)); +pub struct ServerImpl { + buffer_size: u32, + stream_rx: flume::Receiver<(ConnectPacket, MuxStream)>, +} + +impl MultiplexorImpl for ServerImpl { + type Actor = ServerActor; + + async fn handshake( + &mut self, + rx: &mut R, + tx: &mut LockedWebSocketWrite, + v2: Option, + ) -> Result { + if let Some(WispV2Handshake { + mut builders, + closure, + }) = v2 + { + send_info_packet(tx, &mut builders, Role::Server).await?; + tx.lock().await; + tx.get() + .send(Packet::new_continue(0, self.buffer_size).encode()) + .await?; + tx.unlock(); + + (closure)(&mut builders).await?; + + let packet = + MaybeInfoPacket::decode(rx.next_erroring().await?, &mut builders, Role::Server)?; + + match packet { + MaybeInfoPacket::Info(info) => { + let mut supported_extensions = + get_supported_extensions(info.extensions, &mut builders); + + handle_handshake(rx, tx, &mut supported_extensions).await?; + + // v2 client + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V2 { + extensions: supported_extensions, + }, + downgraded: false, + buffer_size: self.buffer_size, + }) + } + MaybeInfoPacket::Packet(packet) => { + // downgrade to v1 + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { + packet: Some(packet), + }, + downgraded: true, + buffer_size: self.buffer_size, + }) + } + } + } else { + // user asked for v1 server + tx.lock().await; + tx.get() + .send(Packet::new_continue(0, self.buffer_size).encode()) + .await?; + tx.unlock(); + + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { packet: None }, + downgraded: false, + buffer_size: self.buffer_size, + }) + } + } + + async fn handle_error( + &mut self, + err: WispError, + tx: &mut LockedWebSocketWrite, + ) -> Result { + match err { + WispError::PasswordExtensionCredsInvalid => { + tx.lock().await; + tx.get() + .send(Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).encode()) + .await?; + tx.get().close().await?; + tx.unlock(); + Ok(err) + } + WispError::CertAuthExtensionSigInvalid => { + tx.lock().await; + tx.get() + .send(Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).encode()) + .await?; + tx.get().close().await?; + tx.unlock(); + Ok(err) + } + x => Ok(x), + } } } -impl Multiplexor for ServerMux { - fn has_extension(&self, extension_id: u8) -> bool { - self.supported_extensions - .iter() - .any(|x| x.get_id() == extension_id) +impl Multiplexor, W> { + /// Create a new server-side multiplexor. + /// + /// If `wisp_v2` is None a Wisp v1 connection is created, otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. + #[expect(clippy::new_ret_no_self)] + pub async fn new( + rx: R, + tx: W, + buffer_size: u32, + wisp_v2: Option, + ) -> Result, W>, WispError> { + let (stream_tx, stream_rx) = flume::unbounded(); + + let mux = ServerImpl { + buffer_size, + stream_rx, + }; + let actor = ServerActor { stream_tx }; + + Self::create(rx, tx, wisp_v2, mux, actor).await } - async fn exit(&self, reason: CloseReason) -> Result<(), WispError> { - self.close_with_reason(reason).await + + /// Wait for a stream to be created. + pub async fn wait_for_stream(&self) -> Option<(ConnectPacket, MuxStream)> { + self.mux.stream_rx.recv_async().await.ok() } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index e7de23d..93650e5 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,293 +1,230 @@ use std::fmt::Display; +use bytes::{Buf, BufMut}; +use num_enum::{FromPrimitive, IntoPrimitive}; + use crate::{ extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::{ - self, DynWebSocketRead, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, - WebSocketWrite, - }, - Role, WispError, WISP_VERSION, + ws::{Payload, PayloadMut, PayloadRef, WebSocketRead, WebSocketWrite}, + LockedWebSocketWrite, Role, WispError, WISP_VERSION, }; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -/// Wisp stream type. -#[derive(Debug, PartialEq, Copy, Clone)] +trait PacketCodec: Sized { + fn size_hint(&self) -> usize; + + fn encode_into(&self, packet: &mut PayloadMut); + fn decode(packet: &mut Payload) -> Result; +} + +#[derive(FromPrimitive, IntoPrimitive, Debug, Copy, Clone, Eq, PartialEq)] +#[repr(u8)] pub enum StreamType { - /// TCP Wisp stream. - Tcp, - /// UDP Wisp stream. - Udp, - /// Unknown Wisp stream type used for custom streams by protocol extensions. - Unknown(u8), + Tcp = 0x01, + Udp = 0x02, + #[num_enum(catch_all)] + Other(u8), } -impl From for StreamType { - fn from(value: u8) -> Self { - use StreamType as S; - match value { - 0x01 => S::Tcp, - 0x02 => S::Udp, - x => S::Unknown(x), +impl PacketCodec for StreamType { + fn size_hint(&self) -> usize { + size_of::() + } + + fn encode_into(&self, packet: &mut PayloadMut) { + packet.put_u8((*self).into()); + } + + fn decode(packet: &mut Payload) -> Result { + if packet.remaining() < size_of::() { + return Err(WispError::PacketTooSmall); } + + Ok(Self::from(packet.get_u8())) } } -impl From for u8 { - fn from(value: StreamType) -> Self { - use StreamType as S; - match value { - S::Tcp => 0x01, - S::Udp => 0x02, - S::Unknown(x) => x, +#[derive(FromPrimitive, IntoPrimitive, Debug, Copy, Clone, Eq, PartialEq)] +#[repr(u8)] +pub enum CloseReason { + /// Reason unspecified or unknown. + Unknown = 0x01, + /// Voluntary stream closure. + Voluntary = 0x02, + /// Unexpected stream closure due to a network error. + Unexpected = 0x03, + /// Incompatible extensions. + ExtensionsIncompatible = 0x04, + + /// Stream creation failed due to invalid information. + ServerStreamInvalidInfo = 0x41, + /// Stream creation failed due to an unreachable destination host. + ServerStreamUnreachable = 0x42, + /// Stream creation timed out due to the destination server not responding. + ServerStreamConnectionTimedOut = 0x43, + /// Stream creation failed due to the destination server refusing the connection. + ServerStreamConnectionRefused = 0x44, + /// TCP data transfer timed out. + ServerStreamTimedOut = 0x47, + /// Stream destination address/domain is intentionally blocked by the proxy server. + ServerStreamBlockedAddress = 0x48, + /// Connection throttled by the server. + ServerStreamThrottled = 0x49, + + /// The client has encountered an unexpected error and is unable to recieve any more data. + ClientUnexpected = 0x81, + + /// Authentication failed due to invalid username/password. + ExtensionsPasswordAuthFailed = 0xc0, + /// Authentication failed due to invalid signature. + ExtensionsCertAuthFailed = 0xc1, + /// Authentication required but the client did not provide credentials. + ExtensionsAuthRequired = 0xc2, + + #[num_enum(catch_all)] + Other(u8), +} + +impl Display for CloseReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use CloseReason as C; + if let C::Other(x) = self { + return write!(f, "Other: {x}"); } - } -} -mod close { - use std::fmt::Display; + write!( + f, + "{}", + match self { + C::Unknown => "Unknown close reason", + C::Voluntary => "Voluntarily closed", + C::Unexpected => "Unexpectedly closed", + C::ExtensionsIncompatible => "Incompatible protocol extensions", - use atomic_enum::atomic_enum; + C::ServerStreamInvalidInfo => "Stream creation failed due to invalid information", + C::ServerStreamUnreachable => + "Stream creation failed due to an unreachable destination", + C::ServerStreamConnectionTimedOut => + "Stream creation failed due to destination not responding", + C::ServerStreamConnectionRefused => + "Stream creation failed due to destination refusing connection", + C::ServerStreamTimedOut => "TCP timed out", + C::ServerStreamBlockedAddress => "Destination address is blocked", + C::ServerStreamThrottled => "Throttled", - use crate::WispError; + C::ClientUnexpected => "Client encountered unexpected error", - /// Close reason. - /// - /// See [the - /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) - #[derive(PartialEq)] - #[repr(u8)] - #[atomic_enum] - pub enum CloseReason { - /// Reason unspecified or unknown. - Unknown = 0x01, - /// Voluntary stream closure. - Voluntary = 0x02, - /// Unexpected stream closure due to a network error. - Unexpected = 0x03, - /// Incompatible extensions. - ExtensionsIncompatible = 0x04, - - /// Stream creation failed due to invalid information. - ServerStreamInvalidInfo = 0x41, - /// Stream creation failed due to an unreachable destination host. - ServerStreamUnreachable = 0x42, - /// Stream creation timed out due to the destination server not responding. - ServerStreamConnectionTimedOut = 0x43, - /// Stream creation failed due to the destination server refusing the connection. - ServerStreamConnectionRefused = 0x44, - /// TCP data transfer timed out. - ServerStreamTimedOut = 0x47, - /// Stream destination address/domain is intentionally blocked by the proxy server. - ServerStreamBlockedAddress = 0x48, - /// Connection throttled by the server. - ServerStreamThrottled = 0x49, - - /// The client has encountered an unexpected error and is unable to recieve any more data. - ClientUnexpected = 0x81, - - /// Authentication failed due to invalid username/password. - ExtensionsPasswordAuthFailed = 0xc0, - /// Authentication failed due to invalid signature. - ExtensionsCertAuthFailed = 0xc1, - /// Authentication required but the client did not provide credentials. - ExtensionsAuthRequired = 0xc2, - } - - impl TryFrom for CloseReason { - type Error = WispError; - fn try_from(close_reason: u8) -> Result { - match close_reason { - 0x01 => Ok(Self::Unknown), - 0x02 => Ok(Self::Voluntary), - 0x03 => Ok(Self::Unexpected), - 0x04 => Ok(Self::ExtensionsIncompatible), - - 0x41 => Ok(Self::ServerStreamInvalidInfo), - 0x42 => Ok(Self::ServerStreamUnreachable), - 0x43 => Ok(Self::ServerStreamConnectionTimedOut), - 0x44 => Ok(Self::ServerStreamConnectionRefused), - 0x47 => Ok(Self::ServerStreamTimedOut), - 0x48 => Ok(Self::ServerStreamBlockedAddress), - 0x49 => Ok(Self::ServerStreamThrottled), - - 0x81 => Ok(Self::ClientUnexpected), - - 0xc0 => Ok(Self::ExtensionsPasswordAuthFailed), - 0xc1 => Ok(Self::ExtensionsCertAuthFailed), - 0xc2 => Ok(Self::ExtensionsAuthRequired), - - _ => Err(Self::Error::InvalidCloseReason), + C::ExtensionsPasswordAuthFailed => "Invalid username/password", + C::ExtensionsCertAuthFailed => "Invalid signature", + C::ExtensionsAuthRequired => "Authentication required", + C::Other(_) => unreachable!(), } - } - } - - impl Display for CloseReason { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use CloseReason as C; - write!( - f, - "{}", - match self { - C::Unknown => "Unknown close reason", - C::Voluntary => "Voluntarily closed", - C::Unexpected => "Unexpectedly closed", - C::ExtensionsIncompatible => "Incompatible protocol extensions", - - C::ServerStreamInvalidInfo => - "Stream creation failed due to invalid information", - C::ServerStreamUnreachable => - "Stream creation failed due to an unreachable destination", - C::ServerStreamConnectionTimedOut => - "Stream creation failed due to destination not responding", - C::ServerStreamConnectionRefused => - "Stream creation failed due to destination refusing connection", - C::ServerStreamTimedOut => "TCP timed out", - C::ServerStreamBlockedAddress => "Destination address is blocked", - C::ServerStreamThrottled => "Throttled", - - C::ClientUnexpected => "Client encountered unexpected error", - - C::ExtensionsPasswordAuthFailed => "Invalid username/password", - C::ExtensionsCertAuthFailed => "Invalid signature", - C::ExtensionsAuthRequired => "Authentication required", - } - ) - } + ) } } -pub(crate) use close::AtomicCloseReason; -pub use close::CloseReason; +impl PacketCodec for CloseReason { + fn size_hint(&self) -> usize { + size_of::() + } -trait Encode { - fn encode(self, bytes: &mut BytesMut); + fn encode_into(&self, packet: &mut PayloadMut) { + packet.put_u8((*self).into()); + } + + fn decode(packet: &mut Payload) -> Result { + if packet.remaining() < size_of::() { + return Err(WispError::PacketTooSmall); + } + + Ok(Self::from(packet.get_u8())) + } } -/// Packet used to create a new stream. -/// -/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct ConnectPacket { - /// Whether the new stream should use a TCP or UDP socket. pub stream_type: StreamType, - /// Destination TCP/UDP port for the new stream. - pub destination_port: u16, - /// Destination hostname, in a UTF-8 string. - pub destination_hostname: String, + + pub host: String, + pub port: u16, } -impl ConnectPacket { - /// Create a new connect packet. - pub fn new( - stream_type: StreamType, - destination_port: u16, - destination_hostname: String, - ) -> Self { - Self { - stream_type, - destination_port, - destination_hostname, - } +impl PacketCodec for ConnectPacket { + fn size_hint(&self) -> usize { + self.stream_type.size_hint() + self.host.len() + size_of::() } -} -impl TryFrom> for ConnectPacket { - type Error = WispError; - fn try_from(mut bytes: Payload<'_>) -> Result { - if bytes.remaining() < (1 + 2) { - return Err(Self::Error::PacketTooSmall); + fn encode_into(&self, packet: &mut PayloadMut) { + self.stream_type.encode_into(packet); + packet.put_u16_le(self.port); + packet.extend_from_slice(self.host.as_bytes()); + } + + fn decode(packet: &mut Payload) -> Result { + if packet.remaining() < (size_of::() + size_of::()) { + return Err(WispError::PacketTooSmall); } + + let stream_type = StreamType::decode(packet)?; + let port = packet.get_u16_le(); + let host = String::from_utf8(packet.to_vec())?; + Ok(Self { - stream_type: bytes.get_u8().into(), - destination_port: bytes.get_u16_le(), - destination_hostname: std::str::from_utf8(&bytes)?.to_string(), + stream_type, + host, + port, }) } } -impl Encode for ConnectPacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.stream_type.into()); - bytes.put_u16_le(self.destination_port); - bytes.extend(self.destination_hostname.bytes()); - } -} - -/// Packet used for Wisp TCP stream flow control. -/// -/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue). -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ContinuePacket { - /// Number of packets that the server can buffer for the current stream. pub buffer_remaining: u32, } -impl ContinuePacket { - /// Create a new continue packet. - pub fn new(buffer_remaining: u32) -> Self { - Self { buffer_remaining } +impl PacketCodec for ContinuePacket { + fn size_hint(&self) -> usize { + size_of::() } -} -impl TryFrom> for ContinuePacket { - type Error = WispError; - fn try_from(mut bytes: Payload<'_>) -> Result { - if bytes.remaining() < 4 { - return Err(Self::Error::PacketTooSmall); + fn encode_into(&self, packet: &mut PayloadMut) { + packet.put_u32_le(self.buffer_remaining); + } + + fn decode(packet: &mut Payload) -> Result { + if packet.remaining() < size_of::() { + return Err(WispError::PacketTooSmall); } - Ok(Self { - buffer_remaining: bytes.get_u32_le(), - }) + + let buffer_remaining = packet.get_u32_le(); + + Ok(Self { buffer_remaining }) } } -impl Encode for ContinuePacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u32_le(self.buffer_remaining); - } -} - -/// Packet used to close a stream. -/// -/// See [the -/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close). -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct ClosePacket { - /// The close reason. pub reason: CloseReason, } -impl ClosePacket { - /// Create a new close packet. - pub fn new(reason: CloseReason) -> Self { - Self { reason } +impl PacketCodec for ClosePacket { + fn size_hint(&self) -> usize { + self.reason.size_hint() + } + + fn encode_into(&self, packet: &mut PayloadMut) { + self.reason.encode_into(packet); + } + + fn decode(packet: &mut Payload) -> Result { + let reason = CloseReason::decode(packet)?; + + Ok(Self { reason }) } } -impl TryFrom> for ClosePacket { - type Error = WispError; - fn try_from(mut bytes: Payload<'_>) -> Result { - if bytes.remaining() < 1 { - return Err(Self::Error::PacketTooSmall); - } - Ok(Self { - reason: bytes.get_u8().try_into()?, - }) - } -} - -impl Encode for ClosePacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.reason as u8); - } -} - -/// Wisp version sent in the handshake. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct WispVersion { - /// Major Wisp version according to semver. pub major: u8, - /// Minor Wisp version according to semver. pub minor: u8, } @@ -297,182 +234,47 @@ impl Display for WispVersion { } } -/// Packet used in the initial handshake. -/// -/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) +impl PacketCodec for WispVersion { + fn size_hint(&self) -> usize { + size_of::() * 2 + } + + fn encode_into(&self, packet: &mut PayloadMut) { + packet.put_u8(self.major); + packet.put_u8(self.minor); + } + + fn decode(packet: &mut Payload) -> Result { + if packet.remaining() < 2 { + return Err(WispError::PacketTooSmall); + } + + Ok(Self { + major: packet.get_u8(), + minor: packet.get_u8(), + }) + } +} + #[derive(Debug, Clone)] pub struct InfoPacket { - /// Wisp version sent in the packet. pub version: WispVersion, - /// List of protocol extensions sent in the packet. pub extensions: Vec, } -impl Encode for InfoPacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.version.major); - bytes.put_u8(self.version.minor); - for extension in self.extensions { - bytes.extend_from_slice(&Bytes::from(extension)); - } - } -} - -#[derive(Debug, Clone)] -/// Type of packet recieved. -pub enum PacketType<'a> { - /// Connect packet. - Connect(ConnectPacket), - /// Data packet. - Data(Payload<'a>), - /// Continue packet. - Continue(ContinuePacket), - /// Close packet. - Close(ClosePacket), - /// Info packet. - Info(InfoPacket), -} - -impl PacketType<'_> { - /// Get the packet type used in the protocol. - pub fn as_u8(&self) -> u8 { - use PacketType as P; - match self { - P::Connect(_) => 0x01, - P::Data(_) => 0x02, - P::Continue(_) => 0x03, - P::Close(_) => 0x04, - P::Info(_) => 0x05, - } - } - - pub(crate) fn get_packet_size(&self) -> usize { - use PacketType as P; - match self { - P::Connect(p) => 1 + 2 + p.destination_hostname.len(), - P::Data(p) => p.len(), - P::Continue(_) => 4, - P::Close(_) => 1, - P::Info(_) => 2, - } - } -} - -impl Encode for PacketType<'_> { - fn encode(self, bytes: &mut BytesMut) { - use PacketType as P; - match self { - P::Connect(x) => x.encode(bytes), - P::Data(x) => bytes.extend_from_slice(&x), - P::Continue(x) => x.encode(bytes), - P::Close(x) => x.encode(bytes), - P::Info(x) => x.encode(bytes), - }; - } -} - -/// Wisp protocol packet. -#[derive(Debug, Clone)] -pub struct Packet<'a> { - /// Stream this packet is associated with. - pub stream_id: u32, - /// Packet type recieved. - pub packet_type: PacketType<'a>, -} - -impl<'a> Packet<'a> { - /// Create a new packet. - /// - /// The helper functions should be used for most use cases. - pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self { - Self { - stream_id, - packet_type: packet, - } - } - - /// Create a new connect packet. - pub fn new_connect( - stream_id: u32, - stream_type: StreamType, - destination_port: u16, - destination_hostname: String, - ) -> Self { - Self { - stream_id, - packet_type: PacketType::Connect(ConnectPacket::new( - stream_type, - destination_port, - destination_hostname, - )), - } - } - - /// Create a new data packet. - pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self { - Self { - stream_id, - packet_type: PacketType::Data(data), - } - } - - /// Create a new continue packet. - pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { - Self { - stream_id, - packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)), - } - } - - /// Create a new close packet. - pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { - Self { - stream_id, - packet_type: PacketType::Close(ClosePacket::new(reason)), - } - } - - pub(crate) fn new_info(extensions: Vec) -> Self { - Self { - stream_id: 0, - packet_type: PacketType::Info(InfoPacket { - version: WISP_VERSION, - extensions, - }), - } - } - - fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result { - use PacketType as P; - Ok(Self { - stream_id: bytes.get_u32_le(), - packet_type: match packet_type { - 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), - 0x02 => P::Data(bytes), - 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), - 0x04 => P::Close(ClosePacket::try_from(bytes)?), - // 0x05 is handled seperately - _ => return Err(WispError::InvalidPacketType), - }, - }) - } - - fn parse_info( - mut bytes: Payload<'a>, +impl InfoPacket { + pub(crate) fn decode( + packet: &mut Payload, + builders: &mut [AnyProtocolExtensionBuilder], role: Role, - extension_builders: &mut [AnyProtocolExtensionBuilder], ) -> Result { - // packet type is already read by code that calls this - if bytes.remaining() < 4 + 2 { + if packet.remaining() < (size_of::() * 2) { return Err(WispError::PacketTooSmall); } - if bytes.get_u32_le() != 0 { - return Err(WispError::InvalidStreamId); - } let version = WispVersion { - major: bytes.get_u8(), - minor: bytes.get_u8(), + major: packet.get_u8(), + minor: packet.get_u8(), }; if version.major != WISP_VERSION.major { @@ -484,151 +286,213 @@ impl<'a> Packet<'a> { let mut extensions = Vec::new(); - while bytes.remaining() > 4 { + while packet.remaining() >= (size_of::() + size_of::()) { // We have some extensions - let id = bytes.get_u8(); - let length = usize::try_from(bytes.get_u32_le())?; - if bytes.remaining() < length { + let id = packet.get_u8(); + let length = usize::try_from(packet.get_u32_le())?; + + if packet.remaining() < length { return Err(WispError::PacketTooSmall); } - if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) { - extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?); + + if let Some(builder) = builders.iter_mut().find(|x| x.get_id() == id) { + extensions.push(builder.build_from_bytes(packet.split_to(length), role)?); } else { - bytes.advance(length); + packet.advance(length); } } Ok(Self { - stream_id: 0, - packet_type: PacketType::Info(InfoPacket { - version, - extensions, - }), + version, + extensions, }) } - pub(crate) fn maybe_parse_info( - frame: Frame<'a>, - role: Role, - extension_builders: &mut [AnyProtocolExtensionBuilder], - ) -> Result { - if !frame.finished { - return Err(WispError::WsFrameNotFinished); + pub(crate) fn encode(&self) -> Payload { + let mut packet = PayloadMut::with_capacity( + size_of::() + size_of::() + self.version.size_hint(), + ); + packet.put_u8(0x05); + packet.put_u32(0); + self.version.encode_into(&mut packet); + for extension in &self.extensions { + extension.encode_into(&mut packet); } - if frame.opcode != OpCode::Binary { - return Err(WispError::WsFrameInvalidType(frame.opcode)); - } - let mut bytes = frame.payload; - if bytes.remaining() < 1 { - return Err(WispError::PacketTooSmall); - } - let packet_type = bytes.get_u8(); - if packet_type == 0x05 { - Self::parse_info(bytes, role, extension_builders) - } else { - Self::parse_packet(packet_type, bytes) + packet.freeze() + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum PacketType<'a> { + Connect(ConnectPacket), + Data(PayloadRef<'a>), + Continue(ContinuePacket), + Close(ClosePacket), +} + +impl PacketType<'_> { + pub(crate) fn size_hint(&self) -> usize { + match self { + Self::Connect(x) => x.size_hint(), + Self::Data(x) => x.len(), + Self::Continue(x) => x.size_hint(), + Self::Close(x) => x.size_hint(), } } - pub(crate) async fn maybe_handle_extension( - frame: Frame<'a>, - extensions: &mut [AnyProtocolExtension], - read: &mut R, - write: &LockedWebSocketWrite, - ) -> Result, WispError> { - if !frame.finished { - return Err(WispError::WsFrameNotFinished); + pub(crate) fn get_type(&self) -> u8 { + match self { + Self::Connect(_) => 0x01, + Self::Data(_) => 0x02, + Self::Continue(_) => 0x03, + Self::Close(_) => 0x04, } - if frame.opcode != OpCode::Binary { - return Err(WispError::WsFrameInvalidType(frame.opcode)); + } + + pub(crate) fn encode(&self, packet: &mut PayloadMut) { + match self { + Self::Connect(x) => x.encode_into(packet), + Self::Data(x) => packet.extend_from_slice(x), + Self::Continue(x) => x.encode_into(packet), + Self::Close(x) => x.encode_into(packet), } - let mut bytes = frame.payload; - if bytes.remaining() < 5 { + } + + pub(crate) fn decode(mut packet: Payload, ty: u8) -> Result, WispError> { + Ok(match ty { + 0x01 => PacketType::Connect(ConnectPacket::decode(&mut packet)?), + 0x02 => PacketType::Data(packet.into()), + 0x03 => PacketType::Continue(ContinuePacket::decode(&mut packet)?), + 0x04 => PacketType::Close(ClosePacket::decode(&mut packet)?), + x => return Err(WispError::InvalidPacketType(x)), + }) + } +} + +pub(crate) enum MaybeInfoPacket<'a> { + Packet(Packet<'a>), + Info(InfoPacket), +} + +impl MaybeInfoPacket<'static> { + pub(crate) fn decode( + mut packet: Payload, + builders: &mut [AnyProtocolExtensionBuilder], + role: Role, + ) -> Result { + if packet.remaining() < size_of::() + size_of::() { return Err(WispError::PacketTooSmall); } - let packet_type = bytes.get_u8(); - match packet_type { - 0x01 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Connect(bytes.try_into()?), - })), - 0x02 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Data(bytes), - })), - 0x03 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Continue(bytes.try_into()?), - })), - 0x04 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Close(bytes.try_into()?), - })), - 0x05 => Ok(None), - packet_type => { - if let Some(extension) = extensions - .iter_mut() - .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) - { - extension - .handle_packet( - packet_type, - BytesMut::from(bytes).freeze(), - DynWebSocketRead::from_mut(read), - write, - ) - .await?; - Ok(None) - } else { - Err(WispError::InvalidPacketType) + + let ty = packet.get_u8(); + let stream_id = packet.get_u32_le(); + + if ty == 0x05 { + Ok(Self::Info(InfoPacket::decode(&mut packet, builders, role)?)) + } else { + Ok(Self::Packet(Packet { + stream_id, + packet_type: PacketType::decode(packet, ty)?, + })) + } + } +} + +pub(crate) enum MaybeExtensionPacket<'a> { + Packet(Packet<'a>), + ExtensionHandled, +} + +impl MaybeExtensionPacket<'static> { + pub(crate) async fn decode( + mut packet: Payload, + extensions: &mut [AnyProtocolExtension], + rx: &mut dyn WebSocketRead, + tx: &mut LockedWebSocketWrite, + ) -> Result { + if packet.remaining() < size_of::() + size_of::() { + return Err(WispError::PacketTooSmall); + } + + let ty = packet.get_u8(); + let stream_id = packet.get_u32_le(); + + if (0x01..=0x04).contains(&ty) { + Ok(Self::Packet(Packet { + stream_id, + packet_type: PacketType::decode(packet, ty)?, + })) + } else { + tx.lock().await; + let mut handle = tx.get_handle(); + for extension in extensions { + if extension.get_supported_packets().contains(&ty) { + extension.handle_packet(ty, packet, rx, &mut handle).await?; + return Ok(Self::ExtensionHandled); } } + drop(handle); + + Err(WispError::InvalidPacketType(ty)) } } } -impl Encode for Packet<'_> { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.packet_type.as_u8()); - bytes.put_u32_le(self.stream_id); - self.packet_type.encode(bytes); - } +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Packet<'a> { + pub stream_id: u32, + pub packet_type: PacketType<'a>, } -impl<'a> TryFrom> for Packet<'a> { - type Error = WispError; - fn try_from(mut bytes: Payload<'a>) -> Result { - if bytes.remaining() < 1 { - return Err(Self::Error::PacketTooSmall); +impl Packet<'_> { + fn size_hint(&self) -> usize { + size_of::() + size_of::() + self.packet_type.size_hint() + } + + fn encode_into(&self, packet: &mut PayloadMut) { + packet.put_u8(self.packet_type.get_type()); + packet.put_u32_le(self.stream_id); + self.packet_type.encode(packet); + } + + pub(crate) fn encode(&self) -> Payload { + let mut payload = PayloadMut::with_capacity(self.size_hint()); + self.encode_into(&mut payload); + payload.into() + } + + pub(crate) fn decode(mut packet: Payload) -> Result, WispError> { + if packet.remaining() < size_of::() + size_of::() { + return Err(WispError::PacketTooSmall); } - let packet_type = bytes.get_u8(); - Self::parse_packet(packet_type, bytes) - } -} -impl From> for BytesMut { - fn from(packet: Packet) -> Self { - let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); - packet.encode(&mut encoded); - encoded - } -} + let ty = packet.get_u8(); + let stream_id = packet.get_u32_le(); -impl<'a> TryFrom> for Packet<'a> { - type Error = WispError; - fn try_from(frame: ws::Frame<'a>) -> Result { - if !frame.finished { - return Err(Self::Error::WsFrameNotFinished); + Ok(Packet { + stream_id, + packet_type: PacketType::decode(packet, ty)?, + }) + } + + pub fn new_data<'a>(stream_id: u32, data: impl Into>) -> Packet<'a> { + Packet { + stream_id, + packet_type: PacketType::Data(data.into()), } - if frame.opcode != ws::OpCode::Binary { - return Err(Self::Error::WsFrameInvalidType(frame.opcode)); - } - Packet::try_from(frame.payload) } -} -impl From> for ws::Frame<'static> { - fn from(packet: Packet) -> Self { - Self::binary(Payload::Bytes(BytesMut::from(packet))) + pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { + Self { + stream_id, + packet_type: PacketType::Continue(ContinuePacket { buffer_remaining }), + } + } + + pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { + Self { + stream_id, + packet_type: PacketType::Close(ClosePacket { reason }), + } } } diff --git a/wisp/src/stream/compat.rs b/wisp/src/stream/compat.rs index 21f54b3..e9d188f 100644 --- a/wisp/src/stream/compat.rs +++ b/wisp/src/stream/compat.rs @@ -1,338 +1,240 @@ use std::{ + io, pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll}, + sync::Arc, + task::{ready, Context, Poll}, }; -use bytes::BytesMut; use futures::{ - ready, stream::IntoAsyncRead, task::noop_waker_ref, AsyncBufRead, AsyncRead, AsyncWrite, Sink, - Stream, TryStreamExt, + channel::oneshot, stream::IntoAsyncRead, AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, + SinkExt, Stream, StreamExt, TryStreamExt, }; -use pin_project_lite::pin_project; +use pin_project::pin_project; -use crate::{ws::Payload, AtomicCloseReason, CloseReason, WispError}; +use crate::{ + locked_sink::LockedWebSocketWrite, + packet::{ClosePacket, CloseReason, Packet}, + ws::{Payload, WebSocketWrite}, + WispError, +}; -pin_project! { - /// Multiplexor stream that implements futures `Stream + Sink`. - pub struct MuxStreamIo { - #[pin] - pub(crate) rx: MuxStreamIoStream, - #[pin] - pub(crate) tx: MuxStreamIoSink, +use super::{MuxStream, MuxStreamRead, MuxStreamWrite, StreamInfo, WsEvent}; + +struct MapToIo(MuxStreamRead); + +impl Stream for MapToIo { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next_unpin(cx).map_err(std::io::Error::other) } } -impl MuxStreamIo { - /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. - pub fn into_asyncrw(self) -> MuxStreamAsyncRW { - MuxStreamAsyncRW { - rx: self.rx.into_asyncread(), - tx: self.tx.into_asyncwrite(), - } - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - self.rx.get_close_reason() - } - - /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamIoStream, MuxStreamIoSink) { - (self.rx, self.tx) - } +// TODO: don't use `futures` for this so get_close_reason etc can be implemented +#[pin_project] +pub struct MuxStreamAsyncRead { + #[pin] + inner: IntoAsyncRead>, } -impl Stream for MuxStreamIo { - type Item = Result, std::io::Error>; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx) - } -} - -impl Sink for MuxStreamIo { - type Error = std::io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_ready(cx) - } - fn start_send(self: Pin<&mut Self>, item: BytesMut) -> Result<(), Self::Error> { - self.project().tx.start_send(item) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } -} - -pin_project! { - /// Read side of a multiplexor stream that implements futures `Stream`. - pub struct MuxStreamIoStream { - #[pin] - pub(crate) rx: Pin, WispError>> + Send>>, - pub(crate) is_closed: Arc, - pub(crate) close_reason: Arc, - } -} - -impl MuxStreamIoStream { - /// Turn the stream into one that implements futures `AsyncRead + AsyncBufRead`. - pub fn into_asyncread(self) -> MuxStreamAsyncRead { - MuxStreamAsyncRead::new(self) - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) - } else { - None +impl MuxStreamAsyncRead { + pub(crate) fn new(inner: MuxStreamRead) -> Self { + Self { + inner: MapToIo(inner).into_async_read(), } } } -impl Stream for MuxStreamIoStream { - type Item = Result, std::io::Error>; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .rx - .poll_next(cx) - .map_err(std::io::Error::other) - } -} - -pin_project! { - /// Write side of a multiplexor stream that implements futures `Sink`. - pub struct MuxStreamIoSink { - #[pin] - pub(crate) tx: Pin, Error = WispError> + Send>>, - pub(crate) is_closed: Arc, - pub(crate) close_reason: Arc, - } -} - -impl MuxStreamIoSink { - /// Turn the sink into one that implements futures `AsyncWrite`. - pub fn into_asyncwrite(self) -> MuxStreamAsyncWrite { - MuxStreamAsyncWrite::new(self) - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) - } else { - None - } - } -} - -impl Sink for MuxStreamIoSink { - type Error = std::io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_ready(cx) - .map_err(std::io::Error::other) - } - fn start_send(self: Pin<&mut Self>, item: BytesMut) -> Result<(), Self::Error> { - self.project() - .tx - .start_send(Payload::Bytes(item)) - .map_err(std::io::Error::other) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_flush(cx) - .map_err(std::io::Error::other) - } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .tx - .poll_close(cx) - .map_err(std::io::Error::other) - } -} - -pin_project! { - /// Multiplexor stream that implements futures `AsyncRead + AsyncBufRead + AsyncWrite`. - pub struct MuxStreamAsyncRW { - #[pin] - rx: MuxStreamAsyncRead, - #[pin] - tx: MuxStreamAsyncWrite, - } -} - -impl MuxStreamAsyncRW { - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - self.rx.get_close_reason() - } - - /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { - (self.rx, self.tx) - } -} - -impl AsyncRead for MuxStreamAsyncRW { +impl AsyncRead for MuxStreamAsyncRead { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], - ) -> Poll> { - self.project().rx.poll_read(cx, buf) + ) -> Poll> { + self.project().inner.poll_read(cx, buf) } fn poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, - bufs: &mut [std::io::IoSliceMut<'_>], - ) -> Poll> { - self.project().rx.poll_read_vectored(cx, bufs) + bufs: &mut [io::IoSliceMut<'_>], + ) -> Poll> { + self.project().inner.poll_read_vectored(cx, bufs) } } -impl AsyncBufRead for MuxStreamAsyncRW { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_fill_buf(cx) +impl AsyncBufRead for MuxStreamAsyncRead { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().rx.consume(amt); + self.project().inner.consume(amt); } } -impl AsyncWrite for MuxStreamAsyncRW { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().tx.poll_write(cx, buf) - } +pub struct MuxStreamAsyncWrite { + inner: flume::r#async::SendSink<'static, WsEvent>, + write: LockedWebSocketWrite, + info: Arc, - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) - } + oneshot: Option>>, } -pin_project! { - /// Read side of a multiplexor stream that implements futures `AsyncRead + AsyncBufRead`. - pub struct MuxStreamAsyncRead { - #[pin] - rx: IntoAsyncRead, - is_closed: Arc, - close_reason: Arc, - } -} - -impl MuxStreamAsyncRead { - pub(crate) fn new(stream: MuxStreamIoStream) -> Self { +impl MuxStreamAsyncWrite { + pub(crate) fn new(inner: MuxStreamWrite) -> Self { Self { - is_closed: stream.is_closed.clone(), - close_reason: stream.close_reason.clone(), - rx: stream.into_async_read(), + inner: inner.inner, + write: inner.write, + info: inner.info, + + oneshot: None, } } /// Get the stream's close reason, if it was closed. pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) - } else { - None - } + self.inner.is_disconnected().then(|| self.info.get_reason()) } } -impl AsyncRead for MuxStreamAsyncRead { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.project().rx.poll_read(cx, buf) - } -} -impl AsyncBufRead for MuxStreamAsyncRead { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_fill_buf(cx) - } - fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().rx.consume(amt); - } -} - -pin_project! { - /// Write side of a multiplexor stream that implements futures `AsyncWrite`. - pub struct MuxStreamAsyncWrite { - #[pin] - tx: MuxStreamIoSink, - error: Option - } -} - -impl MuxStreamAsyncWrite { - pub(crate) fn new(sink: MuxStreamIoSink) -> Self { - Self { - tx: sink, - error: None, - } - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - self.tx.get_close_reason() - } -} - -impl AsyncWrite for MuxStreamAsyncWrite { +impl AsyncWrite for MuxStreamAsyncWrite { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { - if let Some(err) = self.error.take() { - return Poll::Ready(Err(err)); - } + ) -> Poll> { + ready!(self.write.poll_lock(cx)); + ready!(self.write.get().poll_flush(cx)).map_err(io::Error::other)?; + ready!(self.write.get().poll_ready(cx)).map_err(io::Error::other)?; - let mut this = self.as_mut().project(); + let packet = Packet::new_data(self.info.id, buf); + self.write + .get() + .start_send(packet.encode()) + .map_err(io::Error::other)?; - ready!(this.tx.as_mut().poll_ready(cx))?; - match this.tx.as_mut().start_send(buf.into()) { - Ok(()) => { - let mut cx = Context::from_waker(noop_waker_ref()); - let cx = &mut cx; - - match this.tx.poll_flush(cx) { - Poll::Ready(Err(err)) => { - self.error = Some(err); - } - Poll::Ready(Ok(())) | Poll::Pending => {} - } - - Poll::Ready(Ok(buf.len())) - } - Err(e) => Poll::Ready(Err(e)), - } + self.write.unlock(); + Poll::Ready(Ok(buf.len())) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_flush(cx) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.write.poll_lock(cx)); + ready!(self.write.get().poll_flush(cx)).map_err(io::Error::other)?; + self.write.unlock(); + Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().tx.poll_close(cx) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(oneshot) = &mut self.oneshot { + let ret = ready!(oneshot.poll_unpin(cx)); + self.oneshot.take(); + Poll::Ready( + ret.map_err(|_| io::Error::other(WispError::MuxMessageFailedToSend))? + .map_err(io::Error::other), + ) + } else { + ready!(self.as_mut().poll_flush(cx))?; + + ready!(self.inner.poll_ready_unpin(cx)) + .map_err(|_| io::Error::other(WispError::MuxMessageFailedToSend))?; + + let (tx, rx) = oneshot::channel(); + self.oneshot = Some(rx); + + let pkt = WsEvent::Close( + self.info.id, + ClosePacket { + reason: CloseReason::Unknown, + }, + tx, + ); + + self.inner + .start_send_unpin(pkt) + .map_err(|_| io::Error::other(WispError::MuxMessageFailedToSend))?; + + Poll::Pending + } + } +} + +#[pin_project] +pub struct MuxStreamAsyncRW { + #[pin] + read: MuxStreamAsyncRead, + #[pin] + write: MuxStreamAsyncWrite, +} + +impl MuxStreamAsyncRW { + pub(crate) fn new(old: MuxStream) -> Self { + Self { + read: MuxStreamAsyncRead::new(old.read), + write: MuxStreamAsyncWrite::new(old.write), + } + } + + pub fn into_split(self) -> (MuxStreamAsyncRead, MuxStreamAsyncWrite) { + (self.read, self.write) + } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.write.get_close_reason() + } +} + +impl AsyncRead for MuxStreamAsyncRW { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().read.poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [io::IoSliceMut<'_>], + ) -> Poll> { + self.project().read.poll_read_vectored(cx, bufs) + } +} + +impl AsyncBufRead for MuxStreamAsyncRW { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().read.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().read.consume(amt); + } +} + +impl AsyncWrite for MuxStreamAsyncRW { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().write.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().write.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_close(cx) } } diff --git a/wisp/src/stream/handles.rs b/wisp/src/stream/handles.rs new file mode 100644 index 0000000..9a8b4b7 --- /dev/null +++ b/wisp/src/stream/handles.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +use futures::channel::oneshot; + +use crate::{ + packet::{ClosePacket, CloseReason}, + ws::WebSocketWrite, + WispError, +}; + +use super::{StreamInfo, WsEvent}; + +/// Close handle for a multiplexor stream. +#[derive(Clone)] +pub struct MuxStreamCloser { + pub(crate) info: Arc, + pub(crate) inner: flume::Sender>, +} + +impl MuxStreamCloser { + /// Close the stream. You will no longer be able to write or read after this has been called. + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + if self.inner.is_disconnected() { + return Err(WispError::StreamAlreadyClosed); + } + + let (tx, rx) = oneshot::channel::>(); + let evt = WsEvent::Close(self.info.id, ClosePacket { reason }, tx); + + self.inner + .send_async(evt) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; + + Ok(()) + } + + /// Get the stream's close reason, if it was closed. + pub fn get_close_reason(&self) -> Option { + self.inner.is_disconnected().then(|| self.info.get_reason()) + } +} diff --git a/wisp/src/stream/mod.rs b/wisp/src/stream/mod.rs index 8f5c704..9c26a68 100644 --- a/wisp/src/stream/mod.rs +++ b/wisp/src/stream/mod.rs @@ -1,407 +1,183 @@ -mod compat; -mod sink_unfold; -pub use compat::*; - -use crate::{ - inner::WsEvent, - ws::{Frame, LockedWebSocketWrite, Payload, WebSocketWrite}, - AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, -}; - -use bytes::{BufMut, Bytes, BytesMut}; -use event_listener::Event; -use flume as mpsc; -use futures::{channel::oneshot, select, stream, FutureExt, Sink, Stream}; use std::{ pin::Pin, - sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, - }, + sync::Arc, + task::{ready, Context, Poll}, }; -/// Read side of a multiplexor stream. -pub struct MuxStreamRead { - /// ID of the stream. - pub stream_id: u32, - /// Type of the stream. - pub stream_type: StreamType, +use futures::{channel::oneshot, FutureExt, Sink, SinkExt, Stream, StreamExt}; - role: Role, +use crate::{ + mux::inner::{FlowControl, StreamInfo, WsEvent}, + packet::{ClosePacket, CloseReason, Packet}, + ws::{Payload, WebSocketWrite}, + LockedWebSocketWrite, WispError, +}; - tx: LockedWebSocketWrite, - rx: mpsc::Receiver>, +mod compat; +mod handles; +pub use compat::*; +pub use handles::*; - is_closed: Arc, - is_closed_event: Arc, - close_reason: Arc, - - should_flow_control: bool, - flow_control: Arc, - flow_control_read: AtomicU32, - target_flow_control: u32, +macro_rules! unlock_some { + ($unlock:expr, $x:expr) => { + if let Err(err) = $x { + $unlock.unlock(); + return Poll::Ready(Some(Err(err))); + } + }; +} +macro_rules! unlock { + ($unlock:expr, $x:expr) => { + if let Err(err) = $x { + $unlock.unlock(); + return Poll::Ready(Err(err)); + } + }; } -impl MuxStreamRead { - /// Read an event from the stream. - pub async fn read(&self) -> Result>, WispError> { - if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) { - return Ok(None); - } - let bytes = select! { - x = self.rx.recv_async() => x.map_err(|_| WispError::MuxMessageFailedToRecv)?, - () = self.is_closed_event.listen().fuse() => return Ok(None) - }; - if self.role == Role::Server && self.should_flow_control { - let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; - if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { - self.tx - .write_frame( - Packet::new_continue( - self.stream_id, - self.flow_control.fetch_add(val, Ordering::AcqRel) + val, - ) - .into(), - ) - .await?; - self.flow_control_read.store(0, Ordering::Release); - } - } - Ok(Some(bytes)) - } +pub struct MuxStreamRead { + inner: flume::r#async::RecvStream<'static, Payload>, + write: LockedWebSocketWrite, + info: Arc, - pub(crate) fn into_inner_stream( - self, - ) -> Pin, WispError>> + Send>> { - Box::pin(stream::unfold(self, |rx| async move { - Some((rx.read().await.transpose()?, rx)) - })) - } - - /// Turn the read half into one that implements futures `Stream`, consuming it. - pub fn into_stream(self) -> MuxStreamIoStream { - MuxStreamIoStream { - close_reason: self.close_reason.clone(), - is_closed: self.is_closed.clone(), - rx: self.into_inner_stream(), - } - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) - } else { - None - } - } + read_cnt: u32, + chunk: Option, } -/// Write side of a multiplexor stream. -pub struct MuxStreamWrite { - /// ID of the stream. - pub stream_id: u32, - /// Type of the stream. - pub stream_type: StreamType, - - role: Role, - mux_tx: mpsc::Sender>, - tx: LockedWebSocketWrite, - - is_closed: Arc, - close_reason: Arc, - - continue_recieved: Arc, - should_flow_control: bool, - flow_control: Arc, -} - -impl MuxStreamWrite { - pub(crate) async fn write_payload_internal<'a>( - &self, - header: Frame<'static>, - body: Frame<'a>, - ) -> Result<(), WispError> { - if self.role == Role::Client - && self.should_flow_control - && self.flow_control.load(Ordering::Acquire) == 0 - { - self.continue_recieved.listen().await; - } - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - - self.tx.write_split(header, body).await?; - - if self.role == Role::Client && self.stream_type == StreamType::Tcp { - self.flow_control.store( - self.flow_control.load(Ordering::Acquire).saturating_sub(1), - Ordering::Release, - ); - } - Ok(()) - } - - /// Write a payload to the stream. - pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> { - let frame: Frame<'static> = Frame::from(Packet::new_data( - self.stream_id, - Payload::Bytes(BytesMut::new()), - )); - self.write_payload_internal(frame, Frame::binary(data)) - .await - } - - /// Write data to the stream. - pub async fn write>(&self, data: D) -> Result<(), WispError> { - self.write_payload(Payload::Borrowed(data.as_ref())).await - } - - /// Get a handle to close the connection. - /// - /// Useful to close the connection without having access to the stream. - /// - /// # Example - /// ``` - /// let handle = stream.get_close_handle(); - /// if let Err(error) = handle_stream(stream) { - /// handle.close(0x01); - /// } - /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { - MuxStreamCloser { - stream_id: self.stream_id, - close_channel: self.mux_tx.clone(), - is_closed: self.is_closed.clone(), - close_reason: self.close_reason.clone(), - } - } - - /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: self.stream_id, - tx: self.tx.clone(), - is_closed: self.is_closed.clone(), - } - } - - /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } - self.is_closed.store(true, Ordering::Release); - - let (tx, rx) = oneshot::channel::>(); - self.mux_tx - .send_async(WsEvent::Close( - Packet::new_close(self.stream_id, reason), - tx, - )) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; - - Ok(()) - } - - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) - } else { - None - } - } - - pub(crate) fn into_inner_sink( - self, - ) -> Pin, Error = WispError> + Send>> { - let handle = self.get_close_handle(); - Box::pin(sink_unfold::unfold( - self, - |tx, data| async move { - tx.write_payload(data).await?; - Ok(tx) - }, - handle, - |handle| async move { - handle.close(CloseReason::Unknown).await?; - Ok(handle) - }, - )) - } - - /// Turn the write half into one that implements futures `Sink`, consuming it. - pub fn into_sink(self) -> MuxStreamIoSink { - MuxStreamIoSink { - close_reason: self.close_reason.clone(), - is_closed: self.is_closed.clone(), - tx: self.into_inner_sink(), - } - } -} - -impl Drop for MuxStreamWrite { - fn drop(&mut self) { - if !self.is_closed.load(Ordering::Acquire) { - self.is_closed.store(true, Ordering::Release); - let (tx, _) = oneshot::channel(); - let _ = self.mux_tx.send(WsEvent::Close( - Packet::new_close(self.stream_id, CloseReason::Unknown), - tx, - )); - } - } -} - -/// Multiplexor stream. -pub struct MuxStream { - /// ID of the stream. - pub stream_id: u32, - rx: MuxStreamRead, - tx: MuxStreamWrite, -} - -impl MuxStream { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - stream_id: u32, - role: Role, - stream_type: StreamType, - rx: mpsc::Receiver>, - mux_tx: mpsc::Sender>, - tx: LockedWebSocketWrite, - is_closed: Arc, - is_closed_event: Arc, - close_reason: Arc, - should_flow_control: bool, - flow_control: Arc, - continue_recieved: Arc, - target_flow_control: u32, +impl MuxStreamRead { + fn new( + inner: flume::Receiver, + write: LockedWebSocketWrite, + info: Arc, ) -> Self { Self { - stream_id, - rx: MuxStreamRead { - stream_id, - stream_type, - role, + inner: inner.into_stream(), + write, + info, - tx: tx.clone(), - rx, - - is_closed: is_closed.clone(), - is_closed_event, - close_reason: close_reason.clone(), - - should_flow_control, - flow_control: flow_control.clone(), - flow_control_read: AtomicU32::new(0), - target_flow_control, - }, - tx: MuxStreamWrite { - stream_id, - stream_type, - role, - - mux_tx, - tx, - - is_closed, - close_reason, - - continue_recieved, - should_flow_control, - flow_control, - }, + chunk: None, + read_cnt: 0, } } - /// Read an event from the stream. - pub async fn read(&self) -> Result>, WispError> { - self.rx.read().await + pub fn get_stream_id(&self) -> u32 { + self.info.id } - /// Write a payload to the stream. - pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> { - self.tx.write_payload(data).await - } - - /// Write data to the stream. - pub async fn write>(&self, data: D) -> Result<(), WispError> { - self.tx.write(data).await - } - - /// Get a handle to close the connection. - /// - /// Useful to close the connection without having access to the stream. - /// - /// # Example - /// ``` - /// let handle = stream.get_close_handle(); - /// if let Err(error) = handle_stream(stream) { - /// handle.close(0x01); - /// } - /// ``` - pub fn get_close_handle(&self) -> MuxStreamCloser { - self.tx.get_close_handle() - } - - /// Get a protocol extension stream to send protocol extension packets. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - self.tx.get_protocol_extension_stream() - } - - /// Get the stream's close reason, if it was closed. pub fn get_close_reason(&self) -> Option { - self.rx.get_close_reason() + self.inner.is_disconnected().then(|| self.info.get_reason()) } - /// Close the stream. You will no longer be able to write or read after this has been called. - pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - self.tx.close(reason).await + pub fn into_async_read(self) -> MuxStreamAsyncRead { + MuxStreamAsyncRead::new(self) } +} - /// Split the stream into read and write parts, consuming it. - pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { - (self.rx, self.tx) +impl Stream for MuxStreamRead { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.inner.is_disconnected() { + return Poll::Ready(None); + } + + let was_reading = self.chunk.is_some(); + let chunk = if let Some(chunk) = self.chunk.take() { + chunk + } else { + let Some(chunk) = ready!(self.inner.poll_next_unpin(cx)) else { + return Poll::Ready(None); + }; + chunk + }; + + macro_rules! ready { + ($x:expr) => { + match $x { + Poll::Ready(x) => x, + Poll::Pending => { + self.chunk = Some(chunk); + return Poll::Pending; + } + } + }; + } + + if self.info.flow_status == FlowControl::EnabledSendMessages { + if !was_reading { + self.read_cnt += 1; + } + + if self.read_cnt > self.info.target_flow_control { + ready!(self.write.poll_lock(cx)); + unlock_some!(self.write, ready!(self.write.get().poll_ready(cx))); + let pkt = + Packet::new_continue(self.info.id, self.info.flow_add(self.read_cnt)).encode(); + unlock_some!(self.write, self.write.get().start_send(pkt)); + self.write.unlock(); + + self.read_cnt = 0; + } + } + + Poll::Ready(Some(Ok(chunk))) } +} - /// Turn the stream into one that implements futures `Stream + Sink`, consuming it. - pub fn into_io(self) -> MuxStreamIo { - MuxStreamIo { - rx: self.rx.into_stream(), - tx: self.tx.into_sink(), +pub struct MuxStreamWrite { + inner: flume::r#async::SendSink<'static, WsEvent>, + write: LockedWebSocketWrite, + info: Arc, + + chunk: Option, + + oneshot: Option>>, +} + +impl MuxStreamWrite { + fn new( + inner: flume::Sender>, + write: LockedWebSocketWrite, + info: Arc, + ) -> Self { + Self { + inner: inner.into_sink(), + write, + info, + + chunk: None, + + oneshot: None, } } -} -/// Close handle for a multiplexor stream. -#[derive(Clone)] -pub struct MuxStreamCloser { - /// ID of the stream. - pub stream_id: u32, - close_channel: mpsc::Sender>, - is_closed: Arc, - close_reason: Arc, -} + pub fn get_stream_id(&self) -> u32 { + self.info.id + } + + pub fn get_close_reason(&self) -> Option { + self.inner.is_disconnected().then(|| self.info.get_reason()) + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + MuxStreamCloser { + info: self.info.clone(), + inner: self.inner.sender().clone(), + } + } -impl MuxStreamCloser { /// Close the stream. You will no longer be able to write or read after this has been called. pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { + if self.inner.is_disconnected() { return Err(WispError::StreamAlreadyClosed); } - self.is_closed.store(true, Ordering::Release); let (tx, rx) = oneshot::channel::>(); - self.close_channel - .send_async(WsEvent::Close( - Packet::new_close(self.stream_id, reason), - tx, - )) + let evt = WsEvent::Close(self.info.id, ClosePacket { reason }, tx); + + self.inner + .sender() + .send_async(evt) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; @@ -409,36 +185,170 @@ impl MuxStreamCloser { Ok(()) } - /// Get the stream's close reason, if it was closed. - pub fn get_close_reason(&self) -> Option { - if self.is_closed.load(Ordering::Acquire) { - Some(self.close_reason.load(Ordering::Acquire)) + pub fn into_async_write(self) -> MuxStreamAsyncWrite { + MuxStreamAsyncWrite::new(self) + } + + fn maybe_write(&mut self) -> Result<(), WispError> { + if let Some(chunk) = self.chunk.take() { + let packet = Packet::new_data(self.info.id, chunk).encode(); + self.write.get().start_send(packet)?; + + if self.info.flow_status == FlowControl::EnabledTrackAmount { + self.info.flow_dec(); + } + } + + Ok(()) + } +} + +impl Sink for MuxStreamWrite { + type Error = WispError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.inner.is_disconnected() { + return Poll::Ready(Err(WispError::StreamAlreadyClosed)); + } + + if self.info.flow_status == FlowControl::EnabledTrackAmount && self.info.flow_empty() { + self.info.flow_register(cx); + return Poll::Pending; + } + + if self.chunk.is_some() { + ready!(self.write.poll_lock(cx)); + unlock!(self.write, ready!(self.write.get().poll_ready(cx))); + unlock!(self.write, self.maybe_write()); + self.write.unlock(); + } + + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + debug_assert!(self.chunk.is_none()); + self.chunk = Some(item); + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.write.poll_lock(cx)); + + if self.chunk.is_some() { + unlock!(self.write, ready!(self.write.get().poll_ready(cx))); + unlock!(self.write, self.maybe_write()); + } + unlock!(self.write, ready!(self.write.get().poll_flush(cx))); + + self.write.unlock(); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(oneshot) = &mut self.oneshot { + let ret = ready!(oneshot.poll_unpin(cx)); + self.oneshot.take(); + Poll::Ready(ret.map_err(|_| WispError::MuxMessageFailedToSend)?) } else { - None + ready!(self.as_mut().poll_flush(cx))?; + + ready!(self.inner.poll_ready_unpin(cx)) + .map_err(|_| WispError::MuxMessageFailedToSend)?; + + let (tx, rx) = oneshot::channel(); + self.oneshot = Some(rx); + + let pkt = WsEvent::Close( + self.info.id, + ClosePacket { + reason: CloseReason::Unknown, + }, + tx, + ); + + self.inner + .start_send_unpin(pkt) + .map_err(|_| WispError::MuxMessageFailedToSend)?; + + Poll::Pending } } } -/// Stream for sending arbitrary protocol extension packets. -pub struct MuxProtocolExtensionStream { - /// ID of the stream. - pub stream_id: u32, - pub(crate) tx: LockedWebSocketWrite, - pub(crate) is_closed: Arc, +pub struct MuxStream { + read: MuxStreamRead, + write: MuxStreamWrite, } -impl MuxProtocolExtensionStream { - /// Send a protocol extension packet with this stream's ID. - pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); +impl MuxStream { + pub(crate) fn new( + rx: flume::Receiver, + tx: flume::Sender>, + ws: LockedWebSocketWrite, + info: Arc, + ) -> Self { + Self { + read: MuxStreamRead::new(rx, ws.clone(), info.clone()), + write: MuxStreamWrite::new(tx, ws, info), } - let mut encoded = BytesMut::with_capacity(1 + 4 + data.len()); - encoded.put_u8(packet_type); - encoded.put_u32_le(self.stream_id); - encoded.extend(data); - self.tx - .write_frame(Frame::binary(Payload::Bytes(encoded))) - .await + } + + pub fn get_stream_id(&self) -> u32 { + self.read.get_stream_id() + } + + pub fn get_close_reason(&self) -> Option { + self.read.get_close_reason() + } + + pub fn get_close_handle(&self) -> MuxStreamCloser { + self.write.get_close_handle() + } + + /// Close the stream. You will no longer be able to write or read after this has been called. + pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { + self.write.close(reason).await + } + + pub fn into_async_rw(self) -> MuxStreamAsyncRW { + MuxStreamAsyncRW::new(self) + } + + pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { + (self.read, self.write) + } +} + +impl Stream for MuxStream { + type Item = as Stream>::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.read.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.read.size_hint() + } +} + +impl Sink for MuxStream { + type Error = as Sink>::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.write.poll_ready_unpin(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + self.write.start_send_unpin(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.write.poll_flush_unpin(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.write.poll_close_unpin(cx) } } diff --git a/wisp/src/stream/sink_unfold.rs b/wisp/src/stream/sink_unfold.rs deleted file mode 100644 index 852abce..0000000 --- a/wisp/src/stream/sink_unfold.rs +++ /dev/null @@ -1,146 +0,0 @@ -//! futures sink unfold with a close function -use core::{future::Future, pin::Pin}; -use futures::{ - ready, - task::{Context, Poll}, - Sink, -}; -use pin_project_lite::pin_project; - -pin_project! { - /// UnfoldState used for stream and sink unfolds - #[project = UnfoldStateProj] - #[project_replace = UnfoldStateProjReplace] - #[derive(Debug)] - pub(crate) enum UnfoldState { - Value { - value: T, - }, - Future { - #[pin] - future: Fut, - }, - Empty, - } -} - -impl UnfoldState { - pub(crate) fn project_future(self: Pin<&mut Self>) -> Option> { - match self.project() { - UnfoldStateProj::Future { future } => Some(future), - _ => None, - } - } - - pub(crate) fn take_value(self: Pin<&mut Self>) -> Option { - match &*self { - Self::Value { .. } => match self.project_replace(Self::Empty) { - UnfoldStateProjReplace::Value { value } => Some(value), - _ => unreachable!(), - }, - _ => None, - } - } -} - -pin_project! { - /// Sink for the [`unfold`] function. - #[derive(Debug)] - #[must_use = "sinks do nothing unless polled"] - pub struct Unfold { - function: F, - close_function: CF, - #[pin] - state: UnfoldState, - #[pin] - close_state: UnfoldState - } -} - -pub(crate) fn unfold( - init: T, - function: F, - close_init: CT, - close_function: CF, -) -> Unfold -where - F: FnMut(T, Item) -> R, - R: Future>, - CF: FnMut(CT) -> CR, - CR: Future>, -{ - Unfold { - function, - close_function, - state: UnfoldState::Value { value: init }, - close_state: UnfoldState::Value { value: close_init }, - } -} - -impl Sink for Unfold -where - F: FnMut(T, Item) -> R, - R: Future>, - CF: FnMut(CT) -> CR, - CR: Future>, -{ - type Error = E; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } - - fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - let mut this = self.project(); - let future = match this.state.as_mut().take_value() { - Some(value) => (this.function)(value, item), - None => panic!("start_send called without poll_ready being called first"), - }; - this.state.set(UnfoldState::Future { future }); - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - Poll::Ready(if let Some(future) = this.state.as_mut().project_future() { - match ready!(future.poll(cx)) { - Ok(state) => { - this.state.set(UnfoldState::Value { value: state }); - Ok(()) - } - Err(err) => { - this.state.set(UnfoldState::Empty); - Err(err) - } - } - } else { - Ok(()) - }) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - let mut this = self.project(); - Poll::Ready( - if let Some(future) = this.close_state.as_mut().project_future() { - match ready!(future.poll(cx)) { - Ok(state) => { - this.close_state.set(UnfoldState::Value { value: state }); - Ok(()) - } - Err(err) => { - this.close_state.set(UnfoldState::Empty); - Err(err) - } - } - } else { - let future = match this.close_state.as_mut().take_value() { - Some(value) => (this.close_function)(value), - None => panic!("start_send called without poll_ready being called first"), - }; - this.close_state.set(UnfoldState::Future { future }); - return Poll::Pending; - }, - ) - } -} diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs deleted file mode 100644 index 002567e..0000000 --- a/wisp/src/ws.rs +++ /dev/null @@ -1,557 +0,0 @@ -//! Abstraction over WebSocket implementations. -//! -//! Use the [`fastwebsockets`] implementation of these traits as an example for implementing them -//! for other WebSocket implementations. -//! -//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs -use std::{future::Future, ops::Deref, pin::Pin, sync::Arc}; - -use crate::WispError; -use bytes::{Buf, Bytes, BytesMut}; -use futures::{lock::Mutex, TryFutureExt}; - -/// Payload of the websocket frame. -#[derive(Debug)] -pub enum Payload<'a> { - /// Borrowed payload. Currently used when writing data. - Borrowed(&'a [u8]), - /// `BytesMut` payload. Currently used when reading data. - Bytes(BytesMut), -} - -impl From for Payload<'static> { - fn from(value: BytesMut) -> Self { - Self::Bytes(value) - } -} - -impl<'a> From<&'a [u8]> for Payload<'a> { - fn from(value: &'a [u8]) -> Self { - Self::Borrowed(value) - } -} - -impl Payload<'_> { - /// Turn a Payload<'a> into a Payload<'static> by copying the data. - #[must_use] - pub fn into_owned(self) -> Self { - match self { - Self::Bytes(x) => Self::Bytes(x), - Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)), - } - } -} - -impl From> for BytesMut { - fn from(value: Payload<'_>) -> Self { - match value { - Payload::Bytes(x) => x, - Payload::Borrowed(x) => x.into(), - } - } -} - -impl From> for Bytes { - fn from(value: Payload<'static>) -> Self { - match value { - Payload::Bytes(x) => x.freeze(), - Payload::Borrowed(x) => x.into(), - } - } -} - -impl Deref for Payload<'_> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - match self { - Self::Bytes(x) => x, - Self::Borrowed(x) => x, - } - } -} - -impl AsRef<[u8]> for Payload<'_> { - fn as_ref(&self) -> &[u8] { - self - } -} - -impl Clone for Payload<'_> { - fn clone(&self) -> Self { - match self { - Self::Bytes(x) => Self::Bytes(x.clone()), - Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)), - } - } -} - -impl Buf for Payload<'_> { - fn remaining(&self) -> usize { - match self { - Self::Bytes(x) => x.remaining(), - Self::Borrowed(x) => x.remaining(), - } - } - - fn chunk(&self) -> &[u8] { - match self { - Self::Bytes(x) => x.chunk(), - Self::Borrowed(x) => x.chunk(), - } - } - - fn advance(&mut self, cnt: usize) { - match self { - Self::Bytes(x) => x.advance(cnt), - Self::Borrowed(x) => x.advance(cnt), - } - } -} - -/// Opcode of the WebSocket frame. -#[derive(Debug, PartialEq, Clone, Copy)] -pub enum OpCode { - /// Text frame. - Text, - /// Binary frame. - Binary, - /// Close frame. - Close, - /// Ping frame. - Ping, - /// Pong frame. - Pong, -} - -/// WebSocket frame. -#[derive(Debug, Clone)] -pub struct Frame<'a> { - /// Whether the frame is finished or not. - pub finished: bool, - /// Opcode of the WebSocket frame. - pub opcode: OpCode, - /// Payload of the WebSocket frame. - pub payload: Payload<'a>, -} - -impl<'a> Frame<'a> { - /// Create a new frame. - pub fn new(opcode: OpCode, payload: Payload<'a>, finished: bool) -> Self { - Self { - finished, - opcode, - payload, - } - } - - /// Create a new text frame. - pub fn text(payload: Payload<'a>) -> Self { - Self { - finished: true, - opcode: OpCode::Text, - payload, - } - } - - /// Create a new binary frame. - pub fn binary(payload: Payload<'a>) -> Self { - Self { - finished: true, - opcode: OpCode::Binary, - payload, - } - } - - /// Create a new close frame. - pub fn close(payload: Payload<'a>) -> Self { - Self { - finished: true, - opcode: OpCode::Close, - payload, - } - } -} - -/// Generic WebSocket read trait. -pub trait WebSocketRead: Send { - /// Read a frame from the socket. - fn wisp_read_frame( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> impl Future, WispError>> + Send; - - /// Read a split frame from the socket. - fn wisp_read_split( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> impl Future, Option>), WispError>> + Send { - self.wisp_read_frame(tx).map_ok(|x| (x, None)) - } -} - -// similar to what dynosaur does -mod wsr_inner { - use std::{future::Future, pin::Pin, ptr}; - - use crate::WispError; - - use super::{Frame, LockingWebSocketWrite, WebSocketRead}; - - trait ErasedWebSocketRead: Send { - fn wisp_read_frame<'a>( - &'a mut self, - tx: &'a dyn LockingWebSocketWrite, - ) -> Pin, WispError>> + Send + 'a>>; - - #[expect(clippy::type_complexity)] - fn wisp_read_split<'a>( - &'a mut self, - tx: &'a dyn LockingWebSocketWrite, - ) -> Pin< - Box< - dyn Future, Option>), WispError>> - + Send - + 'a, - >, - >; - } - - impl ErasedWebSocketRead for T { - fn wisp_read_frame<'a>( - &'a mut self, - tx: &'a dyn LockingWebSocketWrite, - ) -> Pin, WispError>> + Send + 'a>> { - Box::pin(self.wisp_read_frame(tx)) - } - - fn wisp_read_split<'a>( - &'a mut self, - tx: &'a dyn LockingWebSocketWrite, - ) -> Pin< - Box< - dyn Future, Option>), WispError>> - + Send - + 'a, - >, - > { - Box::pin(self.wisp_read_split(tx)) - } - } - - /// `WebSocketRead` trait object. - #[repr(transparent)] - pub struct DynWebSocketRead { - ptr: dyn ErasedWebSocketRead + 'static, - } - impl WebSocketRead for DynWebSocketRead { - async fn wisp_read_frame( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - self.ptr.wisp_read_frame(tx).await - } - - async fn wisp_read_split( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result<(Frame<'static>, Option>), WispError> { - self.ptr.wisp_read_split(tx).await - } - } - impl DynWebSocketRead { - /// Create a `WebSocketRead` trait object from a boxed `WebSocketRead`. - pub fn new(val: Box) -> Box { - let val: Box = val; - unsafe { std::mem::transmute(val) } - } - /// Create a `WebSocketRead` trait object from a `WebSocketRead`. - pub fn boxed(val: impl WebSocketRead + 'static) -> Box { - Self::new(Box::new(val)) - } - /// Create a `WebSocketRead` trait object from a `WebSocketRead` reference. - pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self { - let val: &(dyn ErasedWebSocketRead + 'static) = val; - unsafe { &*(ptr::from_ref::(val) as *const DynWebSocketRead) } - } - /// Create a `WebSocketRead` trait object from a mutable `WebSocketRead` reference. - pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self { - let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val; - unsafe { - &mut *(ptr::from_mut::(val) as *mut DynWebSocketRead) - } - } - } -} -pub use wsr_inner::DynWebSocketRead; - -/// Generic WebSocket write trait. -pub trait WebSocketWrite: Send { - /// Write a frame to the socket. - fn wisp_write_frame( - &mut self, - frame: Frame<'_>, - ) -> impl Future> + Send; - - /// Write a split frame to the socket. - fn wisp_write_split( - &mut self, - header: Frame<'_>, - body: Frame<'_>, - ) -> impl Future> + Send { - async move { - let mut payload = BytesMut::from(header.payload); - payload.extend_from_slice(&body.payload); - self.wisp_write_frame(Frame::binary(Payload::Bytes(payload))) - .await - } - } - - /// Close the socket. - fn wisp_close(&mut self) -> impl Future> + Send; -} - -// similar to what dynosaur does -mod wsw_inner { - use std::{future::Future, pin::Pin, ptr}; - - use crate::WispError; - - use super::{Frame, WebSocketWrite}; - - trait ErasedWebSocketWrite: Send { - fn wisp_write_frame<'a>( - &'a mut self, - frame: Frame<'a>, - ) -> Pin> + Send + 'a>>; - - fn wisp_write_split<'a>( - &'a mut self, - header: Frame<'a>, - body: Frame<'a>, - ) -> Pin> + Send + 'a>>; - - fn wisp_close<'a>( - &'a mut self, - ) -> Pin> + Send + 'a>>; - } - - impl ErasedWebSocketWrite for T { - fn wisp_write_frame<'a>( - &'a mut self, - frame: Frame<'a>, - ) -> Pin> + Send + 'a>> { - Box::pin(self.wisp_write_frame(frame)) - } - - fn wisp_write_split<'a>( - &'a mut self, - header: Frame<'a>, - body: Frame<'a>, - ) -> Pin> + Send + 'a>> { - Box::pin(self.wisp_write_split(header, body)) - } - - fn wisp_close<'a>( - &'a mut self, - ) -> Pin> + Send + 'a>> { - Box::pin(self.wisp_close()) - } - } - - /// `WebSocketWrite` trait object. - #[repr(transparent)] - pub struct DynWebSocketWrite { - ptr: dyn ErasedWebSocketWrite + 'static, - } - impl WebSocketWrite for DynWebSocketWrite { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - self.ptr.wisp_write_frame(frame).await - } - - async fn wisp_write_split( - &mut self, - header: Frame<'_>, - body: Frame<'_>, - ) -> Result<(), WispError> { - self.ptr.wisp_write_split(header, body).await - } - - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.ptr.wisp_close().await - } - } - impl DynWebSocketWrite { - /// Create a new `WebSocketWrite` trait object from a boxed `WebSocketWrite`. - pub fn new(val: Box) -> Box { - let val: Box = val; - unsafe { std::mem::transmute(val) } - } - /// Create a new `WebSocketWrite` trait object from a `WebSocketWrite`. - pub fn boxed(val: impl WebSocketWrite + 'static) -> Box { - Self::new(Box::new(val)) - } - /// Create a new `WebSocketWrite` trait object from a `WebSocketWrite` reference. - pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self { - let val: &(dyn ErasedWebSocketWrite + 'static) = val; - unsafe { - &*(ptr::from_ref::(val) as *const DynWebSocketWrite) - } - } - /// Create a new `WebSocketWrite` trait object from a mutable `WebSocketWrite` reference. - pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self { - let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val; - unsafe { - &mut *(ptr::from_mut::(val) as *mut DynWebSocketWrite) - } - } - } -} -pub use wsw_inner::DynWebSocketWrite; - -mod private { - pub trait Sealed {} -} - -/// Helper trait object for `LockedWebSocketWrite`. -pub trait LockingWebSocketWrite: private::Sealed + Sync { - /// Write a frame to the websocket. - fn wisp_write_frame<'a>( - &'a self, - frame: Frame<'a>, - ) -> Pin> + Send + 'a>>; - - /// Write a split frame to the websocket. - fn wisp_write_split<'a>( - &'a self, - header: Frame<'a>, - body: Frame<'a>, - ) -> Pin> + Send + 'a>>; - - /// Close the websocket. - fn wisp_close<'a>(&'a self) - -> Pin> + Send + 'a>>; -} - -/// Locked WebSocket. -pub struct LockedWebSocketWrite(Arc>); - -impl Clone for LockedWebSocketWrite { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl LockedWebSocketWrite { - /// Create a new locked websocket. - pub fn new(ws: T) -> Self { - Self(Mutex::new(ws).into()) - } - - /// Create a new locked websocket from an existing mutex. - pub fn from_locked(locked: Arc>) -> Self { - Self(locked) - } - - /// Write a frame to the websocket. - pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> { - self.0.lock().await.wisp_write_frame(frame).await - } - - /// Write a split frame to the websocket. - pub async fn write_split(&self, header: Frame<'_>, body: Frame<'_>) -> Result<(), WispError> { - self.0.lock().await.wisp_write_split(header, body).await - } - - /// Close the websocket. - pub async fn close(&self) -> Result<(), WispError> { - self.0.lock().await.wisp_close().await - } -} - -impl private::Sealed for LockedWebSocketWrite {} - -impl LockingWebSocketWrite for LockedWebSocketWrite { - fn wisp_write_frame<'a>( - &'a self, - frame: Frame<'a>, - ) -> Pin> + Send + 'a>> { - Box::pin(self.write_frame(frame)) - } - - fn wisp_write_split<'a>( - &'a self, - header: Frame<'a>, - body: Frame<'a>, - ) -> Pin> + Send + 'a>> { - Box::pin(self.write_split(header, body)) - } - - fn wisp_close<'a>( - &'a self, - ) -> Pin> + Send + 'a>> { - Box::pin(self.close()) - } -} - -/// Combines two different `WebSocketRead`s together. -pub enum EitherWebSocketRead { - /// First `WebSocketRead` variant. - Left(A), - /// Second `WebSocketRead` variant. - Right(B), -} -impl WebSocketRead for EitherWebSocketRead { - async fn wisp_read_frame( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result, WispError> { - match self { - Self::Left(x) => x.wisp_read_frame(tx).await, - Self::Right(x) => x.wisp_read_frame(tx).await, - } - } - - async fn wisp_read_split( - &mut self, - tx: &dyn LockingWebSocketWrite, - ) -> Result<(Frame<'static>, Option>), WispError> { - match self { - Self::Left(x) => x.wisp_read_split(tx).await, - Self::Right(x) => x.wisp_read_split(tx).await, - } - } -} - -/// Combines two different `WebSocketWrite`s together. -pub enum EitherWebSocketWrite { - /// First `WebSocketWrite` variant. - Left(A), - /// Second `WebSocketWrite` variant. - Right(B), -} -impl WebSocketWrite for EitherWebSocketWrite { - async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { - match self { - Self::Left(x) => x.wisp_write_frame(frame).await, - Self::Right(x) => x.wisp_write_frame(frame).await, - } - } - - async fn wisp_write_split( - &mut self, - header: Frame<'_>, - body: Frame<'_>, - ) -> Result<(), WispError> { - match self { - Self::Left(x) => x.wisp_write_split(header, body).await, - Self::Right(x) => x.wisp_write_split(header, body).await, - } - } - - async fn wisp_close(&mut self) -> Result<(), WispError> { - match self { - Self::Left(x) => x.wisp_close().await, - Self::Right(x) => x.wisp_close().await, - } - } -} diff --git a/wisp/src/ws/mod.rs b/wisp/src/ws/mod.rs new file mode 100644 index 0000000..8152de6 --- /dev/null +++ b/wisp/src/ws/mod.rs @@ -0,0 +1,83 @@ +use std::ops::Deref; + +use bytes::{Bytes, BytesMut}; +use futures::{Sink, Stream, StreamExt}; + +use crate::WispError; + +mod split; +pub use split::*; +mod unfold; +pub use unfold::*; + +#[cfg(feature = "tokio-websockets")] +mod tokio_websockets; +#[cfg(feature = "tokio-websockets")] +pub use self::tokio_websockets::*; + +#[cfg(feature = "tokio-tungstenite")] +mod tokio_tungstenite; +#[cfg(feature = "tokio-tungstenite")] +pub use self::tokio_tungstenite::*; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum PayloadRef<'a> { + Owned(Payload), + Borrowed(&'a [u8]), +} + +impl PayloadRef<'_> { + pub fn into_owned(self) -> Payload { + match self { + Self::Owned(x) => x, + Self::Borrowed(x) => BytesMut::from(x).freeze(), + } + } +} + +impl From for PayloadRef<'static> { + fn from(value: Payload) -> Self { + Self::Owned(value) + } +} +impl<'a> From<&'a [u8]> for PayloadRef<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Borrowed(value) + } +} + +impl Deref for PayloadRef<'_> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + match self { + Self::Owned(x) => x, + Self::Borrowed(x) => x, + } + } +} + +pub type Payload = Bytes; +pub type PayloadMut = BytesMut; + +pub trait WebSocketRead: + Stream> + Send + Unpin + 'static +{ +} +impl> + Send + Unpin + 'static> WebSocketRead for S {} + +pub(crate) trait WebSocketReadExt: WebSocketRead { + async fn next_erroring(&mut self) -> Result { + self.next().await.ok_or(WispError::WsImplSocketClosed)? + } +} +impl WebSocketReadExt for S {} + +pub trait WebSocketWrite: Sink + Send + Unpin + 'static {} +impl + Send + Unpin + 'static> WebSocketWrite for S {} + +pub trait WebSocketExt: WebSocketRead + WebSocketWrite + Sized { + fn split_fast(self) -> (WebSocketSplitRead, WebSocketSplitWrite) { + split::split(self) + } +} +impl WebSocketExt for S {} diff --git a/wisp/src/ws/split.rs b/wisp/src/ws/split.rs new file mode 100644 index 0000000..0bdbbc3 --- /dev/null +++ b/wisp/src/ws/split.rs @@ -0,0 +1,64 @@ +use std::sync::{Arc, Mutex, MutexGuard}; + +use futures::{Sink, SinkExt, Stream, StreamExt}; + +use super::{WebSocketRead, WebSocketWrite}; + +fn lock(mutex: &Mutex) -> MutexGuard<'_, T> { + mutex.lock().expect("WebSocketSplit mutex was poisoned") +} + +pub(crate) fn split( + s: S, +) -> (WebSocketSplitRead, WebSocketSplitWrite) { + let inner = Arc::new(Mutex::new(s)); + + ( + WebSocketSplitRead(inner.clone()), + WebSocketSplitWrite(inner), + ) +} + +pub struct WebSocketSplitRead(Arc>); + +impl Stream for WebSocketSplitRead { + type Item = S::Item; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + lock(&self.0).poll_next_unpin(cx) + } +} + +pub struct WebSocketSplitWrite(Arc>); + +impl, T> Sink for WebSocketSplitWrite { + type Error = >::Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + >::poll_ready_unpin(&mut *lock(&self.0), cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + >::start_send_unpin(&mut *lock(&self.0), item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + >::poll_flush_unpin(&mut *lock(&self.0), cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + >::poll_close_unpin(&mut *lock(&self.0), cx) + } +} diff --git a/wisp/src/ws/tokio_tungstenite.rs b/wisp/src/ws/tokio_tungstenite.rs new file mode 100644 index 0000000..a4fee0f --- /dev/null +++ b/wisp/src/ws/tokio_tungstenite.rs @@ -0,0 +1,79 @@ +use std::task::Poll; + +use futures::{Sink, Stream}; +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; + +use crate::WispError; + +use super::Payload; + +#[pin_project] +pub struct TokioTungsteniteTransport( + #[pin] pub WebSocketStream, +); + +fn map_err(x: tokio_tungstenite::tungstenite::Error) -> WispError { + if matches!(x, tokio_tungstenite::tungstenite::Error::AlreadyClosed) { + WispError::WsImplSocketClosed + } else { + WispError::WsImplError(Box::new(x)) + } +} + +impl Stream for TokioTungsteniteTransport { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.as_mut().project().0.poll_next(cx) { + Poll::Ready(Some(Ok(x))) => { + if x.is_binary() { + Poll::Ready(Some(Ok(x.into_data()))) + } else if x.is_close() { + Poll::Ready(None) + } else { + self.poll_next(cx) + } + } + Poll::Ready(Some(Err(x))) => Poll::Ready(Some(Err(map_err(x)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl Sink for TokioTungsteniteTransport { + type Error = WispError; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().0.poll_ready(cx).map_err(map_err) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + self.project() + .0 + .start_send(Message::binary(item)) + .map_err(map_err) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().0.poll_flush(cx).map_err(map_err) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().0.poll_flush(cx).map_err(map_err) + } +} diff --git a/wisp/src/ws/tokio_websockets.rs b/wisp/src/ws/tokio_websockets.rs new file mode 100644 index 0000000..e51edde --- /dev/null +++ b/wisp/src/ws/tokio_websockets.rs @@ -0,0 +1,107 @@ +use std::{pin::Pin, task::Poll}; + +use futures::{Sink, SinkExt, Stream, StreamExt}; +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_websockets::{Message, WebSocketStream}; + +use crate::WispError; + +use super::Payload; + +#[pin_project] +pub struct TokioWebsocketsTransport( + #[pin] pub WebSocketStream, +); + +fn map_err(x: tokio_websockets::Error) -> WispError { + if matches!(x, tokio_websockets::Error::AlreadyClosed) { + WispError::WsImplSocketClosed + } else { + WispError::WsImplError(Box::new(x)) + } +} + +impl Stream for TokioWebsocketsTransport { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.0.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(x))) => { + if x.is_binary() { + Poll::Ready(Some(Ok(x.into_payload().into()))) + } else if x.is_close() { + Poll::Ready(None) + } else { + self.poll_next(cx) + } + } + Poll::Ready(Some(Err(x))) => Poll::Ready(Some(Err(map_err(x)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl Sink for TokioWebsocketsTransport { + type Error = WispError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_ready_unpin(cx).map_err(map_err) + } + + fn start_send(mut self: Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + self.0 + .start_send_unpin(Message::binary(item)) + .map_err(map_err) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_flush_unpin(cx).map_err(map_err) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_close_unpin(cx).map_err(map_err) + } +} + +impl Sink for TokioWebsocketsTransport { + type Error = tokio_websockets::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_ready_unpin(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.0.start_send_unpin(item) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_flush_unpin(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.0.poll_close_unpin(cx) + } +} diff --git a/wisp/src/ws/unfold.rs b/wisp/src/ws/unfold.rs new file mode 100644 index 0000000..02d265f --- /dev/null +++ b/wisp/src/ws/unfold.rs @@ -0,0 +1,198 @@ +// Similar to `futures-util` `StreamExt::unfold` and `SinkExt::unfold` + +use std::{ + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use futures::{Sink, Stream}; +use pin_project::pin_project; + +use crate::WispError; + +use super::Payload; + +pub fn async_iterator_transport_read( + init: State, + func: Func, +) -> AsyncIteratorTransportRead +where + Func: FnMut(State) -> Fut, + Fut: Future, WispError>>, +{ + AsyncIteratorTransportRead { + func, + state: IteratorState::Value(init), + } +} + +pub fn async_iterator_transport_write( + init: State, + func: Func, + close_init: CloseState, + close_func: CloseFunc, +) -> AsyncIteratorTransportWrite +where + Func: FnMut(State, Payload) -> Fut, + Fut: Future>, + CloseFunc: FnMut(CloseState) -> CloseFut, + CloseFut: Future>, +{ + AsyncIteratorTransportWrite { + func, + state: IteratorState::Value(init), + + close: close_func, + close_state: IteratorState::Value(close_init), + } +} + +#[pin_project(project = IteratorStateProj, project_replace = IteratorStateProjReplace)] +enum IteratorState { + Value(S), + Future(#[pin] Fut), + Empty, +} + +impl IteratorState { + pub fn take_state(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Value { .. } => match self.project_replace(Self::Empty) { + IteratorStateProjReplace::Value(value) => Some(value), + _ => unreachable!(), + }, + _ => None, + } + } + + pub fn get_future(self: Pin<&mut Self>) -> Option> { + match self.project() { + IteratorStateProj::Future(future) => Some(future), + _ => None, + } + } +} + +#[pin_project] +pub struct AsyncIteratorTransportRead +where + Func: FnMut(State) -> Fut, + Fut: Future, WispError>>, +{ + func: Func, + #[pin] + state: IteratorState, +} + +impl Stream for AsyncIteratorTransportRead +where + Func: FnMut(State) -> Fut, + Fut: Future, WispError>>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if let Some(state) = this.state.as_mut().take_state() { + this.state.set(IteratorState::Future((this.func)(state))); + } + + let ret = match this.state.as_mut().get_future() { + Some(fut) => ready!(fut.poll(cx)), + None => panic!("AsyncIteratorTransportRead was polled after completion"), + }; + + match ret { + Ok(Some((ret, state))) => { + this.state.set(IteratorState::Value(state)); + Poll::Ready(Some(Ok(ret))) + } + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err))), + } + } +} + +#[pin_project] +pub struct AsyncIteratorTransportWrite +where + Func: FnMut(State, Payload) -> Fut, + Fut: Future>, + CloseFunc: FnMut(CloseState) -> CloseFut, + CloseFut: Future>, +{ + func: Func, + #[pin] + state: IteratorState, + + close: CloseFunc, + #[pin] + close_state: IteratorState, +} + +impl Sink + for AsyncIteratorTransportWrite +where + Func: FnMut(State, Payload) -> Fut, + Fut: Future>, + CloseFunc: FnMut(CloseState) -> CloseFut, + CloseFut: Future>, +{ + type Error = WispError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Payload) -> Result<(), Self::Error> { + let mut this = self.project(); + let fut = match this.state.as_mut().take_state() { + Some(state) => (this.func)(state, item), + None => panic!("start_send called on AsyncIteratorTransportWrite without poll_ready being called first"), + }; + this.state.set(IteratorState::Future(fut)); + + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + Poll::Ready(if let Some(future) = this.state.as_mut().get_future() { + match ready!(future.poll(cx)) { + Ok(state) => { + this.state.set(IteratorState::Value(state)); + Ok(()) + } + Err(err) => { + this.state.set(IteratorState::Empty); + Err(err) + } + } + } else { + Ok(()) + }) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.project(); + + if let Some(future) = this.close_state.as_mut().get_future() { + let ret = ready!(future.poll(cx)); + this.close_state.set(IteratorState::Empty); + Poll::Ready(ret) + } else { + let future = match this.close_state.as_mut().take_state() { + Some(value) => (this.close)(value), + None => { + panic!("poll_close called on AsyncIteratorTransportWrite after it finished") + } + }; + + this.close_state.set(IteratorState::Future(future)); + Poll::Pending + } + } +}