Add support for interacting with Claude in the assistant panel (#11798)
Release Notes: - Added support for interacting with Claude in the assistant panel. You can enable it by adding the following to your `settings.json`: ```json "assistant": { "version": "1", "provider": { "name": "anthropic" } } ```
This commit is contained in:
parent
019d98898e
commit
5944caaa90
12 changed files with 446 additions and 21 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -225,6 +225,8 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
"http 0.1.0",
|
"http 0.1.0",
|
||||||
|
"isahc",
|
||||||
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
@ -332,6 +334,7 @@ dependencies = [
|
||||||
name = "assistant"
|
name = "assistant"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anthropic",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"chrono",
|
"chrono",
|
||||||
"client",
|
"client",
|
||||||
|
|
|
@ -5,6 +5,10 @@ edition = "2021"
|
||||||
publish = false
|
publish = false
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
schemars = ["dep:schemars"]
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
||||||
|
@ -15,6 +19,8 @@ path = "src/anthropic.rs"
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
http.workspace = true
|
http.workspace = true
|
||||||
|
isahc.workspace = true
|
||||||
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,21 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||||
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{convert::TryFrom, sync::Arc};
|
use std::{convert::TryFrom, time::Duration};
|
||||||
|
|
||||||
|
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
#[default]
|
#[default]
|
||||||
#[serde(rename = "claude-3-opus-20240229")]
|
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
|
||||||
Claude3Opus,
|
Claude3Opus,
|
||||||
#[serde(rename = "claude-3-sonnet-20240229")]
|
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
|
||||||
Claude3Sonnet,
|
Claude3Sonnet,
|
||||||
#[serde(rename = "claude-3-haiku-20240307")]
|
#[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
|
||||||
Claude3Haiku,
|
Claude3Haiku,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +32,14 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||||
|
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||||
|
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn display_name(&self) -> &'static str {
|
pub fn display_name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::Claude3Opus => "Claude 3 Opus",
|
Self::Claude3Opus => "Claude 3 Opus",
|
||||||
|
@ -141,20 +153,24 @@ pub enum TextDelta {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
client: Arc<dyn HttpClient>,
|
client: &dyn HttpClient,
|
||||||
api_url: &str,
|
api_url: &str,
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
low_speed_timeout: Option<Duration>,
|
||||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||||
let uri = format!("{api_url}/v1/messages");
|
let uri = format!("{api_url}/v1/messages");
|
||||||
let request = HttpRequest::builder()
|
let mut request_builder = HttpRequest::builder()
|
||||||
.method(Method::POST)
|
.method(Method::POST)
|
||||||
.uri(uri)
|
.uri(uri)
|
||||||
.header("Anthropic-Version", "2023-06-01")
|
.header("Anthropic-Version", "2023-06-01")
|
||||||
.header("Anthropic-Beta", "messages-2023-12-15")
|
.header("Anthropic-Beta", "tools-2024-04-04")
|
||||||
.header("X-Api-Key", api_key)
|
.header("X-Api-Key", api_key)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json");
|
||||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||||
|
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||||
|
}
|
||||||
|
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||||
let mut response = client.send(request).await?;
|
let mut response = client.send(request).await?;
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let reader = BufReader::new(response.into_body());
|
let reader = BufReader::new(response.into_body());
|
||||||
|
|
|
@ -11,6 +11,7 @@ doctest = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
anthropic = { workspace = true, features = ["schemars"] }
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
client.workspace = true
|
client.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
|
|
|
@ -7,7 +7,7 @@ mod saved_conversation;
|
||||||
mod streaming_diff;
|
mod streaming_diff;
|
||||||
|
|
||||||
pub use assistant_panel::AssistantPanel;
|
pub use assistant_panel::AssistantPanel;
|
||||||
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
|
use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
|
||||||
use client::{proto, Client};
|
use client::{proto, Client};
|
||||||
use command_palette_hooks::CommandPaletteFilter;
|
use command_palette_hooks::CommandPaletteFilter;
|
||||||
pub(crate) use completion_provider::*;
|
pub(crate) use completion_provider::*;
|
||||||
|
@ -72,6 +72,7 @@ impl Display for Role {
|
||||||
pub enum LanguageModel {
|
pub enum LanguageModel {
|
||||||
ZedDotDev(ZedDotDevModel),
|
ZedDotDev(ZedDotDevModel),
|
||||||
OpenAi(OpenAiModel),
|
OpenAi(OpenAiModel),
|
||||||
|
Anthropic(AnthropicModel),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LanguageModel {
|
impl Default for LanguageModel {
|
||||||
|
@ -84,6 +85,7 @@ impl LanguageModel {
|
||||||
pub fn telemetry_id(&self) -> String {
|
pub fn telemetry_id(&self) -> String {
|
||||||
match self {
|
match self {
|
||||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||||
|
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
|
||||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
|
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -91,6 +93,7 @@ impl LanguageModel {
|
||||||
pub fn display_name(&self) -> String {
|
pub fn display_name(&self) -> String {
|
||||||
match self {
|
match self {
|
||||||
LanguageModel::OpenAi(model) => model.display_name().into(),
|
LanguageModel::OpenAi(model) => model.display_name().into(),
|
||||||
|
LanguageModel::Anthropic(model) => model.display_name().into(),
|
||||||
LanguageModel::ZedDotDev(model) => model.display_name().into(),
|
LanguageModel::ZedDotDev(model) => model.display_name().into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -98,6 +101,7 @@ impl LanguageModel {
|
||||||
pub fn max_token_count(&self) -> usize {
|
pub fn max_token_count(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
LanguageModel::OpenAi(model) => model.max_token_count(),
|
||||||
|
LanguageModel::Anthropic(model) => model.max_token_count(),
|
||||||
LanguageModel::ZedDotDev(model) => model.max_token_count(),
|
LanguageModel::ZedDotDev(model) => model.max_token_count(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,6 +109,7 @@ impl LanguageModel {
|
||||||
pub fn id(&self) -> &str {
|
pub fn id(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
LanguageModel::OpenAi(model) => model.id(),
|
LanguageModel::OpenAi(model) => model.id(),
|
||||||
|
LanguageModel::Anthropic(model) => model.id(),
|
||||||
LanguageModel::ZedDotDev(model) => model.id(),
|
LanguageModel::ZedDotDev(model) => model.id(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -800,6 +800,11 @@ impl AssistantPanel {
|
||||||
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
|
open_ai::Model::FourTurbo => open_ai::Model::FourOmni,
|
||||||
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
|
open_ai::Model::FourOmni => open_ai::Model::ThreePointFiveTurbo,
|
||||||
}),
|
}),
|
||||||
|
LanguageModel::Anthropic(model) => LanguageModel::Anthropic(match &model {
|
||||||
|
anthropic::Model::Claude3Opus => anthropic::Model::Claude3Sonnet,
|
||||||
|
anthropic::Model::Claude3Sonnet => anthropic::Model::Claude3Haiku,
|
||||||
|
anthropic::Model::Claude3Haiku => anthropic::Model::Claude3Opus,
|
||||||
|
}),
|
||||||
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
|
||||||
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
|
||||||
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
pub use anthropic::Model as AnthropicModel;
|
||||||
use gpui::Pixels;
|
use gpui::Pixels;
|
||||||
pub use open_ai::Model as OpenAiModel;
|
pub use open_ai::Model as OpenAiModel;
|
||||||
use schemars::{
|
use schemars::{
|
||||||
|
@ -161,6 +162,15 @@ pub enum AssistantProvider {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
},
|
},
|
||||||
|
#[serde(rename = "anthropic")]
|
||||||
|
Anthropic {
|
||||||
|
#[serde(default)]
|
||||||
|
default_model: AnthropicModel,
|
||||||
|
#[serde(default = "anthropic_api_url")]
|
||||||
|
api_url: String,
|
||||||
|
#[serde(default)]
|
||||||
|
low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AssistantProvider {
|
impl Default for AssistantProvider {
|
||||||
|
@ -172,7 +182,11 @@ impl Default for AssistantProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_ai_url() -> String {
|
fn open_ai_url() -> String {
|
||||||
"https://api.openai.com/v1".into()
|
open_ai::OPEN_AI_API_URL.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn anthropic_api_url() -> String {
|
||||||
|
anthropic::ANTHROPIC_API_URL.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Deserialize, Serialize)]
|
#[derive(Default, Debug, Deserialize, Serialize)]
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
mod anthropic;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod fake;
|
mod fake;
|
||||||
mod open_ai;
|
mod open_ai;
|
||||||
mod zed;
|
mod zed;
|
||||||
|
|
||||||
|
pub use anthropic::*;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub use fake::*;
|
pub use fake::*;
|
||||||
pub use open_ai::*;
|
pub use open_ai::*;
|
||||||
|
@ -42,6 +44,17 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
settings_version,
|
settings_version,
|
||||||
)),
|
)),
|
||||||
|
AssistantProvider::Anthropic {
|
||||||
|
default_model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||||
|
default_model.clone(),
|
||||||
|
api_url.clone(),
|
||||||
|
client.http_client(),
|
||||||
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
|
settings_version,
|
||||||
|
)),
|
||||||
};
|
};
|
||||||
cx.set_global(provider);
|
cx.set_global(provider);
|
||||||
|
|
||||||
|
@ -64,13 +77,28 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
settings_version,
|
settings_version,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
(
|
||||||
|
CompletionProvider::Anthropic(provider),
|
||||||
|
AssistantProvider::Anthropic {
|
||||||
|
default_model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
provider.update(
|
||||||
|
default_model.clone(),
|
||||||
|
api_url.clone(),
|
||||||
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
|
settings_version,
|
||||||
|
);
|
||||||
|
}
|
||||||
(
|
(
|
||||||
CompletionProvider::ZedDotDev(provider),
|
CompletionProvider::ZedDotDev(provider),
|
||||||
AssistantProvider::ZedDotDev { default_model },
|
AssistantProvider::ZedDotDev { default_model },
|
||||||
) => {
|
) => {
|
||||||
provider.update(default_model.clone(), settings_version);
|
provider.update(default_model.clone(), settings_version);
|
||||||
}
|
}
|
||||||
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
|
(_, AssistantProvider::ZedDotDev { default_model }) => {
|
||||||
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||||
default_model.clone(),
|
default_model.clone(),
|
||||||
client.clone(),
|
client.clone(),
|
||||||
|
@ -79,7 +107,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
(
|
(
|
||||||
CompletionProvider::ZedDotDev(_),
|
_,
|
||||||
AssistantProvider::OpenAi {
|
AssistantProvider::OpenAi {
|
||||||
default_model,
|
default_model,
|
||||||
api_url,
|
api_url,
|
||||||
|
@ -94,8 +122,22 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
settings_version,
|
settings_version,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
#[cfg(test)]
|
(
|
||||||
(CompletionProvider::Fake(_), _) => unimplemented!(),
|
_,
|
||||||
|
AssistantProvider::Anthropic {
|
||||||
|
default_model,
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||||
|
default_model.clone(),
|
||||||
|
api_url.clone(),
|
||||||
|
client.http_client(),
|
||||||
|
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||||
|
settings_version,
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -104,6 +146,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
|
|
||||||
pub enum CompletionProvider {
|
pub enum CompletionProvider {
|
||||||
OpenAi(OpenAiCompletionProvider),
|
OpenAi(OpenAiCompletionProvider),
|
||||||
|
Anthropic(AnthropicCompletionProvider),
|
||||||
ZedDotDev(ZedDotDevCompletionProvider),
|
ZedDotDev(ZedDotDevCompletionProvider),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
Fake(FakeCompletionProvider),
|
Fake(FakeCompletionProvider),
|
||||||
|
@ -119,6 +162,7 @@ impl CompletionProvider {
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn settings_version(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.settings_version(),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
|
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => unimplemented!(),
|
CompletionProvider::Fake(_) => unimplemented!(),
|
||||||
|
@ -128,6 +172,7 @@ impl CompletionProvider {
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn is_authenticated(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
|
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
|
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => true,
|
CompletionProvider::Fake(_) => true,
|
||||||
|
@ -137,6 +182,7 @@ impl CompletionProvider {
|
||||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
|
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
|
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||||
|
@ -146,6 +192,7 @@ impl CompletionProvider {
|
||||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
|
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
|
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => unimplemented!(),
|
CompletionProvider::Fake(_) => unimplemented!(),
|
||||||
|
@ -155,6 +202,7 @@ impl CompletionProvider {
|
||||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
|
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
|
||||||
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
|
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||||
|
@ -164,6 +212,9 @@ impl CompletionProvider {
|
||||||
pub fn default_model(&self) -> LanguageModel {
|
pub fn default_model(&self) -> LanguageModel {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
||||||
|
CompletionProvider::Anthropic(provider) => {
|
||||||
|
LanguageModel::Anthropic(provider.default_model())
|
||||||
|
}
|
||||||
CompletionProvider::ZedDotDev(provider) => {
|
CompletionProvider::ZedDotDev(provider) => {
|
||||||
LanguageModel::ZedDotDev(provider.default_model())
|
LanguageModel::ZedDotDev(provider.default_model())
|
||||||
}
|
}
|
||||||
|
@ -179,6 +230,7 @@ impl CompletionProvider {
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
|
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
|
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(_) => unimplemented!(),
|
CompletionProvider::Fake(_) => unimplemented!(),
|
||||||
|
@ -191,6 +243,7 @@ impl CompletionProvider {
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
match self {
|
match self {
|
||||||
CompletionProvider::OpenAi(provider) => provider.complete(request),
|
CompletionProvider::OpenAi(provider) => provider.complete(request),
|
||||||
|
CompletionProvider::Anthropic(provider) => provider.complete(request),
|
||||||
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
|
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
CompletionProvider::Fake(provider) => provider.complete(),
|
CompletionProvider::Fake(provider) => provider.complete(),
|
||||||
|
|
317
crates/assistant/src/completion_provider/anthropic.rs
Normal file
317
crates/assistant/src/completion_provider/anthropic.rs
Normal file
|
@ -0,0 +1,317 @@
|
||||||
|
use crate::count_open_ai_tokens;
|
||||||
|
use crate::{
|
||||||
|
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
|
||||||
|
Role,
|
||||||
|
};
|
||||||
|
use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||||
|
use http::HttpClient;
|
||||||
|
use settings::Settings;
|
||||||
|
use std::time::Duration;
|
||||||
|
use std::{env, sync::Arc};
|
||||||
|
use theme::ThemeSettings;
|
||||||
|
use ui::prelude::*;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
pub struct AnthropicCompletionProvider {
|
||||||
|
api_key: Option<String>,
|
||||||
|
api_url: String,
|
||||||
|
default_model: AnthropicModel,
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
low_speed_timeout: Option<Duration>,
|
||||||
|
settings_version: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnthropicCompletionProvider {
|
||||||
|
pub fn new(
|
||||||
|
default_model: AnthropicModel,
|
||||||
|
api_url: String,
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
low_speed_timeout: Option<Duration>,
|
||||||
|
settings_version: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
api_key: None,
|
||||||
|
api_url,
|
||||||
|
default_model,
|
||||||
|
http_client,
|
||||||
|
low_speed_timeout,
|
||||||
|
settings_version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(
|
||||||
|
&mut self,
|
||||||
|
default_model: AnthropicModel,
|
||||||
|
api_url: String,
|
||||||
|
low_speed_timeout: Option<Duration>,
|
||||||
|
settings_version: usize,
|
||||||
|
) {
|
||||||
|
self.default_model = default_model;
|
||||||
|
self.api_url = api_url;
|
||||||
|
self.low_speed_timeout = low_speed_timeout;
|
||||||
|
self.settings_version = settings_version;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn settings_version(&self) -> usize {
|
||||||
|
self.settings_version
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_authenticated(&self) -> bool {
|
||||||
|
self.api_key.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
if self.is_authenticated() {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
let api_url = self.api_url.clone();
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
||||||
|
api_key
|
||||||
|
} else {
|
||||||
|
let (_, api_key) = cx
|
||||||
|
.update(|cx| cx.read_credentials(&api_url))?
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||||
|
String::from_utf8(api_key)?
|
||||||
|
};
|
||||||
|
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||||
|
if let CompletionProvider::Anthropic(provider) = provider {
|
||||||
|
provider.api_key = Some(api_key);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
delete_credentials.await.log_err();
|
||||||
|
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||||
|
if let CompletionProvider::Anthropic(provider) = provider {
|
||||||
|
provider.api_key = None;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
|
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_model(&self) -> AnthropicModel {
|
||||||
|
self.default_model.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count_tokens(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
count_open_ai_tokens(request, cx.background_executor())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn complete(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
let request = self.to_anthropic_request(request);
|
||||||
|
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let api_key = self.api_key.clone();
|
||||||
|
let api_url = self.api_url.clone();
|
||||||
|
let low_speed_timeout = self.low_speed_timeout;
|
||||||
|
async move {
|
||||||
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
|
let request = stream_completion(
|
||||||
|
http_client.as_ref(),
|
||||||
|
&api_url,
|
||||||
|
&api_key,
|
||||||
|
request,
|
||||||
|
low_speed_timeout,
|
||||||
|
);
|
||||||
|
let response = request.await?;
|
||||||
|
let stream = response
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(response) => match response {
|
||||||
|
anthropic::ResponseEvent::ContentBlockStart {
|
||||||
|
content_block, ..
|
||||||
|
} => match content_block {
|
||||||
|
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||||
|
},
|
||||||
|
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||||
|
match delta {
|
||||||
|
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
|
||||||
|
let model = match request.model {
|
||||||
|
LanguageModel::Anthropic(model) => model,
|
||||||
|
_ => self.default_model(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut system_message = String::new();
|
||||||
|
let messages = request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|message| {
|
||||||
|
match message.role {
|
||||||
|
Role::User => Some(RequestMessage {
|
||||||
|
role: AnthropicRole::User,
|
||||||
|
content: message.content,
|
||||||
|
}),
|
||||||
|
Role::Assistant => Some(RequestMessage {
|
||||||
|
role: AnthropicRole::Assistant,
|
||||||
|
content: message.content,
|
||||||
|
}),
|
||||||
|
// Anthropic's API breaks system instructions out as a separate field rather
|
||||||
|
// than having a system message role.
|
||||||
|
Role::System => {
|
||||||
|
if !system_message.is_empty() {
|
||||||
|
system_message.push_str("\n\n");
|
||||||
|
}
|
||||||
|
system_message.push_str(&message.content);
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Request {
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream: true,
|
||||||
|
system: system_message,
|
||||||
|
max_tokens: 4092,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AuthenticationPrompt {
|
||||||
|
api_key: View<Editor>,
|
||||||
|
api_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthenticationPrompt {
|
||||||
|
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||||
|
Self {
|
||||||
|
api_key: cx.new_view(|cx| {
|
||||||
|
let mut editor = Editor::single_line(cx);
|
||||||
|
editor.set_placeholder_text(
|
||||||
|
"sk-000000000000000000000000000000000000000000000000",
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
editor
|
||||||
|
}),
|
||||||
|
api_url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||||
|
let api_key = self.api_key.read(cx).text(cx);
|
||||||
|
if api_key.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||||
|
cx.spawn(|_, mut cx| async move {
|
||||||
|
write_credentials.await?;
|
||||||
|
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||||
|
if let CompletionProvider::Anthropic(provider) = provider {
|
||||||
|
provider.api_key = Some(api_key);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
let settings = ThemeSettings::get_global(cx);
|
||||||
|
let text_style = TextStyle {
|
||||||
|
color: cx.theme().colors().text,
|
||||||
|
font_family: settings.ui_font.family.clone(),
|
||||||
|
font_features: settings.ui_font.features.clone(),
|
||||||
|
font_size: rems(0.875).into(),
|
||||||
|
font_weight: FontWeight::NORMAL,
|
||||||
|
font_style: FontStyle::Normal,
|
||||||
|
line_height: relative(1.3),
|
||||||
|
background_color: None,
|
||||||
|
underline: None,
|
||||||
|
strikethrough: None,
|
||||||
|
white_space: WhiteSpace::Normal,
|
||||||
|
};
|
||||||
|
EditorElement::new(
|
||||||
|
&self.api_key,
|
||||||
|
EditorStyle {
|
||||||
|
background: cx.theme().colors().editor_background,
|
||||||
|
local_player: cx.theme().players().local(),
|
||||||
|
text: text_style,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for AuthenticationPrompt {
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
const INSTRUCTIONS: [&str; 4] = [
|
||||||
|
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||||
|
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||||
|
"",
|
||||||
|
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||||
|
];
|
||||||
|
|
||||||
|
v_flex()
|
||||||
|
.p_4()
|
||||||
|
.size_full()
|
||||||
|
.on_action(cx.listener(Self::save_api_key))
|
||||||
|
.children(
|
||||||
|
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.my_2()
|
||||||
|
.px_2()
|
||||||
|
.py_1()
|
||||||
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
.rounded_md()
|
||||||
|
.child(self.render_api_key_editor(cx)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
Label::new(
|
||||||
|
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||||
|
)
|
||||||
|
.size(LabelSize::Small),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_2()
|
||||||
|
.child(Label::new("Click on").size(LabelSize::Small))
|
||||||
|
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
|
||||||
|
.child(
|
||||||
|
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
}
|
|
@ -151,8 +151,8 @@ impl OpenAiCompletionProvider {
|
||||||
|
|
||||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||||
let model = match request.model {
|
let model = match request.model {
|
||||||
LanguageModel::ZedDotDev(_) => self.default_model(),
|
|
||||||
LanguageModel::OpenAi(model) => model,
|
LanguageModel::OpenAi(model) => model,
|
||||||
|
_ => self.default_model(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Request {
|
Request {
|
||||||
|
@ -205,8 +205,12 @@ pub fn count_open_ai_tokens(
|
||||||
|
|
||||||
match request.model {
|
match request.model {
|
||||||
LanguageModel::OpenAi(OpenAiModel::FourOmni)
|
LanguageModel::OpenAi(OpenAiModel::FourOmni)
|
||||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni) => {
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
|
||||||
// Tiktoken doesn't yet support gpt-4o, so we manually use the
|
| LanguageModel::Anthropic(_)
|
||||||
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
|
||||||
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
|
||||||
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
|
||||||
|
// Tiktoken doesn't yet support these models, so we manually use the
|
||||||
// same tokenizer as GPT-4.
|
// same tokenizer as GPT-4.
|
||||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,6 @@ impl ZedDotDevCompletionProvider {
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
match request.model {
|
match request.model {
|
||||||
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
|
||||||
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
||||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
||||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
|
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
|
||||||
|
@ -108,6 +107,7 @@ impl ZedDotDevCompletionProvider {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4489,8 +4489,8 @@ async fn complete_with_anthropic(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut stream = anthropic::stream_completion(
|
let mut stream = anthropic::stream_completion(
|
||||||
session.http_client.clone(),
|
session.http_client.as_ref(),
|
||||||
"https://api.anthropic.com",
|
anthropic::ANTHROPIC_API_URL,
|
||||||
&api_key,
|
&api_key,
|
||||||
anthropic::Request {
|
anthropic::Request {
|
||||||
model,
|
model,
|
||||||
|
@ -4499,6 +4499,7 @@ async fn complete_with_anthropic(
|
||||||
system: system_message,
|
system: system_message,
|
||||||
max_tokens: 4092,
|
max_tokens: 4092,
|
||||||
},
|
},
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue