Add tool calling support for Gemini models (#27772)

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-03-31 17:46:42 +02:00 committed by GitHub
parent f6d58f76e4
commit c8a9a74e6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 735 additions and 251 deletions

240
Cargo.lock generated
View file

@ -86,7 +86,7 @@ version = "0.25.1-dev"
source = "git+https://github.com/zed-industries/alacritty.git?branch=add-hush-login-flag#828457c9ff1f7ea0a0469337cc8a37ee3a1b0590"
dependencies = [
"base64 0.22.1",
"bitflags 2.9.0",
"bitflags 2.8.0",
"home",
"libc",
"log",
@ -128,7 +128,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
dependencies = [
"alsa-sys",
"bitflags 2.9.0",
"bitflags 2.8.0",
"cfg-if",
"libc",
]
@ -1888,7 +1888,7 @@ version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@ -1911,7 +1911,7 @@ version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@ -1969,9 +1969,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.9.0"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
dependencies = [
"serde",
]
@ -2001,9 +2001,9 @@ source = "git+https://github.com/kvark/blade?rev=b16f5c7bd873c7126f48c82c39e7ae6
dependencies = [
"ash",
"ash-window",
"bitflags 2.9.0",
"bitflags 2.8.0",
"bytemuck",
"codespan-reporting",
"codespan-reporting 0.11.1",
"glow",
"gpu-alloc",
"gpu-alloc-ash",
@ -2049,9 +2049,9 @@ dependencies = [
[[package]]
name = "blake3"
version = "1.7.0"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b17679a8d69b6d7fd9cd9801a536cec9fa5e5970b69f9d4747f70b39b031f5e7"
checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753"
dependencies = [
"arrayref",
"arrayvec",
@ -2289,12 +2289,11 @@ dependencies = [
[[package]]
name = "bzip2-sys"
version = "0.1.11+1.0.8"
version = "0.1.13+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14"
dependencies = [
"cc",
"libc",
"pkg-config",
]
@ -2330,7 +2329,7 @@ version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"log",
"polling",
"rustix",
@ -2368,7 +2367,7 @@ dependencies = [
"cap-primitives",
"cap-std",
"io-lifetimes",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -2396,7 +2395,7 @@ dependencies = [
"ipnet",
"maybe-owned",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
"winx",
]
@ -2845,7 +2844,7 @@ version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f79398230a6e2c08f5c9760610eb6924b52aa9e7950a619602baba59dcbbdbb2"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block",
"cocoa-foundation 0.2.0",
"core-foundation 0.10.0",
@ -2875,7 +2874,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14045fb83be07b5acf1c0884b2180461635b433455fa35d1cd6f17f1450679d"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block",
"core-foundation 0.10.0",
"core-graphics-types 0.2.0",
@ -2893,6 +2892,17 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "codespan-reporting"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
"serde",
"termcolor",
"unicode-width",
]
[[package]]
name = "collab"
version = "0.44.0"
@ -3344,7 +3354,7 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.10.0",
"core-graphics-types 0.2.0",
"foreign-types 0.5.0",
@ -3368,7 +3378,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.10.0",
"libc",
]
@ -3379,7 +3389,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e4583956b9806b69f73fcb23aee05eb3620efc282972f08f6a6db7504f8334d"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block",
"cfg-if",
"core-foundation 0.10.0",
@ -3467,7 +3477,7 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e418dd4f5128c3e93eab12246391c54a20c496811131f85754dc8152ee207892"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"fontdb 0.16.2",
"log",
"rangemap",
@ -3839,9 +3849,9 @@ checksum = "96a6ac251f4a2aca6b3f91340350eab87ae57c3f127ffeb585e92bd336717991"
[[package]]
name = "cxx"
version = "1.0.134"
version = "1.0.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5a32d755fe20281b46118ee4b507233311fb7a48a0cfd42f554b93640521a2f"
checksum = "fdb3e596b379180315d2f934231e233a2fc745041f88231807774093d8de45f2"
dependencies = [
"cc",
"cxxbridge-cmd",
@ -3853,12 +3863,12 @@ dependencies = [
[[package]]
name = "cxx-build"
version = "1.0.134"
version = "1.0.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11645536ada5d1c8804312cbffc9ab950f2216154de431de930da47ca6955199"
checksum = "3743fae7f47620cd34ec23bab819db9ee52da93166a058f87ab0ad99d777dc9b"
dependencies = [
"cc",
"codespan-reporting",
"codespan-reporting 0.12.0",
"proc-macro2",
"quote",
"scratch",
@ -3867,12 +3877,12 @@ dependencies = [
[[package]]
name = "cxxbridge-cmd"
version = "1.0.134"
version = "1.0.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebcc9c78e3c7289665aab921a2b394eaffe8bdb369aa18d81ffc0f534fd49385"
checksum = "aaea0273c049b126a3918df88a1670c9c0168e0738df9370a988ff69070d4fff"
dependencies = [
"clap",
"codespan-reporting",
"codespan-reporting 0.12.0",
"proc-macro2",
"quote",
"syn 2.0.100",
@ -3880,15 +3890,15 @@ dependencies = [
[[package]]
name = "cxxbridge-flags"
version = "1.0.134"
version = "1.0.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a22a87bd9e78d7204d793261470a4c9d585154fddd251828d8aefbb5f74c3bf"
checksum = "020a9a3d6b792aab7f30f6e323893ad7f45052e572cde5d014c47fe67c89495f"
[[package]]
name = "cxxbridge-macro"
version = "1.0.134"
version = "1.0.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dfdb020ff8787c5daf6e0dca743005cc8782868faeadfbabb8824ede5cb1c72"
checksum = "ee54cd01f94db0328c4c73036d38bd8c3bb88927e953d05ffefe743edbf4eb68"
dependencies = [
"proc-macro2",
"quote",
@ -4633,7 +4643,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -5121,7 +5131,7 @@ name = "font-kit"
version = "0.14.1"
source = "git+https://github.com/zed-industries/font-kit?rev=5474cfad4b719a72ec8ed2cb7327b2b01fd10568#5474cfad4b719a72ec8ed2cb7327b2b01fd10568"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"byteorder",
"core-foundation 0.10.0",
"core-graphics 0.24.0",
@ -5298,7 +5308,7 @@ checksum = "5e2e6123af26f0f2c51cc66869137080199406754903cc926a7690401ce09cb4"
dependencies = [
"io-lifetimes",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -5321,7 +5331,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
name = "fsevent"
version = "0.1.0"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.10.0",
"fsevent-sys 3.1.0",
"parking_lot",
@ -5649,7 +5659,7 @@ version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5220b8ba44c68a9a7f7a7659e864dd73692e417ef0211bea133c7b74e031eeb9"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"libc",
"libgit2-sys",
"log",
@ -5815,7 +5825,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"gpu-alloc-types",
]
@ -5836,7 +5846,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
]
[[package]]
@ -6149,7 +6159,7 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd54745cfacb7b97dee45e8fdb91814b62bccddb481debb7de0f9ee6b7bf5b43"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"byteorder",
"heed-traits",
"heed-types",
@ -6709,9 +6719,9 @@ dependencies = [
[[package]]
name = "image"
version = "0.25.6"
version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db35664ce6b9810857a38a906215e75a9c879f0696556a39f59c62829710251a"
checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b"
dependencies = [
"bytemuck",
"byteorder-lite",
@ -6886,7 +6896,7 @@ version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"inotify-sys",
"libc",
]
@ -6956,7 +6966,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2285ddfe3054097ef4b2fe909ef8c3bcd1ea52a8f0d274416caebeef39f04a65"
dependencies = [
"io-lifetimes",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -7630,7 +7640,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
"windows-targets 0.48.5",
]
[[package]]
@ -7655,7 +7665,7 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"libc",
"redox_syscall 0.5.8",
]
@ -7708,9 +7718,9 @@ dependencies = [
[[package]]
name = "link-cplusplus"
version = "1.0.9"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9"
checksum = "4a6f6da007f968f9def0d65a05b187e2960183de70c160204ecfccf0ee330212"
dependencies = [
"cc",
]
@ -7915,9 +7925,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.27"
version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
dependencies = [
"serde",
"value-bag",
@ -8276,7 +8286,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block",
"core-graphics-types 0.1.3",
"foreign-types 0.5.0",
@ -8427,12 +8437,6 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "multimap"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "naga"
version = "23.1.0"
@ -8441,9 +8445,9 @@ checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f"
dependencies = [
"arrayvec",
"bit-set 0.8.0",
"bitflags 2.9.0",
"bitflags 2.8.0",
"cfg_aliases 0.1.1",
"codespan-reporting",
"codespan-reporting 0.11.1",
"hexf-parse",
"indexmap",
"log",
@ -8510,7 +8514,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"jni-sys",
"log",
"ndk-sys",
@ -8545,7 +8549,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"cfg-if",
"cfg_aliases 0.2.1",
"libc",
@ -8629,7 +8633,7 @@ version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fee8403b3d66ac7b26aee6e40a897d85dc5ce26f44da36b8b73e987cc52e943"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"filetime",
"fsevent-sys 4.1.0",
"inotify",
@ -8647,7 +8651,7 @@ name = "notify"
version = "8.0.0"
source = "git+https://github.com/zed-industries/notify.git?rev=bbb9ea5ae52b253e095737847e367c30653a2e96#bbb9ea5ae52b253e095737847e367c30653a2e96"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"filetime",
"fsevent-sys 4.1.0",
"inotify",
@ -8923,7 +8927,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"libc",
"objc2",
@ -8939,7 +8943,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-core-location",
@ -8963,7 +8967,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@ -9005,7 +9009,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"libc",
"objc2",
@ -9029,7 +9033,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@ -9041,7 +9045,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-foundation",
@ -9064,7 +9068,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-cloud-kit",
@ -9096,7 +9100,7 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"block2",
"objc2",
"objc2-core-location",
@ -9158,9 +9162,9 @@ checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "oo7"
version = "0.4.3"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb23d3ec3527d65a83be1c1795cb883c52cfa57147d42acc797127df56fc489"
checksum = "72c84df357c7049f98c8b157abe71ee751531166c14ba09366e08bc6ab1ea2c9"
dependencies = [
"aes",
"ashpd",
@ -9239,7 +9243,7 @@ version = "0.10.70"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
@ -10751,7 +10755,7 @@ dependencies = [
"itertools 0.10.5",
"lazy_static",
"log",
"multimap 0.8.3",
"multimap",
"petgraph",
"prost 0.9.0",
"prost-types 0.9.0",
@ -10767,10 +10771,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes 1.10.1",
"heck 0.5.0",
"heck 0.4.1",
"itertools 0.12.1",
"log",
"multimap 0.10.0",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
@ -10878,7 +10882,7 @@ version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76979bea66e7875e7509c4ec5300112b316af87fa7a252ca91c448b32dfe3993"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"memchr",
"pulldown-cmark-escape",
"unicase",
@ -10890,7 +10894,7 @@ version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f86ba2052aebccc42cbbb3ed234b8b13ce76f75c3551a303cb2bcffcff12bb14"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"memchr",
"unicase",
]
@ -10995,7 +10999,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -11271,7 +11275,7 @@ version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
]
[[package]]
@ -11915,13 +11919,13 @@ version = "0.38.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"errno 0.3.10",
"itoa",
"libc",
"linux-raw-sys",
"once_cell",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -12032,7 +12036,7 @@ dependencies = [
"security-framework 3.0.1",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -12075,7 +12079,7 @@ version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfb9cf8877777222e4a3bc7eb247e398b56baba500c38c1c46842431adc8b55c"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"bytemuck",
"libm",
"smallvec",
@ -12092,7 +12096,7 @@ version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3c7c96f8a08ee34eff8857b11b49b07d71d1c3f4e88f8a88d4c9e9f90b1702"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"bytemuck",
"core_maths",
"log",
@ -12189,9 +12193,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
version = "1.0.7"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
checksum = "9f6280af86e5f559536da57a45ebc84948833b3bee313a7dd25232e09c878a52"
[[package]]
name = "scrypt"
@ -12315,7 +12319,7 @@ version = "0.1.0"
dependencies = [
"any_vec",
"anyhow",
"bitflags 2.9.0",
"bitflags 2.8.0",
"client",
"collections",
"editor",
@ -12357,7 +12361,7 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
@ -12370,7 +12374,7 @@ version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.10.0",
"core-foundation-sys",
"libc",
@ -12970,7 +12974,7 @@ version = "0.3.0+sdk-1.3.268.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
]
[[package]]
@ -13137,7 +13141,7 @@ dependencies = [
"atoi",
"base64 0.22.1",
"bigdecimal",
"bitflags 2.9.0",
"bitflags 2.8.0",
"byteorder",
"bytes 1.10.1",
"chrono",
@ -13184,7 +13188,7 @@ dependencies = [
"atoi",
"base64 0.22.1",
"bigdecimal",
"bitflags 2.9.0",
"bitflags 2.8.0",
"byteorder",
"chrono",
"crc",
@ -13647,7 +13651,7 @@ version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"core-foundation 0.9.4",
"system-configuration-sys 0.6.0",
]
@ -13691,13 +13695,13 @@ version = "0.27.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc4592f674ce18521c2a81483873a49596655b179f71c5e05d10c1fe66c78745"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"cap-fs-ext",
"cap-std",
"fd-lock",
"io-lifetimes",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
"winx",
]
@ -13839,7 +13843,7 @@ dependencies = [
"getrandom 0.3.1",
"once_cell",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@ -14521,7 +14525,7 @@ version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"bytes 1.10.1",
"futures-core",
"futures-util",
@ -15452,7 +15456,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5924018406ce0063cd67f8e008104968b74b563ee1b85dde3ed1f7cb87d3dbd"
dependencies = [
"arrayvec",
"bitflags 2.9.0",
"bitflags 2.8.0",
"cursor-icon",
"log",
"memchr",
@ -15676,7 +15680,7 @@ version = "0.201.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84e5df6dba6c0d7fafc63a450f1738451ed7a0b52295d83e868218fa286bf708"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"indexmap",
"semver",
]
@ -15687,7 +15691,7 @@ version = "0.221.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d06bfa36ab3ac2be0dee563380147a5b81ba10dd8885d7fbbc9eb574be67d185"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"hashbrown 0.15.2",
"indexmap",
"semver",
@ -15713,7 +15717,7 @@ checksum = "11976a250672556d1c4c04c6d5d7656ac9192ac9edc42a4587d6c21460010e69"
dependencies = [
"anyhow",
"async-trait",
"bitflags 2.9.0",
"bitflags 2.8.0",
"bumpalo",
"cc",
"cfg-if",
@ -15919,7 +15923,7 @@ checksum = "8d1be69bfcab1bdac74daa7a1f9695ab992b9c8e21b9b061e7d66434097e0ca4"
dependencies = [
"anyhow",
"async-trait",
"bitflags 2.9.0",
"bitflags 2.8.0",
"bytes 1.10.1",
"cap-fs-ext",
"cap-net-ext",
@ -16000,7 +16004,7 @@ version = "0.31.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2120de3d33638aaef5b9f4472bff75f07c56379cf76ea320bd3a3d65ecaf73f"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"rustix",
"wayland-backend",
"wayland-scanner",
@ -16023,7 +16027,7 @@ version = "0.31.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f81f365b8b4a97f422ac0e8737c438024b5951734506b0e1d775c73030561f4"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"wayland-backend",
"wayland-client",
"wayland-scanner",
@ -16035,7 +16039,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23803551115ff9ea9bce586860c5c5a971e360825a0309264102a9495a5ff479"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"wayland-backend",
"wayland-client",
"wayland-protocols",
@ -16202,7 +16206,7 @@ checksum = "4b9af35bc9629c52c261465320a9a07959164928b4241980ba1cf923b9e6751d"
dependencies = [
"anyhow",
"async-trait",
"bitflags 2.9.0",
"bitflags 2.8.0",
"thiserror 1.0.69",
"tracing",
"wasmtime",
@ -16258,7 +16262,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@ -16817,8 +16821,8 @@ version = "0.36.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f3fd376f71958b862e7afb20cfe5a22830e1963462f3a17f49d82a6c1d1f42d"
dependencies = [
"bitflags 2.9.0",
"windows-sys 0.59.0",
"bitflags 2.8.0",
"windows-sys 0.52.0",
]
[[package]]
@ -16836,7 +16840,7 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "288f992ea30e6b5c531b52cdd5f3be81c148554b09ea416f058d16556ba92c27"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
"wit-bindgen-rt 0.22.0",
"wit-bindgen-rust-macro",
]
@ -16863,7 +16867,7 @@ version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
]
[[package]]
@ -16901,7 +16905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "421c0c848a0660a8c22e2fd217929a0191f14476b68962afd2af89fd22e39825"
dependencies = [
"anyhow",
"bitflags 2.9.0",
"bitflags 2.8.0",
"indexmap",
"log",
"serde",
@ -16920,7 +16924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66c55ca8772d2b270e28066caed50ce4e53a28c3ac10e01efbd90e5be31e448b"
dependencies = [
"anyhow",
"bitflags 2.9.0",
"bitflags 2.8.0",
"indexmap",
"log",
"serde",
@ -17164,7 +17168,7 @@ name = "xim-parser"
version = "0.2.1"
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
dependencies = [
"bitflags 2.9.0",
"bitflags 2.8.0",
]
[[package]]

View file

@ -1787,12 +1787,13 @@ impl ActiveThread {
fn handle_deny_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
_: &ClickEvent,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |thread, cx| {
thread.deny_tool_use(tool_use_id, cx);
thread.deny_tool_use(tool_use_id, tool_name, cx);
});
}
@ -1865,10 +1866,12 @@ impl ActiveThread {
})
.child({
let tool_id = tool.id.clone();
let tool_name = tool.name.clone();
Button::new("deny-tool", "Deny").on_click(cx.listener(
move |this, event, window, cx| {
this.handle_deny_tool(
tool_id.clone(),
tool_name.clone(),
event,
window,
cx,

View file

@ -777,7 +777,7 @@ impl Thread {
LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
input_schema: tool.input_schema(model.tool_input_format()),
}
}));
@ -1030,17 +1030,23 @@ impl Thread {
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
if let Some(last_assistant_message) = thread
let last_assistant_message_id = thread
.messages
.iter()
.rfind(|message| message.role == Role::Assistant)
{
thread.tool_use.request_tool_use(
last_assistant_message.id,
tool_use,
cx,
);
}
.map(|message| message.id)
.unwrap_or_else(|| {
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text("Using tool...".to_string())],
cx,
)
});
thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
cx,
);
}
}
@ -1257,6 +1263,7 @@ impl Thread {
tool: Arc<dyn Tool>,
cx: &mut Context<Thread>,
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let run_tool = tool.run(
input,
messages,
@ -1271,9 +1278,11 @@ impl Thread {
thread
.update(cx, |thread, cx| {
let pending_tool_use = thread
.tool_use
.insert_tool_output(tool_use_id.clone(), output);
let pending_tool_use = thread.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
output,
);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
@ -1561,12 +1570,18 @@ impl Thread {
self.cumulative_token_usage.clone()
}
pub fn deny_tool_use(&mut self, tool_use_id: LanguageModelToolUseId, cx: &mut Context<Self>) {
pub fn deny_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
cx: &mut Context<Self>,
) {
let err = Err(anyhow::anyhow!(
"Permission to run tool action denied by user"
));
self.tool_use.insert_tool_output(tool_use_id.clone(), err);
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,

View file

@ -113,6 +113,7 @@ impl ToolUseState {
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
},
@ -134,6 +135,7 @@ impl ToolUseState {
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
tool_name: tool_use.name.clone(),
content: "Tool canceled by user".into(),
is_error: true,
},
@ -313,6 +315,7 @@ impl ToolUseState {
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<String>,
) -> Option<PendingToolUse> {
match output {
@ -321,6 +324,7 @@ impl ToolUseState {
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: tool_result.into(),
is_error: false,
},
@ -332,6 +336,7 @@ impl ToolUseState {
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: err.to_string().into(),
is_error: true,
},
@ -379,6 +384,7 @@ impl ToolUseState {
request_message.content.push(MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.

View file

@ -11,6 +11,7 @@ use anyhow::Result;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
pub use crate::action_log::*;
@ -50,7 +51,7 @@ pub trait Tool: 'static + Send + Sync {
fn needs_confirmation(&self) -> bool;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self) -> serde_json::Value {
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
}

View file

@ -18,6 +18,7 @@ mod path_search_tool;
mod read_file_tool;
mod regex_search_tool;
mod replace;
mod schema;
mod symbol_info_tool;
mod thinking_tool;

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -38,9 +39,8 @@ impl Tool for BashTool {
IconName::Terminal
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(BashToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<BashToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,8 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -162,9 +163,8 @@ impl Tool for BatchTool {
IconName::Cog
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(BatchToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<BatchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -7,7 +7,7 @@ use assistant_tool::{ActionLog, Tool};
use collections::IndexMap;
use gpui::{App, AsyncApp, Entity, Task};
use language::{CodeLabel, Language, LanguageRegistry};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use lsp::SymbolKind;
use project::{DocumentSymbol, Project, Symbol};
use regex::{Regex, RegexBuilder};
@ -17,6 +17,7 @@ use ui::IconName;
use util::markdown::MarkdownString;
use crate::code_symbol_iter::{CodeSymbolIterator, Entry};
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
@ -93,9 +94,8 @@ impl Tool for CodeSymbolsTool {
IconName::Eye
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(CodeSymbolsInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CodeSymbolsInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -53,9 +55,8 @@ impl Tool for CopyPathTool {
IconName::Clipboard
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(CopyPathToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CopyPathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -43,9 +45,8 @@ impl Tool for CreateDirectoryTool {
IconName::Folder
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(CreateDirectoryToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CreateDirectoryToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -50,9 +52,8 @@ impl Tool for CreateFileTool {
IconName::FileCreate
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(CreateFileToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<CreateFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,8 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use futures::{channel::mpsc, SinkExt, StreamExt};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -44,9 +45,8 @@ impl Tool for DeletePathTool {
IconName::FileDelete
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(DeletePathToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<DeletePathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,8 +1,9 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -57,9 +58,8 @@ impl Tool for DiagnosticsTool {
IconName::Warning
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(DiagnosticsToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<DiagnosticsToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -2,12 +2,14 @@ mod edit_action;
pub mod log;
use crate::replace::{replace_exact, replace_with_flexible_indent};
use crate::schema::json_schema_for;
use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool};
use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::{channel::mpsc, SinkExt, StreamExt};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::LanguageModelToolSchemaFormat;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
@ -91,9 +93,8 @@ impl Tool for EditFilesTool {
IconName::Pencil
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(EditFilesToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<EditFilesToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -2,13 +2,14 @@ use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{anyhow, bail, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -127,9 +128,8 @@ impl Tool for FetchTool {
IconName::Globe
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(FetchToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<FetchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -140,9 +141,8 @@ impl Tool for FindReplaceFileTool {
IconName::Pencil
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(FindReplaceFileToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<FindReplaceFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -55,9 +56,8 @@ impl Tool for ListDirectoryTool {
IconName::Folder
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ListDirectoryToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<ListDirectoryToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -53,9 +54,8 @@ impl Tool for MovePathTool {
IconName::ArrowRightLeft
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(MovePathToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<MovePathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,10 +1,11 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -44,9 +45,8 @@ impl Tool for NowTool {
IconName::Info
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(NowToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<NowToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -34,9 +35,8 @@ impl Tool for OpenTool {
IconName::ExternalLink
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(OpenToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<OpenToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,7 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -52,9 +53,8 @@ impl Tool for PathSearchTool {
IconName::SearchCode
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(PathSearchToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<PathSearchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,11 +1,12 @@
use std::path::Path;
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -58,9 +59,8 @@ impl Tool for ReadFileTool {
IconName::Eye
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ReadFileToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<ReadFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,9 +1,10 @@
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
search::{SearchQuery, SearchResult},
Project,
@ -55,9 +56,8 @@ impl Tool for RegexSearchTool {
IconName::Regex
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(RegexSearchToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<RegexSearchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -0,0 +1,120 @@
use anyhow::Result;
use language_model::LanguageModelToolSchemaFormat;
use schemars::{
schema::{RootSchema, Schema, SchemaObject},
JsonSchema,
};
pub fn json_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> serde_json::Value {
let schema = root_schema_for::<T>(format);
schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON")
}
pub fn schema_to_json(
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(value),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
transform_fields_to_json_schema_subset(&mut value);
Ok(value)
}
}
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
schemars::r#gen::SchemaSettings::default()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
settings
.visitors
.push(Box::new(TransformToJsonSchemaSubsetVisitor));
})
.into_generator()
}
};
generator.root_schema_for::<T>()
}
#[derive(Debug, Clone)]
struct TransformToJsonSchemaSubsetVisitor;
impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
fn visit_root_schema(&mut self, root: &mut RootSchema) {
schemars::visit::visit_root_schema(self, root)
}
fn visit_schema(&mut self, schema: &mut Schema) {
schemars::visit::visit_schema(self, schema)
}
fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(instance_type) = schema.instance_type.take() {
schema.instance_type = match instance_type {
schemars::schema::SingleOrVec::Single(t) => {
Some(schemars::schema::SingleOrVec::Single(t))
}
schemars::schema::SingleOrVec::Vec(items) => items
.into_iter()
.next()
.map(schemars::schema::SingleOrVec::from),
};
}
// One of is not supported, use anyOf instead.
if let Some(subschema) = schema.subschemas.as_mut() {
if let Some(one_of) = subschema.one_of.take() {
subschema.any_of = Some(one_of);
}
}
schemars::visit::visit_schema_object(self, schema)
}
}
fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) {
if let serde_json::Value::Object(obj) = json {
if let Some(default) = obj.get("default") {
let is_null = default.is_null();
//Default is not supported, so we need to remove it.
obj.remove("default");
if is_null {
obj.insert("nullable".to_string(), serde_json::Value::Bool(true));
}
}
// If a type is not specified for an input parameter we need to add it.
if obj.contains_key("description")
&& !obj.contains_key("type")
&& !(obj.contains_key("anyOf")
|| obj.contains_key("oneOf")
|| obj.contains_key("allOf"))
{
obj.insert(
"type".to_string(),
serde_json::Value::String("string".to_string()),
);
}
//Format field is only partially supported (e.g. not uint compatibility)
obj.remove("format");
for (_, value) in obj.iter_mut() {
if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value {
transform_fields_to_json_schema_subset(value);
}
}
} else if let serde_json::Value::Array(arr) = json {
for item in arr.iter_mut() {
transform_fields_to_json_schema_subset(item);
}
}
}

View file

@ -2,7 +2,7 @@ use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -10,6 +10,8 @@ use std::{fmt::Write, ops::Range, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownString;
use crate::schema::json_schema_for;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SymbolInfoToolInput {
/// The relative path to the file containing the symbol.
@ -82,9 +84,8 @@ impl Tool for SymbolInfoTool {
IconName::Eye
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(SymbolInfoToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<SymbolInfoToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {

View file

@ -1,9 +1,10 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -35,9 +36,8 @@ impl Tool for ThinkingTool {
IconName::Brain
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(ThinkingToolInput);
serde_json::to_value(&schema).unwrap()
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<ThinkingToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {

View file

@ -4,7 +4,7 @@ use anyhow::{anyhow, bail, Result};
use assistant_tool::{ActionLog, Tool, ToolSource};
use gpui::{App, Entity, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use crate::manager::ContextServerManager;
@ -53,7 +53,7 @@ impl Tool for ContextServerTool {
true
}
fn input_schema(&self) -> serde_json::Value {
fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value {
match &self.tool.input_schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })

View file

@ -127,6 +127,10 @@ pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_config: Option<ToolConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -134,6 +138,7 @@ pub struct GenerateContentRequest {
pub struct GenerateContentResponse {
pub candidates: Option<Vec<GenerateContentCandidate>>,
pub prompt_feedback: Option<PromptFeedback>,
pub usage_metadata: Option<UsageMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -166,6 +171,8 @@ pub enum Role {
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
FunctionCallPart(FunctionCallPart),
FunctionResponsePart(FunctionResponsePart),
}
#[derive(Debug, Serialize, Deserialize)]
@ -187,6 +194,18 @@ pub struct GenerativeContentBlob {
pub data: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
pub function_call: FunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
pub function_response: FunctionResponse,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
@ -210,6 +229,17 @@ pub struct PromptFeedback {
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub prompt_token_count: Option<usize>,
pub cached_content_token_count: Option<usize>,
pub candidates_token_count: Option<usize>,
pub tool_use_prompt_token_count: Option<usize>,
pub thoughts_token_count: Option<usize>,
pub total_token_count: Option<usize>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
@ -298,6 +328,53 @@ pub struct CountTokensResponse {
pub total_tokens: usize,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: FunctionCallingMode,
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
Auto,
Any,
None,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {

View file

@ -66,6 +66,15 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelToolSchemaFormat {
/// A JSON schema, see https://json-schema.org
JsonSchema,
/// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
JsonSchemaSubset,
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
@ -176,6 +185,10 @@ pub trait LanguageModel: Send + Sync {
LanguageModelAvailability::Public
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchema
}
fn max_token_count(&self) -> usize;
fn max_output_tokens(&self) -> Option<u32> {
None

View file

@ -167,6 +167,7 @@ impl LanguageModelImage {
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub tool_name: Arc<str>,
pub is_error: bool,
pub content: Arc<str>,
}

View file

@ -2,13 +2,17 @@ use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{future::BoxFuture, FutureExt, StreamExt};
use google_ai::stream_generate_content;
use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@ -17,7 +21,8 @@ use language_model::{
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc};
use std::pin::Pin;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, List, Tooltip};
@ -174,7 +179,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
request_limiter: RateLimiter::new(4),
}))
}
@ -211,7 +216,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -240,7 +245,39 @@ pub struct GoogleLanguageModel {
model: google_ai::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
request_limiter: RateLimiter,
}
impl GoogleLanguageModel {
fn stream_completion(
&self,
request: google_ai::GenerateContentRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
let request = google_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&api_key,
request,
);
request.await.context("failed to stream completion")
}
.boxed()
}
}
impl LanguageModel for GoogleLanguageModel {
@ -260,6 +297,10 @@ impl LanguageModel for GoogleLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
fn telemetry_id(&self) -> String {
format!("google/{}", self.model.id())
}
@ -305,40 +346,67 @@ impl LanguageModel for GoogleLanguageModel {
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
Ok(map_to_language_model_completion_events(response))
});
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncApp,
request: LanguageModelRequest,
name: String,
description: String,
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
let mut request = into_google(request, self.model.id().to_string());
request.tools = Some(vec![google_ai::Tool {
function_declarations: vec![google_ai::FunctionDeclaration {
name: name.clone(),
description,
parameters: schema,
}],
}]);
request.tool_config = Some(google_ai::ToolConfig {
function_calling_config: google_ai::FunctionCallingConfig {
mode: google_ai::FunctionCallingMode::Any,
allowed_function_names: Some(vec![name]),
},
});
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
Ok(response
.filter_map(|event| async move {
match event {
Ok(response) => {
if let Some(candidates) = &response.candidates {
for candidate in candidates {
for part in &candidate.content.parts {
if let google_ai::Part::FunctionCallPart(
function_call_part,
) = part
{
return Some(Ok(serde_json::to_string(
&function_call_part.function_call.args,
)
.unwrap_or_default()));
}
}
}
}
None
}
Err(e) => Some(Err(e)),
}
})
.boxed())
})
.boxed()
}
}
@ -351,11 +419,41 @@ pub fn into_google(
contents: request
.messages
.into_iter()
.map(|msg| google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: msg.string_contents(),
})],
role: match msg.role {
.map(|message| google_ai::Content {
parts: message
.content
.into_iter()
.filter_map(|content| match content {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
Some(Part::TextPart(google_ai::TextPart { text }))
} else {
None
}
}
language_model::MessageContent::Image(_) => None,
language_model::MessageContent::ToolUse(tool_use) => {
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
function_call: google_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
}))
}
language_model::MessageContent::ToolResult(tool_result) => Some(
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": tool_result.content
}),
},
}),
),
})
.collect(),
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
@ -371,9 +469,119 @@ pub fn into_google(
top_k: None,
}),
safety_settings: None,
tools: Some(
request
.tools
.into_iter()
.map(|tool| google_ai::Tool {
function_declarations: vec![FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
}],
})
.collect(),
),
tool_config: None,
}
}
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
usage: UsageMetadata,
stop_reason: StopReason,
}
futures::stream::unfold(
State {
events,
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
},
|mut state| async move {
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&state.usage),
)))
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
state.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!(
"Unexpected google finish_reason: {finish_reason}"
);
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => events.push(Ok(
LanguageModelCompletionEvent::Text(text_part.text),
)),
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> =
function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
input: function_call_part.function_call.args,
},
)));
}
Part::FunctionResponsePart(_) => {}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
state.stop_reason = StopReason::ToolUse;
}
events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
return Some((events, state));
}
Err(err) => {
return Some((vec![Err(anyhow!(err))], state));
}
}
}
None
},
)
.flat_map(futures::stream::iter)
}
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
@ -403,6 +611,36 @@ pub fn count_google_tokens(
.boxed()
}
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
if let Some(prompt_token_count) = new.prompt_token_count {
usage.prompt_token_count = Some(prompt_token_count);
}
if let Some(cached_content_token_count) = new.cached_content_token_count {
usage.cached_content_token_count = Some(cached_content_token_count);
}
if let Some(candidates_token_count) = new.candidates_token_count {
usage.candidates_token_count = Some(candidates_token_count);
}
if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
}
if let Some(thoughts_token_count) = new.thoughts_token_count {
usage.thoughts_token_count = Some(thoughts_token_count);
}
if let Some(total_token_count) = new.total_token_count {
usage.total_token_count = Some(total_token_count);
}
}
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
language_model::TokenUsage {
input_tokens: usage.prompt_token_count.unwrap_or(0) as u32,
output_tokens: usage.candidates_token_count.unwrap_or(0) as u32,
cache_read_input_tokens: usage.cached_content_token_count.unwrap_or(0) as u32,
cache_creation_input_tokens: 0,
}
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,