mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anyhow::{Context as _, Result};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
    CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
    MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: usize,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
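// Illustrative sketch, not part of the original file: one way a consumer might
// fold a stream of `LanguageModelCompletionEvent`s into plain text plus the
// final token usage. The function name and shape are assumptions for
// demonstration only; `stream_completion_text` below does this for real.
#[cfg(test)]
#[allow(dead_code)]
async fn example_collect_completion_text(
    mut events: BoxStream<
        'static,
        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
    >,
) -> Result<(String, TokenUsage), LanguageModelCompletionError> {
    let mut text = String::new();
    let mut usage = TokenUsage::default();
    while let Some(event) = events.next().await {
        match event? {
            // Streamed text chunks arrive in order; append them.
            LanguageModelCompletionEvent::Text(chunk) => text.push_str(&chunk),
            // Usage updates supersede earlier ones; keep the latest.
            LanguageModelCompletionEvent::UsageUpdate(current) => usage = current,
            // A stop event ends the turn.
            LanguageModelCompletionEvent::Stop(_) => break,
            // Status updates, thinking, tool use, etc. are ignored here.
            _ => {}
        }
    }
    Ok((text, usage))
}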
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    pub limit: UsageLimit,
    pub amount: i32,
}

impl RequestUsage {
    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
        let limit = headers
            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
            .with_context(|| {
                format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header")
            })?;
        let limit = UsageLimit::from_str(limit.to_str()?)?;

        let amount = headers
            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
            .with_context(|| {
                format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header")
            })?;
        let amount = amount.to_str()?.parse::<i32>()?;

        Ok(Self { limit, amount })
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens
                + other.cache_read_input_tokens,
        }
    }
}

impl Sub for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens
                - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
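// Illustrative sketch, not part of the original file: `TokenUsage` implements
// `Add`/`Sub` above, so per-request usage can be accumulated field by field.
// The module and test names are assumptions for demonstration.
#[cfg(test)]
mod token_usage_example {
    use super::TokenUsage;

    #[test]
    fn accumulating_usage() {
        let first = TokenUsage {
            input_tokens: 100,
            output_tokens: 20,
            ..TokenUsage::default()
        };
        let second = TokenUsage {
            input_tokens: 40,
            output_tokens: 10,
            cache_read_input_tokens: 25,
            ..TokenUsage::default()
        };
        let total = first + second;
        // `total_tokens` sums all four counters: 140 + 30 + 25 + 0.
        assert_eq!(total.total_tokens(), 195);
    }
}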
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Returns the availability of this language model.
    fn availability(&self) -> LanguageModelAvailability {
        LanguageModelAvailability::Public
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "max mode".
    fn supports_max_mode(&self) -> bool {
        if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
            return false;
        }

        const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ];

        for model in MAX_MODE_CAPABLE_MODELS {
            if self.id().0 == model.id() {
                return true;
            }
        }

        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> usize;
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}
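// Illustrative sketch, not part of the original file: a minimal
// `LanguageModelTool` implementation. The struct, its field, and the tool
// name/description are assumptions for demonstration; the input schema is
// derived from the type via `schemars::JsonSchema`, and the trait's
// `DeserializeOwned` bound is satisfied by `#[derive(Deserialize)]`.
#[cfg(test)]
mod tool_example {
    use super::LanguageModelTool;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    #[allow(dead_code)]
    struct ReadFileTool {
        /// The path of the file to read.
        path: String,
    }

    impl LanguageModelTool for ReadFileTool {
        fn name() -> String {
            "read_file".into()
        }

        fn description() -> String {
            "Reads the file at the given path.".into()
        }
    }
}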
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
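// Illustrative sketch, not part of the original file: the `From<String>` impls
// above let provider and model identifiers be built directly from strings. The
// model name used here is an arbitrary placeholder.
#[cfg(test)]
mod id_example {
    use super::{LanguageModelId, LanguageModelProviderId};

    #[test]
    fn ids_from_strings() {
        let provider: LanguageModelProviderId = "zed.dev".to_string().into();
        // `LanguageModelProviderId` implements `Display` by delegating to its
        // inner `SharedString`.
        assert_eq!(provider.to_string(), "zed.dev");

        let model: LanguageModelId = "example-model-id".to_string().into();
        assert_eq!(model.0.to_string(), "example-model-id");
    }
}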