
This expands our deserialization of JSON from models to be more tolerant of different variations that the model may send, including capitalization, wrapping things in objects vs. being plain strings, etc. Also when deserialization fails, it reports the entire error in the JSON so we can see what failed to deserialize. (Previously these errors were very unhelpful at diagnosing the problem.) Finally, also removes the `WrappedText` variant since the custom deserializer just turns that style of JSON into a normal `Text` variant. Release Notes: - N/A
1070 lines
39 KiB
Rust
1070 lines
39 KiB
Rust
use crate::AllLanguageModelSettings;
|
|
use crate::ui::InstructionListItem;
|
|
use anthropic::{
|
|
AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
|
|
ToolResultPart, Usage,
|
|
};
|
|
use anyhow::{Context as _, Result, anyhow};
|
|
use collections::{BTreeMap, HashMap};
|
|
use credentials_provider::CredentialsProvider;
|
|
use editor::{Editor, EditorElement, EditorStyle};
|
|
use futures::Stream;
|
|
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
|
use gpui::{
|
|
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
|
};
|
|
use http_client::HttpClient;
|
|
use language_model::{
|
|
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
|
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
|
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
|
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
|
|
LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
|
|
};
|
|
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::{Settings, SettingsStore};
|
|
use std::pin::Pin;
|
|
use std::str::FromStr;
|
|
use std::sync::Arc;
|
|
use strum::IntoEnumIterator;
|
|
use theme::ThemeSettings;
|
|
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
|
use util::ResultExt;
|
|
|
|
// Provider identifier; wrapped in `LanguageModelProviderId` by the `id()` / `provider_id()` methods in this file.
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
|
|
// Provider name; wrapped in `LanguageModelProviderName` by the `name()` / `provider_name()` methods in this file.
const PROVIDER_NAME: &str = "Anthropic";
|
|
|
|
#[derive(Default, Clone, Debug, PartialEq)]
|
|
pub struct AnthropicSettings {
|
|
pub api_url: String,
|
|
/// Extend Zed's list of Anthropic models.
|
|
pub available_models: Vec<AvailableModel>,
|
|
pub needs_setting_migration: bool,
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
|
pub struct AvailableModel {
|
|
/// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc
|
|
pub name: String,
|
|
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
|
|
pub display_name: Option<String>,
|
|
/// The model's context window size.
|
|
pub max_tokens: usize,
|
|
/// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
|
|
pub tool_override: Option<String>,
|
|
/// Configuration of Anthropic's caching API.
|
|
pub cache_configuration: Option<LanguageModelCacheConfiguration>,
|
|
pub max_output_tokens: Option<u32>,
|
|
pub default_temperature: Option<f32>,
|
|
#[serde(default)]
|
|
pub extra_beta_headers: Vec<String>,
|
|
/// The model's mode (e.g. thinking)
|
|
pub mode: Option<ModelMode>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
|
|
#[serde(tag = "type", rename_all = "lowercase")]
|
|
pub enum ModelMode {
|
|
#[default]
|
|
Default,
|
|
Thinking {
|
|
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
|
|
budget_tokens: Option<u32>,
|
|
},
|
|
}
|
|
|
|
impl From<ModelMode> for AnthropicModelMode {
|
|
fn from(value: ModelMode) -> Self {
|
|
match value {
|
|
ModelMode::Default => AnthropicModelMode::Default,
|
|
ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<AnthropicModelMode> for ModelMode {
|
|
fn from(value: AnthropicModelMode) -> Self {
|
|
match value {
|
|
AnthropicModelMode::Default => ModelMode::Default,
|
|
AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct AnthropicLanguageModelProvider {
|
|
http_client: Arc<dyn HttpClient>,
|
|
state: gpui::Entity<State>,
|
|
}
|
|
|
|
// Environment variable checked first by `State::authenticate`; when set, it is used instead of the stored credentials.
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
|
|
|
|
pub struct State {
|
|
api_key: Option<String>,
|
|
api_key_from_env: bool,
|
|
_subscription: Subscription,
|
|
}
|
|
|
|
impl State {
|
|
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
|
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
|
let api_url = AllLanguageModelSettings::get_global(cx)
|
|
.anthropic
|
|
.api_url
|
|
.clone();
|
|
cx.spawn(async move |this, cx| {
|
|
credentials_provider
|
|
.delete_credentials(&api_url, &cx)
|
|
.await
|
|
.ok();
|
|
this.update(cx, |this, cx| {
|
|
this.api_key = None;
|
|
this.api_key_from_env = false;
|
|
cx.notify();
|
|
})
|
|
})
|
|
}
|
|
|
|
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
|
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
|
let api_url = AllLanguageModelSettings::get_global(cx)
|
|
.anthropic
|
|
.api_url
|
|
.clone();
|
|
cx.spawn(async move |this, cx| {
|
|
credentials_provider
|
|
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
|
|
.await
|
|
.ok();
|
|
|
|
this.update(cx, |this, cx| {
|
|
this.api_key = Some(api_key);
|
|
cx.notify();
|
|
})
|
|
})
|
|
}
|
|
|
|
fn is_authenticated(&self) -> bool {
|
|
self.api_key.is_some()
|
|
}
|
|
|
|
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
|
if self.is_authenticated() {
|
|
return Task::ready(Ok(()));
|
|
}
|
|
|
|
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
|
let api_url = AllLanguageModelSettings::get_global(cx)
|
|
.anthropic
|
|
.api_url
|
|
.clone();
|
|
|
|
cx.spawn(async move |this, cx| {
|
|
let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
|
|
(api_key, true)
|
|
} else {
|
|
let (_, api_key) = credentials_provider
|
|
.read_credentials(&api_url, &cx)
|
|
.await?
|
|
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
|
(
|
|
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
|
false,
|
|
)
|
|
};
|
|
|
|
this.update(cx, |this, cx| {
|
|
this.api_key = Some(api_key);
|
|
this.api_key_from_env = from_env;
|
|
cx.notify();
|
|
})?;
|
|
|
|
Ok(())
|
|
})
|
|
}
|
|
}
|
|
|
|
impl AnthropicLanguageModelProvider {
|
|
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
|
let state = cx.new(|cx| State {
|
|
api_key: None,
|
|
api_key_from_env: false,
|
|
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
|
cx.notify();
|
|
}),
|
|
});
|
|
|
|
Self { http_client, state }
|
|
}
|
|
|
|
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
|
|
Arc::new(AnthropicModel {
|
|
id: LanguageModelId::from(model.id().to_string()),
|
|
model,
|
|
state: self.state.clone(),
|
|
http_client: self.http_client.clone(),
|
|
request_limiter: RateLimiter::new(4),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
|
type ObservableEntity = State;
|
|
|
|
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
|
|
Some(self.state.clone())
|
|
}
|
|
}
|
|
|
|
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|
fn id(&self) -> LanguageModelProviderId {
|
|
LanguageModelProviderId(PROVIDER_ID.into())
|
|
}
|
|
|
|
fn name(&self) -> LanguageModelProviderName {
|
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
|
}
|
|
|
|
fn icon(&self) -> IconName {
|
|
IconName::AiAnthropic
|
|
}
|
|
|
|
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
|
Some(self.create_language_model(anthropic::Model::default()))
|
|
}
|
|
|
|
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
|
Some(self.create_language_model(anthropic::Model::default_fast()))
|
|
}
|
|
|
|
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
|
[
|
|
anthropic::Model::ClaudeSonnet4,
|
|
anthropic::Model::ClaudeSonnet4Thinking,
|
|
]
|
|
.into_iter()
|
|
.map(|model| self.create_language_model(model))
|
|
.collect()
|
|
}
|
|
|
|
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
|
let mut models = BTreeMap::default();
|
|
|
|
// Add base models from anthropic::Model::iter()
|
|
for model in anthropic::Model::iter() {
|
|
if !matches!(model, anthropic::Model::Custom { .. }) {
|
|
models.insert(model.id().to_string(), model);
|
|
}
|
|
}
|
|
|
|
// Override with available models from settings
|
|
for model in AllLanguageModelSettings::get_global(cx)
|
|
.anthropic
|
|
.available_models
|
|
.iter()
|
|
{
|
|
models.insert(
|
|
model.name.clone(),
|
|
anthropic::Model::Custom {
|
|
name: model.name.clone(),
|
|
display_name: model.display_name.clone(),
|
|
max_tokens: model.max_tokens,
|
|
tool_override: model.tool_override.clone(),
|
|
cache_configuration: model.cache_configuration.as_ref().map(|config| {
|
|
anthropic::AnthropicModelCacheConfiguration {
|
|
max_cache_anchors: config.max_cache_anchors,
|
|
should_speculate: config.should_speculate,
|
|
min_total_token: config.min_total_token,
|
|
}
|
|
}),
|
|
max_output_tokens: model.max_output_tokens,
|
|
default_temperature: model.default_temperature,
|
|
extra_beta_headers: model.extra_beta_headers.clone(),
|
|
mode: model.mode.clone().unwrap_or_default().into(),
|
|
},
|
|
);
|
|
}
|
|
|
|
models
|
|
.into_values()
|
|
.map(|model| self.create_language_model(model))
|
|
.collect()
|
|
}
|
|
|
|
fn is_authenticated(&self, cx: &App) -> bool {
|
|
self.state.read(cx).is_authenticated()
|
|
}
|
|
|
|
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
|
self.state.update(cx, |state, cx| state.authenticate(cx))
|
|
}
|
|
|
|
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
|
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
|
.into()
|
|
}
|
|
|
|
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
|
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
|
}
|
|
}
|
|
|
|
pub struct AnthropicModel {
|
|
id: LanguageModelId,
|
|
model: anthropic::Model,
|
|
state: gpui::Entity<State>,
|
|
http_client: Arc<dyn HttpClient>,
|
|
request_limiter: RateLimiter,
|
|
}
|
|
|
|
pub fn count_anthropic_tokens(
|
|
request: LanguageModelRequest,
|
|
cx: &App,
|
|
) -> BoxFuture<'static, Result<usize>> {
|
|
cx.background_spawn(async move {
|
|
let messages = request.messages;
|
|
let mut tokens_from_images = 0;
|
|
let mut string_messages = Vec::with_capacity(messages.len());
|
|
|
|
for message in messages {
|
|
use language_model::MessageContent;
|
|
|
|
let mut string_contents = String::new();
|
|
|
|
for content in message.content {
|
|
match content {
|
|
MessageContent::Text(text) => {
|
|
string_contents.push_str(&text);
|
|
}
|
|
MessageContent::Thinking { .. } => {
|
|
// Thinking blocks are not included in the input token count.
|
|
}
|
|
MessageContent::RedactedThinking(_) => {
|
|
// Thinking blocks are not included in the input token count.
|
|
}
|
|
MessageContent::Image(image) => {
|
|
tokens_from_images += image.estimate_tokens();
|
|
}
|
|
MessageContent::ToolUse(_tool_use) => {
|
|
// TODO: Estimate token usage from tool uses.
|
|
}
|
|
MessageContent::ToolResult(tool_result) => match &tool_result.content {
|
|
LanguageModelToolResultContent::Text(text) => {
|
|
string_contents.push_str(text);
|
|
}
|
|
LanguageModelToolResultContent::Image(image) => {
|
|
tokens_from_images += image.estimate_tokens();
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
if !string_contents.is_empty() {
|
|
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
|
|
role: match message.role {
|
|
Role::User => "user".into(),
|
|
Role::Assistant => "assistant".into(),
|
|
Role::System => "system".into(),
|
|
},
|
|
content: Some(string_contents),
|
|
name: None,
|
|
function_call: None,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Tiktoken doesn't yet support these models, so we manually use the
|
|
// same tokenizer as GPT-4.
|
|
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
|
|
.map(|tokens| tokens + tokens_from_images)
|
|
})
|
|
.boxed()
|
|
}
|
|
|
|
impl AnthropicModel {
|
|
fn stream_completion(
|
|
&self,
|
|
request: anthropic::Request,
|
|
cx: &AsyncApp,
|
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event, AnthropicError>>>>
|
|
{
|
|
let http_client = self.http_client.clone();
|
|
|
|
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
|
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
|
(state.api_key.clone(), settings.api_url.clone())
|
|
}) else {
|
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
|
};
|
|
|
|
async move {
|
|
let api_key = api_key.context("Missing Anthropic API Key")?;
|
|
let request =
|
|
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
|
request.await.context("failed to stream completion")
|
|
}
|
|
.boxed()
|
|
}
|
|
}
|
|
|
|
impl LanguageModel for AnthropicModel {
|
|
fn id(&self) -> LanguageModelId {
|
|
self.id.clone()
|
|
}
|
|
|
|
fn name(&self) -> LanguageModelName {
|
|
LanguageModelName::from(self.model.display_name().to_string())
|
|
}
|
|
|
|
fn provider_id(&self) -> LanguageModelProviderId {
|
|
LanguageModelProviderId(PROVIDER_ID.into())
|
|
}
|
|
|
|
fn provider_name(&self) -> LanguageModelProviderName {
|
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
|
}
|
|
|
|
fn supports_tools(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn supports_images(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
|
match choice {
|
|
LanguageModelToolChoice::Auto
|
|
| LanguageModelToolChoice::Any
|
|
| LanguageModelToolChoice::None => true,
|
|
}
|
|
}
|
|
|
|
fn telemetry_id(&self) -> String {
|
|
format!("anthropic/{}", self.model.id())
|
|
}
|
|
|
|
fn api_key(&self, cx: &App) -> Option<String> {
|
|
self.state.read(cx).api_key.clone()
|
|
}
|
|
|
|
fn max_token_count(&self) -> usize {
|
|
self.model.max_token_count()
|
|
}
|
|
|
|
fn max_output_tokens(&self) -> Option<u32> {
|
|
Some(self.model.max_output_tokens())
|
|
}
|
|
|
|
fn count_tokens(
|
|
&self,
|
|
request: LanguageModelRequest,
|
|
cx: &App,
|
|
) -> BoxFuture<'static, Result<usize>> {
|
|
count_anthropic_tokens(request, cx)
|
|
}
|
|
|
|
fn stream_completion(
|
|
&self,
|
|
request: LanguageModelRequest,
|
|
cx: &AsyncApp,
|
|
) -> BoxFuture<
|
|
'static,
|
|
Result<
|
|
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
|
>,
|
|
> {
|
|
let request = into_anthropic(
|
|
request,
|
|
self.model.request_id().into(),
|
|
self.model.default_temperature(),
|
|
self.model.max_output_tokens(),
|
|
self.model.mode(),
|
|
);
|
|
let request = self.stream_completion(request, cx);
|
|
let future = self.request_limiter.stream(async move {
|
|
let response = request
|
|
.await
|
|
.map_err(|err| match err.downcast::<AnthropicError>() {
|
|
Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
|
|
Err(err) => anyhow!(err),
|
|
})?;
|
|
Ok(AnthropicEventMapper::new().map_stream(response))
|
|
});
|
|
async move { Ok(future.await?.boxed()) }.boxed()
|
|
}
|
|
|
|
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
|
self.model
|
|
.cache_configuration()
|
|
.map(|config| LanguageModelCacheConfiguration {
|
|
max_cache_anchors: config.max_cache_anchors,
|
|
should_speculate: config.should_speculate,
|
|
min_total_token: config.min_total_token,
|
|
})
|
|
}
|
|
}
|
|
|
|
pub fn into_anthropic(
|
|
request: LanguageModelRequest,
|
|
model: String,
|
|
default_temperature: f32,
|
|
max_output_tokens: u32,
|
|
mode: AnthropicModelMode,
|
|
) -> anthropic::Request {
|
|
let mut new_messages: Vec<anthropic::Message> = Vec::new();
|
|
let mut system_message = String::new();
|
|
|
|
for message in request.messages {
|
|
if message.contents_empty() {
|
|
continue;
|
|
}
|
|
|
|
match message.role {
|
|
Role::User | Role::Assistant => {
|
|
let cache_control = if message.cache {
|
|
Some(anthropic::CacheControl {
|
|
cache_type: anthropic::CacheControlType::Ephemeral,
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
let anthropic_message_content: Vec<anthropic::RequestContent> = message
|
|
.content
|
|
.into_iter()
|
|
.filter_map(|content| match content {
|
|
MessageContent::Text(text) => {
|
|
if !text.is_empty() {
|
|
Some(anthropic::RequestContent::Text {
|
|
text,
|
|
cache_control,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
MessageContent::Thinking {
|
|
text: thinking,
|
|
signature,
|
|
} => {
|
|
if !thinking.is_empty() {
|
|
Some(anthropic::RequestContent::Thinking {
|
|
thinking,
|
|
signature: signature.unwrap_or_default(),
|
|
cache_control,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
MessageContent::RedactedThinking(data) => {
|
|
if !data.is_empty() {
|
|
Some(anthropic::RequestContent::RedactedThinking {
|
|
data: String::from_utf8(data).ok()?,
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
|
|
source: anthropic::ImageSource {
|
|
source_type: "base64".to_string(),
|
|
media_type: "image/png".to_string(),
|
|
data: image.source.to_string(),
|
|
},
|
|
cache_control,
|
|
}),
|
|
MessageContent::ToolUse(tool_use) => {
|
|
Some(anthropic::RequestContent::ToolUse {
|
|
id: tool_use.id.to_string(),
|
|
name: tool_use.name.to_string(),
|
|
input: tool_use.input,
|
|
cache_control,
|
|
})
|
|
}
|
|
MessageContent::ToolResult(tool_result) => {
|
|
Some(anthropic::RequestContent::ToolResult {
|
|
tool_use_id: tool_result.tool_use_id.to_string(),
|
|
is_error: tool_result.is_error,
|
|
content: match tool_result.content {
|
|
LanguageModelToolResultContent::Text(text) => {
|
|
ToolResultContent::Plain(text.to_string())
|
|
}
|
|
LanguageModelToolResultContent::Image(image) => {
|
|
ToolResultContent::Multipart(vec![ToolResultPart::Image {
|
|
source: anthropic::ImageSource {
|
|
source_type: "base64".to_string(),
|
|
media_type: "image/png".to_string(),
|
|
data: image.source.to_string(),
|
|
},
|
|
}])
|
|
}
|
|
},
|
|
cache_control,
|
|
})
|
|
}
|
|
})
|
|
.collect();
|
|
let anthropic_role = match message.role {
|
|
Role::User => anthropic::Role::User,
|
|
Role::Assistant => anthropic::Role::Assistant,
|
|
Role::System => unreachable!("System role should never occur here"),
|
|
};
|
|
if let Some(last_message) = new_messages.last_mut() {
|
|
if last_message.role == anthropic_role {
|
|
last_message.content.extend(anthropic_message_content);
|
|
continue;
|
|
}
|
|
}
|
|
new_messages.push(anthropic::Message {
|
|
role: anthropic_role,
|
|
content: anthropic_message_content,
|
|
});
|
|
}
|
|
Role::System => {
|
|
if !system_message.is_empty() {
|
|
system_message.push_str("\n\n");
|
|
}
|
|
system_message.push_str(&message.string_contents());
|
|
}
|
|
}
|
|
}
|
|
|
|
anthropic::Request {
|
|
model,
|
|
messages: new_messages,
|
|
max_tokens: max_output_tokens,
|
|
system: if system_message.is_empty() {
|
|
None
|
|
} else {
|
|
Some(anthropic::StringOrContents::String(system_message))
|
|
},
|
|
thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
|
|
Some(anthropic::Thinking::Enabled { budget_tokens })
|
|
} else {
|
|
None
|
|
},
|
|
tools: request
|
|
.tools
|
|
.into_iter()
|
|
.map(|tool| anthropic::Tool {
|
|
name: tool.name,
|
|
description: tool.description,
|
|
input_schema: tool.input_schema,
|
|
})
|
|
.collect(),
|
|
tool_choice: request.tool_choice.map(|choice| match choice {
|
|
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
|
|
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
|
|
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
|
|
}),
|
|
metadata: None,
|
|
stop_sequences: Vec::new(),
|
|
temperature: request.temperature.or(Some(default_temperature)),
|
|
top_k: None,
|
|
top_p: None,
|
|
}
|
|
}
|
|
|
|
pub struct AnthropicEventMapper {
|
|
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
|
usage: Usage,
|
|
stop_reason: StopReason,
|
|
}
|
|
|
|
impl AnthropicEventMapper {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
tool_uses_by_index: HashMap::default(),
|
|
usage: Usage::default(),
|
|
stop_reason: StopReason::EndTurn,
|
|
}
|
|
}
|
|
|
|
pub fn map_stream(
|
|
mut self,
|
|
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
|
{
|
|
events.flat_map(move |event| {
|
|
futures::stream::iter(match event {
|
|
Ok(event) => self.map_event(event),
|
|
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
|
})
|
|
})
|
|
}
|
|
|
|
pub fn map_event(
|
|
&mut self,
|
|
event: Event,
|
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
|
match event {
|
|
Event::ContentBlockStart {
|
|
index,
|
|
content_block,
|
|
} => match content_block {
|
|
ResponseContent::Text { text } => {
|
|
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
|
}
|
|
ResponseContent::Thinking { thinking } => {
|
|
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
|
text: thinking,
|
|
signature: None,
|
|
})]
|
|
}
|
|
ResponseContent::RedactedThinking { .. } => {
|
|
// Redacted thinking is encrypted and not accessible to the user, see:
|
|
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
|
|
Vec::new()
|
|
}
|
|
ResponseContent::ToolUse { id, name, .. } => {
|
|
self.tool_uses_by_index.insert(
|
|
index,
|
|
RawToolUse {
|
|
id,
|
|
name,
|
|
input_json: String::new(),
|
|
},
|
|
);
|
|
Vec::new()
|
|
}
|
|
},
|
|
Event::ContentBlockDelta { index, delta } => match delta {
|
|
ContentDelta::TextDelta { text } => {
|
|
vec![Ok(LanguageModelCompletionEvent::Text(text))]
|
|
}
|
|
ContentDelta::ThinkingDelta { thinking } => {
|
|
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
|
text: thinking,
|
|
signature: None,
|
|
})]
|
|
}
|
|
ContentDelta::SignatureDelta { signature } => {
|
|
vec![Ok(LanguageModelCompletionEvent::Thinking {
|
|
text: "".to_string(),
|
|
signature: Some(signature),
|
|
})]
|
|
}
|
|
ContentDelta::InputJsonDelta { partial_json } => {
|
|
if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
|
|
tool_use.input_json.push_str(&partial_json);
|
|
|
|
// Try to convert invalid (incomplete) JSON into
|
|
// valid JSON that serde can accept, e.g. by closing
|
|
// unclosed delimiters. This way, we can update the
|
|
// UI with whatever has been streamed back so far.
|
|
if let Ok(input) = serde_json::Value::from_str(
|
|
&partial_json_fixer::fix_json(&tool_use.input_json),
|
|
) {
|
|
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: tool_use.id.clone().into(),
|
|
name: tool_use.name.clone().into(),
|
|
is_input_complete: false,
|
|
raw_input: tool_use.input_json.clone(),
|
|
input,
|
|
},
|
|
))];
|
|
}
|
|
}
|
|
return vec![];
|
|
}
|
|
},
|
|
Event::ContentBlockStop { index } => {
|
|
if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
|
|
let input_json = tool_use.input_json.trim();
|
|
let input_value = if input_json.is_empty() {
|
|
Ok(serde_json::Value::Object(serde_json::Map::default()))
|
|
} else {
|
|
serde_json::Value::from_str(input_json)
|
|
};
|
|
let event_result = match input_value {
|
|
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
|
LanguageModelToolUse {
|
|
id: tool_use.id.into(),
|
|
name: tool_use.name.into(),
|
|
is_input_complete: true,
|
|
input,
|
|
raw_input: tool_use.input_json.clone(),
|
|
},
|
|
)),
|
|
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
|
|
id: tool_use.id.into(),
|
|
tool_name: tool_use.name.into(),
|
|
raw_input: input_json.into(),
|
|
json_parse_error: json_parse_err.to_string(),
|
|
}),
|
|
};
|
|
|
|
vec![event_result]
|
|
} else {
|
|
Vec::new()
|
|
}
|
|
}
|
|
Event::MessageStart { message } => {
|
|
update_usage(&mut self.usage, &message.usage);
|
|
vec![
|
|
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
|
&self.usage,
|
|
))),
|
|
Ok(LanguageModelCompletionEvent::StartMessage {
|
|
message_id: message.id,
|
|
}),
|
|
]
|
|
}
|
|
Event::MessageDelta { delta, usage } => {
|
|
update_usage(&mut self.usage, &usage);
|
|
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
|
self.stop_reason = match stop_reason {
|
|
"end_turn" => StopReason::EndTurn,
|
|
"max_tokens" => StopReason::MaxTokens,
|
|
"tool_use" => StopReason::ToolUse,
|
|
"refusal" => StopReason::Refusal,
|
|
_ => {
|
|
log::error!("Unexpected anthropic stop_reason: {stop_reason}");
|
|
StopReason::EndTurn
|
|
}
|
|
};
|
|
}
|
|
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
|
convert_usage(&self.usage),
|
|
))]
|
|
}
|
|
Event::MessageStop => {
|
|
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
|
}
|
|
Event::Error { error } => {
|
|
vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
|
AnthropicError::ApiError(error)
|
|
)))]
|
|
}
|
|
_ => Vec::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A tool use whose input is still being streamed.
struct RawToolUse {
    /// Tool-use id taken from the content block that started it.
    id: String,
    /// Tool name taken from the content block that started it.
    name: String,
    /// Concatenation of every partial-JSON delta received so far.
    input_json: String,
}
|
|
|
|
pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
|
|
if let AnthropicError::ApiError(api_err) = &err {
|
|
if let Some(tokens) = api_err.match_window_exceeded() {
|
|
return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
|
|
}
|
|
}
|
|
|
|
anyhow!(err)
|
|
}
|
|
|
|
/// Updates usage data by preferring counts from `new`.
|
|
fn update_usage(usage: &mut Usage, new: &Usage) {
|
|
if let Some(input_tokens) = new.input_tokens {
|
|
usage.input_tokens = Some(input_tokens);
|
|
}
|
|
if let Some(output_tokens) = new.output_tokens {
|
|
usage.output_tokens = Some(output_tokens);
|
|
}
|
|
if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
|
|
usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
|
|
}
|
|
if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
|
|
usage.cache_read_input_tokens = Some(cache_read_input_tokens);
|
|
}
|
|
}
|
|
|
|
/// Converts Anthropic usage counters into `language_model::TokenUsage`,
/// treating every missing counter as zero.
fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
    let input_tokens = usage.input_tokens.unwrap_or_default();
    let output_tokens = usage.output_tokens.unwrap_or_default();
    let cache_creation_input_tokens = usage.cache_creation_input_tokens.unwrap_or_default();
    let cache_read_input_tokens = usage.cache_read_input_tokens.unwrap_or_default();

    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_creation_input_tokens,
        cache_read_input_tokens,
    }
}
|
|
|
|
struct ConfigurationView {
|
|
api_key_editor: Entity<Editor>,
|
|
state: gpui::Entity<State>,
|
|
load_credentials_task: Option<Task<()>>,
|
|
}
|
|
|
|
impl ConfigurationView {
|
|
const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
|
|
|
|
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
|
cx.observe(&state, |_, _, cx| {
|
|
cx.notify();
|
|
})
|
|
.detach();
|
|
|
|
let load_credentials_task = Some(cx.spawn({
|
|
let state = state.clone();
|
|
async move |this, cx| {
|
|
if let Some(task) = state
|
|
.update(cx, |state, cx| state.authenticate(cx))
|
|
.log_err()
|
|
{
|
|
// We don't log an error, because "not signed in" is also an error.
|
|
let _ = task.await;
|
|
}
|
|
this.update(cx, |this, cx| {
|
|
this.load_credentials_task = None;
|
|
cx.notify();
|
|
})
|
|
.log_err();
|
|
}
|
|
}));
|
|
|
|
Self {
|
|
api_key_editor: cx.new(|cx| {
|
|
let mut editor = Editor::single_line(window, cx);
|
|
editor.set_placeholder_text(Self::PLACEHOLDER_TEXT, cx);
|
|
editor
|
|
}),
|
|
state,
|
|
load_credentials_task,
|
|
}
|
|
}
|
|
|
|
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
|
let api_key = self.api_key_editor.read(cx).text(cx);
|
|
if api_key.is_empty() {
|
|
return;
|
|
}
|
|
|
|
let state = self.state.clone();
|
|
cx.spawn_in(window, async move |_, cx| {
|
|
state
|
|
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
|
.await
|
|
})
|
|
.detach_and_log_err(cx);
|
|
|
|
cx.notify();
|
|
}
|
|
|
|
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
|
self.api_key_editor
|
|
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
|
|
|
let state = self.state.clone();
|
|
cx.spawn_in(window, async move |_, cx| {
|
|
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
|
})
|
|
.detach_and_log_err(cx);
|
|
|
|
cx.notify();
|
|
}
|
|
|
|
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
|
let settings = ThemeSettings::get_global(cx);
|
|
let text_style = TextStyle {
|
|
color: cx.theme().colors().text,
|
|
font_family: settings.ui_font.family.clone(),
|
|
font_features: settings.ui_font.features.clone(),
|
|
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
|
font_size: rems(0.875).into(),
|
|
font_weight: settings.ui_font.weight,
|
|
font_style: FontStyle::Normal,
|
|
line_height: relative(1.3),
|
|
white_space: WhiteSpace::Normal,
|
|
..Default::default()
|
|
};
|
|
EditorElement::new(
|
|
&self.api_key_editor,
|
|
EditorStyle {
|
|
background: cx.theme().colors().editor_background,
|
|
local_player: cx.theme().players().local(),
|
|
text: text_style,
|
|
..Default::default()
|
|
},
|
|
)
|
|
}
|
|
|
|
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
|
!self.state.read(cx).is_authenticated()
|
|
}
|
|
}
|
|
|
|
impl Render for ConfigurationView {
|
|
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
|
let env_var_set = self.state.read(cx).api_key_from_env;
|
|
|
|
if self.load_credentials_task.is_some() {
|
|
div().child(Label::new("Loading credentials...")).into_any()
|
|
} else if self.should_render_editor(cx) {
|
|
v_flex()
|
|
.size_full()
|
|
.on_action(cx.listener(Self::save_api_key))
|
|
.child(Label::new("To use Zed's assistant with Anthropic, you need to add an API key. Follow these steps:"))
|
|
.child(
|
|
List::new()
|
|
.child(
|
|
InstructionListItem::new(
|
|
"Create one by visiting",
|
|
Some("Anthropic's settings"),
|
|
Some("https://console.anthropic.com/settings/keys")
|
|
)
|
|
)
|
|
.child(
|
|
InstructionListItem::text_only("Paste your API key below and hit enter to start using the assistant")
|
|
)
|
|
)
|
|
.child(
|
|
h_flex()
|
|
.w_full()
|
|
.my_2()
|
|
.px_2()
|
|
.py_1()
|
|
.bg(cx.theme().colors().editor_background)
|
|
.border_1()
|
|
.border_color(cx.theme().colors().border)
|
|
.rounded_sm()
|
|
.child(self.render_api_key_editor(cx)),
|
|
)
|
|
.child(
|
|
Label::new(
|
|
format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
|
|
)
|
|
.size(LabelSize::Small)
|
|
.color(Color::Muted),
|
|
)
|
|
.into_any()
|
|
} else {
|
|
h_flex()
|
|
.mt_1()
|
|
.p_1()
|
|
.justify_between()
|
|
.rounded_md()
|
|
.border_1()
|
|
.border_color(cx.theme().colors().border)
|
|
.bg(cx.theme().colors().background)
|
|
.child(
|
|
h_flex()
|
|
.gap_1()
|
|
.child(Icon::new(IconName::Check).color(Color::Success))
|
|
.child(Label::new(if env_var_set {
|
|
format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
|
|
} else {
|
|
"API key configured.".to_string()
|
|
})),
|
|
)
|
|
.child(
|
|
Button::new("reset-key", "Reset Key")
|
|
.label_size(LabelSize::Small)
|
|
.icon(Some(IconName::Trash))
|
|
.icon_size(IconSize::Small)
|
|
.icon_position(IconPosition::Start)
|
|
.disabled(env_var_set)
|
|
.when(env_var_set, |this| {
|
|
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
|
|
})
|
|
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
|
)
|
|
.into_any()
|
|
}
|
|
}
|
|
}
|