agent2: Make model
of Thread
optional (#36395)
Related to #36394 Release Notes: - N/A
This commit is contained in:
parent
2075627d6c
commit
2eadd5a396
5 changed files with 195 additions and 138 deletions
|
@ -427,10 +427,12 @@ impl NativeAgent {
|
||||||
self.models.refresh_list(cx);
|
self.models.refresh_list(cx);
|
||||||
for session in self.sessions.values_mut() {
|
for session in self.sessions.values_mut() {
|
||||||
session.thread.update(cx, |thread, _| {
|
session.thread.update(cx, |thread, _| {
|
||||||
let model_id = LanguageModels::model_id(&thread.model());
|
if let Some(model) = thread.model() {
|
||||||
|
let model_id = LanguageModels::model_id(model);
|
||||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||||
thread.set_model(model.clone());
|
thread.set_model(model.clone());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -622,13 +624,15 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||||
else {
|
else {
|
||||||
return Task::ready(Err(anyhow!("Session not found")));
|
return Task::ready(Err(anyhow!("Session not found")));
|
||||||
};
|
};
|
||||||
let model = thread.read(cx).model().clone();
|
let Some(model) = thread.read(cx).model() else {
|
||||||
|
return Task::ready(Err(anyhow!("Model not found")));
|
||||||
|
};
|
||||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||||
else {
|
else {
|
||||||
return Task::ready(Err(anyhow!("Provider not found")));
|
return Task::ready(Err(anyhow!("Provider not found")));
|
||||||
};
|
};
|
||||||
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
||||||
&model, &provider,
|
model, &provider,
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -679,19 +683,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
let available_count = registry.available_models(cx).count();
|
let available_count = registry.available_models(cx).count();
|
||||||
log::debug!("Total available models: {}", available_count);
|
log::debug!("Total available models: {}", available_count);
|
||||||
|
|
||||||
let default_model = registry
|
let default_model = registry.default_model().and_then(|default_model| {
|
||||||
.default_model()
|
|
||||||
.and_then(|default_model| {
|
|
||||||
agent
|
agent
|
||||||
.models
|
.models
|
||||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||||
})
|
});
|
||||||
.ok_or_else(|| {
|
|
||||||
log::warn!("No default model configured in settings");
|
|
||||||
anyhow!(
|
|
||||||
"No default model. Please configure a default model in settings."
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let mut thread = Thread::new(
|
let mut thread = Thread::new(
|
||||||
|
@ -777,13 +773,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
log::debug!("Message id: {:?}", id);
|
log::debug!("Message id: {:?}", id);
|
||||||
log::debug!("Message content: {:?}", content);
|
log::debug!("Message content: {:?}", content);
|
||||||
|
|
||||||
Ok(thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| thread.send(id, content, cx))
|
||||||
log::info!(
|
|
||||||
"Sending message to thread with model: {:?}",
|
|
||||||
thread.model().name()
|
|
||||||
);
|
|
||||||
thread.send(id, content, cx)
|
|
||||||
}))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1008,7 +998,7 @@ mod tests {
|
||||||
agent.read_with(cx, |agent, _| {
|
agent.read_with(cx, |agent, _| {
|
||||||
let session = agent.sessions.get(&session_id).unwrap();
|
let session = agent.sessions.get(&session_id).unwrap();
|
||||||
session.thread.read_with(cx, |thread, _| {
|
session.thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(thread.model().id().0, "fake");
|
assert_eq!(thread.model().unwrap().id().0, "fake");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
|
@ -73,6 +74,7 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
|
@ -101,9 +103,11 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
project_context.borrow_mut().shell = "test-shell".into();
|
project_context.borrow_mut().shell = "test-shell".into();
|
||||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -136,9 +140,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
// Send initial user message and verify it's cached
|
// Send initial user message and verify it's cached
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let completion = fake_model.pending_completions().pop().unwrap();
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
@ -157,9 +163,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
// Send another user message and verify only the latest is cached
|
// Send another user message and verify only the latest is cached
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let completion = fake_model.pending_completions().pop().unwrap();
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
@ -191,9 +199,11 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
// Simulate a tool call and verify that the latest tool result is cached
|
// Simulate a tool call and verify that the latest tool result is cached
|
||||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let tool_use = LanguageModelToolUse {
|
let tool_use = LanguageModelToolUse {
|
||||||
|
@ -273,6 +283,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||||
|
@ -291,6 +302,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||||
|
@ -322,10 +334,12 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||||
|
|
||||||
// Test a tool call that's likely to complete *before* streaming stops.
|
// Test a tool call that's likely to complete *before* streaming stops.
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(WordListTool);
|
thread.add_tool(WordListTool);
|
||||||
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut saw_partial_tool_use = false;
|
let mut saw_partial_tool_use = false;
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
|
@ -371,10 +385,12 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(ToolRequiringPermission);
|
thread.add_tool(ToolRequiringPermission);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
|
@ -501,9 +517,11 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
|
@ -528,10 +546,12 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events = thread.update(cx, |thread, cx| {
|
let events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let tool_use = LanguageModelToolUse {
|
let tool_use = LanguageModelToolUse {
|
||||||
id: "tool_id_1".into(),
|
id: "tool_id_1".into(),
|
||||||
|
@ -644,10 +664,12 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events = thread.update(cx, |thread, cx| {
|
let events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let tool_use = LanguageModelToolUse {
|
let tool_use = LanguageModelToolUse {
|
||||||
|
@ -677,9 +699,11 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
.is::<language_model::ToolUseLimitReachedError>()
|
.is::<language_model::ToolUseLimitReachedError>()
|
||||||
);
|
);
|
||||||
|
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), vec!["ghi"], cx)
|
thread.send(UserMessageId::new(), vec!["ghi"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let completion = fake_model.pending_completions().pop().unwrap();
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -790,6 +814,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
@ -857,10 +882,12 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
// Test that test-1 profile (default) has echo and delay tools
|
// Test that test-1 profile (default) has echo and delay tools
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-1".into()));
|
thread.set_profile(AgentProfileId("test-1".into()));
|
||||||
thread.send(UserMessageId::new(), ["test"], cx);
|
thread.send(UserMessageId::new(), ["test"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
|
@ -875,10 +902,12 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-2".into()));
|
thread.set_profile(AgentProfileId("test-2".into()));
|
||||||
thread.send(UserMessageId::new(), ["test2"], cx)
|
thread.send(UserMessageId::new(), ["test2"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
assert_eq!(pending_completions.len(), 1);
|
assert_eq!(pending_completions.len(), 1);
|
||||||
|
@ -896,7 +925,8 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(InfiniteTool);
|
thread.add_tool(InfiniteTool);
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
|
@ -904,7 +934,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
["Call the echo tool, then call the infinite tool, then explain their output"],
|
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Wait until both tools are called.
|
// Wait until both tools are called.
|
||||||
let mut expected_tools = vec!["Echo", "Infinite Tool"];
|
let mut expected_tools = vec!["Echo", "Infinite Tool"];
|
||||||
|
@ -960,6 +991,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
.unwrap()
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
|
@ -978,16 +1010,20 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events_1 = thread.update(cx, |thread, cx| {
|
let events_1 = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
let events_2 = thread.update(cx, |thread, cx| {
|
let events_2 = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
||||||
fake_model
|
fake_model
|
||||||
|
@ -1005,9 +1041,11 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events_1 = thread.update(cx, |thread, cx| {
|
let events_1 = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
||||||
fake_model
|
fake_model
|
||||||
|
@ -1015,9 +1053,11 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
let events_1 = events_1.collect::<Vec<_>>().await;
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
||||||
|
|
||||||
let events_2 = thread.update(cx, |thread, cx| {
|
let events_2 = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
||||||
fake_model
|
fake_model
|
||||||
|
@ -1034,9 +1074,11 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events = thread.update(cx, |thread, cx| {
|
let events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1082,9 +1124,11 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let message_id = UserMessageId::new();
|
let message_id = UserMessageId::new();
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(message_id.clone(), ["Hello"], cx)
|
thread.send(message_id.clone(), ["Hello"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1123,9 +1167,11 @@ async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Ensure we can still send a new message after truncation.
|
// Ensure we can still send a new message after truncation.
|
||||||
thread.update(cx, |thread, cx| {
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hi"], cx)
|
thread.send(UserMessageId::new(), ["Hi"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(),
|
thread.to_markdown(),
|
||||||
|
@ -1291,9 +1337,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Think"], cx)
|
thread.send(UserMessageId::new(), ["Think"], cx)
|
||||||
});
|
})
|
||||||
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
// Simulate streaming partial input.
|
// Simulate streaming partial input.
|
||||||
|
@ -1506,7 +1554,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log,
|
action_log,
|
||||||
templates,
|
templates,
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
|
@ -469,7 +469,7 @@ pub struct Thread {
|
||||||
profile_id: AgentProfileId,
|
profile_id: AgentProfileId,
|
||||||
project_context: Rc<RefCell<ProjectContext>>,
|
project_context: Rc<RefCell<ProjectContext>>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
}
|
}
|
||||||
|
@ -481,7 +481,7 @@ impl Thread {
|
||||||
context_server_registry: Entity<ContextServerRegistry>,
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
action_log: Entity<ActionLog>,
|
action_log: Entity<ActionLog>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Option<Arc<dyn LanguageModel>>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||||
|
@ -512,12 +512,12 @@ impl Thread {
|
||||||
&self.action_log
|
&self.action_log
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn model(&self) -> &Arc<dyn LanguageModel> {
|
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
|
||||||
&self.model
|
self.model.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
||||||
self.model = model;
|
self.model = Some(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn completion_mode(&self) -> CompletionMode {
|
pub fn completion_mode(&self) -> CompletionMode {
|
||||||
|
@ -575,6 +575,7 @@ impl Thread {
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
||||||
|
anyhow::ensure!(self.model.is_some(), "Model not set");
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
self.tool_use_limit_reached,
|
self.tool_use_limit_reached,
|
||||||
"can only resume after tool use limit is reached"
|
"can only resume after tool use limit is reached"
|
||||||
|
@ -584,7 +585,7 @@ impl Thread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
log::info!("Total messages in thread: {}", self.messages.len());
|
log::info!("Total messages in thread: {}", self.messages.len());
|
||||||
Ok(self.run_turn(cx))
|
self.run_turn(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||||
|
@ -595,11 +596,13 @@ impl Thread {
|
||||||
id: UserMessageId,
|
id: UserMessageId,
|
||||||
content: impl IntoIterator<Item = T>,
|
content: impl IntoIterator<Item = T>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
|
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
|
||||||
where
|
where
|
||||||
T: Into<UserMessageContent>,
|
T: Into<UserMessageContent>,
|
||||||
{
|
{
|
||||||
log::info!("Thread::send called with model: {:?}", self.model.name());
|
let model = self.model().context("No language model configured")?;
|
||||||
|
|
||||||
|
log::info!("Thread::send called with model: {:?}", model.name());
|
||||||
self.advance_prompt_id();
|
self.advance_prompt_id();
|
||||||
|
|
||||||
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||||
|
@ -616,10 +619,10 @@ impl Thread {
|
||||||
fn run_turn(
|
fn run_turn(
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
||||||
self.cancel();
|
self.cancel();
|
||||||
|
|
||||||
let model = self.model.clone();
|
let model = self.model.clone().context("No language model configured")?;
|
||||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||||
let event_stream = AgentResponseEventStream(events_tx);
|
let event_stream = AgentResponseEventStream(events_tx);
|
||||||
let message_ix = self.messages.len().saturating_sub(1);
|
let message_ix = self.messages.len().saturating_sub(1);
|
||||||
|
@ -637,7 +640,7 @@ impl Thread {
|
||||||
);
|
);
|
||||||
let request = this.update(cx, |this, cx| {
|
let request = this.update(cx, |this, cx| {
|
||||||
this.build_completion_request(completion_intent, cx)
|
this.build_completion_request(completion_intent, cx)
|
||||||
})?;
|
})??;
|
||||||
|
|
||||||
log::info!("Calling model.stream_completion");
|
log::info!("Calling model.stream_completion");
|
||||||
let mut events = model.stream_completion(request, cx).await?;
|
let mut events = model.stream_completion(request, cx).await?;
|
||||||
|
@ -729,7 +732,7 @@ impl Thread {
|
||||||
.ok();
|
.ok();
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
events_rx
|
Ok(events_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
||||||
|
@ -917,7 +920,7 @@ impl Thread {
|
||||||
status: Some(acp::ToolCallStatus::InProgress),
|
status: Some(acp::ToolCallStatus::InProgress),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
});
|
});
|
||||||
let supports_images = self.model.supports_images();
|
let supports_images = self.model().map_or(false, |model| model.supports_images());
|
||||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||||
log::info!("Running tool {}", tool_use.name);
|
log::info!("Running tool {}", tool_use.name);
|
||||||
Some(cx.foreground_executor().spawn(async move {
|
Some(cx.foreground_executor().spawn(async move {
|
||||||
|
@ -1005,7 +1008,9 @@ impl Thread {
|
||||||
&self,
|
&self,
|
||||||
completion_intent: CompletionIntent,
|
completion_intent: CompletionIntent,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> LanguageModelRequest {
|
) -> Result<LanguageModelRequest> {
|
||||||
|
let model = self.model().context("No language model configured")?;
|
||||||
|
|
||||||
log::debug!("Building completion request");
|
log::debug!("Building completion request");
|
||||||
log::debug!("Completion intent: {:?}", completion_intent);
|
log::debug!("Completion intent: {:?}", completion_intent);
|
||||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||||
|
@ -1021,9 +1026,7 @@ impl Thread {
|
||||||
Some(LanguageModelRequestTool {
|
Some(LanguageModelRequestTool {
|
||||||
name: tool_name,
|
name: tool_name,
|
||||||
description: tool.description().to_string(),
|
description: tool.description().to_string(),
|
||||||
input_schema: tool
|
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
|
||||||
.input_schema(self.model.tool_input_format())
|
|
||||||
.log_err()?,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -1042,20 +1045,22 @@ impl Thread {
|
||||||
tools,
|
tools,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
temperature: AgentSettings::temperature_for_model(self.model(), cx),
|
temperature: AgentSettings::temperature_for_model(&model, cx),
|
||||||
thinking_allowed: true,
|
thinking_allowed: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
log::debug!("Completion request built successfully");
|
log::debug!("Completion request built successfully");
|
||||||
request
|
Ok(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||||
|
let model = self.model().context("No language model configured")?;
|
||||||
|
|
||||||
let profile = AgentSettings::get_global(cx)
|
let profile = AgentSettings::get_global(cx)
|
||||||
.profiles
|
.profiles
|
||||||
.get(&self.profile_id)
|
.get(&self.profile_id)
|
||||||
.context("profile not found")?;
|
.context("profile not found")?;
|
||||||
let provider_id = self.model.provider_id();
|
let provider_id = model.provider_id();
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.tools
|
.tools
|
||||||
|
|
|
@ -237,11 +237,17 @@ impl AgentTool for EditFileTool {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = self.thread.update(cx, |thread, cx| {
|
let Some(request) = self.thread.update(cx, |thread, cx| {
|
||||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
thread
|
||||||
});
|
.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||||
|
.ok()
|
||||||
|
}) else {
|
||||||
|
return Task::ready(Err(anyhow!("Failed to build completion request")));
|
||||||
|
};
|
||||||
let thread = self.thread.read(cx);
|
let thread = self.thread.read(cx);
|
||||||
let model = thread.model().clone();
|
let Some(model) = thread.model().cloned() else {
|
||||||
|
return Task::ready(Err(anyhow!("No language model configured")));
|
||||||
|
};
|
||||||
let action_log = thread.action_log().clone();
|
let action_log = thread.action_log().clone();
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = self.authorize(&input, &event_stream, cx);
|
||||||
|
@ -520,7 +526,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log,
|
action_log,
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model,
|
Some(model),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -717,7 +723,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -853,7 +859,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -979,7 +985,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -1116,7 +1122,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -1226,7 +1232,7 @@ mod tests {
|
||||||
context_server_registry.clone(),
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -1307,7 +1313,7 @@ mod tests {
|
||||||
context_server_registry.clone(),
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -1391,7 +1397,7 @@ mod tests {
|
||||||
context_server_registry.clone(),
|
context_server_registry.clone(),
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -1472,7 +1478,7 @@ mod tests {
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
action_log.clone(),
|
action_log.clone(),
|
||||||
Templates::new(),
|
Templates::new(),
|
||||||
model.clone(),
|
Some(model.clone()),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
|
@ -94,7 +94,9 @@ impl ProfileProvider for Entity<agent2::Thread> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn profiles_supported(&self, cx: &App) -> bool {
|
fn profiles_supported(&self, cx: &App) -> bool {
|
||||||
self.read(cx).model().supports_tools()
|
self.read(cx)
|
||||||
|
.model()
|
||||||
|
.map_or(false, |model| model.supports_tools())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2475,7 +2477,10 @@ impl AcpThreadView {
|
||||||
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||||
let thread = self.as_native_thread(cx)?.read(cx);
|
let thread = self.as_native_thread(cx)?.read(cx);
|
||||||
|
|
||||||
if !thread.model().supports_burn_mode() {
|
if thread
|
||||||
|
.model()
|
||||||
|
.map_or(true, |model| !model.supports_burn_mode())
|
||||||
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3219,7 +3224,10 @@ impl AcpThreadView {
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Callout> {
|
) -> Option<Callout> {
|
||||||
let thread = self.as_native_thread(cx)?;
|
let thread = self.as_native_thread(cx)?;
|
||||||
let supports_burn_mode = thread.read(cx).model().supports_burn_mode();
|
let supports_burn_mode = thread
|
||||||
|
.read(cx)
|
||||||
|
.model()
|
||||||
|
.map_or(false, |model| model.supports_burn_mode());
|
||||||
|
|
||||||
let focus_handle = self.focus_handle(cx);
|
let focus_handle = self.focus_handle(cx);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue