diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 421199114f..8d86988306 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -12,7 +12,7 @@ mod streaming_diff; pub use assistant_panel::AssistantPanel; -use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel}; +use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel}; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; @@ -87,14 +87,14 @@ impl Display for Role { #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum LanguageModel { - ZedDotDev(ZedDotDevModel), + Cloud(CloudModel), OpenAi(OpenAiModel), Anthropic(AnthropicModel), } impl Default for LanguageModel { fn default() -> Self { - LanguageModel::ZedDotDev(ZedDotDevModel::default()) + LanguageModel::Cloud(CloudModel::default()) } } @@ -103,7 +103,7 @@ impl LanguageModel { match self { LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), - LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()), + LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), } } @@ -111,7 +111,7 @@ impl LanguageModel { match self { LanguageModel::OpenAi(model) => model.display_name().into(), LanguageModel::Anthropic(model) => model.display_name().into(), - LanguageModel::ZedDotDev(model) => model.display_name().into(), + LanguageModel::Cloud(model) => model.display_name().into(), } } @@ -119,7 +119,7 @@ impl LanguageModel { match self { LanguageModel::OpenAi(model) => model.max_token_count(), LanguageModel::Anthropic(model) => model.max_token_count(), - LanguageModel::ZedDotDev(model) => model.max_token_count(), + LanguageModel::Cloud(model) => model.max_token_count(), } } @@ -127,7 +127,7 @@ impl LanguageModel { match self { LanguageModel::OpenAi(model) => model.id(), LanguageModel::Anthropic(model) => model.id(), - LanguageModel::ZedDotDev(model) => model.id(), + LanguageModel::Cloud(model) => model.id(), } } } @@ -172,6 +172,20 @@ impl LanguageModelRequest { tools: Vec::new(), } } + + /// Before we send the request to the server, we can perform fixups on it appropriate to the model. 
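+    /// For Anthropic models, for example, consecutive messages that share a
+    /// role are merged and system messages are hoisted into a single leading
+    /// system message (see `preprocess_anthropic_request`).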
+ pub fn preprocess(&mut self) { + match &self.model { + LanguageModel::OpenAi(_) => {} + LanguageModel::Anthropic(_) => {} + LanguageModel::Cloud(model) => match model { + CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => { + preprocess_anthropic_request(self); + } + _ => {} + }, + } + } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index e1bd2805da..70165a7d63 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -17,7 +17,7 @@ use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection}; use client::telemetry::Telemetry; use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque}; -use editor::actions::ShowCompletions; +use editor::{actions::ShowCompletions, GutterDimensions}; use editor::{ actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt}, display_map::{ @@ -469,7 +469,7 @@ impl AssistantPanel { ) }); - let measurements = Arc::new(Mutex::new(BlockMeasurements::default())); + let measurements = Arc::new(Mutex::new(GutterDimensions::default())); let inline_assistant = cx.new_view(|cx| { InlineAssistant::new( inline_assist_id, @@ -492,10 +492,7 @@ impl AssistantPanel { render: Box::new({ let inline_assistant = inline_assistant.clone(); move |cx: &mut BlockContext| { - *measurements.lock() = BlockMeasurements { - gutter_width: cx.gutter_dimensions.width, - gutter_margin: cx.gutter_dimensions.margin, - }; + *measurements.lock() = *cx.gutter_dimensions; inline_assistant.clone().into_any_element() } }), @@ -583,6 +580,7 @@ impl AssistantPanel { ], }, ); + self.pending_inline_assist_ids_by_editor .entry(editor.downgrade()) .or_default() @@ -810,7 +808,7 @@ impl AssistantPanel { codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; anyhow::Ok(()) }) - .detach(); + .detach_and_log_err(cx); } fn update_highlights_for_editor(&self, editor: &View, cx: &mut ViewContext) { @@ -1431,7 +1429,7 @@ impl Panel for AssistantPanel { return None; } - Some(IconName::Ai) + Some(IconName::ZedAssistant) } fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> { @@ -3151,7 +3149,7 @@ impl ConversationEditor { h_flex() .id(("message_header", message_id.0)) - .pl(cx.gutter_dimensions.width + cx.gutter_dimensions.margin) + .pl(cx.gutter_dimensions.full_width()) .h_11() .w_full() .relative() @@ -3551,7 +3549,7 @@ struct InlineAssistant { prompt_editor: View, confirmed: bool, include_conversation: bool, - measurements: Arc>, + gutter_dimensions: Arc>, prompt_history: VecDeque, prompt_history_ix: Option, pending_prompt: String, @@ -3563,7 +3561,8 @@ impl EventEmitter for InlineAssistant {} impl Render for InlineAssistant { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let measurements = *self.measurements.lock(); + let gutter_dimensions = *self.gutter_dimensions.lock(); + let icon_size = IconSize::default(); h_flex() .w_full() .py_2() @@ -3576,14 +3575,20 @@ impl Render for InlineAssistant { .on_action(cx.listener(Self::move_down)) .child( h_flex() - .w(measurements.gutter_width + measurements.gutter_margin) + .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)) + .pr(gutter_dimensions.fold_area_width()) + .justify_end() .children(if let Some(error) = self.codegen.read(cx).error() { let error_message = SharedString::from(error.to_string()); Some( div() .id("error") .tooltip(move |cx| 
Tooltip::text(error_message.clone(), cx)) - .child(Icon::new(IconName::XCircle).color(Color::Error)), + .child( + Icon::new(IconName::XCircle) + .size(icon_size) + .color(Color::Error), + ), ) } else { None @@ -3603,7 +3608,7 @@ impl InlineAssistant { #[allow(clippy::too_many_arguments)] fn new( id: usize, - measurements: Arc>, + gutter_dimensions: Arc>, include_conversation: bool, prompt_history: VecDeque, codegen: Model, @@ -3630,7 +3635,7 @@ impl InlineAssistant { prompt_editor, confirmed: false, include_conversation, - measurements, + gutter_dimensions, prompt_history, prompt_history_ix: None, pending_prompt: String::new(), @@ -3755,13 +3760,6 @@ impl InlineAssistant { } } -// This wouldn't need to exist if we could pass parameters when rendering child views. -#[derive(Copy, Clone, Default)] -struct BlockMeasurements { - gutter_width: Pixels, - gutter_margin: Pixels, -} - struct PendingInlineAssist { editor: WeakView, inline_assistant: Option<(BlockId, View)>, diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 5d866b6efc..3efaff100d 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -14,10 +14,10 @@ use serde::{ use settings::{Settings, SettingsSources}; use strum::{EnumIter, IntoEnumIterator}; -use crate::LanguageModel; +use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest}; #[derive(Clone, Debug, Default, PartialEq, EnumIter)] -pub enum ZedDotDevModel { +pub enum CloudModel { Gpt3Point5Turbo, Gpt4, Gpt4Turbo, @@ -29,7 +29,7 @@ pub enum ZedDotDevModel { Custom(String), } -impl Serialize for ZedDotDevModel { +impl Serialize for CloudModel { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel { } } -impl<'de> Deserialize<'de> for ZedDotDevModel { +impl<'de> Deserialize<'de> for CloudModel { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, @@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { struct ZedDotDevModelVisitor; impl<'de> Visitor<'de> for ZedDotDevModelVisitor { - type Value = ZedDotDevModel; + type Value = CloudModel; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a string for a ZedDotDevModel variant or a custom model") @@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { where E: de::Error, { - let model = ZedDotDevModel::iter() + let model = CloudModel::iter() .find(|model| model.id() == value) - .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string())); + .unwrap_or_else(|| CloudModel::Custom(value.to_string())); Ok(model) } } @@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel { } } -impl JsonSchema for ZedDotDevModel { +impl JsonSchema for CloudModel { fn schema_name() -> String { "ZedDotDevModel".to_owned() } fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { - let variants = ZedDotDevModel::iter() + let variants = CloudModel::iter() .filter_map(|model| { let id = model.id(); if id.is_empty() { @@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel { enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), metadata: Some(Box::new(Metadata { title: Some("ZedDotDevModel".to_owned()), - default: Some(ZedDotDevModel::default().id().into()), + default: Some(CloudModel::default().id().into()), examples: variants.into_iter().map(Into::into).collect(), ..Default::default() })), @@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel { } } 
-impl ZedDotDevModel { +impl CloudModel { pub fn id(&self) -> &str { match self { Self::Gpt3Point5Turbo => "gpt-3.5-turbo", @@ -133,6 +133,15 @@ impl ZedDotDevModel { Self::Custom(_) => 4096, // TODO: Make this configurable } } + + pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { + preprocess_anthropic_request(request) + } + _ => {} + } + } } #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] @@ -147,7 +156,7 @@ pub enum AssistantDockPosition { #[derive(Debug, PartialEq)] pub enum AssistantProvider { ZedDotDev { - model: ZedDotDevModel, + model: CloudModel, }, OpenAi { model: OpenAiModel, @@ -175,9 +184,7 @@ impl Default for AssistantProvider { #[serde(tag = "name", rename_all = "snake_case")] pub enum AssistantProviderContent { #[serde(rename = "zed.dev")] - ZedDotDev { - default_model: Option, - }, + ZedDotDev { default_model: Option }, #[serde(rename = "openai")] OpenAi { default_model: Option, @@ -281,7 +288,7 @@ impl AssistantSettingsContent { Some(AssistantProviderContent::ZedDotDev { default_model: model, }) => { - if let LanguageModel::ZedDotDev(new_model) = new_model { + if let LanguageModel::Cloud(new_model) = new_model { *model = Some(new_model); } } @@ -302,7 +309,7 @@ impl AssistantSettingsContent { } } provider => match new_model { - LanguageModel::ZedDotDev(model) => { + LanguageModel::Cloud(model) => { *provider = Some(AssistantProviderContent::ZedDotDev { default_model: Some(model), }) @@ -613,7 +620,7 @@ mod tests { assert_eq!( AssistantSettings::get_global(cx).provider, AssistantProvider::ZedDotDev { - model: ZedDotDevModel::Custom("custom".into()) + model: CloudModel::Custom("custom".into()) } ); } diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 8483a2ae14..2a725189eb 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -11,6 +11,7 @@ use language::{Rope, TransactionId}; use multi_buffer::MultiBufferRow; use std::{cmp, future, ops::Range, sync::Arc, time::Instant}; +#[derive(Debug)] pub enum Event { Finished, Undone, @@ -120,91 +121,98 @@ impl Codegen { let mut edit_start = range.start.to_offset(&snapshot); let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); - let diff = cx.background_executor().spawn(async move { - let mut response_latency = None; - let request_start = Instant::now(); - let diff = async { - let chunks = strip_invalid_spans_from_codeblock(response.await?); - futures::pin_mut!(chunks); - let mut diff = StreamingDiff::new(selected_text.to_string()); + let diff: Task> = + cx.background_executor().spawn(async move { + let mut response_latency = None; + let request_start = Instant::now(); + let diff = async { + let chunks = strip_invalid_spans_from_codeblock(response.await?); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); - let mut new_text = String::new(); - let mut base_indent = None; - let mut line_indent = None; - let mut first_line = true; + let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; - while let Some(chunk) = chunks.next().await { - if response_latency.is_none() { - response_latency = Some(request_start.elapsed()); - } - let chunk = chunk?; + while let Some(chunk) = chunks.next().await { + if response_latency.is_none() { + response_latency = Some(request_start.elapsed()); + } + let chunk = chunk?; - let mut lines = 
chunk.split('\n').peekable(); - while let Some(line) = lines.next() { - new_text.push_str(line); - if line_indent.is_none() { - if let Some(non_whitespace_ch_ix) = - new_text.find(|ch: char| !ch.is_whitespace()) - { - line_indent = Some(non_whitespace_ch_ix); - base_indent = base_indent.or(line_indent); + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { + new_text.push_str(line); + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); - let line_indent = line_indent.unwrap(); - let base_indent = base_indent.unwrap(); - let indent_delta = - line_indent as i32 - base_indent as i32; - let mut corrected_indent_len = cmp::max( - 0, - suggested_line_indent.len as i32 + indent_delta, - ) - as usize; - if first_line { - corrected_indent_len = corrected_indent_len - .saturating_sub( - selection_start.column as usize, - ); + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = + line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub( + selection_start.column as usize, + ); + } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); } + } - let indent_char = suggested_line_indent.char(); - let mut indent_buffer = [0; 4]; - let indent_str = - indent_char.encode_utf8(&mut indent_buffer); - new_text.replace_range( - ..line_indent, - &indent_str.repeat(corrected_indent_len), - ); + if line_indent.is_some() { + hunks_tx.send(diff.push_new(&new_text)).await?; + new_text.clear(); + } + + if lines.peek().is_some() { + hunks_tx.send(diff.push_new("\n")).await?; + line_indent = None; + first_line = false; } } - - if line_indent.is_some() { - hunks_tx.send(diff.push_new(&new_text)).await?; - new_text.clear(); - } - - if lines.peek().is_some() { - hunks_tx.send(diff.push_new("\n")).await?; - line_indent = None; - first_line = false; - } } + hunks_tx.send(diff.push_new(&new_text)).await?; + hunks_tx.send(diff.finish()).await?; + + anyhow::Ok(()) + }; + + let result = diff.await; + + let error_message = + result.as_ref().err().map(|error| error.to_string()); + if let Some(telemetry) = telemetry { + telemetry.report_assistant_event( + None, + telemetry_events::AssistantKind::Inline, + model_telemetry_id, + response_latency, + error_message, + ); } - hunks_tx.send(diff.push_new(&new_text)).await?; - hunks_tx.send(diff.finish()).await?; - anyhow::Ok(()) - }; - - let error_message = diff.await.err().map(|error| error.to_string()); - if let Some(telemetry) = telemetry { - telemetry.report_assistant_event( - None, - telemetry_events::AssistantKind::Inline, - model_telemetry_id, - response_latency, - error_message, - ); - } - }); + result?; + Ok(()) + }); while let Some(hunks) = hunks_rx.next().await { this.update(&mut cx, |this, cx| { @@ -266,7 +274,7 @@ impl Codegen { })?; } - diff.await; + diff.await?; anyhow::Ok(()) }; diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs index 99b8b407fb..01ea6325ad 100644 --- a/crates/assistant/src/completion_provider.rs +++ 
b/crates/assistant/src/completion_provider.rs @@ -1,14 +1,14 @@ mod anthropic; +mod cloud; #[cfg(test)] mod fake; mod open_ai; -mod zed; pub use anthropic::*; +pub use cloud::*; #[cfg(test)] pub use fake::*; pub use open_ai::*; -pub use zed::*; use crate::{ assistant_settings::{AssistantProvider, AssistantSettings}, @@ -25,8 +25,8 @@ use std::time::Duration; pub fn init(client: Arc, cx: &mut AppContext) { let mut settings_version = 0; let provider = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev( - ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), + AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud( + CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), ), AssistantProvider::OpenAi { model, @@ -87,14 +87,11 @@ pub fn init(client: Arc, cx: &mut AppContext) { settings_version, ); } - ( - CompletionProvider::ZedDotDev(provider), - AssistantProvider::ZedDotDev { model }, - ) => { + (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => { provider.update(model.clone(), settings_version); } (_, AssistantProvider::ZedDotDev { model }) => { - *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( + *provider = CompletionProvider::Cloud(CloudCompletionProvider::new( model.clone(), client.clone(), settings_version, @@ -142,7 +139,7 @@ pub fn init(client: Arc, cx: &mut AppContext) { pub enum CompletionProvider { OpenAi(OpenAiCompletionProvider), Anthropic(AnthropicCompletionProvider), - ZedDotDev(ZedDotDevCompletionProvider), + Cloud(CloudCompletionProvider), #[cfg(test)] Fake(FakeCompletionProvider), } @@ -164,9 +161,9 @@ impl CompletionProvider { .available_models() .map(LanguageModel::Anthropic) .collect(), - CompletionProvider::ZedDotDev(provider) => provider + CompletionProvider::Cloud(provider) => provider .available_models() - .map(LanguageModel::ZedDotDev) + .map(LanguageModel::Cloud) .collect(), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), @@ -177,7 +174,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.settings_version(), CompletionProvider::Anthropic(provider) => provider.settings_version(), - CompletionProvider::ZedDotDev(provider) => provider.settings_version(), + CompletionProvider::Cloud(provider) => provider.settings_version(), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), } @@ -187,7 +184,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.is_authenticated(), CompletionProvider::Anthropic(provider) => provider.is_authenticated(), - CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(), + CompletionProvider::Cloud(provider) => provider.is_authenticated(), #[cfg(test)] CompletionProvider::Fake(_) => true, } @@ -197,7 +194,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.authenticate(cx), CompletionProvider::Anthropic(provider) => provider.authenticate(cx), - CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx), + CompletionProvider::Cloud(provider) => provider.authenticate(cx), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), } @@ -207,7 +204,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx), CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx), - 
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx), + CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx), #[cfg(test)] CompletionProvider::Fake(_) => unimplemented!(), } @@ -217,7 +214,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx), CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx), - CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())), + CompletionProvider::Cloud(_) => Task::ready(Ok(())), #[cfg(test)] CompletionProvider::Fake(_) => Task::ready(Ok(())), } @@ -227,7 +224,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()), CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()), - CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()), + CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()), #[cfg(test)] CompletionProvider::Fake(_) => LanguageModel::default(), } @@ -241,7 +238,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx), CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx), - CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx), + CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx), #[cfg(test)] CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))), } @@ -254,7 +251,7 @@ impl CompletionProvider { match self { CompletionProvider::OpenAi(provider) => provider.complete(request), CompletionProvider::Anthropic(provider) => provider.complete(request), - CompletionProvider::ZedDotDev(provider) => provider.complete(request), + CompletionProvider::Cloud(provider) => provider.complete(request), #[cfg(test)] CompletionProvider::Fake(provider) => provider.complete(), } diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/assistant/src/completion_provider/anthropic.rs index 8ae40993bc..d17a601284 100644 --- a/crates/assistant/src/completion_provider/anthropic.rs +++ b/crates/assistant/src/completion_provider/anthropic.rs @@ -1,9 +1,9 @@ -use crate::count_open_ai_tokens; use crate::{ assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, }; -use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole}; +use crate::{count_open_ai_tokens, LanguageModelRequestMessage}; +use anthropic::{stream_completion, Request, RequestMessage}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; @@ -167,53 +167,37 @@ impl AnthropicCompletionProvider { .boxed() } - fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request { + fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { + preprocess_anthropic_request(&mut request); + let model = match request.model { LanguageModel::Anthropic(model) => model, _ => self.model(), }; let mut system_message = String::new(); - - let mut messages: Vec = Vec::new(); - for message in request.messages { - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let role = match message.role { - Role::User => AnthropicRole::User, - Role::Assistant => AnthropicRole::Assistant, - _ => unreachable!(), - }; - - if let Some(last_message) = 
messages.last_mut() {
-                    if last_message.role == role {
-                        last_message.content.push_str("\n\n");
-                        last_message.content.push_str(&message.content);
-                        continue;
-                    }
-                }
-
-                messages.push(RequestMessage {
-                    role,
-                    content: message.content,
-                });
-            }
-            Role::System => {
-                if !system_message.is_empty() {
-                    system_message.push_str("\n\n");
-                }
-                system_message.push_str(&message.content);
-            }
-        }
+        if request
+            .messages
+            .first()
+            .map_or(false, |message| message.role == Role::System)
+        {
+            system_message = request.messages.remove(0).content;
+        }

         Request {
             model,
-            messages,
+            messages: request
+                .messages
+                .iter()
+                .map(|msg| RequestMessage {
+                    role: match msg.role {
+                        Role::User => anthropic::Role::User,
+                        Role::Assistant => anthropic::Role::Assistant,
+                        Role::System => unreachable!("filtered out by preprocess_request"),
+                    },
+                    content: msg.content.clone(),
+                })
+                .collect(),
             stream: true,
             system: system_message,
             max_tokens: 4092,
@@ -221,6 +205,49 @@ impl AnthropicCompletionProvider {
     }
 }

+pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
+    let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+    let mut system_message = String::new();
+
+    for message in request.messages.drain(..) {
+        if message.content.is_empty() {
+            continue;
+        }
+
+        match message.role {
+            Role::User | Role::Assistant => {
+                if let Some(last_message) = new_messages.last_mut() {
+                    if last_message.role == message.role {
+                        last_message.content.push_str("\n\n");
+                        last_message.content.push_str(&message.content);
+                        continue;
+                    }
+                }
+
+                new_messages.push(message);
+            }
+            Role::System => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.content);
+            }
+        }
+    }
+
+    if !system_message.is_empty() {
+        new_messages.insert(
+            0,
+            LanguageModelRequestMessage {
+                role: Role::System,
+                content: system_message,
+            },
+        );
+    }
+
+    request.messages = new_messages;
+}
+
 struct AuthenticationPrompt {
     api_key: View<Editor>,
     api_url: String,
diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/cloud.rs
similarity index 80%
rename from crates/assistant/src/completion_provider/zed.rs
rename to crates/assistant/src/completion_provider/cloud.rs
index d300541a88..e2c157353c 100644
--- a/crates/assistant/src/completion_provider/zed.rs
+++ b/crates/assistant/src/completion_provider/cloud.rs
@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
+    assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
     LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
@@ -10,17 +10,17 @@ use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;

-pub struct ZedDotDevCompletionProvider {
+pub struct CloudCompletionProvider {
     client: Arc<Client>,
-    model: ZedDotDevModel,
+    model: CloudModel,
     settings_version: usize,
     status: client::Status,
     _maintain_client_status: Task<()>,
 }

-impl ZedDotDevCompletionProvider {
+impl CloudCompletionProvider {
     pub fn new(
-        model: ZedDotDevModel,
+        model: CloudModel,
         client: Arc<Client>,
         settings_version: usize,
         cx: &mut AppContext,
     ) -> Self {
         let mut status_rx = client.status();
         let status = *status_rx.borrow();
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::ZedDotDev(provider) = provider {
+                    if let CompletionProvider::Cloud(provider) = provider {
                         provider.status = status;
                     } else {
                         unreachable!()
                     }
                 });
             }
         });
@@ -47,20 +47,20 @@ impl ZedDotDevCompletionProvider {
         }
     }

-    pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
+    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
         self.model = model;
         self.settings_version = settings_version;
     }

-    pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
-        let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
+    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        ZedDotDevModel::iter().filter_map(move |model| {
-            if let ZedDotDevModel::Custom(_) = model {
-                Some(ZedDotDevModel::Custom(custom_model.take()?))
+        CloudModel::iter().filter_map(move |model| {
+            if let CloudModel::Custom(_) = model {
+                Some(CloudModel::Custom(custom_model.take()?))
             } else {
                 Some(model)
             }
@@ -71,7 +71,7 @@ impl ZedDotDevCompletionProvider {
         self.settings_version
     }

-    pub fn model(&self) -> ZedDotDevModel {
+    pub fn model(&self) -> CloudModel {
         self.model.clone()
     }

@@ -94,21 +94,19 @@ impl ZedDotDevCompletionProvider {
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
         match request.model {
-            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
-            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
+            LanguageModel::Cloud(CloudModel::Gpt4)
+            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
+            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
+            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            LanguageModel::ZedDotDev(
-                ZedDotDevModel::Claude3Opus
-                | ZedDotDevModel::Claude3Sonnet
-                | ZedDotDevModel::Claude3Haiku,
+            LanguageModel::Cloud(
+                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku,
             ) => {
                 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
                count_open_ai_tokens(request, cx.background_executor())
            }
-            LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+            LanguageModel::Cloud(CloudModel::Custom(model)) => {
                let request = self.client.request(proto::CountTokensWithLanguageModel {
                    model,
                    messages: request
@@ -129,8 +127,10 @@ impl ZedDotDevCompletionProvider {

     pub fn complete(
         &self,
-        request: LanguageModelRequest,
+        mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        request.preprocess();
+
         let request = proto::CompleteWithLanguageModel {
             model: request.model.id().to_string(),
             messages: request
diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs
index 6ab43d773b..ecd127a418 100644
--- a/crates/assistant/src/completion_provider/open_ai.rs
+++ b/crates/assistant/src/completion_provider/open_ai.rs
@@ -1,4 +1,4 @@
-use crate::assistant_settings::ZedDotDevModel;
+use crate::assistant_settings::CloudModel;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -210,9 +210,9 @@ pub fn count_open_ai_tokens(
     match request.model {
         LanguageModel::Anthropic(_)
-        | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
-        | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
-        | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
+        | LanguageModel::Cloud(CloudModel::Claude3Opus)
+        | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
+        | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
             // Tiktoken doesn't yet support these models, so we manually use the
             // same tokenizer as GPT-4.
             tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs
index 23272ae1c8..3c472540c0 100644
--- a/crates/editor/src/editor.rs
+++ b/crates/editor/src/editor.rs
@@ -554,6 +554,20 @@ pub struct GutterDimensions {
     pub git_blame_entries_width: Option<Pixels>,
 }

+impl GutterDimensions {
+    /// The full width of the space taken up by the gutter.
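+    /// That is, `margin + width`.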
+    pub fn full_width(&self) -> Pixels {
+        self.margin + self.width
+    }
+
+    /// The width of the space reserved for the fold indicators.
+    /// Use alongside `justify_end` and `gutter_width` to
+    /// right-align content with the line numbers.
+    pub fn fold_area_width(&self) -> Pixels {
+        self.margin + self.right_padding
+    }
+}
+
 impl Default for GutterDimensions {
     fn default() -> Self {
         Self {
diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs
index 280be02523..3abb250efb 100644
--- a/crates/editor/src/element.rs
+++ b/crates/editor/src/element.rs
@@ -1125,9 +1125,7 @@ impl EditorElement {
                     ix as f32 * line_height - (scroll_pixel_position.y % line_height),
                 );
                 let centering_offset = point(
-                    (gutter_dimensions.right_padding + gutter_dimensions.margin
-                        - fold_indicator_size.width)
-                        / 2.,
+                    (gutter_dimensions.fold_area_width() - fold_indicator_size.width) / 2.,
                     (line_height - fold_indicator_size.height) / 2.,
                 );
                 let origin = gutter_hitbox.origin + position + centering_offset;
@@ -4629,7 +4627,7 @@ impl Element for EditorElement {
                     &mut scroll_width,
                     &gutter_dimensions,
                     em_width,
-                    gutter_dimensions.width + gutter_dimensions.margin,
+                    gutter_dimensions.full_width(),
                     line_height,
                     &line_layouts,
                     cx,
diff --git a/crates/editor/src/hunk_diff.rs b/crates/editor/src/hunk_diff.rs
index b927c05262..a2fa4521c6 100644
--- a/crates/editor/src/hunk_diff.rs
+++ b/crates/editor/src/hunk_diff.rs
@@ -320,7 +320,7 @@ impl Editor {
                         div()
                             .bg(deleted_hunk_color)
                             .size_full()
-                            .pl(gutter_dimensions.width + gutter_dimensions.margin)
+                            .pl(gutter_dimensions.full_width())
                             .child(editor_with_deleted_text.clone())
                             .into_any_element()
                     }),
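
For reviewers: a self-contained sketch of the message fixups that `preprocess_anthropic_request` applies before a request is sent. The `Role` and `Message` types below are simplified stand-ins for the crate's `Role` and `LanguageModelRequestMessage`, and the function takes and returns a plain `Vec` (the real code mutates a `LanguageModelRequest` in place) so it can be run as-is:

```rust
// Minimal, standalone model of the Anthropic request preprocessing.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Role {
    User,
    Assistant,
    System,
}

#[derive(Debug)]
struct Message {
    role: Role,
    content: String,
}

fn preprocess(messages: Vec<Message>) -> Vec<Message> {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();

    for message in messages {
        if message.content.is_empty() {
            continue;
        }
        match message.role {
            // Consecutive user/assistant messages with the same role are
            // merged into one, separated by a blank line.
            Role::User | Role::Assistant => {
                if let Some(last) = new_messages.last_mut() {
                    if last.role == message.role {
                        last.content.push_str("\n\n");
                        last.content.push_str(&message.content);
                        continue;
                    }
                }
                new_messages.push(message);
            }
            // System messages are accumulated and hoisted into a single
            // message at the front of the list.
            Role::System => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.content);
            }
        }
    }

    if !system_message.is_empty() {
        new_messages.insert(
            0,
            Message {
                role: Role::System,
                content: system_message,
            },
        );
    }

    new_messages
}

fn main() {
    let messages = vec![
        Message { role: Role::User, content: "Refactor this function.".into() },
        Message { role: Role::System, content: "You are a coding assistant.".into() },
        Message { role: Role::User, content: "Keep the public API stable.".into() },
    ];
    // Prints one leading System message, then a single merged User message.
    for message in preprocess(messages) {
        println!("{:?}: {:?}", message.role, message.content);
    }
}
```

Note that the merged system text must be inserted into `new_messages` before that vector replaces `request.messages`; inserting it into the just-drained `request.messages` would be silently overwritten by the final assignment.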