Add tool calling support for GitHub Copilot Chat (#28035)
This PR adds tool calling support for GitHub Copilot Chat models. Currently only supports the Claude family of models. Release Notes: - agent: Added tool calling support for Claude models in GitHub Copilot Chat. --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
c2afc2271b
commit
02e4267bc6
2 changed files with 318 additions and 85 deletions
|
@ -1,14 +1,17 @@
|
|||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use copilot::copilot_chat::{
|
||||
ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
|
||||
Role as CopilotChatRole,
|
||||
ResponseEvent, Tool, ToolCall,
|
||||
};
|
||||
use copilot::{Copilot, Status};
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use futures::{FutureExt, Stream, StreamExt};
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
|
||||
Transformation, percentage, svg,
|
||||
|
@ -16,12 +19,14 @@ use gpui::{
|
|||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use std::time::Duration;
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
use util::maybe;
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
use super::google::count_google_tokens;
|
||||
|
@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
match self.model {
|
||||
CopilotChatModel::Claude3_5Sonnet
|
||||
| CopilotChatModel::Claude3_7Sonnet
|
||||
| CopilotChatModel::Claude3_7SonnetThinking => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
|
@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
let copilot_request = self.to_copilot_chat_request(request);
|
||||
let is_streaming = copilot_request.stream;
|
||||
let copilot_request = match self.to_copilot_chat_request(request) {
|
||||
Ok(request) => request,
|
||||
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
||||
};
|
||||
|
||||
let request_limiter = self.request_limiter.clone();
|
||||
let future = cx.spawn(async move |cx| {
|
||||
let response = CopilotChat::stream_completion(copilot_request, cx.clone());
|
||||
request_limiter.stream(async move {
|
||||
let response = response.await?;
|
||||
let stream = response
|
||||
.filter_map(move |response| async move {
|
||||
match response {
|
||||
Ok(result) => {
|
||||
let choice = result.choices.first();
|
||||
match choice {
|
||||
Some(choice) if !is_streaming => {
|
||||
match &choice.message {
|
||||
Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
|
||||
None => Some(Err(anyhow::anyhow!(
|
||||
"The Copilot Chat API returned a response with no message content"
|
||||
))),
|
||||
}
|
||||
},
|
||||
Some(choice) => {
|
||||
match &choice.delta {
|
||||
Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
|
||||
None => Some(Err(anyhow::anyhow!(
|
||||
"The Copilot Chat API returned a response with no delta content"
|
||||
))),
|
||||
}
|
||||
},
|
||||
None => Some(Err(anyhow::anyhow!(
|
||||
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
|
||||
))),
|
||||
}
|
||||
}
|
||||
Err(err) => Some(Err(err)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
Ok(stream)
|
||||
}).await
|
||||
let request = CopilotChat::stream_completion(copilot_request, cx.clone());
|
||||
request_limiter
|
||||
.stream(async move {
|
||||
let response = request.await?;
|
||||
Ok(map_to_language_model_completion_events(response))
|
||||
})
|
||||
.await
|
||||
});
|
||||
|
||||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
struct State {
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_calls_by_index: HashMap::default(),
|
||||
},
|
||||
|mut state| async move {
|
||||
if let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no choices"))],
|
||||
state,
|
||||
));
|
||||
};
|
||||
|
||||
let Some(delta) = choice.delta.as_ref() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no delta"))],
|
||||
state,
|
||||
));
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
if let Some(content) = delta.content.clone() {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||
}
|
||||
|
||||
for tool_call in &delta.tool_calls {
|
||||
let entry = state
|
||||
.tool_calls_by_index
|
||||
.entry(tool_call.index)
|
||||
.or_default();
|
||||
|
||||
if let Some(tool_id) = tool_call.id.clone() {
|
||||
entry.id = tool_id;
|
||||
}
|
||||
|
||||
if let Some(function) = tool_call.function.as_ref() {
|
||||
if let Some(name) = function.name.clone() {
|
||||
entry.name = name;
|
||||
}
|
||||
|
||||
if let Some(arguments) = function.arguments.clone() {
|
||||
entry.arguments.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
Some("stop") => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||
StopReason::EndTurn,
|
||||
)));
|
||||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(state.tool_calls_by_index.drain().map(
|
||||
|(_, tool_call)| {
|
||||
maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
)?,
|
||||
},
|
||||
))
|
||||
})
|
||||
},
|
||||
));
|
||||
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||
StopReason::ToolUse,
|
||||
)));
|
||||
}
|
||||
Some(stop_reason) => {
|
||||
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||
StopReason::EndTurn,
|
||||
)));
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => return Some((vec![Err(err)], state)),
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
)
|
||||
.flat_map(futures::stream::iter)
|
||||
}
|
||||
|
||||
impl CopilotChatLanguageModel {
|
||||
pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
|
||||
CopilotChatRequest::new(
|
||||
self.model.clone(),
|
||||
request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| ChatMessage {
|
||||
role: match msg.role {
|
||||
Role::User => CopilotChatRole::User,
|
||||
Role::Assistant => CopilotChatRole::Assistant,
|
||||
Role::System => CopilotChatRole::System,
|
||||
},
|
||||
content: msg.string_contents(),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
pub fn to_copilot_chat_request(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> Result<CopilotChatRequest> {
|
||||
let model = self.model.clone();
|
||||
|
||||
let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||
for message in request.messages {
|
||||
if let Some(last_message) = request_messages.last_mut() {
|
||||
if last_message.role == message.role {
|
||||
last_message.content.extend(message.content);
|
||||
} else {
|
||||
request_messages.push(message);
|
||||
}
|
||||
} else {
|
||||
request_messages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
let mut messages: Vec<ChatMessage> = Vec::new();
|
||||
for message in request_messages {
|
||||
let text_content = {
|
||||
let mut buffer = String::new();
|
||||
for string in message.content.iter().filter_map(|content| match content {
|
||||
MessageContent::Text(text) => Some(text.as_str()),
|
||||
MessageContent::ToolUse(_)
|
||||
| MessageContent::ToolResult(_)
|
||||
| MessageContent::Image(_) => None,
|
||||
}) {
|
||||
buffer.push_str(string);
|
||||
}
|
||||
|
||||
buffer
|
||||
};
|
||||
|
||||
match message.role {
|
||||
Role::User => {
|
||||
for content in &message.content {
|
||||
if let MessageContent::ToolResult(tool_result) = content {
|
||||
messages.push(ChatMessage::Tool {
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
content: tool_result.content.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::User {
|
||||
content: text_content,
|
||||
});
|
||||
}
|
||||
Role::Assistant => {
|
||||
let mut tool_calls = Vec::new();
|
||||
for content in &message.content {
|
||||
if let MessageContent::ToolUse(tool_use) = content {
|
||||
tool_calls.push(ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
content: copilot::copilot_chat::ToolCallContent::Function {
|
||||
function: copilot::copilot_chat::FunctionContent {
|
||||
name: tool_use.name.to_string(),
|
||||
arguments: serde_json::to_string(&tool_use.input)?,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::Assistant {
|
||||
content: if text_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_content)
|
||||
},
|
||||
tool_calls,
|
||||
});
|
||||
}
|
||||
Role::System => messages.push(ChatMessage::System {
|
||||
content: message.string_contents(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
let tools = request
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| Tool::Function {
|
||||
function: copilot::copilot_chat::Function {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.input_schema.clone(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(CopilotChatRequest {
|
||||
intent: true,
|
||||
n: 1,
|
||||
stream: model.uses_streaming(),
|
||||
temperature: 0.1,
|
||||
model,
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue