diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 6c3189b660..d61dec9f72 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -9,15 +9,17 @@ use editor::{Anchor, Editor, ExcerptId, ExcerptRange, MultiBuffer}; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use gpui::{ - actions, elements::*, executor::Background, Action, AppContext, AsyncAppContext, Entity, - ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, - WindowContext, + actions, + elements::*, + executor::Background, + platform::{CursorStyle, MouseButton}, + Action, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task, + View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; use isahc::{http::StatusCode, Request, RequestExt}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use settings::SettingsStore; use std::{cell::RefCell, io, rc::Rc, sync::Arc, time::Duration}; -use tiktoken_rs::model::get_context_size; use util::{post_inc, ResultExt, TryFutureExt}; use workspace::{ dock::{DockPosition, Panel}, @@ -430,7 +432,7 @@ impl Assistant { pending_completions: Default::default(), languages: language_registry, token_count: None, - max_token_count: get_context_size(model), + max_token_count: tiktoken_rs::model::get_context_size(model), pending_token_count: Task::ready(None), model: model.into(), _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], @@ -483,6 +485,7 @@ impl Assistant { .await?; this.update(&mut cx, |this, cx| { + this.max_token_count = tiktoken_rs::model::get_context_size(&this.model); this.token_count = Some(token_count); cx.notify() }); @@ -496,6 +499,12 @@ impl Assistant { Some(self.max_token_count as isize - self.token_count? as isize) } + fn set_model(&mut self, model: String, cx: &mut ModelContext) { + self.model = model; + self.count_remaining_tokens(cx); + cx.notify(); + } + fn assist(&mut self, cx: &mut ModelContext) { let messages = self .messages @@ -825,6 +834,16 @@ impl AssistantEditor { }); } } + + fn cycle_model(&mut self, cx: &mut ViewContext) { + self.assistant.update(cx, |assistant, cx| { + let new_model = match assistant.model.as_str() { + "gpt-4" => "gpt-3.5-turbo", + _ => "gpt-4", + }; + assistant.set_model(new_model.into(), cx); + }); + } } impl Entity for AssistantEditor { @@ -837,27 +856,23 @@ impl View for AssistantEditor { } fn render(&mut self, cx: &mut ViewContext) -> AnyElement { + enum Model {} let theme = &theme::current(cx).assistant; - let remaining_tokens = self - .assistant - .read(cx) - .remaining_tokens() - .map(|remaining_tokens| { - let remaining_tokens_style = if remaining_tokens <= 0 { - &theme.no_remaining_tokens - } else { - &theme.remaining_tokens - }; - Label::new( - remaining_tokens.to_string(), - remaining_tokens_style.text.clone(), - ) - .contained() - .with_style(remaining_tokens_style.container) - .aligned() - .top() - .right() - }); + let assistant = &self.assistant.read(cx); + let model = assistant.model.clone(); + let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| { + let remaining_tokens_style = if remaining_tokens <= 0 { + &theme.no_remaining_tokens + } else { + &theme.remaining_tokens + }; + Label::new( + remaining_tokens.to_string(), + remaining_tokens_style.text.clone(), + ) + .contained() + .with_style(remaining_tokens_style.container) + }); Stack::new() .with_child( @@ -865,7 +880,25 @@ impl View for AssistantEditor { .contained() .with_style(theme.container), ) - .with_children(remaining_tokens) + .with_child( + Flex::row() + .with_child( + MouseEventHandler::::new(0, cx, |state, _| { + let style = theme.model.style_for(state, false); + Label::new(model, style.text.clone()) + .contained() + .with_style(style.container) + }) + .with_cursor_style(CursorStyle::PointingHand) + .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)), + ) + .with_children(remaining_tokens) + .contained() + .with_style(theme.model_info_container) + .aligned() + .top() + .right(), + ) .into_any() } diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index 97aac92afd..f746f90193 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -976,6 +976,8 @@ pub struct AssistantStyle { pub sent_at: ContainedText, pub user_sender: ContainedText, pub assistant_sender: ContainedText, + pub model_info_container: ContainerStyle, + pub model: Interactive, pub remaining_tokens: ContainedText, pub no_remaining_tokens: ContainedText, pub api_key_editor: FieldEditor, diff --git a/styles/src/styleTree/assistant.ts b/styles/src/styleTree/assistant.ts index 3d21ee8519..217476bc31 100644 --- a/styles/src/styleTree/assistant.ts +++ b/styles/src/styleTree/assistant.ts @@ -11,7 +11,8 @@ export default function assistant(colorScheme: ColorScheme) { }, header: { border: border(layer, "default", { bottom: true, top: true }), - margin: { bottom: 6, top: 6 } + margin: { bottom: 6, top: 6 }, + background: editor(colorScheme).background }, user_sender: { ...text(layer, "sans", "default", { size: "sm", weight: "bold" }), @@ -23,17 +24,32 @@ export default function assistant(colorScheme: ColorScheme) { margin: { top: 2, left: 8 }, ...text(layer, "sans", "default", { size: "2xs" }), }, - remaining_tokens: { - padding: 4, + model_info_container: { margin: { right: 16, top: 4 }, + }, + model: { background: background(layer, "on"), + border: border(layer, "on", { overlay: true }), + padding: 4, + cornerRadius: 4, + ...text(layer, "sans", "default", { size: "xs" }), + hover: { + background: background(layer, "on", "hovered"), + } + }, + remaining_tokens: { + background: background(layer, "on"), + border: border(layer, "on", { overlay: true }), + padding: 4, + margin: { left: 4 }, cornerRadius: 4, ...text(layer, "sans", "positive", { size: "xs" }), }, no_remaining_tokens: { - padding: 4, - margin: { right: 16, top: 4 }, background: background(layer, "on"), + border: border(layer, "on", { overlay: true }), + padding: 4, + margin: { left: 4 }, cornerRadius: 4, ...text(layer, "sans", "negative", { size: "xs" }), },