diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 4ed22717fc..6ef9d880fa 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -125,6 +125,7 @@ pub struct GenerateContentRequest { #[serde(default, skip_serializing_if = "String::is_empty")] pub model: String, pub contents: Vec, + pub system_instructions: Option, pub generation_config: Option, pub safety_settings: Option>, #[serde(skip_serializing_if = "Option::is_none")] @@ -159,6 +160,12 @@ pub struct Content { pub role: Role, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SystemInstructions { + pub parts: Vec, +} + #[derive(Debug, PartialEq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub enum Role { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 36a01a30c2..78e3dadcad 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -3,14 +3,16 @@ use collections::BTreeMap; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; -use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata}; +use google_ai::{ + FunctionDeclaration, GenerateContentResponse, Part, SystemInstructions, UsageMetadata, +}; use gpui::{ AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, }; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, @@ -359,48 +361,65 @@ impl LanguageModel for GoogleLanguageModel { } pub fn into_google( - request: LanguageModelRequest, + mut request: LanguageModelRequest, model: String, ) -> google_ai::GenerateContentRequest { + fn map_content(content: Vec) -> Vec { + content + .into_iter() + .filter_map(|content| match content { + language_model::MessageContent::Text(text) => { + if !text.is_empty() { + Some(Part::TextPart(google_ai::TextPart { text })) + } else { + None + } + } + language_model::MessageContent::Image(_) => None, + language_model::MessageContent::ToolUse(tool_use) => { + Some(Part::FunctionCallPart(google_ai::FunctionCallPart { + function_call: google_ai::FunctionCall { + name: tool_use.name.to_string(), + args: tool_use.input, + }, + })) + } + language_model::MessageContent::ToolResult(tool_result) => Some( + Part::FunctionResponsePart(google_ai::FunctionResponsePart { + function_response: google_ai::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": tool_result.content + }), + }, + }), + ), + }) + .collect() + } + + let system_instructions = if request + .messages + .first() + .map_or(false, |msg| matches!(msg.role, Role::System)) + { + let message = request.messages.remove(0); + Some(SystemInstructions { + parts: map_content(message.content), + }) + } else { + None + }; + google_ai::GenerateContentRequest { model, + system_instructions, contents: request .messages .into_iter() .map(|message| google_ai::Content { - parts: message - .content - .into_iter() - .filter_map(|content| match content { - language_model::MessageContent::Text(text) => { - if !text.is_empty() { - Some(Part::TextPart(google_ai::TextPart { text })) - } else { - None - } - } - language_model::MessageContent::Image(_) => None, - language_model::MessageContent::ToolUse(tool_use) => { - Some(Part::FunctionCallPart(google_ai::FunctionCallPart { - function_call: google_ai::FunctionCall { - name: tool_use.name.to_string(), - args: tool_use.input, - }, - })) - } - language_model::MessageContent::ToolResult(tool_result) => Some( - Part::FunctionResponsePart(google_ai::FunctionResponsePart { - function_response: google_ai::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": tool_result.content - }), - }, - }), - ), - }) - .collect(), + parts: map_content(message.content), role: match message.role { Role::User => google_ai::Role::User, Role::Assistant => google_ai::Role::Model,