Add tool calling support for Gemini models (#27772)
Release Notes:

- N/A
parent f6d58f76e4
commit c8a9a74e6a

32 changed files with 735 additions and 251 deletions
@@ -2,13 +2,17 @@ use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{future::BoxFuture, FutureExt, StreamExt};
-use google_ai::stream_generate_content;
+use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
 };
 use http_client::HttpClient;
-use language_model::{AuthenticateError, LanguageModelCompletionEvent};
+use language_model::{
+    AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
+    LanguageModelToolUse, LanguageModelToolUseId, StopReason,
+};
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -17,7 +21,8 @@ use language_model::{
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::{future, sync::Arc};
+use std::pin::Pin;
+use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{prelude::*, Icon, IconName, List, Tooltip};
@@ -174,7 +179,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
             model,
             state: self.state.clone(),
             http_client: self.http_client.clone(),
-            rate_limiter: RateLimiter::new(4),
+            request_limiter: RateLimiter::new(4),
         }))
     }
 
@@ -211,7 +216,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
                     model,
                     state: self.state.clone(),
                     http_client: self.http_client.clone(),
-                    rate_limiter: RateLimiter::new(4),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -240,7 +245,39 @@ pub struct GoogleLanguageModel {
     model: google_ai::Model,
     state: gpui::Entity<State>,
     http_client: Arc<dyn HttpClient>,
-    rate_limiter: RateLimiter,
+    request_limiter: RateLimiter,
 }
 
+impl GoogleLanguageModel {
+    fn stream_completion(
+        &self,
+        request: google_ai::GenerateContentRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
+    > {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).google;
+            (state.api_key.clone(), settings.api_url.clone())
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
+            let request = google_ai::stream_generate_content(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+            );
+            request.await.context("failed to stream completion")
+        }
+        .boxed()
+    }
+}
+
 impl LanguageModel for GoogleLanguageModel {
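
The new `stream_completion` helper has a "future of a stream" shape: the outer `BoxFuture` covers state lookup and request setup, while the inner `BoxStream` yields the streamed `GenerateContentResponse` values. Below is a minimal sketch of that shape, assuming only the `futures` crate; the `connect` function and its `String` error type are illustrative, not from this commit:

    use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};

    // Outer future: setup/auth. Inner stream: the response chunks.
    fn connect() -> BoxFuture<'static, Result<BoxStream<'static, String>, String>> {
        async {
            // Setup work happens here, before any chunk is produced.
            Ok(futures::stream::iter(vec!["chunk-1".to_string(), "chunk-2".to_string()]).boxed())
        }
        .boxed()
    }

    fn main() {
        futures::executor::block_on(async {
            let mut stream = connect().await.expect("setup failed");
            while let Some(chunk) = stream.next().await {
                println!("{chunk}");
            }
        });
    }

Splitting the shape this way lets callers rate-limit the setup future separately from consuming the stream, which is how `request_limiter` is applied in the hunks below.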
@@ -260,6 +297,10 @@ impl LanguageModel for GoogleLanguageModel {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+        LanguageModelToolSchemaFormat::JsonSchemaSubset
+    }
+
     fn telemetry_id(&self) -> String {
         format!("google/{}", self.model.id())
     }
@@ -305,40 +346,67 @@ impl LanguageModel for GoogleLanguageModel {
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
         let request = into_google(request, self.model.id().to_string());
 
-        let http_client = self.http_client.clone();
-        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).google;
-            (state.api_key.clone(), settings.api_url.clone())
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        let future = self.rate_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
-            let response =
-                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
-            let events = response.await?;
-            Ok(google_ai::extract_text_from_events(events).boxed())
+        let request = self.stream_completion(request, cx);
+        let future = self.request_limiter.stream(async move {
+            let response = request.await.map_err(|err| anyhow!(err))?;
+            Ok(map_to_language_model_completion_events(response))
         });
-        async move {
-            Ok(future
-                .await?
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
     fn use_any_tool(
         &self,
-        _request: LanguageModelRequest,
-        _name: String,
-        _description: String,
-        _schema: serde_json::Value,
-        _cx: &AsyncApp,
+        request: LanguageModelRequest,
+        name: String,
+        description: String,
+        schema: serde_json::Value,
+        cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        future::ready(Err(anyhow!("not implemented"))).boxed()
+        let mut request = into_google(request, self.model.id().to_string());
+        request.tools = Some(vec![google_ai::Tool {
+            function_declarations: vec![google_ai::FunctionDeclaration {
+                name: name.clone(),
+                description,
+                parameters: schema,
+            }],
+        }]);
+        request.tool_config = Some(google_ai::ToolConfig {
+            function_calling_config: google_ai::FunctionCallingConfig {
+                mode: google_ai::FunctionCallingMode::Any,
+                allowed_function_names: Some(vec![name]),
+            },
+        });
+        let response = self.stream_completion(request, cx);
+        self.request_limiter
+            .run(async move {
+                let response = response.await?;
+                Ok(response
+                    .filter_map(|event| async move {
+                        match event {
+                            Ok(response) => {
+                                if let Some(candidates) = &response.candidates {
+                                    for candidate in candidates {
+                                        for part in &candidate.content.parts {
+                                            if let google_ai::Part::FunctionCallPart(
+                                                function_call_part,
+                                            ) = part
+                                            {
+                                                return Some(Ok(serde_json::to_string(
+                                                    &function_call_part.function_call.args,
+                                                )
+                                                .unwrap_or_default()));
+                                            }
+                                        }
+                                    }
+                                }
+                                None
+                            }
+                            Err(e) => Some(Err(e)),
+                        }
+                    })
+                    .boxed())
+            })
+            .boxed()
     }
 }
 
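
Note how `use_any_tool` pins Gemini to a single function: it declares exactly one `FunctionDeclaration` and sets the function-calling mode to `Any` with an allow-list containing just that name, so the model must answer with a call to that function rather than free text. On the wire this corresponds to the Gemini API's `toolConfig` request block. The fragment below is a hedged sketch of that JSON shape; the tool name and schema are invented for illustration, and the exact field renaming is handled by the `google_ai` crate's serde derives:

    use serde_json::json;

    fn main() {
        // "ANY" mode forces the model to call one of the allowed functions.
        let fragment = json!({
            "tools": [{
                "functionDeclarations": [{
                    "name": "get_weather", // hypothetical tool
                    "description": "Look up the current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": { "city": { "type": "string" } }
                    }
                }]
            }],
            "toolConfig": {
                "functionCallingConfig": {
                    "mode": "ANY",
                    "allowedFunctionNames": ["get_weather"]
                }
            }
        });
        println!("{}", serde_json::to_string_pretty(&fragment).unwrap());
    }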
@@ -351,11 +419,41 @@ pub fn into_google(
         contents: request
             .messages
             .into_iter()
-            .map(|msg| google_ai::Content {
-                parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
-                    text: msg.string_contents(),
-                })],
-                role: match msg.role {
+            .map(|message| google_ai::Content {
+                parts: message
+                    .content
+                    .into_iter()
+                    .filter_map(|content| match content {
+                        language_model::MessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                Some(Part::TextPart(google_ai::TextPart { text }))
+                            } else {
+                                None
+                            }
+                        }
+                        language_model::MessageContent::Image(_) => None,
+                        language_model::MessageContent::ToolUse(tool_use) => {
+                            Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
+                                function_call: google_ai::FunctionCall {
+                                    name: tool_use.name.to_string(),
+                                    args: tool_use.input,
+                                },
+                            }))
+                        }
+                        language_model::MessageContent::ToolResult(tool_result) => Some(
+                            Part::FunctionResponsePart(google_ai::FunctionResponsePart {
+                                function_response: google_ai::FunctionResponse {
+                                    name: tool_result.tool_name.to_string(),
+                                    // The API expects a valid JSON object
+                                    response: serde_json::json!({
+                                        "output": tool_result.content
+                                    }),
+                                },
+                            }),
+                        ),
+                    })
+                    .collect(),
+                role: match message.role {
                     Role::User => google_ai::Role::User,
                     Role::Assistant => google_ai::Role::Model,
                     Role::System => google_ai::Role::User, // Google AI doesn't have a system role
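
One detail in the message conversion deserves a callout: per the comment in the hunk above, Gemini requires `functionResponse.response` to be a JSON object rather than a bare string, which is why the tool result is wrapped under an "output" key. A self-contained sketch of that rule (the helper name is ours, not the commit's):

    use serde_json::{json, Value};

    // The API expects a JSON object here, so plain tool output is wrapped.
    fn wrap_tool_result(text: &str) -> Value {
        json!({ "output": text })
    }

    fn main() {
        assert_eq!(
            wrap_tool_result("72F and sunny").to_string(),
            r#"{"output":"72F and sunny"}"#
        );
    }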
@@ -371,9 +469,119 @@ pub fn into_google(
             top_k: None,
         }),
         safety_settings: None,
+        tools: Some(
+            request
+                .tools
+                .into_iter()
+                .map(|tool| google_ai::Tool {
+                    function_declarations: vec![FunctionDeclaration {
+                        name: tool.name,
+                        description: tool.description,
+                        parameters: tool.input_schema,
+                    }],
+                })
+                .collect(),
+        ),
+        tool_config: None,
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
+        usage: UsageMetadata,
+        stop_reason: StopReason,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            usage: UsageMetadata::default(),
+            stop_reason: StopReason::EndTurn,
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
+                        let mut wants_to_use_tool = false;
+                        if let Some(usage_metadata) = event.usage_metadata {
+                            update_usage(&mut state.usage, &usage_metadata);
+                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                convert_usage(&state.usage),
+                            )))
+                        }
+                        if let Some(candidates) = event.candidates {
+                            for candidate in candidates {
+                                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
+                                    state.stop_reason = match finish_reason {
+                                        "STOP" => StopReason::EndTurn,
+                                        "MAX_TOKENS" => StopReason::MaxTokens,
+                                        _ => {
+                                            log::error!(
+                                                "Unexpected google finish_reason: {finish_reason}"
+                                            );
+                                            StopReason::EndTurn
+                                        }
+                                    };
+                                }
+                                candidate
+                                    .content
+                                    .parts
+                                    .into_iter()
+                                    .for_each(|part| match part {
+                                        Part::TextPart(text_part) => events.push(Ok(
+                                            LanguageModelCompletionEvent::Text(text_part.text),
+                                        )),
+                                        Part::InlineDataPart(_) => {}
+                                        Part::FunctionCallPart(function_call_part) => {
+                                            wants_to_use_tool = true;
+                                            let name: Arc<str> =
+                                                function_call_part.function_call.name.into();
+                                            let next_tool_id =
+                                                TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
+                                            let id: LanguageModelToolUseId =
+                                                format!("{}-{}", name, next_tool_id).into();
+
+                                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id,
+                                                    name,
+                                                    input: function_call_part.function_call.args,
+                                                },
+                                            )));
+                                        }
+                                        Part::FunctionResponsePart(_) => {}
+                                    });
+                            }
+                        }
+
+                        // Even when Gemini wants to use a Tool, the API
+                        // responds with `finish_reason: STOP`
+                        if wants_to_use_tool {
+                            state.stop_reason = StopReason::ToolUse;
+                        }
+                        events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
+                        return Some((events, state));
+                    }
+                    Err(err) => {
+                        return Some((vec![Err(anyhow!(err))], state));
+                    }
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
+
 pub fn count_google_tokens(
     request: LanguageModelRequest,
     cx: &App,
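
The new `map_to_language_model_completion_events` is built on `futures::stream::unfold`: one upstream `GenerateContentResponse` can expand into several completion events (usage update, text, tool use, stop), so the closure emits a `Vec` per response and `.flat_map(futures::stream::iter)` flattens the batches back into a single event stream. It also synthesizes tool-use IDs (the function name plus a process-wide `AtomicU64` counter) because Gemini function calls carry no ID of their own. A minimal sketch of the unfold-then-flatten pattern with toy types, assuming only the `futures` crate:

    use futures::{stream, Stream, StreamExt};

    // Each upstream item expands into a batch of downstream events; unfold
    // threads the stream through as state, flat_map flattens the batches.
    fn expand(
        upstream: impl Stream<Item = u32> + Unpin + Send + 'static,
    ) -> impl Stream<Item = String> {
        stream::unfold(upstream, |mut upstream| async move {
            let item = upstream.next().await?; // None ends the output stream
            let batch = vec![format!("text:{item}"), "stop".to_string()];
            Some((batch, upstream))
        })
        .flat_map(stream::iter)
    }

    fn main() {
        let out: Vec<String> =
            futures::executor::block_on(expand(stream::iter([1u32, 2])).collect());
        assert_eq!(out, ["text:1", "stop", "text:2", "stop"]);
    }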
@@ -403,6 +611,36 @@ pub fn count_google_tokens(
         .boxed()
 }
 
+fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
+    if let Some(prompt_token_count) = new.prompt_token_count {
+        usage.prompt_token_count = Some(prompt_token_count);
+    }
+    if let Some(cached_content_token_count) = new.cached_content_token_count {
+        usage.cached_content_token_count = Some(cached_content_token_count);
+    }
+    if let Some(candidates_token_count) = new.candidates_token_count {
+        usage.candidates_token_count = Some(candidates_token_count);
+    }
+    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
+        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
+    }
+    if let Some(thoughts_token_count) = new.thoughts_token_count {
+        usage.thoughts_token_count = Some(thoughts_token_count);
+    }
+    if let Some(total_token_count) = new.total_token_count {
+        usage.total_token_count = Some(total_token_count);
+    }
+}
+
+fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
+    language_model::TokenUsage {
+        input_tokens: usage.prompt_token_count.unwrap_or(0) as u32,
+        output_tokens: usage.candidates_token_count.unwrap_or(0) as u32,
+        cache_read_input_tokens: usage.cached_content_token_count.unwrap_or(0) as u32,
+        cache_creation_input_tokens: 0,
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
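
The usage plumbing added at the bottom follows a last-write-wins merge: each streamed chunk's usage metadata fields are optional, and `update_usage` overwrites only the fields that are present, so the accumulated state always holds the most recently reported totals. A small sketch of that merge rule with toy fields (plain Rust, no dependencies):

    #[derive(Debug, Default, PartialEq)]
    struct Usage {
        prompt_tokens: Option<u64>,
        output_tokens: Option<u64>,
    }

    // Present fields overwrite; absent fields keep the last seen value.
    fn update_usage(acc: &mut Usage, new: &Usage) {
        if let Some(n) = new.prompt_tokens {
            acc.prompt_tokens = Some(n);
        }
        if let Some(n) = new.output_tokens {
            acc.output_tokens = Some(n);
        }
    }

    fn main() {
        let mut acc = Usage::default();
        update_usage(&mut acc, &Usage { prompt_tokens: Some(12), output_tokens: None });
        update_usage(&mut acc, &Usage { prompt_tokens: None, output_tokens: Some(3) });
        assert_eq!(acc, Usage { prompt_tokens: Some(12), output_tokens: Some(3) });
    }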