ZIm/crates/language_model/src/fake_provider.rs
Antonio Scandurra f517050548
Partially fix assistant onboarding (#25313)
While investigating #24896, I noticed two issues:

1. The default configuration for the `zed.dev` provider was using the
wrong string for Claude 3.5 Sonnet. This meant the provider would always
show up as not configured until the user selected it from the model
picker, because we couldn't deserialize that string to a valid
`anthropic::Model` enum variant.
2. When clicking on `Open New Chat`/`Start New Thread` in the provider
configuration, we would select `Claude 3.5 Haiku` by default instead of
Claude 3.5 Sonnet.

Release Notes:

- Fixed some issues that caused AI providers to sometimes be
misconfigured.
2025-02-24 07:29:55 +00:00

205 lines
5.8 KiB
Rust

use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use serde::Serialize;
use std::sync::Arc;
/// Returns the identifier of the fake language model, built from the string `"fake"`.
pub fn language_model_id() -> LanguageModelId {
    let raw_id = String::from("fake");
    LanguageModelId::from(raw_id)
}
/// Returns the display name of the fake language model, built from the string `"Fake"`.
pub fn language_model_name() -> LanguageModelName {
    let raw_name = String::from("Fake");
    LanguageModelName::from(raw_name)
}
/// Returns the identifier of the fake provider, built from the string `"fake"`.
pub fn provider_id() -> LanguageModelProviderId {
    let raw_id = String::from("fake");
    LanguageModelProviderId::from(raw_id)
}
/// Returns the display name of the fake provider, built from the string `"Fake"`.
pub fn provider_name() -> LanguageModelProviderName {
    let raw_name = String::from("Fake");
    LanguageModelProviderName::from(raw_name)
}
/// A stand-in language model provider.
///
/// It always reports itself as authenticated and hands out freshly created
/// `FakeLanguageModel` instances.
#[derive(Clone, Default)]
pub struct FakeLanguageModelProvider;
/// The fake provider carries no state, so there is nothing to observe.
impl LanguageModelProviderState for FakeLanguageModelProvider {
    type ObservableEntity = ();

    /// Always yields `None`: no entity backs this provider.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        let entity: Option<Entity<Self::ObservableEntity>> = None;
        entity
    }
}
/// Provider implementation backed entirely by canned values.
impl LanguageModelProvider for FakeLanguageModelProvider {
    /// Identifier shared with the free function `provider_id`.
    fn id(&self) -> LanguageModelProviderId {
        provider_id()
    }

    /// Display name shared with the free function `provider_name`.
    fn name(&self) -> LanguageModelProviderName {
        provider_name()
    }

    /// Returns a brand-new `FakeLanguageModel` on every call.
    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel::default());
        Some(model)
    }

    /// Returns a single-element list holding a brand-new `FakeLanguageModel`.
    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel::default());
        vec![model]
    }

    /// The fake provider is always considered authenticated.
    fn is_authenticated(&self, _cx: &App) -> bool {
        true
    }

    /// Completes immediately with success.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        let outcome: Result<(), AuthenticateError> = Ok(());
        Task::ready(outcome)
    }

    /// Not supported by the fake provider; calling this panics.
    fn configuration_view(&self, _window: &mut Window, _cx: &mut App) -> AnyView {
        unimplemented!()
    }

    /// Completes immediately with success; there are no credentials to reset.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        let outcome: Result<()> = Ok(());
        Task::ready(outcome)
    }
}
impl FakeLanguageModelProvider {
pub fn test_model(&self) -> FakeLanguageModel {
FakeLanguageModel::default()
}
}
/// The arguments of one `use_any_tool` call, as recorded by `FakeLanguageModel`.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
/// The language model request passed to `use_any_tool`.
pub request: LanguageModelRequest,
/// The tool name passed to `use_any_tool`.
pub name: String,
/// The tool description passed to `use_any_tool`.
pub description: String,
/// The schema value passed to `use_any_tool`.
pub schema: serde_json::Value,
}
/// A language model whose responses are supplied manually by the caller.
///
/// Every completion or tool-use request is stored together with the sending
/// half of an unbounded channel; the receiving half is returned to the
/// requester as its response stream.
#[derive(Default)]
pub struct FakeLanguageModel {
// Requests passed to `stream_completion`, in call order, each paired with
// the sender that feeds its response stream.
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
// Requests passed to `use_any_tool`, in call order, each paired with the
// sender that feeds its response stream.
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
}
impl FakeLanguageModel {
    /// Returns a copy of every completion request that is still registered,
    /// in the order the requests were made.
    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
        self.current_completion_txs
            .lock()
            .iter()
            .map(|(request, _)| request.clone())
            .collect()
    }

    /// Returns the number of completion requests that are still registered.
    pub fn completion_count(&self) -> usize {
        self.current_completion_txs.lock().len()
    }

    /// Sends `chunk` to the first pending completion whose request equals `request`.
    ///
    /// # Panics
    ///
    /// Panics if no pending completion matches `request`, or if the chunk
    /// cannot be delivered to the stream.
    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
        let pending = self.current_completion_txs.lock();
        let (_, tx) = pending
            .iter()
            .find(|(pending_request, _)| pending_request == request)
            .expect("no pending completion matches the given request");
        tx.unbounded_send(chunk)
            .expect("failed to send chunk to the completion stream");
    }

    /// Removes every pending completion whose request equals `request`.
    ///
    /// Removing an entry drops its sender, which is how the stream is ended.
    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
        self.current_completion_txs
            .lock()
            .retain(|(pending_request, _)| pending_request != request);
    }

    /// Sends `chunk` to the most recently registered pending completion.
    ///
    /// The last entry is addressed directly rather than looked up by request
    /// equality, so the chunk reaches the newest stream even when an older
    /// pending request is equal to it. No requests are cloned.
    ///
    /// # Panics
    ///
    /// Panics if there is no pending completion, or if the chunk cannot be
    /// delivered to the stream.
    pub fn stream_last_completion_response(&self, chunk: String) {
        let pending = self.current_completion_txs.lock();
        let (_, tx) = pending
            .last()
            .expect("no pending completion to stream a response to");
        tx.unbounded_send(chunk)
            .expect("failed to send chunk to the completion stream");
    }

    /// Ends the most recently registered pending completion stream.
    ///
    /// Only the last entry is removed, mirroring how `respond_to_last_tool_use`
    /// pops the last tool-use entry; older pending requests stay open even if
    /// they are equal to it.
    ///
    /// # Panics
    ///
    /// Panics if there is no pending completion.
    pub fn end_last_completion_stream(&self) {
        let ended = self.current_completion_txs.lock().pop();
        // Dropping the popped sender is what ends the stream.
        drop(ended.expect("no pending completion stream to end"));
    }

    /// Serializes `response` to a JSON string and sends it to the most
    /// recently registered tool-use request, removing that request from the
    /// pending list.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails, if there is no pending tool use, or if
    /// the response cannot be delivered to the stream.
    pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
        let response =
            serde_json::to_string(&response).expect("failed to serialize tool use response");
        let mut pending_tool_uses = self.current_tool_use_txs.lock();
        let (_, tx) = pending_tool_uses
            .pop()
            .expect("no pending tool use to respond to");
        tx.unbounded_send(response)
            .expect("failed to send response to the tool use stream");
    }
}
/// Model implementation whose output is driven through channels stored on
/// the `FakeLanguageModel` itself.
impl LanguageModel for FakeLanguageModel {
    /// Identifier shared with the free function `language_model_id`.
    fn id(&self) -> LanguageModelId {
        language_model_id()
    }

    /// Display name shared with the free function `language_model_name`.
    fn name(&self) -> LanguageModelName {
        language_model_name()
    }

    /// Identifier of the fake provider.
    fn provider_id(&self) -> LanguageModelProviderId {
        provider_id()
    }

    /// Display name of the fake provider.
    fn provider_name(&self) -> LanguageModelProviderName {
        provider_name()
    }

    /// Telemetry identifier; always the string `"fake"`.
    fn telemetry_id(&self) -> String {
        String::from("fake")
    }

    /// Fixed token limit of one million.
    fn max_token_count(&self) -> usize {
        1_000_000
    }

    /// Reports zero tokens for any request, via an already-completed future.
    fn count_tokens(
        &self,
        _request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        let token_count: Result<usize> = Ok(0);
        futures::future::ready(token_count).boxed()
    }

    /// Records `request` together with a new channel sender and returns a
    /// stream that yields each string sent on that channel as a `Text` event.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let (chunk_tx, chunk_rx) = mpsc::unbounded();
        {
            let mut pending = self.current_completion_txs.lock();
            pending.push((request, chunk_tx));
        }
        let events: BoxStream<'static, Result<LanguageModelCompletionEvent>> = chunk_rx
            .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
            .boxed();
        async move { Ok(events) }.boxed()
    }

    /// Records the tool-use arguments together with a new channel sender and
    /// returns a stream that yields each string sent on that channel.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        name: String,
        description: String,
        schema: serde_json::Value,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let (response_tx, response_rx) = mpsc::unbounded();
        let recorded = ToolUseRequest {
            request,
            name,
            description,
            schema,
        };
        {
            let mut pending = self.current_tool_use_txs.lock();
            pending.push((recorded, response_tx));
        }
        let responses: BoxStream<'static, Result<String>> = response_rx.map(Ok).boxed();
        async move { Ok(responses) }.boxed()
    }

    /// Returns this model itself.
    fn as_fake(&self) -> &Self {
        self
    }
}