Add thinking budget for Gemini custom models (#31251)
Closes #31243. As described in my issue, the [thinking budget](https://ai.google.dev/gemini-api/docs/thinking) is chosen automatically by Gemini unless it is explicitly set. In order to get fast responses (inline assistant), I prefer to set it to 0. Release Notes: - ai: Added `thinking` mode for custom Google models with a configurable token budget --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
b74477d12e
commit
cf931247d0
3 changed files with 82 additions and 5 deletions
|
@ -289,6 +289,22 @@ pub struct UsageMetadata {
|
|||
pub total_token_count: Option<usize>,
|
||||
}
|
||||
|
||||
/// Configuration for Gemini's "thinking" feature, sent as part of
/// `GenerationConfig` (serialized in camelCase as `thinkingConfig`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
    // Maximum number of tokens the model may spend on internal reasoning.
    // Per the commit description, setting this to 0 yields faster responses
    // (useful for the inline assistant).
    pub thinking_budget: u32,
}
|
||||
|
||||
/// Reasoning mode for a Google model at the API layer.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum GoogleModelMode {
    /// No explicit thinking configuration is sent; the API picks the budget.
    #[default]
    Default,
    /// Thinking enabled, optionally with an explicit token budget.
    Thinking {
        // `None` means no `thinkingConfig` is serialized, leaving the
        // budget up to the API (see the `thinking_config` match in
        // `into_google`).
        budget_tokens: Option<u32>,
    },
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
|
@ -304,6 +320,8 @@ pub struct GenerationConfig {
|
|||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking_config: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -496,6 +514,8 @@ pub enum Model {
|
|||
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
|
||||
display_name: Option<String>,
|
||||
max_tokens: usize,
|
||||
#[serde(default)]
|
||||
mode: GoogleModelMode,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -552,6 +572,21 @@ impl Model {
|
|||
Model::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the reasoning mode for this model.
///
/// All built-in Gemini models use `GoogleModelMode::Default`; only
/// custom models carry a user-configured mode.
pub fn mode(&self) -> GoogleModelMode {
    match self {
        Self::Gemini15Pro
        | Self::Gemini15Flash
        | Self::Gemini20Pro
        | Self::Gemini20Flash
        | Self::Gemini20FlashThinking
        | Self::Gemini20FlashLite
        | Self::Gemini25ProExp0325
        | Self::Gemini25ProPreview0325
        | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
        Self::Custom { mode, .. } => *mode,
    }
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Model {
|
||||
|
|
|
@ -4,6 +4,7 @@ use client::{Client, UserStore, zed_urls};
|
|||
use futures::{
|
||||
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
|
||||
};
|
||||
use google_ai::GoogleModelMode;
|
||||
use gpui::{
|
||||
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
|
||||
};
|
||||
|
@ -750,7 +751,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let model_id = self.model.id.to_string();
|
||||
let generate_content_request = into_google(request, model_id.clone());
|
||||
let generate_content_request =
|
||||
into_google(request, model_id.clone(), GoogleModelMode::Default);
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
@ -922,7 +924,8 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let request = into_google(request, self.model.id.to_string());
|
||||
let request =
|
||||
into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let PerformLlmCompletionResponse {
|
||||
|
|
|
@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider;
|
|||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use google_ai::{
|
||||
FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
|
||||
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
|
||||
ThinkingConfig, UsageMetadata,
|
||||
};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
|
@ -45,11 +46,41 @@ pub struct GoogleSettings {
|
|||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
/// User-facing model mode as written in settings JSON.
///
/// Serialized with an internal `type` tag in lowercase (e.g.
/// `{"type": "thinking", "budget_tokens": 0}`); converted to the API-level
/// `GoogleModelMode` before a request is built.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
|
||||
|
||||
impl From<ModelMode> for GoogleModelMode {
|
||||
fn from(value: ModelMode) -> Self {
|
||||
match value {
|
||||
ModelMode::Default => GoogleModelMode::Default,
|
||||
ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GoogleModelMode> for ModelMode {
|
||||
fn from(value: GoogleModelMode) -> Self {
|
||||
match value {
|
||||
GoogleModelMode::Default => ModelMode::Default,
|
||||
GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A custom Google model entry as declared in user settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model identifier passed through to the API request.
    name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    display_name: Option<String>,
    // Token limit reported via `Model::max_tokens` for the custom model.
    max_tokens: usize,
    // Optional thinking mode; `None` falls back to `ModelMode::Default`
    // (`unwrap_or_default()` at the registration site).
    mode: Option<ModelMode>,
}
|
||||
|
||||
pub struct GoogleLanguageModelProvider {
|
||||
|
@ -216,6 +247,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
mode: model.mode.unwrap_or_default().into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -343,7 +375,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
let model_id = self.model.id().to_string();
|
||||
let request = into_google(request, model_id.clone());
|
||||
let request = into_google(request, model_id.clone(), self.model.mode());
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.state.read(cx).api_key.clone();
|
||||
|
||||
|
@ -379,7 +411,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
>,
|
||||
>,
|
||||
> {
|
||||
let request = into_google(request, self.model.id().to_string());
|
||||
let request = into_google(request, self.model.id().to_string(), self.model.mode());
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request
|
||||
|
@ -394,6 +426,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
pub fn into_google(
|
||||
mut request: LanguageModelRequest,
|
||||
model_id: String,
|
||||
mode: GoogleModelMode,
|
||||
) -> google_ai::GenerateContentRequest {
|
||||
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
|
||||
content
|
||||
|
@ -504,6 +537,12 @@ pub fn into_google(
|
|||
stop_sequences: Some(request.stop),
|
||||
max_output_tokens: None,
|
||||
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
|
||||
thinking_config: match mode {
|
||||
GoogleModelMode::Thinking { budget_tokens } => {
|
||||
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
|
||||
}
|
||||
GoogleModelMode::Default => None,
|
||||
},
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
}),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue