ZIm/crates/language_model/src/fake_provider.rs
Michael Sloan fbf7caf93e
Default to fast model for thread summaries and titles + don't include system prompt / context / thinking segments (#29102)
* Adds a fast / cheaper model to providers and defaults thread
summarization to it. The initial motivation was that
https://github.com/zed-industries/zed/pull/29099 would cause these
requests to fail when used with a thinking model, and using a thinking
model for summarization doesn't seem correct anyway.

* Skips the system prompt, context, and thinking segments when building
summaries (both changes are sketched after this list).

* If tool use is in progress, allows 2 tool uses plus one more agent
response before summarizing.
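
A rough sketch of the first two changes, with hypothetical stand-in types
(`Provider`, `Message`, `Segment`, and `Role` are illustrative, not the
actual request structs):

    // Minimal stand-ins mirroring the shapes involved; none of these are
    // the real types.
    #[derive(Clone, PartialEq)]
    enum Role { System, User, Assistant }

    #[derive(Clone, PartialEq)]
    enum Segment { Text(String), Context(String), Thinking(String) }

    #[derive(Clone)]
    struct Message { role: Role, segments: Vec<Segment> }

    trait Provider {
        type Model;
        fn default_model(&self) -> Option<Self::Model>;
        fn default_fast_model(&self) -> Option<Self::Model>;
    }

    // Prefer the cheaper fast model for summaries, falling back to the
    // default model.
    fn summary_model<P: Provider>(provider: &P) -> Option<P::Model> {
        provider
            .default_fast_model()
            .or_else(|| provider.default_model())
    }

    // Drop the system prompt and keep only plain-text segments, so context
    // and thinking output never reach the summarization request.
    fn summary_messages(messages: &[Message]) -> Vec<Message> {
        messages
            .iter()
            .filter(|message| message.role != Role::System)
            .map(|message| Message {
                role: message.role.clone(),
                segments: message
                    .segments
                    .iter()
                    .filter(|segment| matches!(segment, Segment::Text(_)))
                    .cloned()
                    .collect(),
            })
            .collect()
    }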

A downside is that there was previously some potential for prefix cache
reuse, especially for title summarization (thread summarization omitted
tool results and so would not share a prefix for those). This seems fine,
as these requests should typically be fairly small. Even for full thread
summarization, skipping all tool use / context should greatly reduce
token use.

Release Notes:

- N/A
2025-04-19 23:26:29 +00:00

use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use std::sync::Arc;

/// Identifier used by [`FakeLanguageModel`].
pub fn language_model_id() -> LanguageModelId {
    LanguageModelId::from("fake".to_string())
}

/// Display name used by [`FakeLanguageModel`].
pub fn language_model_name() -> LanguageModelName {
    LanguageModelName::from("Fake".to_string())
}

/// Identifier used by [`FakeLanguageModelProvider`].
pub fn provider_id() -> LanguageModelProviderId {
    LanguageModelProviderId::from("fake".to_string())
}

/// Display name used by [`FakeLanguageModelProvider`].
pub fn provider_name() -> LanguageModelProviderName {
    LanguageModelProviderName::from("Fake".to_string())
}

/// A provider that only serves [`FakeLanguageModel`]s, for use in tests.
#[derive(Clone, Default)]
pub struct FakeLanguageModelProvider;

impl LanguageModelProviderState for FakeLanguageModelProvider {
    type ObservableEntity = ();

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        None
    }
}

impl LanguageModelProvider for FakeLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        provider_id()
    }

    fn name(&self) -> LanguageModelProviderName {
        provider_name()
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(Arc::new(FakeLanguageModel::default()))
    }

    // The fake serves the same model as both the default and the fast model.
    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(Arc::new(FakeLanguageModel::default()))
    }

    fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(FakeLanguageModel::default())]
    }

    fn is_authenticated(&self, _: &App) -> bool {
        true
    }

    fn authenticate(&self, _: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _window: &mut Window, _: &mut App) -> AnyView {
        // The fake provider has no configuration UI.
        unimplemented!()
    }

    fn reset_credentials(&self, _: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

impl FakeLanguageModelProvider {
    /// Returns a fresh [`FakeLanguageModel`] for direct use in tests.
    pub fn test_model(&self) -> FakeLanguageModel {
        FakeLanguageModel::default()
    }
}
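
// Typical use from a test, sketched assuming a gpui `App` context `cx`
// supplied by the harness:
//
//     let provider = FakeLanguageModelProvider;
//     assert!(provider.is_authenticated(cx));
//     let model = provider.test_model();
//     assert_eq!(model.completion_count(), 0);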

/// Captures a tool-use call made against a fake model: the originating
/// request plus the tool's name, description, and JSON schema.
#[derive(Debug, PartialEq)]
pub struct ToolUseRequest {
    pub request: LanguageModelRequest,
    pub name: String,
    pub description: String,
    pub schema: serde_json::Value,
}

/// A test double that records incoming completion requests and lets the
/// test stream response chunks back manually.
#[derive(Default)]
pub struct FakeLanguageModel {
    current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
}

impl FakeLanguageModel {
    /// Returns the requests whose completion streams have not yet ended.
    pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
        self.current_completion_txs
            .lock()
            .iter()
            .map(|(request, _)| request.clone())
            .collect()
    }

    /// Returns the number of completion streams currently open.
    pub fn completion_count(&self) -> usize {
        self.current_completion_txs.lock().len()
    }

    /// Sends `chunk` on the stream for `request`.
    ///
    /// Panics if `request` has no pending completion.
    pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) {
        let current_completion_txs = self.current_completion_txs.lock();
        let tx = current_completion_txs
            .iter()
            .find(|(req, _)| req == request)
            .map(|(_, tx)| tx)
            .unwrap();
        tx.unbounded_send(chunk).unwrap();
    }

    /// Ends the stream for `request` by dropping its sender.
    pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
        self.current_completion_txs
            .lock()
            .retain(|(req, _)| req != request);
    }

    /// Sends `chunk` on the most recently started completion stream.
    pub fn stream_last_completion_response(&self, chunk: String) {
        self.stream_completion_response(self.pending_completions().last().unwrap(), chunk);
    }

    /// Ends the most recently started completion stream.
    pub fn end_last_completion_stream(&self) {
        self.end_completion_stream(self.pending_completions().last().unwrap());
    }
}

impl LanguageModel for FakeLanguageModel {
    fn id(&self) -> LanguageModelId {
        language_model_id()
    }

    fn name(&self) -> LanguageModelName {
        language_model_name()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        provider_id()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        provider_name()
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        "fake".to_string()
    }

    fn max_token_count(&self) -> usize {
        1_000_000
    }

    fn count_tokens(&self, _: LanguageModelRequest, _: &App) -> BoxFuture<'static, Result<usize>> {
        futures::future::ready(Ok(0)).boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        // Record the request as pending and hand back a stream; each chunk the
        // test sends is surfaced to the caller as a `Text` completion event.
        let (tx, rx) = mpsc::unbounded();
        self.current_completion_txs.lock().push((request, tx));
        async move {
            Ok(rx
                .map(|text| Ok(LanguageModelCompletionEvent::Text(text)))
                .boxed())
        }
        .boxed()
    }

    fn as_fake(&self) -> &Self {
        self
    }
}
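
// A sketch of driving the fake from a test. Assumes a gpui async context from
// the test harness and that `LanguageModelRequest` implements `Default`
// (`Clone` and `PartialEq` are already relied on above); illustrative, not a
// real test in this crate.
#[cfg(test)]
mod example {
    use super::*;

    async fn drive_fake(cx: &AsyncApp) {
        let model = FakeLanguageModel::default();
        let request = LanguageModelRequest::default();

        // Kick off a completion; the fake records it as pending.
        let mut events = model
            .stream_completion(request.clone(), cx)
            .await
            .unwrap();
        assert_eq!(model.completion_count(), 1);

        // The test plays the model's part, streaming chunks back by hand.
        model.stream_completion_response(&request, "Hello, ".to_string());
        model.stream_last_completion_response("world!".to_string());
        model.end_last_completion_stream();

        // The caller sees each chunk as a `Text` event until the stream ends.
        while let Some(event) = events.next().await {
            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                print!("{text}");
            }
        }
        assert_eq!(model.completion_count(), 0);
    }
}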