Add tool calling support for GitHub Copilot Chat (#28035)

This PR adds tool calling support for GitHub Copilot Chat models.

Tool calling is currently only supported for the Claude family of models.
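
For illustration, here is a rough sketch of how a tool-enabled request can be assembled with the types this PR adds (the `get_weather` tool, its JSON schema, and the prompt are invented for the example):

```rust
use copilot::copilot_chat::{ChatMessage, Function, Model, Request, Tool};

/// Sketch: build a Copilot Chat request that advertises a single tool.
/// The tool name, schema, and prompt below are made up for illustration.
fn example_request() -> Request {
    let model = Model::Claude3_7Sonnet;
    Request {
        intent: true,
        n: 1,
        stream: model.uses_streaming(),
        temperature: 0.1,
        model,
        messages: vec![ChatMessage::User {
            content: "What's the weather in Berlin?".into(),
        }],
        tools: vec![Tool::Function {
            function: Function {
                name: "get_weather".into(),
                description: "Look up the current weather for a city".into(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }),
            },
        }],
        // `None` leaves the decision of when to call a tool to the model.
        tool_choice: None,
    }
}
```

When the model decides to call a tool, the response streams back tool-call deltas, and the tool's output is returned to the model as a `role: "tool"` message carrying the `tool_call_id` it answers.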

Release Notes:

- agent: Added tool calling support for Claude models in GitHub Copilot
Chat.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Bennet Bo Fenner authored on 2025-04-04 23:41:07 +02:00, committed by GitHub
commit 02e4267bc6 (parent c2afc2271b)
GPG key ID: B5690EEEBB952194
2 changed files with 318 additions and 85 deletions

Changed file 1 of 2: the copilot_chat module (Copilot Chat request/response types and streaming)

@@ -131,25 +131,70 @@ pub struct Request {
     pub temperature: f32,
     pub model: Model,
     pub messages: Vec<ChatMessage>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 }
 
-impl Request {
-    pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
-        Self {
-            intent: true,
-            n: 1,
-            stream: model.uses_streaming(),
-            temperature: 0.1,
-            model,
-            messages,
-        }
-    }
+#[derive(Serialize, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Tool {
+    Function { function: Function },
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ChatMessage {
-    pub role: Role,
-    pub content: String,
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant {
+        content: Option<String>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
+    },
+    User {
+        content: String,
+    },
+    System {
+        content: String,
+    },
+    Tool {
+        content: String,
+        tool_call_id: String,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+    pub id: String,
+    #[serde(flatten)]
+    pub content: ToolCallContent,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+    Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+    pub name: String,
+    pub arguments: String,
 }
 
 #[derive(Deserialize, Debug)]
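
For reference, a rough sketch of the JSON these serde annotations produce (shapes inferred from the `tag`, `rename_all`, and `flatten` attributes above; the values are invented). The layout mirrors the OpenAI-style tool-calling wire format:

```rust
use copilot::copilot_chat::{ChatMessage, FunctionContent, ToolCall, ToolCallContent};

fn print_wire_shapes() -> serde_json::Result<()> {
    // An assistant message carrying a tool call; serializes roughly as:
    // {"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function",
    //   "function":{"name":"get_weather","arguments":"{\"city\":\"Berlin\"}"}}]}
    let assistant = ChatMessage::Assistant {
        content: None,
        tool_calls: vec![ToolCall {
            id: "call_1".into(),
            content: ToolCallContent::Function {
                function: FunctionContent {
                    name: "get_weather".into(),
                    arguments: r#"{"city":"Berlin"}"#.into(),
                },
            },
        }],
    };
    println!("{}", serde_json::to_string_pretty(&assistant)?);

    // The matching tool result sent back to the model; serializes roughly as:
    // {"role":"tool","content":"12°C, overcast","tool_call_id":"call_1"}
    let tool_result = ChatMessage::Tool {
        content: "12°C, overcast".into(),
        tool_call_id: "call_1".into(),
    };
    println!("{}", serde_json::to_string_pretty(&tool_result)?);
    Ok(())
}
```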
@@ -172,6 +217,21 @@ pub struct ResponseChoice {
 pub struct ResponseDelta {
     pub content: Option<String>,
     pub role: Option<Role>,
+    #[serde(default)]
+    pub tool_calls: Vec<ToolCallChunk>,
 }
 
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+    pub index: usize,
+    pub id: Option<String>,
+    pub function: Option<FunctionChunk>,
+}
+
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
 #[derive(Deserialize)]
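
Every field on `ToolCallChunk` and `FunctionChunk` is optional because a streamed tool call arrives in pieces: early chunks typically carry the call `id` and function `name`, while later chunks append fragments of the `arguments` JSON. A minimal sketch of folding such chunks back together, under that assumption (the provider in the second file does the same thing keyed by `index` in `map_to_language_model_completion_events`):

```rust
use copilot::copilot_chat::ToolCallChunk;

/// Sketch: accumulate the streamed chunks belonging to a single tool call.
/// Real chunks arrive inside `ResponseDelta::tool_calls`; here they are assumed
/// to be pre-grouped by their `index`.
fn accumulate(chunks: &[ToolCallChunk]) -> (String, String, String) {
    let (mut id, mut name, mut arguments) = (String::new(), String::new(), String::new());
    for chunk in chunks {
        if let Some(chunk_id) = &chunk.id {
            id = chunk_id.clone();
        }
        if let Some(function) = &chunk.function {
            if let Some(function_name) = &function.name {
                name = function_name.clone();
            }
            if let Some(fragment) = &function.arguments {
                // Argument JSON is delivered as string fragments; concatenate them and
                // parse only once the stream reports `finish_reason: "tool_calls"`.
                arguments.push_str(fragment);
            }
        }
    }
    (id, name, arguments)
}
```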
@@ -385,7 +445,8 @@ async fn stream_completion(
     let is_streaming = request.stream;
 
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let json = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(json))?;
     let mut response = client.send(request).await?;
 
     if !response.status().is_success() {
@@ -413,9 +474,7 @@ async fn stream_completion(
                         match serde_json::from_str::<ResponseEvent>(line) {
                             Ok(response) => {
-                                if response.choices.is_empty()
-                                    || response.choices.first().unwrap().finish_reason.is_some()
-                                {
+                                if response.choices.is_empty() {
                                     None
                                 } else {
                                     Some(Ok(response))

Changed file 2 of 2: the Copilot Chat language model provider (CopilotChatLanguageModel)

@@ -1,14 +1,17 @@
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 
 use anyhow::{Result, anyhow};
+use collections::HashMap;
 use copilot::copilot_chat::{
     ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
-    Role as CopilotChatRole,
+    ResponseEvent, Tool, ToolCall,
 };
 use copilot::{Copilot, Status};
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt};
+use futures::{FutureExt, Stream, StreamExt};
 use gpui::{
     Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
     Transformation, percentage, svg,
@@ -16,12 +19,14 @@ use gpui::{
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
 };
 use settings::SettingsStore;
 use std::time::Duration;
 use strum::IntoEnumIterator;
 use ui::prelude::*;
+use util::maybe;
 
 use super::anthropic::count_anthropic_tokens;
 use super::google::count_google_tokens;
@@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet
+            | CopilotChatModel::Claude3_7Sonnet
+            | CopilotChatModel::Claude3_7SonnetThinking => true,
+            _ => false,
+        }
     }
 
     fn telemetry_id(&self) -> String {
@@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }
 
-        let copilot_request = self.to_copilot_chat_request(request);
-        let is_streaming = copilot_request.stream;
+        let copilot_request = match self.to_copilot_chat_request(request) {
+            Ok(request) => request,
+            Err(err) => return futures::future::ready(Err(err)).boxed(),
+        };
 
         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(async move |cx| {
-            let response = CopilotChat::stream_completion(copilot_request, cx.clone());
-            request_limiter.stream(async move {
-                let response = response.await?;
-                let stream = response
-                    .filter_map(move |response| async move {
-                        match response {
-                            Ok(result) => {
-                                let choice = result.choices.first();
-                                match choice {
-                                    Some(choice) if !is_streaming => {
-                                        match &choice.message {
-                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
-                                            None => Some(Err(anyhow::anyhow!(
-                                                "The Copilot Chat API returned a response with no message content"
-                                            ))),
-                                        }
-                                    },
-                                    Some(choice) => {
-                                        match &choice.delta {
-                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
-                                            None => Some(Err(anyhow::anyhow!(
-                                                "The Copilot Chat API returned a response with no delta content"
-                                            ))),
-                                        }
-                                    },
-                                    None => Some(Err(anyhow::anyhow!(
-                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
-                                    ))),
-                                }
-                            }
-                            Err(err) => Some(Err(err)),
-                        }
-                    })
-                    .boxed();
-                Ok(stream)
-            }).await
+            let request = CopilotChat::stream_completion(copilot_request, cx.clone());
+            request_limiter
+                .stream(async move {
+                    let response = request.await?;
+                    Ok(map_to_language_model_completion_events(response))
+                })
+                .await
         });
 
-        async move {
-            Ok(future
-                .await?
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let Some(delta) = choice.delta.as_ref() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no delta"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                        for tool_call in &delta.tool_calls {
+                            let entry = state
+                                .tool_calls_by_index
+                                .entry(tool_call.index)
+                                .or_default();
+
+                            if let Some(tool_id) = tool_call.id.clone() {
+                                entry.id = tool_id;
+                            }
+
+                            if let Some(function) = tool_call.function.as_ref() {
+                                if let Some(name) = function.name.clone() {
+                                    entry.name = name;
+                                }
+
+                                if let Some(arguments) = function.arguments.clone() {
+                                    entry.arguments.push_str(&arguments);
+                                }
+                            }
+                        }
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
+                        }
+
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
+
 impl CopilotChatLanguageModel {
-    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
-        CopilotChatRequest::new(
-            self.model.clone(),
-            request
-                .messages
-                .into_iter()
-                .map(|msg| ChatMessage {
-                    role: match msg.role {
-                        Role::User => CopilotChatRole::User,
-                        Role::Assistant => CopilotChatRole::Assistant,
-                        Role::System => CopilotChatRole::System,
-                    },
-                    content: msg.string_contents(),
-                })
-                .collect(),
-        )
+    pub fn to_copilot_chat_request(
+        &self,
+        request: LanguageModelRequest,
+    ) -> Result<CopilotChatRequest> {
+        let model = self.model.clone();
+
+        let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        for message in request.messages {
+            if let Some(last_message) = request_messages.last_mut() {
+                if last_message.role == message.role {
+                    last_message.content.extend(message.content);
+                } else {
+                    request_messages.push(message);
+                }
+            } else {
+                request_messages.push(message);
+            }
+        }
+
+        let mut messages: Vec<ChatMessage> = Vec::new();
+        for message in request_messages {
+            let text_content = {
+                let mut buffer = String::new();
+                for string in message.content.iter().filter_map(|content| match content {
+                    MessageContent::Text(text) => Some(text.as_str()),
+                    MessageContent::ToolUse(_)
+                    | MessageContent::ToolResult(_)
+                    | MessageContent::Image(_) => None,
+                }) {
+                    buffer.push_str(string);
+                }
+
+                buffer
+            };
+
+            match message.role {
+                Role::User => {
+                    for content in &message.content {
+                        if let MessageContent::ToolResult(tool_result) = content {
+                            messages.push(ChatMessage::Tool {
+                                tool_call_id: tool_result.tool_use_id.to_string(),
+                                content: tool_result.content.to_string(),
+                            });
+                        }
+                    }
+
+                    messages.push(ChatMessage::User {
+                        content: text_content,
+                    });
+                }
+                Role::Assistant => {
+                    let mut tool_calls = Vec::new();
+                    for content in &message.content {
+                        if let MessageContent::ToolUse(tool_use) = content {
+                            tool_calls.push(ToolCall {
+                                id: tool_use.id.to_string(),
+                                content: copilot::copilot_chat::ToolCallContent::Function {
+                                    function: copilot::copilot_chat::FunctionContent {
+                                        name: tool_use.name.to_string(),
+                                        arguments: serde_json::to_string(&tool_use.input)?,
+                                    },
+                                },
+                            });
+                        }
+                    }
+
+                    messages.push(ChatMessage::Assistant {
+                        content: if text_content.is_empty() {
+                            None
+                        } else {
+                            Some(text_content)
+                        },
+                        tool_calls,
+                    });
+                }
+                Role::System => messages.push(ChatMessage::System {
+                    content: message.string_contents(),
+                }),
+            }
+        }
+
+        let tools = request
+            .tools
+            .iter()
+            .map(|tool| Tool::Function {
+                function: copilot::copilot_chat::Function {
+                    name: tool.name.clone(),
+                    description: tool.description.clone(),
+                    parameters: tool.input_schema.clone(),
+                },
+            })
+            .collect();
+
+        Ok(CopilotChatRequest {
+            intent: true,
+            n: 1,
+            stream: model.uses_streaming(),
+            temperature: 0.1,
+            model,
+            messages,
+            tools,
+            tool_choice: None,
+        })
     }
 }