Fix issues with Claude in Assistant2 (#12619)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Mikayla Maki 2024-06-03 16:30:09 -07:00 committed by GitHub
parent afc0650a49
commit 3cd6719b30
11 changed files with 276 additions and 213 deletions


@@ -12,7 +12,7 @@ mod streaming_diff;
pub use assistant_panel::AssistantPanel;
use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
@@ -87,14 +87,14 @@ impl Display for Role {
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::ZedDotDev(ZedDotDevModel::default())
LanguageModel::Cloud(CloudModel::default())
}
}
@@ -103,7 +103,7 @@ impl LanguageModel {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
}
}
@@ -111,7 +111,7 @@ impl LanguageModel {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::ZedDotDev(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
}
}
@@ -119,7 +119,7 @@ impl LanguageModel {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::ZedDotDev(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
}
}
@@ -127,7 +127,7 @@ impl LanguageModel {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
}
}
}
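For illustration, a sketch of the id scheme these accessors produce; the accessor name `telemetry_id` and the Claude id string are assumptions, only the "zed.dev/" prefix is shown in this diff:

// Cloud-hosted models are namespaced under "zed.dev/" for telemetry.
let model = LanguageModel::Cloud(CloudModel::Claude3Sonnet);
assert_eq!(model.telemetry_id(), "zed.dev/claude-3-sonnet"); // id string assumed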
@@ -172,6 +172,20 @@ impl LanguageModelRequest {
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
}
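A hedged sketch of the intended call site; `build_request` and `provider` are placeholders here, only `preprocess` itself comes from this diff:

// Apply model-specific fixups before dispatching to any provider. For
// cloud-hosted Claude models this merges adjacent same-role messages and
// hoists system content to the front (see preprocess_anthropic_request).
let mut request = build_request(); // however the request is assembled
request.preprocess();
let response_stream = provider.complete(request);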
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]


@@ -17,7 +17,7 @@ use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommandOutput, SlashCommandOutputSection};
use client::telemetry::Telemetry;
use collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque};
use editor::actions::ShowCompletions;
use editor::{actions::ShowCompletions, GutterDimensions};
use editor::{
actions::{FoldAt, MoveDown, MoveToEndOfLine, MoveUp, Newline, UnfoldAt},
display_map::{
@@ -469,7 +469,7 @@ impl AssistantPanel {
)
});
let measurements = Arc::new(Mutex::new(BlockMeasurements::default()));
let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
let inline_assistant = cx.new_view(|cx| {
InlineAssistant::new(
inline_assist_id,
@@ -492,10 +492,7 @@ impl AssistantPanel {
render: Box::new({
let inline_assistant = inline_assistant.clone();
move |cx: &mut BlockContext| {
*measurements.lock() = BlockMeasurements {
gutter_width: cx.gutter_dimensions.width,
gutter_margin: cx.gutter_dimensions.margin,
};
*measurements.lock() = *cx.gutter_dimensions;
inline_assistant.clone().into_any_element()
}
}),
@@ -583,6 +580,7 @@ impl AssistantPanel {
],
},
);
self.pending_inline_assist_ids_by_editor
.entry(editor.downgrade())
.or_default()
@@ -810,7 +808,7 @@ impl AssistantPanel {
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
anyhow::Ok(())
})
.detach();
.detach_and_log_err(cx);
}
fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut ViewContext<Self>) {
@@ -1431,7 +1429,7 @@ impl Panel for AssistantPanel {
return None;
}
Some(IconName::Ai)
Some(IconName::ZedAssistant)
}
fn icon_tooltip(&self, _cx: &WindowContext) -> Option<&'static str> {
@@ -3151,7 +3149,7 @@ impl ConversationEditor {
h_flex()
.id(("message_header", message_id.0))
.pl(cx.gutter_dimensions.width + cx.gutter_dimensions.margin)
.pl(cx.gutter_dimensions.full_width())
.h_11()
.w_full()
.relative()
@@ -3551,7 +3549,7 @@ struct InlineAssistant {
prompt_editor: View<Editor>,
confirmed: bool,
include_conversation: bool,
measurements: Arc<Mutex<BlockMeasurements>>,
gutter_dimensions: Arc<Mutex<GutterDimensions>>,
prompt_history: VecDeque<String>,
prompt_history_ix: Option<usize>,
pending_prompt: String,
@@ -3563,7 +3561,8 @@ impl EventEmitter<InlineAssistantEvent> for InlineAssistant {}
impl Render for InlineAssistant {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let measurements = *self.measurements.lock();
let gutter_dimensions = *self.gutter_dimensions.lock();
let icon_size = IconSize::default();
h_flex()
.w_full()
.py_2()
@@ -3576,14 +3575,20 @@ impl Render for InlineAssistant {
.on_action(cx.listener(Self::move_down))
.child(
h_flex()
.w(measurements.gutter_width + measurements.gutter_margin)
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.pr(gutter_dimensions.fold_area_width())
.justify_end()
.children(if let Some(error) = self.codegen.read(cx).error() {
let error_message = SharedString::from(error.to_string());
Some(
div()
.id("error")
.tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
.child(Icon::new(IconName::XCircle).color(Color::Error)),
.child(
Icon::new(IconName::XCircle)
.size(icon_size)
.color(Color::Error),
),
)
} else {
None
@@ -3603,7 +3608,7 @@ impl InlineAssistant {
#[allow(clippy::too_many_arguments)]
fn new(
id: usize,
measurements: Arc<Mutex<BlockMeasurements>>,
gutter_dimensions: Arc<Mutex<GutterDimensions>>,
include_conversation: bool,
prompt_history: VecDeque<String>,
codegen: Model<Codegen>,
@@ -3630,7 +3635,7 @@ impl InlineAssistant {
prompt_editor,
confirmed: false,
include_conversation,
measurements,
gutter_dimensions,
prompt_history,
prompt_history_ix: None,
pending_prompt: String::new(),
@@ -3755,13 +3760,6 @@ impl InlineAssistant {
}
}
// This wouldn't need to exist if we could pass parameters when rendering child views.
#[derive(Copy, Clone, Default)]
struct BlockMeasurements {
gutter_width: Pixels,
gutter_margin: Pixels,
}
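A minimal sketch of the shared-state workaround that comment describes, assuming parking_lot's Mutex (the surrounding code locks without unwrapping): the block's render closure writes the latest measurements, and the child view reads them on its next render.

use parking_lot::Mutex;
use std::sync::Arc;

let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let writer = gutter_dimensions.clone();
// Inside the block's render closure:  *writer.lock() = *cx.gutter_dimensions;
// Inside InlineAssistant::render:     let dims = *gutter_dimensions.lock();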
struct PendingInlineAssist {
editor: WeakView<Editor>,
inline_assistant: Option<(BlockId, View<InlineAssistant>)>,


@@ -14,10 +14,10 @@ use serde::{
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
use crate::LanguageModel;
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
@@ -29,7 +29,7 @@ pub enum ZedDotDevModel {
Custom(String),
}
impl Serialize for ZedDotDevModel {
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
@@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel {
}
}
impl<'de> Deserialize<'de> for ZedDotDevModel {
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
@@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = ZedDotDevModel;
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
@@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where
E: de::Error,
{
let model = ZedDotDevModel::iter()
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
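A behavior sketch for this fallback, assuming serde_json is available and that "gpt-4" is the Gpt4 id (only "gpt-3.5-turbo" is visible in this diff):

// Known ids deserialize to their variants; unknown strings become Custom.
let known: CloudModel = serde_json::from_str("\"gpt-4\"").unwrap();
assert_eq!(known, CloudModel::Gpt4);
let custom: CloudModel = serde_json::from_str("\"my-team/my-model\"").unwrap();
assert_eq!(custom, CloudModel::Custom("my-team/my-model".into()));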
@@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
}
}
impl JsonSchema for ZedDotDevModel {
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = ZedDotDevModel::iter()
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
@@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel {
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(ZedDotDevModel::default().id().into()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
@@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel {
}
}
impl ZedDotDevModel {
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
@@ -133,6 +133,15 @@ impl ZedDotDevModel {
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
@@ -147,7 +156,7 @@ pub enum AssistantDockPosition {
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
ZedDotDev {
model: ZedDotDevModel,
model: CloudModel,
},
OpenAi {
model: OpenAiModel,
@@ -175,9 +184,7 @@ impl Default for AssistantProvider {
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev {
default_model: Option<ZedDotDevModel>,
},
ZedDotDev { default_model: Option<CloudModel> },
#[serde(rename = "openai")]
OpenAi {
default_model: Option<OpenAiModel>,
@@ -281,7 +288,7 @@ impl AssistantSettingsContent {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::ZedDotDev(new_model) = new_model {
if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model);
}
}
@@ -302,7 +309,7 @@ impl AssistantSettingsContent {
}
}
provider => match new_model {
LanguageModel::ZedDotDev(model) => {
LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
@@ -613,7 +620,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: ZedDotDevModel::Custom("custom".into())
model: CloudModel::Custom("custom".into())
}
);
}


@@ -11,6 +11,7 @@ use language::{Rope, TransactionId};
use multi_buffer::MultiBufferRow;
use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
#[derive(Debug)]
pub enum Event {
Finished,
Undone,
@@ -120,91 +121,98 @@ impl Codegen {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background_executor().spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = strip_invalid_spans_from_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let diff: Task<anyhow::Result<()>> =
cx.background_executor().spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = strip_invalid_spans_from_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
}
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
};
let result = diff.await;
let error_message =
result.as_ref().err().map(|error| error.to_string());
if let Some(telemetry) = telemetry {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
model_telemetry_id,
response_latency,
error_message,
);
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
};
let error_message = diff.await.err().map(|error| error.to_string());
if let Some(telemetry) = telemetry {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
model_telemetry_id,
response_latency,
error_message,
);
}
});
result?;
Ok(())
});
while let Some(hunks) = hunks_rx.next().await {
this.update(&mut cx, |this, cx| {
@@ -266,7 +274,7 @@ impl Codegen {
})?;
}
diff.await;
diff.await?;
anyhow::Ok(())
};
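The latency bookkeeping in the rewritten task, reduced to its core as a non-runnable sketch:

// Stamp the clock when the request is issued, record time-to-first-chunk
// exactly once when the first streamed chunk arrives, then hand the
// measurement to telemetry when the task finishes.
let request_start = Instant::now();
let mut response_latency = None;
while let Some(chunk) = chunks.next().await {
    if response_latency.is_none() {
        response_latency = Some(request_start.elapsed());
    }
    // ...feed the chunk into the streaming diff...
}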


@@ -1,14 +1,14 @@
mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;
pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
pub use zed::*;
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
@@ -25,8 +25,8 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
),
AssistantProvider::OpenAi {
model,
@@ -87,14 +87,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
);
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { model },
) => {
(CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
provider.update(model.clone(), settings_version);
}
(_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
*provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
model.clone(),
client.clone(),
settings_version,
@@ -142,7 +139,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
Anthropic(AnthropicCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
Cloud(CloudCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
}
@@ -164,9 +161,9 @@ impl CompletionProvider {
.available_models()
.map(LanguageModel::Anthropic)
.collect(),
CompletionProvider::ZedDotDev(provider) => provider
CompletionProvider::Cloud(provider) => provider
.available_models()
.map(LanguageModel::ZedDotDev)
.map(LanguageModel::Cloud)
.collect(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
@@ -177,7 +174,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
CompletionProvider::Cloud(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@@ -187,7 +184,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
CompletionProvider::Cloud(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
}
@@ -197,7 +194,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
CompletionProvider::Cloud(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@@ -207,7 +204,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@@ -217,7 +214,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
CompletionProvider::Cloud(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@@ -227,7 +224,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
#[cfg(test)]
CompletionProvider::Fake(_) => LanguageModel::default(),
}
@@ -241,7 +238,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
}
@@ -254,7 +251,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
CompletionProvider::Cloud(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
}
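A hypothetical end-to-end call; the update_global usage in the cloud provider suggests CompletionProvider lives in a gpui global, but the `global` accessor name here is an assumption:

let provider = CompletionProvider::global(cx); // accessor name assumed
if provider.is_authenticated() {
    let stream_future = provider.complete(request);
    // Resolves to a BoxStream of Result<String> response chunks.
}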


@@ -1,9 +1,9 @@
use crate::count_open_ai_tokens;
use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role,
};
use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -167,53 +167,37 @@ impl AnthropicCompletionProvider {
.boxed()
}
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let model = match request.model {
LanguageModel::Anthropic(model) => model,
_ => self.model(),
};
let mut system_message = String::new();
let mut messages: Vec<RequestMessage> = Vec::new();
for message in request.messages {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let role = match message.role {
Role::User => AnthropicRole::User,
Role::Assistant => AnthropicRole::Assistant,
_ => unreachable!(),
};
if let Some(last_message) = messages.last_mut() {
if last_message.role == role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
messages.push(RequestMessage {
role,
content: message.content,
});
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
if request
.messages
.first()
.map_or(false, |message| message.role == Role::System)
{
system_message = request.messages.remove(0).content;
}
Request {
model,
messages,
messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true,
system: system_message,
max_tokens: 4092,
@@ -221,6 +205,49 @@ impl AnthropicCompletionProvider {
}
}
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
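A worked example of the fixups, with a hypothetical test helper for brevity:

// request_with_roles is an assumed helper that builds a LanguageModelRequest
// from (Role, content) pairs; it is not part of this diff.
let mut request = request_with_roles(&[
    (Role::System, "be terse"),
    (Role::User, "hi"),
    (Role::User, "also, second thought"),
    (Role::System, "answer in Rust"),
]);
preprocess_anthropic_request(&mut request);
// Result: [System "be terse\n\nanswer in Rust", User "hi\n\nalso, second thought"]
// to_anthropic_request then lifts the leading System message into Anthropic's
// dedicated `system` field.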
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,


@@ -1,5 +1,5 @@
use crate::{
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
@@ -10,17 +10,17 @@ use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
pub struct ZedDotDevCompletionProvider {
pub struct CloudCompletionProvider {
client: Arc<Client>,
model: ZedDotDevModel,
model: CloudModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
}
impl ZedDotDevCompletionProvider {
impl CloudCompletionProvider {
pub fn new(
model: ZedDotDevModel,
model: CloudModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
@@ -30,7 +30,7 @@ impl ZedDotDevCompletionProvider {
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::ZedDotDev(provider) = provider {
if let CompletionProvider::Cloud(provider) = provider {
provider.status = status;
} else {
unreachable!()
@@ -47,20 +47,20 @@ impl ZedDotDevCompletionProvider {
}
}
pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
self.model = model;
self.settings_version = settings_version;
}
pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {
None
};
ZedDotDevModel::iter().filter_map(move |model| {
if let ZedDotDevModel::Custom(_) = model {
Some(ZedDotDevModel::Custom(custom_model.take()?))
CloudModel::iter().filter_map(move |model| {
if let CloudModel::Custom(_) = model {
Some(CloudModel::Custom(custom_model.take()?))
} else {
Some(model)
}
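A sketch of what available_models yields, with the model list abbreviated; the behavior follows from the iterator above:

// Custom("my-model") configured => the Custom placeholder produced by
// CloudModel::iter() is replaced with the configured value exactly once:
//   [Gpt3Point5Turbo, ..., Custom("my-model")]
// Builtin model configured => the Custom placeholder is filtered out:
//   [Gpt3Point5Turbo, ...]
let models: Vec<CloudModel> = provider.available_models().collect();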
@@ -71,7 +71,7 @@ impl ZedDotDevCompletionProvider {
self.settings_version
}
pub fn model(&self) -> ZedDotDevModel {
pub fn model(&self) -> CloudModel {
self.model.clone()
}
@@ -94,21 +94,19 @@ impl ZedDotDevCompletionProvider {
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
LanguageModel::Cloud(CloudModel::Gpt4)
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::ZedDotDev(
ZedDotDevModel::Claude3Opus
| ZedDotDevModel::Claude3Sonnet
| ZedDotDevModel::Claude3Haiku,
LanguageModel::Cloud(
CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku,
) => {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
LanguageModel::Cloud(CloudModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
@@ -129,8 +127,10 @@ impl ZedDotDevCompletionProvider {
pub fn complete(
&self,
request: LanguageModelRequest,
mut request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
request.preprocess();
let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(),
messages: request


@@ -1,4 +1,4 @@
use crate::assistant_settings::ZedDotDevModel;
use crate::assistant_settings::CloudModel;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
@@ -210,9 +210,9 @@ pub fn count_open_ai_tokens(
match request.model {
LanguageModel::Anthropic(_)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
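A hedged sketch of the GPT-4 approximation applied to a Claude request; the ChatCompletionRequestMessage field layout varies across tiktoken_rs versions, so treat the struct literal as an assumption:

use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

let messages = vec![ChatCompletionRequestMessage {
    role: "user".into(),
    content: Some("How do I reverse a Vec in Rust?".into()),
    name: None,
    function_call: None,
}];
// Claude 3 counts are approximated with the GPT-4 tokenizer for now.
let approx = num_tokens_from_messages("gpt-4", &messages).unwrap();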


@@ -554,6 +554,20 @@ pub struct GutterDimensions {
pub git_blame_entries_width: Option<Pixels>,
}
impl GutterDimensions {
/// The full width of the space taken up by the gutter.
pub fn full_width(&self) -> Pixels {
self.margin + self.width
}
/// The width of the space reserved for the fold indicators,
/// used alongside `justify_end` and `gutter_width` to
/// right-align content with the line numbers.
pub fn fold_area_width(&self) -> Pixels {
self.margin + self.right_padding
}
}
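A quick arithmetic check of the new helpers; the values are arbitrary, and the struct-update from Default assumes the remaining fields are public:

let dims = GutterDimensions {
    margin: px(4.),
    width: px(40.),
    right_padding: px(8.),
    ..Default::default()
};
assert_eq!(dims.full_width(), px(44.));      // margin + width
assert_eq!(dims.fold_area_width(), px(12.)); // margin + right_padding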
impl Default for GutterDimensions {
fn default() -> Self {
Self {


@@ -1125,9 +1125,7 @@ impl EditorElement {
ix as f32 * line_height - (scroll_pixel_position.y % line_height),
);
let centering_offset = point(
(gutter_dimensions.right_padding + gutter_dimensions.margin
- fold_indicator_size.width)
/ 2.,
(gutter_dimensions.fold_area_width() - fold_indicator_size.width) / 2.,
(line_height - fold_indicator_size.height) / 2.,
);
let origin = gutter_hitbox.origin + position + centering_offset;
@@ -4629,7 +4627,7 @@ impl Element for EditorElement {
&mut scroll_width,
&gutter_dimensions,
em_width,
gutter_dimensions.width + gutter_dimensions.margin,
gutter_dimensions.full_width(),
line_height,
&line_layouts,
cx,


@@ -320,7 +320,7 @@ impl Editor {
div()
.bg(deleted_hunk_color)
.size_full()
.pl(gutter_dimensions.width + gutter_dimensions.margin)
.pl(gutter_dimensions.full_width())
.child(editor_with_deleted_text.clone())
.into_any_element()
}),