Add tool calling support for Gemini models (#27772)

Release Notes:

- N/A
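
In rough terms, the provider now builds Gemini requests that declare the available tools and, for `use_any_tool`, force the model to call one named function. A minimal sketch (the tool name, description, and schema are made up for illustration; `into_google` and the `google_ai` types are the ones used in the diff below):

```rust
// Sketch only: mirrors how `use_any_tool` forces a single tool call.
// The "fetch_weather" tool is a hypothetical example, not part of this change.
let mut request = into_google(request, model_id);
request.tools = Some(vec![google_ai::Tool {
    function_declarations: vec![google_ai::FunctionDeclaration {
        name: "fetch_weather".into(),
        description: "Return the current weather for a city".into(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": { "city": { "type": "string" } }
        }),
    }],
}]);
request.tool_config = Some(google_ai::ToolConfig {
    function_calling_config: google_ai::FunctionCallingConfig {
        mode: google_ai::FunctionCallingMode::Any,
        allowed_function_names: Some(vec!["fetch_weather".into()]),
    },
});
```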
Bennet Bo Fenner 2025-03-31 17:46:42 +02:00 committed by GitHub
parent f6d58f76e4
commit c8a9a74e6a
32 changed files with 735 additions and 251 deletions


@@ -2,13 +2,17 @@ use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{future::BoxFuture, FutureExt, StreamExt};
use google_ai::stream_generate_content;
use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -17,7 +21,8 @@ use language_model::{
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc};
use std::pin::Pin;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, List, Tooltip};
@@ -174,7 +179,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
request_limiter: RateLimiter::new(4),
}))
}
@@ -211,7 +216,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -240,7 +245,39 @@ pub struct GoogleLanguageModel {
model: google_ai::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
request_limiter: RateLimiter,
}
impl GoogleLanguageModel {
fn stream_completion(
&self,
request: google_ai::GenerateContentRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
let request = google_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&api_key,
request,
);
request.await.context("failed to stream completion")
}
.boxed()
}
}
impl LanguageModel for GoogleLanguageModel {
@@ -260,6 +297,10 @@ impl LanguageModel for GoogleLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
fn telemetry_id(&self) -> String {
format!("google/{}", self.model.id())
}
@@ -305,40 +346,67 @@ impl LanguageModel for GoogleLanguageModel {
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
Ok(map_to_language_model_completion_events(response))
});
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
async move { Ok(future.await?.boxed()) }.boxed()
}
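// Declares exactly one function to Gemini and forces it to be called
// (`FunctionCallingMode::Any` plus `allowed_function_names`), then streams
// back the JSON-serialized arguments of every `FunctionCallPart` in the
// response.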
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncApp,
request: LanguageModelRequest,
name: String,
description: String,
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
let mut request = into_google(request, self.model.id().to_string());
request.tools = Some(vec![google_ai::Tool {
function_declarations: vec![google_ai::FunctionDeclaration {
name: name.clone(),
description,
parameters: schema,
}],
}]);
request.tool_config = Some(google_ai::ToolConfig {
function_calling_config: google_ai::FunctionCallingConfig {
mode: google_ai::FunctionCallingMode::Any,
allowed_function_names: Some(vec![name]),
},
});
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
Ok(response
.filter_map(|event| async move {
match event {
Ok(response) => {
if let Some(candidates) = &response.candidates {
for candidate in candidates {
for part in &candidate.content.parts {
if let google_ai::Part::FunctionCallPart(
function_call_part,
) = part
{
return Some(Ok(serde_json::to_string(
&function_call_part.function_call.args,
)
.unwrap_or_default()));
}
}
}
}
None
}
Err(e) => Some(Err(e)),
}
})
.boxed())
})
.boxed()
}
}
@@ -351,11 +419,41 @@ pub fn into_google(
contents: request
.messages
.into_iter()
.map(|msg| google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: msg.string_contents(),
})],
role: match msg.role {
.map(|message| google_ai::Content {
parts: message
.content
.into_iter()
.filter_map(|content| match content {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
Some(Part::TextPart(google_ai::TextPart { text }))
} else {
None
}
}
language_model::MessageContent::Image(_) => None,
language_model::MessageContent::ToolUse(tool_use) => {
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
function_call: google_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
}))
}
language_model::MessageContent::ToolResult(tool_result) => Some(
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": tool_result.content
}),
},
}),
),
})
.collect(),
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
@@ -371,9 +469,119 @@
top_k: None,
}),
safety_settings: None,
tools: Some(
request
.tools
.into_iter()
.map(|tool| google_ai::Tool {
function_declarations: vec![FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
}],
})
.collect(),
),
tool_config: None,
}
}
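/// Folds the raw `GenerateContentResponse` stream into
/// `LanguageModelCompletionEvent`s: usage metadata is accumulated across
/// chunks, Gemini finish reasons are mapped onto `StopReason`, and each
/// `FunctionCallPart` becomes a `ToolUse` event with an id synthesized from
/// the function name and a global counter.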
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
usage: UsageMetadata,
stop_reason: StopReason,
}
futures::stream::unfold(
State {
events,
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
},
|mut state| async move {
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&state.usage),
)))
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
state.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!(
"Unexpected google finish_reason: {finish_reason}"
);
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => events.push(Ok(
LanguageModelCompletionEvent::Text(text_part.text),
)),
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> =
function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
input: function_call_part.function_call.args,
},
)));
}
Part::FunctionResponsePart(_) => {}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
state.stop_reason = StopReason::ToolUse;
}
events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
return Some((events, state));
}
Err(err) => {
return Some((vec![Err(anyhow!(err))], state));
}
}
}
None
},
)
.flat_map(futures::stream::iter)
}
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
@@ -403,6 +611,36 @@ pub fn count_google_tokens(
.boxed()
}
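// Merge per-chunk usage metadata: a chunk only overwrites the fields it
// actually reports, so counts from earlier chunks are preserved.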
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
if let Some(prompt_token_count) = new.prompt_token_count {
usage.prompt_token_count = Some(prompt_token_count);
}
if let Some(cached_content_token_count) = new.cached_content_token_count {
usage.cached_content_token_count = Some(cached_content_token_count);
}
if let Some(candidates_token_count) = new.candidates_token_count {
usage.candidates_token_count = Some(candidates_token_count);
}
if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
}
if let Some(thoughts_token_count) = new.thoughts_token_count {
usage.thoughts_token_count = Some(thoughts_token_count);
}
if let Some(total_token_count) = new.total_token_count {
usage.total_token_count = Some(total_token_count);
}
}
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
language_model::TokenUsage {
input_tokens: usage.prompt_token_count.unwrap_or(0) as u32,
output_tokens: usage.candidates_token_count.unwrap_or(0) as u32,
cache_read_input_tokens: usage.cached_content_token_count.unwrap_or(0) as u32,
cache_creation_input_tokens: 0,
}
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,