Add tool calling support for GitHub Copilot Chat (#28035)

This PR adds tool calling support for GitHub Copilot Chat models.

Currently only supports the Claude family of models.

Release Notes:

- agent: Added tool calling support for Claude models in GitHub Copilot
Chat.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Bennet Bo Fenner 2025-04-04 23:41:07 +02:00 committed by GitHub
parent c2afc2271b
commit 02e4267bc6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 318 additions and 85 deletions

View file

@ -131,25 +131,70 @@ pub struct Request {
pub temperature: f32,
pub model: Model,
pub messages: Vec<ChatMessage>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
impl Request {
pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
Self {
intent: true,
n: 1,
stream: model.uses_streaming(),
temperature: 0.1,
model,
messages,
}
}
#[derive(Serialize, Deserialize)]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
Function { function: Function },
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
Auto,
Any,
Tool { name: String },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ChatMessage {
pub role: Role,
pub content: String,
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
content: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
content: String,
},
System {
content: String,
},
Tool {
content: String,
tool_call_id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
pub name: String,
pub arguments: String,
}
#[derive(Deserialize, Debug)]
@ -172,6 +217,21 @@ pub struct ResponseChoice {
pub struct ResponseDelta {
pub content: Option<String>,
pub role: Option<Role>,
#[serde(default)]
pub tool_calls: Vec<ToolCallChunk>,
}
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
pub index: usize,
pub id: Option<String>,
pub function: Option<FunctionChunk>,
}
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
pub name: Option<String>,
pub arguments: Option<String>,
}
#[derive(Deserialize)]
@ -385,7 +445,8 @@ async fn stream_completion(
let is_streaming = request.stream;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let json = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(json))?;
let mut response = client.send(request).await?;
if !response.status().is_success() {
@ -413,9 +474,7 @@ async fn stream_completion(
match serde_json::from_str::<ResponseEvent>(line) {
Ok(response) => {
if response.choices.is_empty()
|| response.choices.first().unwrap().finish_reason.is_some()
{
if response.choices.is_empty() {
None
} else {
Some(Ok(response))

View file

@ -1,14 +1,17 @@
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use collections::HashMap;
use copilot::copilot_chat::{
ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
Role as CopilotChatRole,
ResponseEvent, Tool, ToolCall,
};
use copilot::{Copilot, Status};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use futures::{FutureExt, Stream, StreamExt};
use gpui::{
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
Transformation, percentage, svg,
@ -16,12 +19,14 @@ use gpui::{
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
}
fn supports_tools(&self) -> bool {
false
match self.model {
CopilotChatModel::Claude3_5Sonnet
| CopilotChatModel::Claude3_7Sonnet
| CopilotChatModel::Claude3_7SonnetThinking => true,
_ => false,
}
}
fn telemetry_id(&self) -> String {
@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
}
}
let copilot_request = self.to_copilot_chat_request(request);
let is_streaming = copilot_request.stream;
let copilot_request = match self.to_copilot_chat_request(request) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err)).boxed(),
};
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(async move |cx| {
let response = CopilotChat::stream_completion(copilot_request, cx.clone());
request_limiter.stream(async move {
let response = response.await?;
let stream = response
.filter_map(move |response| async move {
match response {
Ok(result) => {
let choice = result.choices.first();
match choice {
Some(choice) if !is_streaming => {
match &choice.message {
Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no message content"
))),
}
},
Some(choice) => {
match &choice.delta {
Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no delta content"
))),
}
},
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
}
}
Err(err) => Some(Err(err)),
}
})
.boxed();
Ok(stream)
}).await
let request = CopilotChat::stream_completion(copilot_request, cx.clone());
request_limiter
.stream(async move {
let response = request.await?;
Ok(map_to_language_model_completion_events(response))
})
.await
});
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
async move { Ok(future.await?.boxed()) }.boxed()
}
}
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
#[derive(Default)]
struct RawToolCall {
id: String,
name: String,
arguments: String,
}
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
tool_calls_by_index: HashMap<usize, RawToolCall>,
}
futures::stream::unfold(
State {
events,
tool_calls_by_index: HashMap::default(),
},
|mut state| async move {
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
state,
));
};
let Some(delta) = choice.delta.as_ref() else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
state,
));
};
let mut events = Vec::new();
if let Some(content) = delta.content.clone() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
for tool_call in &delta.tool_calls {
let entry = state
.tool_calls_by_index
.entry(tool_call.index)
.or_default();
if let Some(tool_id) = tool_call.id.clone() {
entry.id = tool_id;
}
if let Some(function) = tool_call.function.as_ref() {
if let Some(name) = function.name.clone() {
entry.name = name;
}
if let Some(arguments) = function.arguments.clone() {
entry.arguments.push_str(&arguments);
}
}
}
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(
StopReason::EndTurn,
)));
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
},
));
events.push(Ok(LanguageModelCompletionEvent::Stop(
StopReason::ToolUse,
)));
}
Some(stop_reason) => {
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
events.push(Ok(LanguageModelCompletionEvent::Stop(
StopReason::EndTurn,
)));
}
None => {}
}
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
}
}
None
},
)
.flat_map(futures::stream::iter)
}
impl CopilotChatLanguageModel {
pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
CopilotChatRequest::new(
self.model.clone(),
request
.messages
.into_iter()
.map(|msg| ChatMessage {
role: match msg.role {
Role::User => CopilotChatRole::User,
Role::Assistant => CopilotChatRole::Assistant,
Role::System => CopilotChatRole::System,
},
content: msg.string_contents(),
})
.collect(),
)
pub fn to_copilot_chat_request(
&self,
request: LanguageModelRequest,
) -> Result<CopilotChatRequest> {
let model = self.model.clone();
let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
for message in request.messages {
if let Some(last_message) = request_messages.last_mut() {
if last_message.role == message.role {
last_message.content.extend(message.content);
} else {
request_messages.push(message);
}
} else {
request_messages.push(message);
}
}
let mut messages: Vec<ChatMessage> = Vec::new();
for message in request_messages {
let text_content = {
let mut buffer = String::new();
for string in message.content.iter().filter_map(|content| match content {
MessageContent::Text(text) => Some(text.as_str()),
MessageContent::ToolUse(_)
| MessageContent::ToolResult(_)
| MessageContent::Image(_) => None,
}) {
buffer.push_str(string);
}
buffer
};
match message.role {
Role::User => {
for content in &message.content {
if let MessageContent::ToolResult(tool_result) = content {
messages.push(ChatMessage::Tool {
tool_call_id: tool_result.tool_use_id.to_string(),
content: tool_result.content.to_string(),
});
}
}
messages.push(ChatMessage::User {
content: text_content,
});
}
Role::Assistant => {
let mut tool_calls = Vec::new();
for content in &message.content {
if let MessageContent::ToolUse(tool_use) = content {
tool_calls.push(ToolCall {
id: tool_use.id.to_string(),
content: copilot::copilot_chat::ToolCallContent::Function {
function: copilot::copilot_chat::FunctionContent {
name: tool_use.name.to_string(),
arguments: serde_json::to_string(&tool_use.input)?,
},
},
});
}
}
messages.push(ChatMessage::Assistant {
content: if text_content.is_empty() {
None
} else {
Some(text_content)
},
tool_calls,
});
}
Role::System => messages.push(ChatMessage::System {
content: message.string_contents(),
}),
}
}
let tools = request
.tools
.iter()
.map(|tool| Tool::Function {
function: copilot::copilot_chat::Function {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
})
.collect();
Ok(CopilotChatRequest {
intent: true,
n: 1,
stream: model.uses_streaming(),
temperature: 0.1,
model,
messages,
tools,
tool_choice: None,
})
}
}