diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index e8274b1965..e2c1499838 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -131,25 +131,70 @@ pub struct Request { pub temperature: f32, pub model: Model, pub messages: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, } -impl Request { - pub fn new(model: Model, messages: Vec) -> Self { - Self { - intent: true, - n: 1, - stream: model.uses_streaming(), - temperature: 0.1, - model, - messages, - } - } +#[derive(Serialize, Deserialize)] +pub struct Function { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Tool { + Function { function: Function }, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolChoice { + Auto, + Any, + Tool { name: String }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct ChatMessage { - pub role: Role, - pub content: String, +#[serde(tag = "role", rename_all = "lowercase")] +pub enum ChatMessage { + Assistant { + content: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + tool_calls: Vec, + }, + User { + content: String, + }, + System { + content: String, + }, + Tool { + content: String, + tool_call_id: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, } #[derive(Deserialize, Debug)] @@ -172,6 +217,21 @@ pub struct ResponseChoice { pub struct ResponseDelta { pub content: Option, pub role: Option, + #[serde(default)] + pub tool_calls: Vec, +} + +#[derive(Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCallChunk { + pub index: usize, + pub id: Option, + pub function: Option, +} + +#[derive(Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionChunk { + pub name: Option, + pub arguments: Option, } #[derive(Deserialize)] @@ -385,7 +445,8 @@ async fn stream_completion( let is_streaming = request.stream; - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let json = serde_json::to_string(&request)?; + let request = request_builder.body(AsyncBody::from(json))?; let mut response = client.send(request).await?; if !response.status().is_success() { @@ -413,9 +474,7 @@ async fn stream_completion( match serde_json::from_str::(line) { Ok(response) => { - if response.choices.is_empty() - || response.choices.first().unwrap().finish_reason.is_some() - { + if response.choices.is_empty() { None } else { Some(Ok(response)) diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 280be56174..a6f7cf9e29 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -1,14 +1,17 @@ +use std::pin::Pin; +use std::str::FromStr as _; use std::sync::Arc; use anyhow::{Result, anyhow}; +use collections::HashMap; use copilot::copilot_chat::{ ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest, - Role as CopilotChatRole, + ResponseEvent, Tool, ToolCall, }; use copilot::{Copilot, Status}; use futures::future::BoxFuture; use futures::stream::BoxStream; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; use gpui::{ Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, Transformation, percentage, svg, @@ -16,12 +19,14 @@ use gpui::{ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, + LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, }; use settings::SettingsStore; use std::time::Duration; use strum::IntoEnumIterator; use ui::prelude::*; +use util::maybe; use super::anthropic::count_anthropic_tokens; use super::google::count_google_tokens; @@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel { } fn supports_tools(&self) -> bool { - false + match self.model { + CopilotChatModel::Claude3_5Sonnet + | CopilotChatModel::Claude3_7Sonnet + | CopilotChatModel::Claude3_7SonnetThinking => true, + _ => false, + } } fn telemetry_id(&self) -> String { @@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel { } } - let copilot_request = self.to_copilot_chat_request(request); - let is_streaming = copilot_request.stream; + let copilot_request = match self.to_copilot_chat_request(request) { + Ok(request) => request, + Err(err) => return futures::future::ready(Err(err)).boxed(), + }; let request_limiter = self.request_limiter.clone(); let future = cx.spawn(async move |cx| { - let response = CopilotChat::stream_completion(copilot_request, cx.clone()); - request_limiter.stream(async move { - let response = response.await?; - let stream = response - .filter_map(move |response| async move { - match response { - Ok(result) => { - let choice = result.choices.first(); - match choice { - Some(choice) if !is_streaming => { - match &choice.message { - Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())), - None => Some(Err(anyhow::anyhow!( - "The Copilot Chat API returned a response with no message content" - ))), - } - }, - Some(choice) => { - match &choice.delta { - Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())), - None => Some(Err(anyhow::anyhow!( - "The Copilot Chat API returned a response with no delta content" - ))), - } - }, - None => Some(Err(anyhow::anyhow!( - "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again." - ))), - } - } - Err(err) => Some(Err(err)), - } - }) - .boxed(); - - Ok(stream) - }).await + let request = CopilotChat::stream_completion(copilot_request, cx.clone()); + request_limiter + .stream(async move { + let response = request.await?; + Ok(map_to_language_model_completion_events(response)) + }) + .await }); - - async move { - Ok(future - .await? - .map(|result| result.map(LanguageModelCompletionEvent::Text)) - .boxed()) - } - .boxed() + async move { Ok(future.await?.boxed()) }.boxed() } } +pub fn map_to_language_model_completion_events( + events: Pin>>>, +) -> impl Stream> { + #[derive(Default)] + struct RawToolCall { + id: String, + name: String, + arguments: String, + } + + struct State { + events: Pin>>>, + tool_calls_by_index: HashMap, + } + + futures::stream::unfold( + State { + events, + tool_calls_by_index: HashMap::default(), + }, + |mut state| async move { + if let Some(event) = state.events.next().await { + match event { + Ok(event) => { + let Some(choice) = event.choices.first() else { + return Some(( + vec![Err(anyhow!("Response contained no choices"))], + state, + )); + }; + + let Some(delta) = choice.delta.as_ref() else { + return Some(( + vec![Err(anyhow!("Response contained no delta"))], + state, + )); + }; + + let mut events = Vec::new(); + if let Some(content) = delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + for tool_call in &delta.tool_calls { + let entry = state + .tool_calls_by_index + .entry(tool_call.index) + .or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::EndTurn, + ))); + } + Some("tool_calls") => { + events.extend(state.tool_calls_by_index.drain().map( + |(_, tool_call)| { + maybe!({ + Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.into(), + name: tool_call.name.as_str().into(), + input: serde_json::Value::from_str( + &tool_call.arguments, + )?, + }, + )) + }) + }, + )); + + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::ToolUse, + ))); + } + Some(stop_reason) => { + log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::EndTurn, + ))); + } + None => {} + } + + return Some((events, state)); + } + Err(err) => return Some((vec![Err(err)], state)), + } + } + + None + }, + ) + .flat_map(futures::stream::iter) +} + impl CopilotChatLanguageModel { - pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest { - CopilotChatRequest::new( - self.model.clone(), - request - .messages - .into_iter() - .map(|msg| ChatMessage { - role: match msg.role { - Role::User => CopilotChatRole::User, - Role::Assistant => CopilotChatRole::Assistant, - Role::System => CopilotChatRole::System, - }, - content: msg.string_contents(), - }) - .collect(), - ) + pub fn to_copilot_chat_request( + &self, + request: LanguageModelRequest, + ) -> Result { + let model = self.model.clone(); + + let mut request_messages: Vec = Vec::new(); + for message in request.messages { + if let Some(last_message) = request_messages.last_mut() { + if last_message.role == message.role { + last_message.content.extend(message.content); + } else { + request_messages.push(message); + } + } else { + request_messages.push(message); + } + } + + let mut messages: Vec = Vec::new(); + for message in request_messages { + let text_content = { + let mut buffer = String::new(); + for string in message.content.iter().filter_map(|content| match content { + MessageContent::Text(text) => Some(text.as_str()), + MessageContent::ToolUse(_) + | MessageContent::ToolResult(_) + | MessageContent::Image(_) => None, + }) { + buffer.push_str(string); + } + + buffer + }; + + match message.role { + Role::User => { + for content in &message.content { + if let MessageContent::ToolResult(tool_result) = content { + messages.push(ChatMessage::Tool { + tool_call_id: tool_result.tool_use_id.to_string(), + content: tool_result.content.to_string(), + }); + } + } + + messages.push(ChatMessage::User { + content: text_content, + }); + } + Role::Assistant => { + let mut tool_calls = Vec::new(); + for content in &message.content { + if let MessageContent::ToolUse(tool_use) = content { + tool_calls.push(ToolCall { + id: tool_use.id.to_string(), + content: copilot::copilot_chat::ToolCallContent::Function { + function: copilot::copilot_chat::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input)?, + }, + }, + }); + } + } + + messages.push(ChatMessage::Assistant { + content: if text_content.is_empty() { + None + } else { + Some(text_content) + }, + tool_calls, + }); + } + Role::System => messages.push(ChatMessage::System { + content: message.string_contents(), + }), + } + } + + let tools = request + .tools + .iter() + .map(|tool| Tool::Function { + function: copilot::copilot_chat::Function { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.input_schema.clone(), + }, + }) + .collect(); + + Ok(CopilotChatRequest { + intent: true, + n: 1, + stream: model.uses_streaming(), + temperature: 0.1, + model, + messages, + tools, + tool_choice: None, + }) } }