Extract completion provider crate (#14823)
We will soon need `semantic_index` to be able to use
`CompletionProvider`. This is currently impossible due to a cyclic crate
dependency, because `CompletionProvider` lives in the `assistant` crate,
which depends on `semantic_index`.
This PR breaks the dependency cycle by extracting two crates out of
`assistant`: `language_model` and `completion`.
Only one piece of logic changed: [this
code](922fcaf5a6 (diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69)
).
* As of https://github.com/zed-industries/zed/pull/13276, whenever we
ask a given completion provider for its available models, OpenAI
providers would go and ask the global assistant settings whether the
user had configured an `available_models` setting, and if so, return
that.
* This PR changes it so that instead of eagerly asking the assistant
settings for this info (the new crate must not depend on `assistant`, or
else the dependency cycle would be back), OpenAI completion providers
now store the user-configured settings as part of their struct, and
whenever the settings change, we update the provider.
In theory, this change should not change user-visible behavior...but
since it's the only change in this large PR that's more than just moving
code around, I'm mentioning it here in case there's an unexpected
regression in practice! (cc @amtoaer in case you'd like to try out this
branch and verify that the feature is still working the way you expect.)
Release Notes:
- N/A
---------
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
b9a53ffa0b
commit
ec487d8f64
30 changed files with 820 additions and 610 deletions
|
@ -1,6 +1,5 @@
|
|||
pub mod assistant_panel;
|
||||
pub mod assistant_settings;
|
||||
mod completion_provider;
|
||||
mod context;
|
||||
pub mod context_store;
|
||||
mod inline_assistant;
|
||||
|
@ -12,17 +11,20 @@ mod streaming_diff;
|
|||
mod terminal_inline_assistant;
|
||||
|
||||
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
||||
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use completion_provider::*;
|
||||
use completion::CompletionProvider;
|
||||
pub use context::*;
|
||||
pub use context_store::*;
|
||||
use fs::Fs;
|
||||
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||
use gpui::{
|
||||
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
|
||||
};
|
||||
use indexed_docs::IndexedDocsRegistry;
|
||||
pub(crate) use inline_assistant::*;
|
||||
use language_model::LanguageModelResponseMessage;
|
||||
pub(crate) use model_selector::*;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -32,10 +34,7 @@ use slash_command::{
|
|||
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
||||
tabs_command, term_command,
|
||||
};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
sync::Arc,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
pub(crate) use streaming_diff::*;
|
||||
|
||||
actions!(
|
||||
|
@ -73,166 +72,6 @@ impl MessageId {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn from_proto(role: i32) -> Role {
|
||||
match proto::LanguageModelRole::from_i32(role) {
|
||||
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
|
||||
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
|
||||
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
|
||||
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
|
||||
None => Role::User,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_proto(&self) -> proto::LanguageModelRole {
|
||||
match self {
|
||||
Role::User => proto::LanguageModelRole::LanguageModelUser,
|
||||
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
|
||||
Role::System => proto::LanguageModelRole::LanguageModelSystem,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cycle(self) -> Role {
|
||||
match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "user"),
|
||||
Role::Assistant => write!(f, "assistant"),
|
||||
Role::System => write!(f, "system"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum LanguageModel {
|
||||
Cloud(CloudModel),
|
||||
OpenAi(OpenAiModel),
|
||||
Anthropic(AnthropicModel),
|
||||
Ollama(OllamaModel),
|
||||
}
|
||||
|
||||
impl Default for LanguageModel {
|
||||
fn default() -> Self {
|
||||
LanguageModel::Cloud(CloudModel::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel {
|
||||
pub fn telemetry_id(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
|
||||
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
|
||||
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.display_name().into(),
|
||||
LanguageModel::Anthropic(model) => model.display_name().into(),
|
||||
LanguageModel::Cloud(model) => model.display_name().into(),
|
||||
LanguageModel::Ollama(model) => model.display_name().into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
||||
LanguageModel::Anthropic(model) => model.max_token_count(),
|
||||
LanguageModel::Cloud(model) => model.max_token_count(),
|
||||
LanguageModel::Ollama(model) => model.max_token_count(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.id(),
|
||||
LanguageModel::Anthropic(model) => model.id(),
|
||||
LanguageModel::Cloud(model) => model.id(),
|
||||
LanguageModel::Ollama(model) => model.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl LanguageModelRequestMessage {
|
||||
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
|
||||
proto::LanguageModelRequestMessage {
|
||||
role: self.role.to_proto() as i32,
|
||||
content: self.content.clone(),
|
||||
tool_calls: Vec::new(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct LanguageModelRequest {
|
||||
pub model: LanguageModel,
|
||||
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl LanguageModelRequest {
|
||||
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
|
||||
proto::CompleteWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
||||
stop: self.stop.clone(),
|
||||
temperature: self.temperature,
|
||||
tool_choice: None,
|
||||
tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
|
||||
pub fn preprocess(&mut self) {
|
||||
match &self.model {
|
||||
LanguageModel::OpenAi(_) => {}
|
||||
LanguageModel::Anthropic(_) => {}
|
||||
LanguageModel::Ollama(_) => {}
|
||||
LanguageModel::Cloud(model) => match model {
|
||||
CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku
|
||||
| CloudModel::Claude3_5Sonnet => {
|
||||
preprocess_anthropic_request(self);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LanguageModelUsage {
|
||||
pub prompt_tokens: u32,
|
||||
|
@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
|
||||
context_store::init(&client);
|
||||
prompt_library::init(cx);
|
||||
completion_provider::init(client.clone(), cx);
|
||||
init_completion_provider(Arc::clone(&client), cx);
|
||||
assistant_slash_command::init(cx);
|
||||
register_slash_commands(cx);
|
||||
assistant_panel::init(cx);
|
||||
|
@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
.detach();
|
||||
}
|
||||
|
||||
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
|
||||
cx.set_global(CompletionProvider::new(provider, Some(client)));
|
||||
|
||||
let mut settings_version = 0;
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
settings_version += 1;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn register_slash_commands(cx: &mut AppContext) {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
slash_command_registry.register_command(file_command::FileSlashCommand, true);
|
||||
|
|
|
@ -8,18 +8,18 @@ use crate::{
|
|||
SlashCommandCompletionProvider, SlashCommandRegistry,
|
||||
},
|
||||
terminal_inline_assistant::TerminalInlineAssistant,
|
||||
Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
|
||||
CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep,
|
||||
EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant,
|
||||
InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
|
||||
QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split,
|
||||
ToggleFocus, ToggleModelSelector,
|
||||
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
|
||||
DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
|
||||
EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
|
||||
MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
|
||||
RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||
use breadcrumbs::Breadcrumbs;
|
||||
use client::proto;
|
||||
use collections::{BTreeSet, HashMap, HashSet};
|
||||
use completion::CompletionProvider;
|
||||
use editor::{
|
||||
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
||||
display_map::{
|
||||
|
@ -43,6 +43,7 @@ use language::{
|
|||
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
|
||||
ToOffset,
|
||||
};
|
||||
use language_model::Role;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{Project, ProjectLspAdapterDelegate};
|
||||
|
|
|
@ -1,166 +1,19 @@
|
|||
use std::fmt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
|
||||
pub use anthropic::Model as AnthropicModel;
|
||||
use gpui::Pixels;
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
||||
JsonSchema,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
use anthropic::Model as AnthropicModel;
|
||||
use client::Client;
|
||||
use completion::{
|
||||
AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
|
||||
LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
|
||||
};
|
||||
use gpui::{AppContext, Pixels};
|
||||
use language_model::{CloudModel, LanguageModel};
|
||||
use ollama::Model as OllamaModel;
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use parking_lot::RwLock;
|
||||
use schemars::{schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
||||
pub enum CloudModel {
|
||||
Gpt3Point5Turbo,
|
||||
Gpt4,
|
||||
Gpt4Turbo,
|
||||
#[default]
|
||||
Gpt4Omni,
|
||||
Gpt4OmniMini,
|
||||
Claude3_5Sonnet,
|
||||
Claude3Opus,
|
||||
Claude3Sonnet,
|
||||
Claude3Haiku,
|
||||
Gemini15Pro,
|
||||
Gemini15Flash,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Serialize for CloudModel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.id())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for CloudModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = CloudModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
let model = CloudModel::iter()
|
||||
.find(|model| model.id() == value)
|
||||
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for CloudModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = CloudModel::iter()
|
||||
.filter_map(|model| {
|
||||
let id = model.id();
|
||||
if id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(id.to_string())
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(CloudModel::default().id().into()),
|
||||
examples: variants.into_iter().map(Into::into).collect(),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl CloudModel {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4Turbo => "gpt-4-turbo-preview",
|
||||
Self::Gpt4Omni => "gpt-4o",
|
||||
Self::Gpt4OmniMini => "gpt-4o-mini",
|
||||
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
|
||||
Self::Claude3Opus => "claude-3-opus",
|
||||
Self::Claude3Sonnet => "claude-3-sonnet",
|
||||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
Self::Gemini15Pro => "gemini-1.5-pro",
|
||||
Self::Gemini15Flash => "gemini-1.5-flash",
|
||||
Self::Custom(id) => id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
|
||||
Self::Gpt4 => "GPT 4",
|
||||
Self::Gpt4Turbo => "GPT 4 Turbo",
|
||||
Self::Gpt4Omni => "GPT 4 Omni",
|
||||
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Gemini15Pro => "Gemini 1.5 Pro",
|
||||
Self::Gemini15Flash => "Gemini 1.5 Flash",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => 2048,
|
||||
Self::Gpt4 => 4096,
|
||||
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
|
||||
Self::Gpt4OmniMini => 128000,
|
||||
Self::Claude3_5Sonnet
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3Haiku => 200000,
|
||||
Self::Gemini15Pro => 128000,
|
||||
Self::Gemini15Flash => 32000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
|
||||
preprocess_anthropic_request(request)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
@ -620,6 +473,124 @@ fn merge<T>(target: &mut T, value: Option<T>) {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn update_completion_provider_settings(
|
||||
provider: &mut CompletionProvider,
|
||||
version: usize,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let updated = match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => provider
|
||||
.update_current_as::<_, CloudCompletionProvider>(|provider| {
|
||||
provider.update(model.clone(), version);
|
||||
}),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
};
|
||||
|
||||
// Previously configured provider was changed to another one
|
||||
if updated.is_none() {
|
||||
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn create_provider_from_settings(
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
|
||||
match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
|
||||
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
|
||||
)),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
available_models.clone(),
|
||||
))),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
))),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
cx,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Choose which model to use for openai provider.
|
||||
/// If the model is not available, try to use the first available model, or fallback to the original model.
|
||||
fn choose_openai_model(
|
||||
model: &::open_ai::Model,
|
||||
available_models: &[::open_ai::Model],
|
||||
) -> ::open_ai::Model {
|
||||
available_models
|
||||
.iter()
|
||||
.find(|&m| m == model)
|
||||
.or_else(|| available_models.first())
|
||||
.unwrap_or_else(|| model)
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext, UpdateGlobal};
|
||||
|
|
|
@ -1,396 +0,0 @@
|
|||
mod anthropic;
|
||||
mod cloud;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
mod fake;
|
||||
mod ollama;
|
||||
mod open_ai;
|
||||
|
||||
pub use anthropic::*;
|
||||
pub use cloud::*;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use fake::*;
|
||||
pub use ollama::*;
|
||||
pub use open_ai::*;
|
||||
use parking_lot::RwLock;
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
|
||||
use crate::{
|
||||
assistant_settings::{AssistantProvider, AssistantSettings},
|
||||
LanguageModel, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
|
||||
|
||||
/// Choose which model to use for openai provider.
|
||||
/// If the model is not available, try to use the first available model, or fallback to the original model.
|
||||
fn choose_openai_model(
|
||||
model: &::open_ai::Model,
|
||||
available_models: &[::open_ai::Model],
|
||||
) -> ::open_ai::Model {
|
||||
available_models
|
||||
.iter()
|
||||
.find(|&m| m == model)
|
||||
.or_else(|| available_models.first())
|
||||
.unwrap_or_else(|| model)
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let provider = create_provider_from_settings(client.clone(), 0, cx);
|
||||
cx.set_global(CompletionProvider::new(provider, Some(client)));
|
||||
|
||||
let mut settings_version = 0;
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
settings_version += 1;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider.update_settings(settings_version, cx);
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub struct CompletionResponse {
|
||||
inner: BoxStream<'static, Result<String>>,
|
||||
_lock: SemaphoreGuardArc,
|
||||
}
|
||||
|
||||
impl futures::Stream for CompletionResponse {
|
||||
type Item = Result<String>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.inner).poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelCompletionProvider: Send + Sync {
|
||||
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
|
||||
fn settings_version(&self) -> usize;
|
||||
fn is_authenticated(&self) -> bool;
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||
fn model(&self) -> LanguageModel;
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>>;
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn Any;
|
||||
}
|
||||
|
||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||
|
||||
pub struct CompletionProvider {
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
request_limiter: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn new(
|
||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
||||
client: Option<Arc<Client>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
client,
|
||||
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
|
||||
self.provider.read().available_models(cx)
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.provider.read().settings_version()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.provider.read().is_authenticated()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().authenticate(cx)
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
self.provider.read().authentication_prompt(cx)
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.provider.read().reset_credentials(cx)
|
||||
}
|
||||
|
||||
pub fn model(&self) -> LanguageModel {
|
||||
self.provider.read().model()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
self.provider.read().count_tokens(request, cx)
|
||||
}
|
||||
|
||||
pub fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<CompletionResponse>> {
|
||||
let rate_limiter = self.request_limiter.clone();
|
||||
let provider = self.provider.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let lock = rate_limiter.acquire_arc().await;
|
||||
let response = provider.read().stream_completion(request);
|
||||
let response = response.await?;
|
||||
Ok(CompletionResponse {
|
||||
inner: response,
|
||||
_lock: lock,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
||||
let response = self.stream_completion(request, cx);
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let mut chunks = response.await?;
|
||||
let mut completion = String::new();
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
completion.push_str(&chunk);
|
||||
}
|
||||
Ok(completion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl gpui::Global for CompletionProvider {}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> &Self {
|
||||
cx.global::<Self>()
|
||||
}
|
||||
|
||||
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
|
||||
&mut self,
|
||||
update: impl FnOnce(&mut T) -> R,
|
||||
) -> Option<R> {
|
||||
let mut provider = self.provider.write();
|
||||
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
|
||||
Some(update(provider))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
|
||||
let updated = match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => self
|
||||
.update_current_as::<_, CloudCompletionProvider>(|provider| {
|
||||
provider.update(model.clone(), version);
|
||||
}),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
};
|
||||
|
||||
// Previously configured provider was changed to another one
|
||||
if updated.is_none() {
|
||||
if let Some(client) = self.client.clone() {
|
||||
self.provider = create_provider_from_settings(client, version, cx);
|
||||
} else {
|
||||
log::warn!("completion provider cannot be created because client is not set");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider_from_settings(
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
|
||||
match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
|
||||
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
|
||||
)),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
))),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
))),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
cx,
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::AppContext;
|
||||
use parking_lot::RwLock;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
|
||||
use crate::{
|
||||
completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
|
||||
FakeCompletionProvider, LanguageModelRequest,
|
||||
};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_rate_limiting(cx: &mut AppContext) {
|
||||
SettingsStore::test(cx);
|
||||
let fake_provider = FakeCompletionProvider::setup_test(cx);
|
||||
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
|
||||
|
||||
// Enqueue some requests
|
||||
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
||||
let response = provider.stream_completion(
|
||||
LanguageModelRequest {
|
||||
temperature: i as f32 / 10.0,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let mut stream = response.await.unwrap();
|
||||
while let Some(message) = stream.next().await {
|
||||
message.unwrap();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Get the first completion request that is in flight and mark it as completed.
|
||||
let completion = fake_provider
|
||||
.pending_completions()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
fake_provider.finish_completion(&completion);
|
||||
|
||||
// Ensure that the number of in-flight completion requests is reduced.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
// Ensure that another completion request was allowed to acquire the lock.
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||
);
|
||||
|
||||
// Mark all completion requests as finished that are in flight.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
}
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
|
||||
// Wait until the background tasks acquire the lock again.
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
fake_provider.completion_count(),
|
||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||
);
|
||||
|
||||
// Finish all remaining completion requests.
|
||||
for request in fake_provider.pending_completions() {
|
||||
fake_provider.finish_completion(&request);
|
||||
}
|
||||
|
||||
cx.background_executor().run_until_parked();
|
||||
|
||||
assert_eq!(fake_provider.completion_count(), 0);
|
||||
}
|
||||
}
|
|
@ -1,367 +0,0 @@
|
|||
use crate::{
|
||||
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
|
||||
Role,
|
||||
};
|
||||
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
|
||||
use anthropic::{stream_completion, Request, RequestMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct AnthropicCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: AnthropicModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
|
||||
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
|
||||
AnthropicModel::iter()
|
||||
.map(LanguageModel::Anthropic)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Anthropic(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
anthropic::ResponseEvent::ContentBlockStart {
|
||||
content_block, ..
|
||||
} => match content_block {
|
||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||
},
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicCompletionProvider {
|
||||
pub fn new(
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
||||
preprocess_anthropic_request(&mut request);
|
||||
|
||||
let model = match request.model {
|
||||
LanguageModel::Anthropic(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
let mut system_message = String::new();
|
||||
if request
|
||||
.messages
|
||||
.first()
|
||||
.map_or(false, |message| message.role == Role::System)
|
||||
{
|
||||
system_message = request.messages.remove(0).content;
|
||||
}
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: match msg.role {
|
||||
Role::User => anthropic::Role::User,
|
||||
Role::Assistant => anthropic::Role::Assistant,
|
||||
Role::System => unreachable!("filtered out by preprocess_request"),
|
||||
},
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
|
||||
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
||||
for message in request.messages.drain(..) {
|
||||
if message.content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
if let Some(last_message) = new_messages.last_mut() {
|
||||
if last_message.role == message.role {
|
||||
last_message.content.push_str("\n\n");
|
||||
last_message.content.push_str(&message.content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
new_messages.push(message);
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !system_message.is_empty() {
|
||||
new_messages.insert(
|
||||
0,
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: system_message,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
request.messages = new_messages;
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 4] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||
"",
|
||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
|
@ -1,208 +0,0 @@
|
|||
use crate::{
|
||||
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
|
||||
LanguageModelCompletionProvider, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::{future, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct CloudCompletionProvider {
|
||||
client: Arc<Client>,
|
||||
model: CloudModel,
|
||||
settings_version: usize,
|
||||
status: client::Status,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
impl CloudCompletionProvider {
|
||||
pub fn new(
|
||||
model: CloudModel,
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.status = status;
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
model,
|
||||
settings_version,
|
||||
status,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
|
||||
self.model = model;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
||||
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
|
||||
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
|
||||
Some(custom_model)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
CloudModel::iter()
|
||||
.filter_map(move |model| {
|
||||
if let CloudModel::Custom(_) = model {
|
||||
Some(CloudModel::Custom(custom_model.take()?))
|
||||
} else {
|
||||
Some(model)
|
||||
}
|
||||
})
|
||||
.map(LanguageModel::Cloud)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.status.is_connected()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Cloud(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
LanguageModel::Cloud(CloudModel::Gpt4)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
|
||||
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(
|
||||
CloudModel::Claude3_5Sonnet
|
||||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku,
|
||||
) => {
|
||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::Cloud(CloudModel::Custom(model)) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
mut request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
request.preprocess();
|
||||
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: request.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt;
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(|_, cx| {
|
||||
CompletionProvider::global(cx)
|
||||
.authenticate(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,115 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
|
||||
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeCompletionProvider {
|
||||
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn setup_test(cx: &mut AppContext) -> Self {
|
||||
use crate::CompletionProvider;
|
||||
use parking_lot::RwLock;
|
||||
|
||||
let this = Self::default();
|
||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
|
||||
cx.set_global(provider);
|
||||
this
|
||||
}
|
||||
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.keys()
|
||||
.map(|k| serde_json::from_str(k).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn completion_count(&self) -> usize {
|
||||
self.current_completion_txs.lock().len()
|
||||
}
|
||||
|
||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||
let json = serde_json::to_string(request).unwrap();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.get(&json)
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.remove(&serde_json::to_string(request).unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_last_completion(&self) {
|
||||
self.finish_completion(self.pending_completions().last().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for FakeCompletionProvider {
|
||||
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
|
||||
vec![LanguageModel::default()]
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::default()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
futures::future::ready(Ok(0)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
.insert(serde_json::to_string(&_request).unwrap(), tx);
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
|
@ -1,358 +0,0 @@
|
|||
use crate::LanguageModelCompletionProvider;
|
||||
use crate::{
|
||||
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt as _;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use http::HttpClient;
|
||||
use ollama::{
|
||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||
Role as OllamaRole,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
|
||||
pub struct OllamaCompletionProvider {
|
||||
api_url: String,
|
||||
model: OllamaModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
available_models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
|
||||
self.available_models
|
||||
.iter()
|
||||
.map(|m| LanguageModel::Ollama(m.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
!self.available_models.is_empty()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider
|
||||
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.fetch_models(cx)
|
||||
})
|
||||
.unwrap_or_else(|| Task::ready(Ok(())))
|
||||
})
|
||||
});
|
||||
|
||||
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.fetch_models(cx)
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::Ollama(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
// There is no endpoint for this _yet_ in Ollama
|
||||
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
|
||||
let token_count = request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| msg.content.chars().count())
|
||||
.sum::<usize>()
|
||||
/ 4;
|
||||
|
||||
async move { Ok(token_count) }.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let request =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(delta) => {
|
||||
let content = match delta.message {
|
||||
ChatMessage::User { content } => content,
|
||||
ChatMessage::Assistant { content } => content,
|
||||
ChatMessage::System { content } => content,
|
||||
};
|
||||
Some(Ok(content))
|
||||
}
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaCompletionProvider {
|
||||
pub fn new(
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) -> Self {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move {
|
||||
if model.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
preload_model(client.as_ref(), &api_url, &model).await
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
cx: &AppContext,
|
||||
) {
|
||||
cx.spawn({
|
||||
let api_url = api_url.clone();
|
||||
let client = self.http_client.clone();
|
||||
let model = model.name.clone();
|
||||
|
||||
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
if model.name.is_empty() {
|
||||
self.select_first_available_model()
|
||||
} else {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn select_first_available_model(&mut self) {
|
||||
if let Some(model) = self.available_models.first() {
|
||||
self.model = model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<OllamaModel> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| OllamaModel::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.available_models = models;
|
||||
|
||||
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
|
||||
provider.select_first_available_model()
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||
let model = match request.model {
|
||||
LanguageModel::Ollama(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
ChatRequest {
|
||||
model: model.name,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => ChatMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => ChatMessage::Assistant {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::System => ChatMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
keep_alive: model.keep_alive.unwrap_or_default(),
|
||||
stream: true,
|
||||
options: Some(ChatOptions {
|
||||
num_ctx: Some(model.max_tokens),
|
||||
stop: Some(request.stop),
|
||||
temperature: Some(request.temperature),
|
||||
..Default::default()
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for ollama::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => OllamaRole::User,
|
||||
Role::Assistant => OllamaRole::Assistant,
|
||||
Role::System => OllamaRole::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct DownloadOllamaMessage {
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl DownloadOllamaMessage {
|
||||
pub fn new(
|
||||
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
Self { retry_connection }
|
||||
}
|
||||
|
||||
fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("download_ollama_button")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Get Ollama"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
|
||||
}
|
||||
|
||||
fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ButtonLike::new("retry_ollama_models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("Retry"))
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
let connected = (this.retry_connection)(cx);
|
||||
|
||||
cx.spawn(|_this, _cx| async move {
|
||||
connected.await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Once Ollama is on your machine, make sure to download a model or two.")
|
||||
.size(LabelSize::Large),
|
||||
)
|
||||
.child(
|
||||
h_flex().w_full().p_4().justify_center().gap_2().child(
|
||||
ButtonLike::new("view-models")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.child(Label::new("View Available Models"))
|
||||
.on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for DownloadOllamaMessage {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_4()
|
||||
.justify_center()
|
||||
.gap_2()
|
||||
.child(
|
||||
self.render_download_button(cx)
|
||||
)
|
||||
.child(
|
||||
self.render_retry_button(cx)
|
||||
)
|
||||
)
|
||||
.child(self.render_next_steps(cx))
|
||||
.into_any()
|
||||
}
|
||||
}
|
|
@ -1,378 +0,0 @@
|
|||
use crate::assistant_settings::CloudModel;
|
||||
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
|
||||
use crate::LanguageModelCompletionProvider;
|
||||
use crate::{
|
||||
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub fn new(
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.model = model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
_ => self.model.clone(),
|
||||
};
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| match msg.role {
|
||||
Role::User => RequestMessage::User {
|
||||
content: msg.content,
|
||||
},
|
||||
Role::Assistant => RequestMessage::Assistant {
|
||||
content: Some(msg.content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => RequestMessage::System {
|
||||
content: msg.content,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
||||
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
|
||||
if let AssistantProvider::OpenAi {
|
||||
available_models, ..
|
||||
} = &AssistantSettings::get_global(cx).provider
|
||||
{
|
||||
if !available_models.is_empty() {
|
||||
return available_models
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
|
||||
vec![self.model.clone()]
|
||||
} else {
|
||||
OpenAiModel::iter()
|
||||
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
|
||||
.collect()
|
||||
};
|
||||
available_models
|
||||
.into_iter()
|
||||
.map(LanguageModel::OpenAi)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, Self>(|provider| {
|
||||
provider.api_key = None;
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn model(&self) -> LanguageModel {
|
||||
LanguageModel::OpenAi(self.model.clone())
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
background_executor: &gpui::BackgroundExecutor,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
background_executor
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match request.model {
|
||||
LanguageModel::Anthropic(_)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
}
|
||||
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
impl From<Role> for open_ai::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => OpenAiRole::User,
|
||||
Role::Assistant => OpenAiRole::Assistant,
|
||||
Role::System => OpenAiRole::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.api_key = Some(api_key);
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
|
||||
MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
|
||||
};
|
||||
use client::{proto, telemetry::Telemetry};
|
||||
use client::{self, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
|
@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequest, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
|
@ -2477,9 +2479,10 @@ mod tests {
|
|||
use crate::{
|
||||
assistant_panel, prompt_library,
|
||||
slash_command::{active_command, file_command},
|
||||
FakeCompletionProvider, MessageId,
|
||||
MessageId,
|
||||
};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
|
||||
use completion::FakeCompletionProvider;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext, WeakView};
|
||||
use indoc::indoc;
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use crate::{
|
||||
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
|
||||
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, Role, StreamingDiff,
|
||||
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -28,6 +27,7 @@ use gpui::{
|
|||
WhiteSpace, WindowContext,
|
||||
};
|
||||
use language::{Buffer, Point, Selection, TransactionId};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
|
@ -1432,8 +1432,7 @@ impl Render for PromptEditor {
|
|||
PopoverMenu::new("model-switcher")
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models(cx)
|
||||
{
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
|
@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::FakeCompletionProvider;
|
||||
use completion::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
|
|
|
@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
|
|||
.with_handle(self.handle)
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models(cx) {
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
|
||||
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
InlineAssist, InlineAssistant,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assets::Assets;
|
||||
|
@ -19,6 +19,7 @@ use gpui::{
|
|||
};
|
||||
use heed::{types::SerdeBincode, Database, RoTxn};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use parking_lot::RwLock;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use rope::Rope;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
assistant_settings::AssistantSettings, humanize_token_count,
|
||||
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
|
||||
CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
CompletionProvider,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -17,6 +17,7 @@ use gpui::{
|
|||
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use settings::{update_settings_file, Settings};
|
||||
use std::{
|
||||
cmp,
|
||||
|
@ -558,8 +559,7 @@ impl Render for PromptEditor {
|
|||
PopoverMenu::new("model-switcher")
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models(cx)
|
||||
{
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue