ZIm/crates/language_models/src/provider/open_ai.rs
Antonio Scandurra f517050548
Partially fix assistant onboarding (#25313)
While investigating #24896, I noticed two issues:

1. The default configuration for the `zed.dev` provider was using the
wrong string for Claude 3.5 Sonnet. As a result, the provider would always
show up as not configured until the user selected it from the model
picker, because we couldn't deserialize that string to a valid
`anthropic::Model` enum variant.
2. When clicking on `Open New Chat`/`Start New Thread` in the provider
configuration, we would select `Claude 3.5 Haiku` by default instead of
Claude 3.5 Sonnet.

Release Notes:

- Fixed some issues that caused AI providers to sometimes be
misconfigured.
2025-02-24 07:29:55 +00:00

592 lines
20 KiB
Rust

use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use open_ai::{
stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
use crate::AllLanguageModelSettings;
/// Identifier wrapped in `LanguageModelProviderId` by `id()` and `provider_id()`.
const PROVIDER_ID: &str = "openai";
/// Display name wrapped in `LanguageModelProviderName` by `name()` and `provider_name()`.
const PROVIDER_NAME: &str = "OpenAI";
/// Settings for the OpenAI language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
/// URL handed to `stream_completion` for requests; it is also the key under which the
/// API key is stored in, read from, and deleted from the credentials provider.
pub api_url: String,
/// Models declared in settings; `provided_models` exposes each one as an
/// `open_ai::Model::Custom` entry.
pub available_models: Vec<AvailableModel>,
/// NOTE(review): not read anywhere in this file; presumably signals that the stored
/// settings use an outdated format — confirm against the settings loader.
pub needs_setting_migration: bool,
}
/// A model declared in settings. `provided_models` copies every field into an
/// `open_ai::Model::Custom` value.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// Model name; also used as the map key in `provided_models`, so an entry whose name
/// matches a built-in model id replaces that built-in model.
pub name: String,
/// Optional display name forwarded to `open_ai::Model::Custom`.
pub display_name: Option<String>,
/// Forwarded as `max_tokens`; presumably the context window size — TODO confirm.
pub max_tokens: usize,
/// Forwarded as `max_output_tokens`; presumably a cap on generated tokens — TODO confirm.
pub max_output_tokens: Option<u32>,
/// Forwarded as `max_completion_tokens`.
/// NOTE(review): how this differs from `max_output_tokens` is not visible here — verify.
pub max_completion_tokens: Option<u32>,
}
/// Language model provider for OpenAI models.
pub struct OpenAiLanguageModelProvider {
// HTTP client cloned into every `OpenAiLanguageModel` this provider creates.
http_client: Arc<dyn HttpClient>,
// Shared authentication state; cloned into models and the configuration view.
state: gpui::Entity<State>,
}
/// Authentication state shared by the provider, its models, and the configuration view.
pub struct State {
// API key currently held in memory; `None` means "not authenticated".
api_key: Option<String>,
// True when `authenticate` took the key from the `OPENAI_API_KEY` environment variable
// rather than from the credentials provider.
api_key_from_env: bool,
// Keeps the `SettingsStore` observation alive; the callback only calls `cx.notify()`.
_subscription: Subscription,
}
/// Name of the environment variable that `authenticate` checks before falling back to
/// the credentials provider.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
credentials_provider
.delete_credentials(&api_url, &cx)
.await
.log_err();
this.update(&mut cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
.await
.log_err();
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, &cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
impl OpenAiLanguageModelProvider {
    /// Creates the provider together with a fresh, unauthenticated [`State`] entity.
    ///
    /// The state observes the global `SettingsStore` and notifies its own observers
    /// whenever that global changes.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            let settings_subscription =
                cx.observe_global::<SettingsStore>(|_state: &mut State, cx| {
                    cx.notify();
                });
            State {
                api_key: None,
                api_key_from_env: false,
                _subscription: settings_subscription,
            }
        });

        Self { http_client, state }
    }
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    /// Hands out a clone of the shared authentication state entity.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        let state = self.state.clone();
        Some(state)
    }
}
impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAi
    }

    /// Wraps `open_ai::Model::default()` in a language model sharing this provider's
    /// state and HTTP client.
    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let model = open_ai::Model::default();
        let id = LanguageModelId::from(model.id().to_string());
        let language_model = OpenAiLanguageModel {
            id,
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        };
        Some(Arc::new(language_model))
    }

    /// Lists the built-in models plus the models declared in settings, ordered by key.
    ///
    /// A settings entry whose name equals a built-in model id replaces that model.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models_by_id = BTreeMap::default();

        // Built-in models, skipping the `Custom` placeholder variant.
        models_by_id.extend(
            open_ai::Model::iter()
                .filter(|model| !matches!(model, open_ai::Model::Custom { .. }))
                .map(|model| (model.id().to_string(), model)),
        );

        // Models from settings come second so that they win on key collisions.
        let settings = &AllLanguageModelSettings::get_global(cx).openai;
        models_by_id.extend(settings.available_models.iter().map(|available| {
            let custom = open_ai::Model::Custom {
                name: available.name.clone(),
                display_name: available.display_name.clone(),
                max_tokens: available.max_tokens,
                max_output_tokens: available.max_output_tokens,
                max_completion_tokens: available.max_completion_tokens,
            };
            (available.name.clone(), custom)
        }));

        let mut language_models: Vec<Arc<dyn LanguageModel>> = Vec::new();
        for model in models_by_id.into_values() {
            language_models.push(Arc::new(OpenAiLanguageModel {
                id: LanguageModelId::from(model.id().to_string()),
                model,
                state: self.state.clone(),
                http_client: self.http_client.clone(),
                request_limiter: RateLimiter::new(4),
            }));
        }
        language_models
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        state.is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    /// Builds a new [`ConfigurationView`] bound to this provider's state.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        let view = cx.new(|cx| ConfigurationView::new(state, window, cx));
        view.into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
// Built from `model.id()` by the provider when it constructs this value.
id: LanguageModelId,
// The underlying OpenAI model description (id, display name, token limits).
model: open_ai::Model,
// Shared authentication state; the API key is read from it for every request.
state: gpui::Entity<State>,
// HTTP client passed to `open_ai::stream_completion`.
http_client: Arc<dyn HttpClient>,
// Wraps outgoing requests; the provider constructs it with `RateLimiter::new(4)`.
// NOTE(review): presumably that caps concurrent requests at 4 — confirm in `RateLimiter`.
request_limiter: RateLimiter,
}
impl OpenAiLanguageModel {
    /// Sends `request` to the configured API URL and returns the raw response event
    /// stream, with the request routed through this model's rate limiter.
    ///
    /// Resolves to an error when the state entity can no longer be read or when no API
    /// key is held in memory.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        // Snapshot the key and URL now; the returned future must be `'static`.
        let credentials = cx.read_entity(&self.state, |state, cx| {
            let api_url = AllLanguageModelSettings::get_global(cx)
                .openai
                .api_url
                .clone();
            (state.api_key.clone(), api_url)
        });
        let (api_key, api_url) = match credentials {
            Ok(credentials) => credentials,
            Err(_) => {
                return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
            }
        };

        let http_client = self.http_client.clone();
        let limited_request = self.request_limiter.stream(async move {
            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
            let events =
                stream_completion(http_client.as_ref(), &api_url, &api_key, request).await?;
            Ok(events)
        });

        async move { Ok(limited_request.await?.boxed()) }.boxed()
    }
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn telemetry_id(&self) -> String {
format!("openai/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, self.model.clone(), cx)
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
Ok(open_ai::extract_text_from_events(completions.await?)
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
},
}));
request.tools = vec![ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(schema),
},
}];
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
Ok(
open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
.await?
.boxed(),
)
})
.boxed()
}
}
/// Counts the tokens in `request` with `tiktoken_rs` on a background task.
///
/// Custom models and the `O1Mini`, `O1` and `O3Mini` models are counted with the
/// `"gpt-4"` tokenizer; every other model is counted under its own id.
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: open_ai::Model,
    cx: &App,
) -> BoxFuture<'static, Result<usize>> {
    cx.background_spawn(async move {
        let mut messages = Vec::new();
        for message in request.messages {
            let role = match message.role {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
            };
            messages.push(tiktoken_rs::ChatCompletionRequestMessage {
                role: role.into(),
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            });
        }

        let tokenizer_model = match &model {
            open_ai::Model::Custom { .. }
            | open_ai::Model::O1Mini
            | open_ai::Model::O1
            | open_ai::Model::O3Mini => "gpt-4",
            other => other.id(),
        };
        tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
    })
    .boxed()
}
/// View that lets the user enter or reset the OpenAI API key.
struct ConfigurationView {
// Single-line editor the user types the API key into.
api_key_editor: Entity<Editor>,
// Shared authentication state; observed so the view re-renders when it changes.
state: gpui::Entity<State>,
// `Some` while the initial `authenticate` call is in flight; the task sets it back to
// `None` when it finishes, and `render` shows a loading label in the meantime.
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
    /// Creates the view and immediately starts loading credentials in the background.
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            let mut key_input = Editor::single_line(window, cx);
            key_input
                .set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
            key_input
        });

        // Re-render whenever the shared authentication state changes.
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let auth_state = state.clone();
        let load_task = cx.spawn_in(window, |this, mut cx| async move {
            let authentication = auth_state
                .update(&mut cx, |state, cx| state.authenticate(cx))
                .log_err();
            if let Some(task) = authentication {
                // We don't log an error, because "not signed in" is also an error.
                let _ = task.await;
            }

            this.update(&mut cx, |this, cx| {
                this.load_credentials_task = None;
                cx.notify();
            })
            .log_err();
        });

        Self {
            api_key_editor,
            state,
            load_credentials_task: Some(load_task),
        }
    }

    /// Stores the key typed into the editor; does nothing when the editor is empty.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let entered_key = self.api_key_editor.read(cx).text(cx);
        if entered_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        let save_task = cx.spawn_in(window, |_, mut cx| async move {
            let pending = state.update(&mut cx, |state, cx| state.set_api_key(entered_key, cx))?;
            pending.await
        });
        save_task.detach_and_log_err(cx);

        cx.notify();
    }

    /// Clears the editor and resets the API key held by the shared state.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        let reset_task = cx.spawn_in(window, |_, mut cx| async move {
            let pending = state.update(&mut cx, |state, cx| state.reset_api_key(cx))?;
            pending.await
        });
        reset_task.detach_and_log_err(cx);

        cx.notify();
    }

    /// Builds the editor element for the API key input, styled with the UI font.
    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let theme_settings = ThemeSettings::get_global(cx);
        let ui_font = &theme_settings.ui_font;
        let colors = cx.theme().colors();

        let text = TextStyle {
            color: colors.text,
            font_family: ui_font.family.clone(),
            font_features: ui_font.features.clone(),
            font_fallbacks: ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
        };
        let style = EditorStyle {
            background: colors.editor_background,
            local_player: cx.theme().players().local(),
            text,
            ..Default::default()
        };

        EditorElement::new(&self.api_key_editor, style)
    }

    /// The key editor is shown only while no API key is held in memory.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        let authenticated = self.state.read(cx).is_authenticated();
        !authenticated
    }
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
const OPENAI_CONSOLE_URL: &str = "https://platform.openai.com/api-keys";
const INSTRUCTIONS: [&str; 4] = [
"To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:",
" - Create one by visiting:",
" - Ensure your OpenAI account has credits",
" - Paste your API key below and hit enter to start using the assistant",
];
let env_var_set = self.state.read(cx).api_key_from_env;
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
v_flex()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new(INSTRUCTIONS[0]))
.child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
Button::new("openai_console", OPENAI_CONSOLE_URL)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OPENAI_CONSOLE_URL))
)
)
.children(
(2..INSTRUCTIONS.len()).map(|n|
Label::new(INSTRUCTIONS[n])).collect::<Vec<_>>())
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.border_1()
.border_color(cx.theme().colors().border_variant)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small),
)
.child(
Label::new(
"Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
)
.size(LabelSize::Small),
)
.into_any()
} else {
h_flex()
.size_full()
.justify_between()
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
} else {
"API key configured.".to_string()
})),
)
.child(
Button::new("reset-key", "Reset key")
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
.into_any()
}
}
}