Fix issues with Claude in Assistant2 (#12619)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Mikayla Maki 2024-06-03 16:30:09 -07:00 committed by GitHub
parent afc0650a49
commit 3cd6719b30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 276 additions and 213 deletions

View file

@ -12,7 +12,7 @@ mod streaming_diff;
pub use assistant_panel::AssistantPanel; pub use assistant_panel::AssistantPanel;
use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel}; use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry; use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client}; use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
@ -87,14 +87,14 @@ impl Display for Role {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel { pub enum LanguageModel {
ZedDotDev(ZedDotDevModel), Cloud(CloudModel),
OpenAi(OpenAiModel), OpenAi(OpenAiModel),
Anthropic(AnthropicModel), Anthropic(AnthropicModel),
} }
impl Default for LanguageModel { impl Default for LanguageModel {
fn default() -> Self { fn default() -> Self {
LanguageModel::ZedDotDev(ZedDotDevModel::default()) LanguageModel::Cloud(CloudModel::default())
} }
} }
@ -103,7 +103,7 @@ impl LanguageModel {
match self { match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()), LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
} }
} }
@ -111,7 +111,7 @@ impl LanguageModel {
match self { match self {
LanguageModel::OpenAi(model) => model.display_name().into(), LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(), LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::ZedDotDev(model) => model.display_name().into(), LanguageModel::Cloud(model) => model.display_name().into(),
} }
} }
@ -119,7 +119,7 @@ impl LanguageModel {
match self { match self {
LanguageModel::OpenAi(model) => model.max_token_count(), LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(), LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::ZedDotDev(model) => model.max_token_count(), LanguageModel::Cloud(model) => model.max_token_count(),
} }
} }
@ -127,7 +127,7 @@ impl LanguageModel {
match self { match self {
LanguageModel::OpenAi(model) => model.id(), LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(), LanguageModel::Anthropic(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(), LanguageModel::Cloud(model) => model.id(),
} }
} }
} }
@ -172,6 +172,20 @@ impl LanguageModelRequest {
tools: Vec::new(), tools: Vec::new(),
} }
} }
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

View file

@ -17,7 +17,7 @@ use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection}; use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque}; use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque};
use editor::actions::ShowCompletions; use editor::{actions::ShowCompletions, GutterDimensions};
use editor::{ use editor::{
actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt}, actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt},
display_map::{ display_map::{
@ -469,7 +469,7 @@ impl AssistantPanel {
) )
}); });
let measurements = Arc::new(Mutex::new(BlockMeasurements::default())); let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
let inline_assistant = cx.new_view(|cx| { let inline_assistant = cx.new_view(|cx| {
InlineAssistant::new( InlineAssistant::new(
inline_assist_id, inline_assist_id,
@ -492,10 +492,7 @@ impl AssistantPanel {
render: Box::new({ render: Box::new({
let inline_assistant = inline_assistant.clone(); let inline_assistant = inline_assistant.clone();
move |cx: &mut BlockContext| { move |cx: &mut BlockContext| {
*measurements.lock() = BlockMeasurements { *measurements.lock() = *cx.gutter_dimensions;
gutter_width: cx.gutter_dimensions.width,
gutter_margin: cx.gutter_dimensions.margin,
};
inline_assistant.clone().into_any_element() inline_assistant.clone().into_any_element()
} }
}), }),
@ -583,6 +580,7 @@ impl AssistantPanel {
], ],
}, },
); );
self.pending_inline_assist_ids_by_editor self.pending_inline_assist_ids_by_editor
.entry(editor.downgrade()) .entry(editor.downgrade())
.or_default() .or_default()
@ -810,7 +808,7 @@ impl AssistantPanel {
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
anyhow::Ok(()) anyhow::Ok(())
}) })
.detach(); .detach_and_log_err(cx);
} }
fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut ViewContext<Self>) { fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut ViewContext<Self>) {
@ -1431,7 +1429,7 @@ impl Panel for AssistantPanel {
return None; return None;
} }
Some(IconName::Ai) Some(IconName::ZedAssistant)
} }
fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> { fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
@ -3151,7 +3149,7 @@ impl ConversationEditor {
h_flex() h_flex()
.id(("message_header", message_id.0)) .id(("message_header", message_id.0))
.pl(cx.gutter_dimensions.width + cx.gutter_dimensions.margin) .pl(cx.gutter_dimensions.full_width())
.h_11() .h_11()
.w_full() .w_full()
.relative() .relative()
@ -3551,7 +3549,7 @@ struct InlineAssistant {
prompt_editor: View<Editor>, prompt_editor: View<Editor>,
confirmed: bool, confirmed: bool,
include_conversation: bool, include_conversation: bool,
measurements: Arc<Mutex<BlockMeasurements>>, gutter_dimensions: Arc<Mutex<GutterDimensions>>,
prompt_history: VecDeque<String>, prompt_history: VecDeque<String>,
prompt_history_ix: Option<usize>, prompt_history_ix: Option<usize>,
pending_prompt: String, pending_prompt: String,
@ -3563,7 +3561,8 @@ impl EventEmitter<InlineAssistantEvent> for InlineAssistant {}
impl Render for InlineAssistant { impl Render for InlineAssistant {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement { fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let measurements = *self.measurements.lock(); let gutter_dimensions = *self.gutter_dimensions.lock();
let icon_size = IconSize::default();
h_flex() h_flex()
.w_full() .w_full()
.py_2() .py_2()
@ -3576,14 +3575,20 @@ impl Render for InlineAssistant {
.on_action(cx.listener(Self::move_down)) .on_action(cx.listener(Self::move_down))
.child( .child(
h_flex() h_flex()
.w(measurements.gutter_width + measurements.gutter_margin) .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.pr(gutter_dimensions.fold_area_width())
.justify_end()
.children(if let Some(error) = self.codegen.read(cx).error() { .children(if let Some(error) = self.codegen.read(cx).error() {
let error_message = SharedString::from(error.to_string()); let error_message = SharedString::from(error.to_string());
Some( Some(
div() div()
.id("error") .id("error")
.tooltip(move |cx| Tooltip::text(error_message.clone(), cx)) .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
.child(Icon::new(IconName::XCircle).color(Color::Error)), .child(
Icon::new(IconName::XCircle)
.size(icon_size)
.color(Color::Error),
),
) )
} else { } else {
None None
@ -3603,7 +3608,7 @@ impl InlineAssistant {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
id: usize, id: usize,
measurements: Arc<Mutex<BlockMeasurements>>, gutter_dimensions: Arc<Mutex<GutterDimensions>>,
include_conversation: bool, include_conversation: bool,
prompt_history: VecDeque<String>, prompt_history: VecDeque<String>,
codegen: Model<Codegen>, codegen: Model<Codegen>,
@ -3630,7 +3635,7 @@ impl InlineAssistant {
prompt_editor, prompt_editor,
confirmed: false, confirmed: false,
include_conversation, include_conversation,
measurements, gutter_dimensions,
prompt_history, prompt_history,
prompt_history_ix: None, prompt_history_ix: None,
pending_prompt: String::new(), pending_prompt: String::new(),
@ -3755,13 +3760,6 @@ impl InlineAssistant {
} }
} }
// This wouldn't need to exist if we could pass parameters when rendering child views.
#[derive(Copy, Clone, Default)]
struct BlockMeasurements {
gutter_width: Pixels,
gutter_margin: Pixels,
}
struct PendingInlineAssist { struct PendingInlineAssist {
editor: WeakView<Editor>, editor: WeakView<Editor>,
inline_assistant: Option<(BlockId, View<InlineAssistant>)>, inline_assistant: Option<(BlockId, View<InlineAssistant>)>,

View file

@ -14,10 +14,10 @@ use serde::{
use settings::{Settings, SettingsSources}; use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator}; use strum::{EnumIter, IntoEnumIterator};
use crate::LanguageModel; use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)] #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel { pub enum CloudModel {
Gpt3Point5Turbo, Gpt3Point5Turbo,
Gpt4, Gpt4,
Gpt4Turbo, Gpt4Turbo,
@ -29,7 +29,7 @@ pub enum ZedDotDevModel {
Custom(String), Custom(String),
} }
impl Serialize for ZedDotDevModel { impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel {
} }
} }
impl<'de> Deserialize<'de> for ZedDotDevModel { impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
struct ZedDotDevModelVisitor; struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor { impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = ZedDotDevModel; type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model") formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where where
E: de::Error, E: de::Error,
{ {
let model = ZedDotDevModel::iter() let model = CloudModel::iter()
.find(|model| model.id() == value) .find(|model| model.id() == value)
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string())); .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model) Ok(model)
} }
} }
@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
} }
} }
impl JsonSchema for ZedDotDevModel { impl JsonSchema for CloudModel {
fn schema_name() -> String { fn schema_name() -> String {
"ZedDotDevModel".to_owned() "ZedDotDevModel".to_owned()
} }
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = ZedDotDevModel::iter() let variants = CloudModel::iter()
.filter_map(|model| { .filter_map(|model| {
let id = model.id(); let id = model.id();
if id.is_empty() { if id.is_empty() {
@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel {
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata { metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()), title: Some("ZedDotDevModel".to_owned()),
default: Some(ZedDotDevModel::default().id().into()), default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(), examples: variants.into_iter().map(Into::into).collect(),
..Default::default() ..Default::default()
})), })),
@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel {
} }
} }
impl ZedDotDevModel { impl CloudModel {
pub fn id(&self) -> &str { pub fn id(&self) -> &str {
match self { match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo", Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
@ -133,6 +133,15 @@ impl ZedDotDevModel {
Self::Custom(_) => 4096, // TODO: Make this configurable Self::Custom(_) => 4096, // TODO: Make this configurable
} }
} }
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
} }
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
@ -147,7 +156,7 @@ pub enum AssistantDockPosition {
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum AssistantProvider { pub enum AssistantProvider {
ZedDotDev { ZedDotDev {
model: ZedDotDevModel, model: CloudModel,
}, },
OpenAi { OpenAi {
model: OpenAiModel, model: OpenAiModel,
@ -175,9 +184,7 @@ impl Default for AssistantProvider {
#[serde(tag = "name", rename_all = "snake_case")] #[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent { pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")] #[serde(rename = "zed.dev")]
ZedDotDev { ZedDotDev { default_model: Option<CloudModel> },
default_model: Option<ZedDotDevModel>,
},
#[serde(rename = "openai")] #[serde(rename = "openai")]
OpenAi { OpenAi {
default_model: Option<OpenAiModel>, default_model: Option<OpenAiModel>,
@ -281,7 +288,7 @@ impl AssistantSettingsContent {
Some(AssistantProviderContent::ZedDotDev { Some(AssistantProviderContent::ZedDotDev {
default_model: model, default_model: model,
}) => { }) => {
if let LanguageModel::ZedDotDev(new_model) = new_model { if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model); *model = Some(new_model);
} }
} }
@ -302,7 +309,7 @@ impl AssistantSettingsContent {
} }
} }
provider => match new_model { provider => match new_model {
LanguageModel::ZedDotDev(model) => { LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev { *provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model), default_model: Some(model),
}) })
@ -613,7 +620,7 @@ mod tests {
assert_eq!( assert_eq!(
AssistantSettings::get_global(cx).provider, AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev { AssistantProvider::ZedDotDev {
model: ZedDotDevModel::Custom("custom".into()) model: CloudModel::Custom("custom".into())
} }
); );
} }

View file

@ -11,6 +11,7 @@ use language::{Rope, TransactionId};
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use std::{cmp, future, ops::Range, sync::Arc, time::Instant}; use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
#[derive(Debug)]
pub enum Event { pub enum Event {
Finished, Finished,
Undone, Undone,
@ -120,91 +121,98 @@ impl Codegen {
let mut edit_start = range.start.to_offset(&snapshot); let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background_executor().spawn(async move { let diff: Task<anyhow::Result<()>> =
let mut response_latency = None; cx.background_executor().spawn(async move {
let request_start = Instant::now(); let mut response_latency = None;
let diff = async { let request_start = Instant::now();
let chunks = strip_invalid_spans_from_codeblock(response.await?); let diff = async {
futures::pin_mut!(chunks); let chunks = strip_invalid_spans_from_codeblock(response.await?);
let mut diff = StreamingDiff::new(selected_text.to_string()); futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut new_text = String::new(); let mut new_text = String::new();
let mut base_indent = None; let mut base_indent = None;
let mut line_indent = None; let mut line_indent = None;
let mut first_line = true; let mut first_line = true;
while let Some(chunk) = chunks.next().await { while let Some(chunk) = chunks.next().await {
if response_latency.is_none() { if response_latency.is_none() {
response_latency = Some(request_start.elapsed()); response_latency = Some(request_start.elapsed());
} }
let chunk = chunk?; let chunk = chunk?;
let mut lines = chunk.split('\n').peekable(); let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() { while let Some(line) = lines.next() {
new_text.push_str(line); new_text.push_str(line);
if line_indent.is_none() { if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) = if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace()) new_text.find(|ch: char| !ch.is_whitespace())
{ {
line_indent = Some(non_whitespace_ch_ix); line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent); base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap(); let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap(); let base_indent = base_indent.unwrap();
let indent_delta = let indent_delta =
line_indent as i32 - base_indent as i32; line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max( let mut corrected_indent_len = cmp::max(
0, 0,
suggested_line_indent.len as i32 + indent_delta, suggested_line_indent.len as i32 + indent_delta,
) )
as usize; as usize;
if first_line { if first_line {
corrected_indent_len = corrected_indent_len corrected_indent_len = corrected_indent_len
.saturating_sub( .saturating_sub(
selection_start.column as usize, selection_start.column as usize,
); );
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
} }
}
let indent_char = suggested_line_indent.char(); if line_indent.is_some() {
let mut indent_buffer = [0; 4]; hunks_tx.send(diff.push_new(&new_text)).await?;
let indent_str = new_text.clear();
indent_char.encode_utf8(&mut indent_buffer); }
new_text.replace_range(
..line_indent, if lines.peek().is_some() {
&indent_str.repeat(corrected_indent_len), hunks_tx.send(diff.push_new("\n")).await?;
); line_indent = None;
first_line = false;
} }
} }
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
} }
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message =
result.as_ref().err().map(|error| error.to_string());
if let Some(telemetry) = telemetry {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
model_telemetry_id,
response_latency,
error_message,
);
} }
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(()) result?;
}; Ok(())
});
let error_message = diff.await.err().map(|error| error.to_string());
if let Some(telemetry) = telemetry {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
model_telemetry_id,
response_latency,
error_message,
);
}
});
while let Some(hunks) = hunks_rx.next().await { while let Some(hunks) = hunks_rx.next().await {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
@ -266,7 +274,7 @@ impl Codegen {
})?; })?;
} }
diff.await; diff.await?;
anyhow::Ok(()) anyhow::Ok(())
}; };

View file

@ -1,14 +1,14 @@
mod anthropic; mod anthropic;
mod cloud;
#[cfg(test)] #[cfg(test)]
mod fake; mod fake;
mod open_ai; mod open_ai;
mod zed;
pub use anthropic::*; pub use anthropic::*;
pub use cloud::*;
#[cfg(test)] #[cfg(test)]
pub use fake::*; pub use fake::*;
pub use open_ai::*; pub use open_ai::*;
pub use zed::*;
use crate::{ use crate::{
assistant_settings::{AssistantProvider, AssistantSettings}, assistant_settings::{AssistantProvider, AssistantSettings},
@ -25,8 +25,8 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0; let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider { let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev( AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
), ),
AssistantProvider::OpenAi { AssistantProvider::OpenAi {
model, model,
@ -87,14 +87,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version, settings_version,
); );
} }
( (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { model },
) => {
provider.update(model.clone(), settings_version); provider.update(model.clone(), settings_version);
} }
(_, AssistantProvider::ZedDotDev { model }) => { (_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
model.clone(), model.clone(),
client.clone(), client.clone(),
settings_version, settings_version,
@ -142,7 +139,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub enum CompletionProvider { pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider), OpenAi(OpenAiCompletionProvider),
Anthropic(AnthropicCompletionProvider), Anthropic(AnthropicCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider), Cloud(CloudCompletionProvider),
#[cfg(test)] #[cfg(test)]
Fake(FakeCompletionProvider), Fake(FakeCompletionProvider),
} }
@ -164,9 +161,9 @@ impl CompletionProvider {
.available_models() .available_models()
.map(LanguageModel::Anthropic) .map(LanguageModel::Anthropic)
.collect(), .collect(),
CompletionProvider::ZedDotDev(provider) => provider CompletionProvider::Cloud(provider) => provider
.available_models() .available_models()
.map(LanguageModel::ZedDotDev) .map(LanguageModel::Cloud)
.collect(), .collect(),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(), CompletionProvider::Fake(_) => unimplemented!(),
@ -177,7 +174,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(), CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::Anthropic(provider) => provider.settings_version(), CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(), CompletionProvider::Cloud(provider) => provider.settings_version(),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(), CompletionProvider::Fake(_) => unimplemented!(),
} }
@ -187,7 +184,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(), CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::Anthropic(provider) => provider.is_authenticated(), CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(), CompletionProvider::Cloud(provider) => provider.is_authenticated(),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => true, CompletionProvider::Fake(_) => true,
} }
@ -197,7 +194,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx), CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::Anthropic(provider) => provider.authenticate(cx), CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx), CompletionProvider::Cloud(provider) => provider.authenticate(cx),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())), CompletionProvider::Fake(_) => Task::ready(Ok(())),
} }
@ -207,7 +204,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx), CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx), CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx), CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(), CompletionProvider::Fake(_) => unimplemented!(),
} }
@ -217,7 +214,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx), CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx), CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())), CompletionProvider::Cloud(_) => Task::ready(Ok(())),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())), CompletionProvider::Fake(_) => Task::ready(Ok(())),
} }
@ -227,7 +224,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()), CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()), CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()), CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => LanguageModel::default(), CompletionProvider::Fake(_) => LanguageModel::default(),
} }
@ -241,7 +238,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx), CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx), CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx), CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))), CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
} }
@ -254,7 +251,7 @@ impl CompletionProvider {
match self { match self {
CompletionProvider::OpenAi(provider) => provider.complete(request), CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::Anthropic(provider) => provider.complete(request), CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request), CompletionProvider::Cloud(provider) => provider.complete(request),
#[cfg(test)] #[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(), CompletionProvider::Fake(provider) => provider.complete(),
} }

View file

@ -1,9 +1,9 @@
use crate::count_open_ai_tokens;
use crate::{ use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role, Role,
}; };
use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole}; use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@ -167,53 +167,37 @@ impl AnthropicCompletionProvider {
.boxed() .boxed()
} }
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request { fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let model = match request.model { let model = match request.model {
LanguageModel::Anthropic(model) => model, LanguageModel::Anthropic(model) => model,
_ => self.model(), _ => self.model(),
}; };
let mut system_message = String::new(); let mut system_message = String::new();
if request
let mut messages: Vec<RequestMessage> = Vec::new(); .messages
for message in request.messages { .first()
if message.content.is_empty() { .map_or(false, |message| message.role == Role::System)
continue; {
} system_message = request.messages.remove(0).content;
match message.role {
Role::User | Role::Assistant => {
let role = match message.role {
Role::User => AnthropicRole::User,
Role::Assistant => AnthropicRole::Assistant,
_ => unreachable!(),
};
if let Some(last_message) = messages.last_mut() {
if last_message.role == role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
messages.push(RequestMessage {
role,
content: message.content,
});
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
} }
Request { Request {
model, model,
messages, messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true, stream: true,
system: system_message, system: system_message,
max_tokens: 4092, max_tokens: 4092,
@ -221,6 +205,49 @@ impl AnthropicCompletionProvider {
} }
} }
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
request.messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
struct AuthenticationPrompt { struct AuthenticationPrompt {
api_key: View<Editor>, api_key: View<Editor>,
api_url: String, api_url: String,

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel, assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelRequest, LanguageModelRequest,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
@ -10,17 +10,17 @@ use std::{future, sync::Arc};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use ui::prelude::*; use ui::prelude::*;
pub struct ZedDotDevCompletionProvider { pub struct CloudCompletionProvider {
client: Arc<Client>, client: Arc<Client>,
model: ZedDotDevModel, model: CloudModel,
settings_version: usize, settings_version: usize,
status: client::Status, status: client::Status,
_maintain_client_status: Task<()>, _maintain_client_status: Task<()>,
} }
impl ZedDotDevCompletionProvider { impl CloudCompletionProvider {
pub fn new( pub fn new(
model: ZedDotDevModel, model: CloudModel,
client: Arc<Client>, client: Arc<Client>,
settings_version: usize, settings_version: usize,
cx: &mut AppContext, cx: &mut AppContext,
@ -30,7 +30,7 @@ impl ZedDotDevCompletionProvider {
let maintain_client_status = cx.spawn(|mut cx| async move { let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await { while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| { let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::ZedDotDev(provider) = provider { if let CompletionProvider::Cloud(provider) = provider {
provider.status = status; provider.status = status;
} else { } else {
unreachable!() unreachable!()
@ -47,20 +47,20 @@ impl ZedDotDevCompletionProvider {
} }
} }
pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) { pub fn update(&mut self, model: CloudModel, settings_version: usize) {
self.model = model; self.model = model;
self.settings_version = settings_version; self.settings_version = settings_version;
} }
pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> { pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() { let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model) Some(custom_model)
} else { } else {
None None
}; };
ZedDotDevModel::iter().filter_map(move |model| { CloudModel::iter().filter_map(move |model| {
if let ZedDotDevModel::Custom(_) = model { if let CloudModel::Custom(_) = model {
Some(ZedDotDevModel::Custom(custom_model.take()?)) Some(CloudModel::Custom(custom_model.take()?))
} else { } else {
Some(model) Some(model)
} }
@ -71,7 +71,7 @@ impl ZedDotDevCompletionProvider {
self.settings_version self.settings_version
} }
pub fn model(&self) -> ZedDotDevModel { pub fn model(&self) -> CloudModel {
self.model.clone() self.model.clone()
} }
@ -94,21 +94,19 @@ impl ZedDotDevCompletionProvider {
cx: &AppContext, cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> { ) -> BoxFuture<'static, Result<usize>> {
match request.model { match request.model {
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4) LanguageModel::Cloud(CloudModel::Gpt4)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo) | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) | LanguageModel::Cloud(CloudModel::Gpt4Omni)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => { | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor()) count_open_ai_tokens(request, cx.background_executor())
} }
LanguageModel::ZedDotDev( LanguageModel::Cloud(
ZedDotDevModel::Claude3Opus CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku,
| ZedDotDevModel::Claude3Sonnet
| ZedDotDevModel::Claude3Haiku,
) => { ) => {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation. // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor()) count_open_ai_tokens(request, cx.background_executor())
} }
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => { LanguageModel::Cloud(CloudModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel { let request = self.client.request(proto::CountTokensWithLanguageModel {
model, model,
messages: request messages: request
@ -129,8 +127,10 @@ impl ZedDotDevCompletionProvider {
pub fn complete( pub fn complete(
&self, &self,
request: LanguageModelRequest, mut request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> { ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
request.preprocess();
let request = proto::CompleteWithLanguageModel { let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(), model: request.model.id().to_string(),
messages: request messages: request

View file

@ -1,4 +1,4 @@
use crate::assistant_settings::ZedDotDevModel; use crate::assistant_settings::CloudModel;
use crate::{ use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
}; };
@ -210,9 +210,9 @@ pub fn count_open_ai_tokens(
match request.model { match request.model {
LanguageModel::Anthropic(_) LanguageModel::Anthropic(_)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus) | LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet) | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => { | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
// Tiktoken doesn't yet support these models, so we manually use the // Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4. // same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages) tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)

View file

@ -554,6 +554,20 @@ pub struct GutterDimensions {
pub git_blame_entries_width: Option<Pixels>, pub git_blame_entries_width: Option<Pixels>,
} }
impl GutterDimensions {
/// The full width of the space taken up by the gutter.
pub fn full_width(&self) -> Pixels {
self.margin + self.width
}
/// The width of the space reserved for the fold indicators,
/// use alongside 'justify_end' and `gutter_width` to
/// right align content with the line numbers
pub fn fold_area_width(&self) -> Pixels {
self.margin + self.right_padding
}
}
impl Default for GutterDimensions { impl Default for GutterDimensions {
fn default() -> Self { fn default() -> Self {
Self { Self {

View file

@ -1125,9 +1125,7 @@ impl EditorElement {
ix as f32 * line_height - (scroll_pixel_position.y % line_height), ix as f32 * line_height - (scroll_pixel_position.y % line_height),
); );
let centering_offset = point( let centering_offset = point(
(gutter_dimensions.right_padding + gutter_dimensions.margin (gutter_dimensions.fold_area_width() - fold_indicator_size.width) / 2.,
- fold_indicator_size.width)
/ 2.,
(line_height - fold_indicator_size.height) / 2., (line_height - fold_indicator_size.height) / 2.,
); );
let origin = gutter_hitbox.origin + position + centering_offset; let origin = gutter_hitbox.origin + position + centering_offset;
@ -4629,7 +4627,7 @@ impl Element for EditorElement {
&mut scroll_width, &mut scroll_width,
&gutter_dimensions, &gutter_dimensions,
em_width, em_width,
gutter_dimensions.width + gutter_dimensions.margin, gutter_dimensions.full_width(),
line_height, line_height,
&line_layouts, &line_layouts,
cx, cx,

View file

@ -320,7 +320,7 @@ impl Editor {
div() div()
.bg(deleted_hunk_color) .bg(deleted_hunk_color)
.size_full() .size_full()
.pl(gutter_dimensions.width + gutter_dimensions.margin) .pl(gutter_dimensions.full_width())
.child(editor_with_deleted_text.clone()) .child(editor_with_deleted_text.clone())
.into_any_element() .into_any_element()
}), }),