Add thinking budget for Gemini custom models (#31251)

Closes #31243

As described in the linked issue, Gemini automatically chooses the
[thinking budget](https://ai.google.dev/gemini-api/docs/thinking)
unless it is explicitly set to a value. To get fast responses (e.g.
for the inline assistant), I prefer to set it to 0.

Release Notes:

- ai: Added a `thinking` mode for custom Google models, with a configurable
token budget

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
90aca 2025-06-03 13:40:20 +02:00 committed by GitHub
parent b74477d12e
commit cf931247d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 82 additions and 5 deletions

View file

@ -289,6 +289,22 @@ pub struct UsageMetadata {
pub total_token_count: Option<usize>,
}
/// Request-level thinking configuration for the Gemini API.
///
/// Serialized into the request's `generationConfig.thinkingConfig`
/// field (camelCase per the serde rename below).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    /// Maximum number of tokens the model may spend on "thinking"
    /// before producing the answer. Per the PR description, setting
    /// this to 0 is used to disable thinking for faster responses —
    /// TODO confirm against the Gemini thinking-budget docs.
    pub thinking_budget: u32,
}
/// Reasoning ("thinking") mode for a Google model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum GoogleModelMode {
    /// No explicit thinking configuration is sent; Gemini chooses
    /// the thinking budget on its own.
    #[default]
    Default,
    /// Thinking mode with an optional explicit token budget.
    Thinking {
        /// Maximum tokens to spend on thinking. When `None`, no
        /// `thinkingConfig` is serialized into the request and the
        /// budget is left up to the model (see `into_google`, which
        /// maps `None` to no config).
        budget_tokens: Option<u32>,
    },
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
@ -304,6 +320,8 @@ pub struct GenerationConfig {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -496,6 +514,8 @@ pub enum Model {
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
max_tokens: usize,
#[serde(default)]
mode: GoogleModelMode,
},
}
@ -552,6 +572,21 @@ impl Model {
Model::Custom { max_tokens, .. } => *max_tokens,
}
}
/// Returns the reasoning mode for this model.
///
/// All named Gemini variants are hard-wired to
/// `GoogleModelMode::Default`; only `Custom` models carry a
/// user-configurable mode (deserialized from settings with
/// `#[serde(default)]`). The match is deliberately exhaustive —
/// no `_` arm — so adding a new `Model` variant produces a compile
/// error here and forces an explicit decision about its mode.
pub fn mode(&self) -> GoogleModelMode {
    match self {
        Self::Gemini15Pro
        | Self::Gemini15Flash
        | Self::Gemini20Pro
        | Self::Gemini20Flash
        | Self::Gemini20FlashThinking
        | Self::Gemini20FlashLite
        | Self::Gemini25ProExp0325
        | Self::Gemini25ProPreview0325
        | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
        Self::Custom { mode, .. } => *mode,
    }
}
}
impl std::fmt::Display for Model {

View file

@ -4,6 +4,7 @@ use client::{Client, UserStore, zed_urls};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
@ -750,7 +751,8 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let model_id = self.model.id.to_string();
let generate_content_request = into_google(request, model_id.clone());
let generate_content_request =
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
@ -922,7 +924,8 @@ impl LanguageModel for CloudLanguageModel {
}
zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let request = into_google(request, self.model.id.to_string());
let request =
into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {

View file

@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@ -45,11 +46,41 @@ pub struct GoogleSettings {
pub available_models: Vec<AvailableModel>,
}
/// Settings-facing reasoning mode for a custom Google model.
///
/// Mirrors `GoogleModelMode` (see the `From` conversions in both
/// directions) but additionally derives `JsonSchema` and uses an
/// internally tagged serde representation — `{"type": "default"}` or
/// `{"type": "thinking", ...}` — suitable for the user's settings
/// file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
impl From<ModelMode> for GoogleModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => GoogleModelMode::Default,
ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
}
}
}
impl From<GoogleModelMode> for ModelMode {
fn from(value: GoogleModelMode) -> Self {
match value {
GoogleModelMode::Default => ModelMode::Default,
GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
}
}
}
/// A user-defined Google model entry as declared in settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model identifier as configured by the user; forwarded verbatim
    /// when the provider's model list is built.
    name: String,
    /// Optional display name shown in the UI; `None` falls back to
    /// the model name — TODO confirm fallback at the display site.
    display_name: Option<String>,
    /// Context-window size (maximum input tokens) for this model.
    max_tokens: usize,
    /// Optional reasoning mode; `None` falls back to
    /// `ModelMode::Default` (via `unwrap_or_default` at the use site).
    mode: Option<ModelMode>,
}
pub struct GoogleLanguageModelProvider {
@ -216,6 +247,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
mode: model.mode.unwrap_or_default().into(),
},
);
}
@ -343,7 +375,7 @@ impl LanguageModel for GoogleLanguageModel {
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
let model_id = self.model.id().to_string();
let request = into_google(request, model_id.clone());
let request = into_google(request, model_id.clone(), self.model.mode());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
@ -379,7 +411,7 @@ impl LanguageModel for GoogleLanguageModel {
>,
>,
> {
let request = into_google(request, self.model.id().to_string());
let request = into_google(request, self.model.id().to_string(), self.model.mode());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
@ -394,6 +426,7 @@ impl LanguageModel for GoogleLanguageModel {
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
@ -504,6 +537,12 @@ pub fn into_google(
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match mode {
GoogleModelMode::Thinking { budget_tokens } => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
GoogleModelMode::Default => None,
},
top_p: None,
top_k: None,
}),