mod model; pub mod provider; mod rate_limiter; mod registry; mod request; mod role; pub mod settings; use anyhow::Result; use client::{Client, UserStore}; use futures::{future::BoxFuture, stream::BoxStream}; use gpui::{ AnyView, AppContext, AsyncAppContext, FocusHandle, Model, SharedString, Task, WindowContext, }; pub use model::*; use project::Fs; use proto::Plan; pub(crate) use rate_limiter::*; pub use registry::*; pub use request::*; pub use role::*; use schemars::JsonSchema; use serde::de::DeserializeOwned; use std::{future::Future, sync::Arc}; use ui::IconName; pub fn init( user_store: Model, client: Arc, fs: Arc, cx: &mut AppContext, ) { settings::init(fs, cx); registry::init(user_store, client, cx); } /// The availability of a [`LanguageModel`]. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum LanguageModelAvailability { /// The language model is available to the general public. Public, /// The language model is available to users on the indicated plan. RequiresPlan(Plan), } pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; fn provider_id(&self) -> LanguageModelProviderId; fn provider_name(&self) -> LanguageModelProviderName; fn telemetry_id(&self) -> String; /// Returns the availability of this language model. fn availability(&self) -> LanguageModelAvailability { LanguageModelAvailability::Public } fn max_token_count(&self) -> usize; fn count_tokens( &self, request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result>; fn stream_completion( &self, request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>>; fn use_any_tool( &self, request: LanguageModelRequest, name: String, description: String, schema: serde_json::Value, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>; } impl dyn LanguageModel { pub fn use_tool( &self, request: LanguageModelRequest, cx: &AsyncAppContext, ) -> impl 'static + Future> { let schema = schemars::schema_for!(T); let schema_json = serde_json::to_value(&schema).unwrap(); let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx); async move { let response = request.await?; Ok(serde_json::from_value(response)?) } } } pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { fn name() -> String; fn description() -> String; } pub trait LanguageModelProvider: 'static { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; fn icon(&self) -> IconName { IconName::ZedAssistant } fn provided_models(&self, cx: &AppContext) -> Vec>; fn load_model(&self, _model: Arc, _cx: &AppContext) {} fn is_authenticated(&self, cx: &AppContext) -> bool; fn authenticate(&self, cx: &mut AppContext) -> Task>; fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option); fn reset_credentials(&self, cx: &mut AppContext) -> Task>; } pub trait LanguageModelProviderState: 'static { type ObservableEntity; fn observable_entity(&self) -> Option>; fn subscribe( &self, cx: &mut gpui::ModelContext, callback: impl Fn(&mut T, &mut gpui::ModelContext) + 'static, ) -> Option { let entity = self.observable_entity()?; Some(cx.observe(&entity, move |this, _, cx| { callback(this, cx); })) } } #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] pub struct LanguageModelId(pub SharedString); #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] pub struct LanguageModelName(pub SharedString); #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] pub struct LanguageModelProviderId(pub SharedString); #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] pub struct LanguageModelProviderName(pub SharedString); impl From for LanguageModelId { fn from(value: String) -> Self { Self(SharedString::from(value)) } } impl From for LanguageModelName { fn from(value: String) -> Self { Self(SharedString::from(value)) } } impl From for LanguageModelProviderId { fn from(value: String) -> Self { Self(SharedString::from(value)) } } impl From for LanguageModelProviderName { fn from(value: String) -> Self { Self(SharedString::from(value)) } }