Add tool calling support for GitHub Copilot Chat (#28035)
This PR adds tool calling support for GitHub Copilot Chat models. Currently only supports the Claude family of models. Release Notes: - agent: Added tool calling support for Claude models in GitHub Copilot Chat. --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
c2afc2271b
commit
02e4267bc6
2 changed files with 318 additions and 85 deletions
|
@ -131,25 +131,70 @@ pub struct Request {
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
pub model: Model,
|
pub model: Model,
|
||||||
pub messages: Vec<ChatMessage>,
|
pub messages: Vec<ChatMessage>,
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub tools: Vec<Tool>,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Request {
|
#[derive(Serialize, Deserialize)]
|
||||||
pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
|
pub struct Function {
|
||||||
Self {
|
pub name: String,
|
||||||
intent: true,
|
pub description: String,
|
||||||
n: 1,
|
pub parameters: serde_json::Value,
|
||||||
stream: model.uses_streaming(),
|
}
|
||||||
temperature: 0.1,
|
|
||||||
model,
|
#[derive(Serialize, Deserialize)]
|
||||||
messages,
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
}
|
pub enum Tool {
|
||||||
}
|
Function { function: Function },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "lowercase")]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
Auto,
|
||||||
|
Any,
|
||||||
|
Tool { name: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
pub struct ChatMessage {
|
#[serde(tag = "role", rename_all = "lowercase")]
|
||||||
pub role: Role,
|
pub enum ChatMessage {
|
||||||
pub content: String,
|
Assistant {
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
System {
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
content: String,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub content: ToolCallContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "lowercase")]
|
||||||
|
pub enum ToolCallContent {
|
||||||
|
Function { function: FunctionContent },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct FunctionContent {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
|
@ -172,6 +217,21 @@ pub struct ResponseChoice {
|
||||||
pub struct ResponseDelta {
|
pub struct ResponseDelta {
|
||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
pub role: Option<Role>,
|
pub role: Option<Role>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_calls: Vec<ToolCallChunk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ToolCallChunk {
|
||||||
|
pub index: usize,
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub function: Option<FunctionChunk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct FunctionChunk {
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub arguments: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -385,7 +445,8 @@ async fn stream_completion(
|
||||||
|
|
||||||
let is_streaming = request.stream;
|
let is_streaming = request.stream;
|
||||||
|
|
||||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
let json = serde_json::to_string(&request)?;
|
||||||
|
let request = request_builder.body(AsyncBody::from(json))?;
|
||||||
let mut response = client.send(request).await?;
|
let mut response = client.send(request).await?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
|
@ -413,9 +474,7 @@ async fn stream_completion(
|
||||||
|
|
||||||
match serde_json::from_str::<ResponseEvent>(line) {
|
match serde_json::from_str::<ResponseEvent>(line) {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
if response.choices.is_empty()
|
if response.choices.is_empty() {
|
||||||
|| response.choices.first().unwrap().finish_reason.is_some()
|
|
||||||
{
|
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(Ok(response))
|
Some(Ok(response))
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::str::FromStr as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Result, anyhow};
|
||||||
|
use collections::HashMap;
|
||||||
use copilot::copilot_chat::{
|
use copilot::copilot_chat::{
|
||||||
ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
|
ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
|
||||||
Role as CopilotChatRole,
|
ResponseEvent, Tool, ToolCall,
|
||||||
};
|
};
|
||||||
use copilot::{Copilot, Status};
|
use copilot::{Copilot, Status};
|
||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use futures::{FutureExt, StreamExt};
|
use futures::{FutureExt, Stream, StreamExt};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
|
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
|
||||||
Transformation, percentage, svg,
|
Transformation, percentage, svg,
|
||||||
|
@ -16,12 +19,14 @@ use gpui::{
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
|
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||||
};
|
};
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
use util::maybe;
|
||||||
|
|
||||||
use super::anthropic::count_anthropic_tokens;
|
use super::anthropic::count_anthropic_tokens;
|
||||||
use super::google::count_google_tokens;
|
use super::google::count_google_tokens;
|
||||||
|
@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_tools(&self) -> bool {
|
fn supports_tools(&self) -> bool {
|
||||||
false
|
match self.model {
|
||||||
|
CopilotChatModel::Claude3_5Sonnet
|
||||||
|
| CopilotChatModel::Claude3_7Sonnet
|
||||||
|
| CopilotChatModel::Claude3_7SonnetThinking => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn telemetry_id(&self) -> String {
|
fn telemetry_id(&self) -> String {
|
||||||
|
@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let copilot_request = self.to_copilot_chat_request(request);
|
let copilot_request = match self.to_copilot_chat_request(request) {
|
||||||
let is_streaming = copilot_request.stream;
|
Ok(request) => request,
|
||||||
|
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
||||||
|
};
|
||||||
|
|
||||||
let request_limiter = self.request_limiter.clone();
|
let request_limiter = self.request_limiter.clone();
|
||||||
let future = cx.spawn(async move |cx| {
|
let future = cx.spawn(async move |cx| {
|
||||||
let response = CopilotChat::stream_completion(copilot_request, cx.clone());
|
let request = CopilotChat::stream_completion(copilot_request, cx.clone());
|
||||||
request_limiter.stream(async move {
|
request_limiter
|
||||||
let response = response.await?;
|
.stream(async move {
|
||||||
let stream = response
|
let response = request.await?;
|
||||||
.filter_map(move |response| async move {
|
Ok(map_to_language_model_completion_events(response))
|
||||||
match response {
|
|
||||||
Ok(result) => {
|
|
||||||
let choice = result.choices.first();
|
|
||||||
match choice {
|
|
||||||
Some(choice) if !is_streaming => {
|
|
||||||
match &choice.message {
|
|
||||||
Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
|
|
||||||
None => Some(Err(anyhow::anyhow!(
|
|
||||||
"The Copilot Chat API returned a response with no message content"
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Some(choice) => {
|
|
||||||
match &choice.delta {
|
|
||||||
Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
|
|
||||||
None => Some(Err(anyhow::anyhow!(
|
|
||||||
"The Copilot Chat API returned a response with no delta content"
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => Some(Err(anyhow::anyhow!(
|
|
||||||
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => Some(Err(err)),
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.boxed();
|
.await
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}).await
|
|
||||||
});
|
});
|
||||||
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
async move {
|
|
||||||
Ok(future
|
|
||||||
.await?
|
|
||||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
|
||||||
.boxed())
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CopilotChatLanguageModel {
|
pub fn map_to_language_model_completion_events(
|
||||||
pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
|
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||||
CopilotChatRequest::new(
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||||
self.model.clone(),
|
#[derive(Default)]
|
||||||
request
|
struct RawToolCall {
|
||||||
.messages
|
id: String,
|
||||||
.into_iter()
|
name: String,
|
||||||
.map(|msg| ChatMessage {
|
arguments: String,
|
||||||
role: match msg.role {
|
}
|
||||||
Role::User => CopilotChatRole::User,
|
|
||||||
Role::Assistant => CopilotChatRole::Assistant,
|
struct State {
|
||||||
Role::System => CopilotChatRole::System,
|
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||||
|
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
futures::stream::unfold(
|
||||||
|
State {
|
||||||
|
events,
|
||||||
|
tool_calls_by_index: HashMap::default(),
|
||||||
},
|
},
|
||||||
content: msg.string_contents(),
|
|mut state| async move {
|
||||||
|
if let Some(event) = state.events.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(event) => {
|
||||||
|
let Some(choice) = event.choices.first() else {
|
||||||
|
return Some((
|
||||||
|
vec![Err(anyhow!("Response contained no choices"))],
|
||||||
|
state,
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(delta) = choice.delta.as_ref() else {
|
||||||
|
return Some((
|
||||||
|
vec![Err(anyhow!("Response contained no delta"))],
|
||||||
|
state,
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
if let Some(content) = delta.content.clone() {
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for tool_call in &delta.tool_calls {
|
||||||
|
let entry = state
|
||||||
|
.tool_calls_by_index
|
||||||
|
.entry(tool_call.index)
|
||||||
|
.or_default();
|
||||||
|
|
||||||
|
if let Some(tool_id) = tool_call.id.clone() {
|
||||||
|
entry.id = tool_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(function) = tool_call.function.as_ref() {
|
||||||
|
if let Some(name) = function.name.clone() {
|
||||||
|
entry.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(arguments) = function.arguments.clone() {
|
||||||
|
entry.arguments.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match choice.finish_reason.as_deref() {
|
||||||
|
Some("stop") => {
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::EndTurn,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Some("tool_calls") => {
|
||||||
|
events.extend(state.tool_calls_by_index.drain().map(
|
||||||
|
|(_, tool_call)| {
|
||||||
|
maybe!({
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: tool_call.id.into(),
|
||||||
|
name: tool_call.name.as_str().into(),
|
||||||
|
input: serde_json::Value::from_str(
|
||||||
|
&tool_call.arguments,
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
))
|
||||||
})
|
})
|
||||||
.collect(),
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::ToolUse,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Some(stop_reason) => {
|
||||||
|
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
|
||||||
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
|
StopReason::EndTurn,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some((events, state));
|
||||||
|
}
|
||||||
|
Err(err) => return Some((vec![Err(err)], state)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
.flat_map(futures::stream::iter)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CopilotChatLanguageModel {
|
||||||
|
pub fn to_copilot_chat_request(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
) -> Result<CopilotChatRequest> {
|
||||||
|
let model = self.model.clone();
|
||||||
|
|
||||||
|
let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||||
|
for message in request.messages {
|
||||||
|
if let Some(last_message) = request_messages.last_mut() {
|
||||||
|
if last_message.role == message.role {
|
||||||
|
last_message.content.extend(message.content);
|
||||||
|
} else {
|
||||||
|
request_messages.push(message);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
request_messages.push(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut messages: Vec<ChatMessage> = Vec::new();
|
||||||
|
for message in request_messages {
|
||||||
|
let text_content = {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
for string in message.content.iter().filter_map(|content| match content {
|
||||||
|
MessageContent::Text(text) => Some(text.as_str()),
|
||||||
|
MessageContent::ToolUse(_)
|
||||||
|
| MessageContent::ToolResult(_)
|
||||||
|
| MessageContent::Image(_) => None,
|
||||||
|
}) {
|
||||||
|
buffer.push_str(string);
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer
|
||||||
|
};
|
||||||
|
|
||||||
|
match message.role {
|
||||||
|
Role::User => {
|
||||||
|
for content in &message.content {
|
||||||
|
if let MessageContent::ToolResult(tool_result) = content {
|
||||||
|
messages.push(ChatMessage::Tool {
|
||||||
|
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||||
|
content: tool_result.content.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(ChatMessage::User {
|
||||||
|
content: text_content,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::Assistant => {
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
for content in &message.content {
|
||||||
|
if let MessageContent::ToolUse(tool_use) = content {
|
||||||
|
tool_calls.push(ToolCall {
|
||||||
|
id: tool_use.id.to_string(),
|
||||||
|
content: copilot::copilot_chat::ToolCallContent::Function {
|
||||||
|
function: copilot::copilot_chat::FunctionContent {
|
||||||
|
name: tool_use.name.to_string(),
|
||||||
|
arguments: serde_json::to_string(&tool_use.input)?,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(ChatMessage::Assistant {
|
||||||
|
content: if text_content.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(text_content)
|
||||||
|
},
|
||||||
|
tool_calls,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Role::System => messages.push(ChatMessage::System {
|
||||||
|
content: message.string_contents(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let tools = request
|
||||||
|
.tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| Tool::Function {
|
||||||
|
function: copilot::copilot_chat::Function {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.input_schema.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(CopilotChatRequest {
|
||||||
|
intent: true,
|
||||||
|
n: 1,
|
||||||
|
stream: model.uses_streaming(),
|
||||||
|
temperature: 0.1,
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
tools,
|
||||||
|
tool_choice: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue