diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs
index c6472dfd68..85a08d5afa 100644
--- a/crates/google_ai/src/google_ai.rs
+++ b/crates/google_ai/src/google_ai.rs
@@ -289,6 +289,22 @@ pub struct UsageMetadata {
     pub total_token_count: Option<usize>,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ThinkingConfig {
+    pub thinking_budget: u32,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
+pub enum GoogleModelMode {
+    #[default]
+    Default,
+    Thinking {
+        budget_tokens: Option<u32>,
+    },
+}
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerationConfig {
@@ -304,6 +320,8 @@ pub struct GenerationConfig {
     pub top_p: Option<f64>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_k: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking_config: Option<ThinkingConfig>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -496,6 +514,8 @@ pub enum Model {
         /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
         display_name: Option<String>,
         max_tokens: usize,
+        #[serde(default)]
+        mode: GoogleModelMode,
     },
 }
 
@@ -552,6 +572,21 @@ impl Model {
             Model::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn mode(&self) -> GoogleModelMode {
+        match self {
+            Self::Gemini15Pro
+            | Self::Gemini15Flash
+            | Self::Gemini20Pro
+            | Self::Gemini20Flash
+            | Self::Gemini20FlashThinking
+            | Self::Gemini20FlashLite
+            | Self::Gemini25ProExp0325
+            | Self::Gemini25ProPreview0325
+            | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
+            Self::Custom { mode, .. } => *mode,
+        }
+    }
 }
 
 impl std::fmt::Display for Model {
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 6e53bbf0e8..ee6fe8d484 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -4,6 +4,7 @@ use client::{Client, UserStore, zed_urls};
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
+use google_ai::GoogleModelMode;
 use gpui::{
     AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
 };
@@ -750,7 +751,8 @@ impl LanguageModel for CloudLanguageModel {
         let client = self.client.clone();
         let llm_api_token = self.llm_api_token.clone();
         let model_id = self.model.id.to_string();
-        let generate_content_request = into_google(request, model_id.clone());
+        let generate_content_request =
+            into_google(request, model_id.clone(), GoogleModelMode::Default);
         async move {
             let http_client = &client.http_client();
             let token = llm_api_token.acquire(&client).await?;
@@ -922,7 +924,8 @@ impl LanguageModel for CloudLanguageModel {
             }
             zed_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
-                let request = into_google(request, self.model.id.to_string());
+                let request =
+                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index 718a7ba7ea..6ff70a3a91 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use google_ai::{
-    FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
+    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
+    ThinkingConfig, UsageMetadata,
 };
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -45,11 +46,41 @@ pub struct GoogleSettings {
     pub available_models: Vec<AvailableModel>,
 }
 
+#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+    #[default]
+    Default,
+    Thinking {
+        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+        budget_tokens: Option<u32>,
+    },
+}
+
+impl From<ModelMode> for GoogleModelMode {
+    fn from(value: ModelMode) -> Self {
+        match value {
+            ModelMode::Default => GoogleModelMode::Default,
+            ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
+impl From<GoogleModelMode> for ModelMode {
+    fn from(value: GoogleModelMode) -> Self {
+        match value {
+            GoogleModelMode::Default => ModelMode::Default,
+            GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 pub struct AvailableModel {
     name: String,
     display_name: Option<String>,
     max_tokens: usize,
+    mode: Option<ModelMode>,
 }
 
 pub struct GoogleLanguageModelProvider {
@@ -216,6 +247,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
                     name: model.name.clone(),
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
+                    mode: model.mode.unwrap_or_default().into(),
                 },
             );
         }
@@ -343,7 +375,7 @@ impl LanguageModel for GoogleLanguageModel {
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
         let model_id = self.model.id().to_string();
-        let request = into_google(request, model_id.clone());
+        let request = into_google(request, model_id.clone(), self.model.mode());
         let http_client = self.http_client.clone();
         let api_key = self.state.read(cx).api_key.clone();
 
@@ -379,7 +411,7 @@ impl LanguageModel for GoogleLanguageModel {
             >,
         >,
     > {
-        let request = into_google(request, self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string(), self.model.mode());
        let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request
@@ -394,6 +426,7 @@ impl LanguageModel for GoogleLanguageModel {
 pub fn into_google(
     mut request: LanguageModelRequest,
     model_id: String,
+    mode: GoogleModelMode,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content
@@ -504,6 +537,12 @@ pub fn into_google(
         stop_sequences: Some(request.stop),
         max_output_tokens: None,
         temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+        thinking_config: match mode {
+            GoogleModelMode::Thinking { budget_tokens } => {
+                budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
+            }
+            GoogleModelMode::Default => None,
+        },
         top_p: None,
         top_k: None,
     }),
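
For reference, a minimal standalone sketch (not part of the diff) of what the serde attributes in this patch produce. It re-declares the two new types from the diff and assumes `serde`/`serde_json` plus a `main` driver purely for illustration:

```rust
use serde::{Deserialize, Serialize};

// Mirrors the settings-side enum added in provider/google.rs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking { budget_tokens: Option<u32> },
}

// Mirrors the wire-side struct added in google_ai.rs.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    pub thinking_budget: u32,
}

fn main() -> serde_json::Result<()> {
    // Settings form: prints `{"type":"thinking","budget_tokens":8192}` --
    // the shape the new `mode` field of an `available_models` entry accepts.
    let mode = ModelMode::Thinking { budget_tokens: Some(8192) };
    println!("{}", serde_json::to_string(&mode)?);

    // Request form: camelCase renaming yields `{"thinkingBudget":8192}`,
    // which `into_google` attaches as `generationConfig.thinkingConfig`.
    let config = ThinkingConfig { thinking_budget: 8192 };
    println!("{}", serde_json::to_string(&config)?);
    Ok(())
}
```

The two shapes differ deliberately: the settings enum is internally tagged with snake_case fields for `settings.json`, while `ThinkingConfig` is camelCase to match the Gemini API's request body, and `GoogleModelMode::Default` (or a `Thinking` mode with no budget) serializes to no `thinkingConfig` at all because of `skip_serializing_if`.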