assistant2: Restore tool uses when loading saved threads (#25942)
This PR makes it so tool uses are restored when loading saved threads in Assistant 2. Release Notes: - N/A
This commit is contained in:
parent
6635462f7b
commit
b2add8c803
3 changed files with 120 additions and 8 deletions
|
@ -8,8 +8,9 @@ use futures::StreamExt as _;
|
||||||
use gpui::{App, Context, EventEmitter, SharedString, Task};
|
use gpui::{App, Context, EventEmitter, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
|
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||||
|
Role, StopReason,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use util::{post_inc, TryFutureExt as _};
|
use util::{post_inc, TryFutureExt as _};
|
||||||
|
@ -88,7 +89,7 @@ impl Thread {
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
tools,
|
tools,
|
||||||
tool_use: ToolUseState::default(),
|
tool_use: ToolUseState::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,6 +100,7 @@ impl Thread {
|
||||||
_cx: &mut Context<Self>,
|
_cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let next_message_id = MessageId(saved.messages.len());
|
let next_message_id = MessageId(saved.messages.len());
|
||||||
|
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
|
@ -120,7 +122,7 @@ impl Thread {
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
tools,
|
tools,
|
||||||
tool_use: ToolUseState::default(),
|
tool_use,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,6 +191,10 @@ impl Thread {
|
||||||
self.tool_use.tool_uses_for_message(id)
|
self.tool_use.tool_uses_for_message(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||||
|
self.tool_use.tool_results_for_message(id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
self.tool_use.message_has_tool_results(message_id)
|
self.tool_use.message_has_tool_results(message_id)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ use gpui::{
|
||||||
};
|
};
|
||||||
use heed::types::{SerdeBincode, SerdeJson};
|
use heed::types::{SerdeBincode, SerdeJson};
|
||||||
use heed::Database;
|
use heed::Database;
|
||||||
use language_model::Role;
|
use language_model::{LanguageModelToolUseId, Role};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
@ -113,6 +113,24 @@ impl ThreadStore {
|
||||||
id: message.id,
|
id: message.id,
|
||||||
role: message.role,
|
role: message.role,
|
||||||
text: message.text.clone(),
|
text: message.text.clone(),
|
||||||
|
tool_uses: thread
|
||||||
|
.tool_uses_for_message(message.id)
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool_use| SavedToolUse {
|
||||||
|
id: tool_use.id,
|
||||||
|
name: tool_use.name,
|
||||||
|
input: tool_use.input,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
tool_results: thread
|
||||||
|
.tool_results_for_message(message.id)
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool_result| SavedToolResult {
|
||||||
|
tool_use_id: tool_result.tool_use_id.clone(),
|
||||||
|
is_error: tool_result.is_error,
|
||||||
|
content: tool_result.content.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
|
@ -239,11 +257,29 @@ pub struct SavedThread {
|
||||||
pub messages: Vec<SavedMessage>,
|
pub messages: Vec<SavedMessage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct SavedMessage {
|
pub struct SavedMessage {
|
||||||
pub id: MessageId,
|
pub id: MessageId,
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
pub text: String,
|
pub text: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_uses: Vec<SavedToolUse>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_results: Vec<SavedToolResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct SavedToolUse {
|
||||||
|
pub id: LanguageModelToolUseId,
|
||||||
|
pub name: SharedString,
|
||||||
|
pub input: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct SavedToolResult {
|
||||||
|
pub tool_use_id: LanguageModelToolUseId,
|
||||||
|
pub is_error: bool,
|
||||||
|
pub content: Arc<str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GlobalThreadsDatabase(
|
struct GlobalThreadsDatabase(
|
||||||
|
|
|
@ -7,10 +7,11 @@ use futures::FutureExt as _;
|
||||||
use gpui::{SharedString, Task};
|
use gpui::{SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
|
||||||
LanguageModelToolUseId, MessageContent,
|
LanguageModelToolUseId, MessageContent, Role,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::thread::MessageId;
|
use crate::thread::MessageId;
|
||||||
|
use crate::thread_store::SavedMessage;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ToolUse {
|
pub struct ToolUse {
|
||||||
|
@ -28,7 +29,6 @@ pub enum ToolUseStatus {
|
||||||
Error(SharedString),
|
Error(SharedString),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct ToolUseState {
|
pub struct ToolUseState {
|
||||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||||
|
@ -37,6 +37,65 @@ pub struct ToolUseState {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolUseState {
|
impl ToolUseState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tool_uses_by_assistant_message: HashMap::default(),
|
||||||
|
tool_uses_by_user_message: HashMap::default(),
|
||||||
|
tool_results: HashMap::default(),
|
||||||
|
pending_tool_uses_by_id: HashMap::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
|
||||||
|
let mut this = Self::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
match message.role {
|
||||||
|
Role::Assistant => {
|
||||||
|
if !message.tool_uses.is_empty() {
|
||||||
|
this.tool_uses_by_assistant_message.insert(
|
||||||
|
message.id,
|
||||||
|
message
|
||||||
|
.tool_uses
|
||||||
|
.iter()
|
||||||
|
.map(|tool_use| LanguageModelToolUse {
|
||||||
|
id: tool_use.id.clone(),
|
||||||
|
name: tool_use.name.clone().into(),
|
||||||
|
input: tool_use.input.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Role::User => {
|
||||||
|
if !message.tool_results.is_empty() {
|
||||||
|
let tool_uses_by_user_message = this
|
||||||
|
.tool_uses_by_user_message
|
||||||
|
.entry(message.id)
|
||||||
|
.or_default();
|
||||||
|
|
||||||
|
for tool_result in &message.tool_results {
|
||||||
|
let tool_use_id = tool_result.tool_use_id.clone();
|
||||||
|
|
||||||
|
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||||
|
this.tool_results.insert(
|
||||||
|
tool_use_id.clone(),
|
||||||
|
LanguageModelToolResult {
|
||||||
|
tool_use_id,
|
||||||
|
is_error: tool_result.is_error,
|
||||||
|
content: tool_result.content.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Role::System => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
||||||
self.pending_tool_uses_by_id.values().collect()
|
self.pending_tool_uses_by_id.values().collect()
|
||||||
}
|
}
|
||||||
|
@ -84,6 +143,17 @@ impl ToolUseState {
|
||||||
tool_uses
|
tool_uses
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||||
|
let empty = Vec::new();
|
||||||
|
|
||||||
|
self.tool_uses_by_user_message
|
||||||
|
.get(&message_id)
|
||||||
|
.unwrap_or(&empty)
|
||||||
|
.iter()
|
||||||
|
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
self.tool_uses_by_user_message
|
self.tool_uses_by_user_message
|
||||||
.get(&message_id)
|
.get(&message_id)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue