gemini: Pass system prompt as system instructions (#28793)
https://ai.google.dev/gemini-api/docs/text-generation#system-instructions Release Notes: - agent: Improve performance of Gemini models
This commit is contained in:
parent
c381a500f8
commit
c7e80c80c6
2 changed files with 62 additions and 36 deletions
|
@ -125,6 +125,7 @@ pub struct GenerateContentRequest {
|
||||||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub contents: Vec<Content>,
|
pub contents: Vec<Content>,
|
||||||
|
pub system_instructions: Option<SystemInstructions>,
|
||||||
pub generation_config: Option<GenerationConfig>,
|
pub generation_config: Option<GenerationConfig>,
|
||||||
pub safety_settings: Option<Vec<SafetySetting>>,
|
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
@ -159,6 +160,12 @@ pub struct Content {
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SystemInstructions {
|
||||||
|
pub parts: Vec<Part>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
#[derive(Debug, PartialEq, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
|
|
|
@ -3,14 +3,16 @@ use collections::BTreeMap;
|
||||||
use credentials_provider::CredentialsProvider;
|
use credentials_provider::CredentialsProvider;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||||
use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata};
|
use google_ai::{
|
||||||
|
FunctionDeclaration, GenerateContentResponse, Part, SystemInstructions, UsageMetadata,
|
||||||
|
};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||||
};
|
};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
|
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, StopReason,
|
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||||
|
@ -359,48 +361,65 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_google(
|
pub fn into_google(
|
||||||
request: LanguageModelRequest,
|
mut request: LanguageModelRequest,
|
||||||
model: String,
|
model: String,
|
||||||
) -> google_ai::GenerateContentRequest {
|
) -> google_ai::GenerateContentRequest {
|
||||||
|
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
|
||||||
|
content
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|content| match content {
|
||||||
|
language_model::MessageContent::Text(text) => {
|
||||||
|
if !text.is_empty() {
|
||||||
|
Some(Part::TextPart(google_ai::TextPart { text }))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
language_model::MessageContent::Image(_) => None,
|
||||||
|
language_model::MessageContent::ToolUse(tool_use) => {
|
||||||
|
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
|
||||||
|
function_call: google_ai::FunctionCall {
|
||||||
|
name: tool_use.name.to_string(),
|
||||||
|
args: tool_use.input,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
language_model::MessageContent::ToolResult(tool_result) => Some(
|
||||||
|
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
|
||||||
|
function_response: google_ai::FunctionResponse {
|
||||||
|
name: tool_result.tool_name.to_string(),
|
||||||
|
// The API expects a valid JSON object
|
||||||
|
response: serde_json::json!({
|
||||||
|
"output": tool_result.content
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
let system_instructions = if request
|
||||||
|
.messages
|
||||||
|
.first()
|
||||||
|
.map_or(false, |msg| matches!(msg.role, Role::System))
|
||||||
|
{
|
||||||
|
let message = request.messages.remove(0);
|
||||||
|
Some(SystemInstructions {
|
||||||
|
parts: map_content(message.content),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
google_ai::GenerateContentRequest {
|
google_ai::GenerateContentRequest {
|
||||||
model,
|
model,
|
||||||
|
system_instructions,
|
||||||
contents: request
|
contents: request
|
||||||
.messages
|
.messages
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|message| google_ai::Content {
|
.map(|message| google_ai::Content {
|
||||||
parts: message
|
parts: map_content(message.content),
|
||||||
.content
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|content| match content {
|
|
||||||
language_model::MessageContent::Text(text) => {
|
|
||||||
if !text.is_empty() {
|
|
||||||
Some(Part::TextPart(google_ai::TextPart { text }))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
language_model::MessageContent::Image(_) => None,
|
|
||||||
language_model::MessageContent::ToolUse(tool_use) => {
|
|
||||||
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
|
|
||||||
function_call: google_ai::FunctionCall {
|
|
||||||
name: tool_use.name.to_string(),
|
|
||||||
args: tool_use.input,
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
language_model::MessageContent::ToolResult(tool_result) => Some(
|
|
||||||
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
|
|
||||||
function_response: google_ai::FunctionResponse {
|
|
||||||
name: tool_result.tool_name.to_string(),
|
|
||||||
// The API expects a valid JSON object
|
|
||||||
response: serde_json::json!({
|
|
||||||
"output": tool_result.content
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
role: match message.role {
|
role: match message.role {
|
||||||
Role::User => google_ai::Role::User,
|
Role::User => google_ai::Role::User,
|
||||||
Role::Assistant => google_ai::Role::Model,
|
Role::Assistant => google_ai::Role::Model,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue