Allow Anthropic custom models to override temperature (#18160)

Release Notes:

- Allow Anthropic custom models to override "temperature"

This also centralizes the defaulting of "temperature", moving it inside each
model's `into_x` call instead of sprinkling it around the code.
Roy Williams, 2024-09-20 16:59:12 -04:00 (committed by GitHub)
parent 7d62fda5a3
commit 5905fbb9ac
12 changed files with 54 additions and 17 deletions
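
For context, here is a minimal sketch of the pattern the diff converges on, using hypothetical pared-down types (only the `temperature`/`default_temperature` handling mirrors the real code): the request stores `Option<f32>`, call sites pass `None`, and each `into_x` conversion is the single place a model's default gets applied.

```rust
// Hypothetical, pared-down types for illustration; not the actual Zed structs.
struct LanguageModelRequest {
    temperature: Option<f32>, // None means "use the model's default"
}

struct AnthropicRequest {
    temperature: Option<f32>,
}

impl LanguageModelRequest {
    // The conversion is the one place defaulting happens: an explicit user
    // override wins, otherwise the model-supplied default is filled in.
    fn into_anthropic(self, default_temperature: f32) -> AnthropicRequest {
        AnthropicRequest {
            temperature: self.temperature.or(Some(default_temperature)),
        }
    }
}

fn main() {
    // A caller that doesn't care about temperature just passes None...
    let request = LanguageModelRequest { temperature: None };
    // ...and a custom model configured with `default_temperature: 0.5`
    // supplies its default at conversion time.
    assert_eq!(request.into_anthropic(0.5).temperature, Some(0.5));
}
```

The hunks below apply exactly this shape across the Anthropic and Cloud providers, while the Ollama and Google conversions fall back to 1.0 when no override is set.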

@@ -49,6 +49,7 @@ pub enum Model {
         /// Indicates whether this custom model supports caching.
         cache_configuration: Option<AnthropicModelCacheConfiguration>,
         max_output_tokens: Option<u32>,
+        default_temperature: Option<f32>,
     },
 }
@@ -124,6 +125,19 @@ impl Model {
         }
     }
 
+    pub fn default_temperature(&self) -> f32 {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 1.0,
+            Self::Custom {
+                default_temperature,
+                ..
+            } => default_temperature.unwrap_or(1.0),
+        }
+    }
+
     pub fn tool_model_id(&self) -> &str {
         if let Self::Custom {
             tool_override: Some(tool_override),

@@ -2180,7 +2180,7 @@ impl Context {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {

@@ -2732,7 +2732,7 @@ impl CodegenAlternative {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.,
+            temperature: None,
         })
     }

@@ -796,7 +796,7 @@ impl PromptLibrary {
                 }],
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: 1.,
+                temperature: None,
             },
             cx,
         )

@@ -216,7 +216,7 @@ async fn commands_for_summaries(
         }],
         tools: Vec::new(),
         stop: Vec::new(),
-        temperature: 1.0,
+        temperature: None,
     };
     while let Some(current_summaries) = stack.pop() {

@@ -284,7 +284,7 @@ impl TerminalInlineAssistant {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         })
     }

@@ -51,6 +51,7 @@ pub struct AvailableModel {
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
+    pub default_temperature: Option<f32>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -200,6 +201,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                     }
                 }),
                 max_output_tokens: model.max_output_tokens,
+                default_temperature: model.default_temperature,
             },
         );
     }
@@ -375,8 +377,11 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request =
-            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
+        let request = request.into_anthropic(
+            self.model.id().into(),
+            self.model.default_temperature(),
+            self.model.max_output_tokens(),
+        );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
@@ -405,6 +410,7 @@ impl LanguageModel for AnthropicModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(
             self.model.tool_model_id().into(),
+            self.model.default_temperature(),
             self.model.max_output_tokens(),
         );
         request.tool_choice = Some(anthropic::ToolChoice::Tool {

@@ -87,6 +87,8 @@ pub struct AvailableModel {
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+    /// The default temperature to use for this model.
+    pub default_temperature: Option<f32>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -255,6 +257,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                         min_total_token: config.min_total_token,
                     }
                 }),
+                default_temperature: model.default_temperature,
                 max_output_tokens: model.max_output_tokens,
             }),
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
@@ -516,7 +519,11 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
+                let request = request.into_anthropic(
+                    model.id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -642,8 +649,11 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request =
-                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
+                let mut request = request.into_anthropic(
+                    model.tool_model_id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });

@@ -235,7 +235,7 @@ impl OllamaLanguageModel {
             options: Some(ChatOptions {
                 num_ctx: Some(self.model.max_tokens),
                 stop: Some(request.stop),
-                temperature: Some(request.temperature),
+                temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
             tools: vec![],

@@ -236,7 +236,7 @@ pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
     pub tools: Vec<LanguageModelRequestTool>,
     pub stop: Vec<String>,
-    pub temperature: f32,
+    pub temperature: Option<f32>,
 }
 
 impl LanguageModelRequest {
@@ -262,7 +262,7 @@ impl LanguageModelRequest {
                 .collect(),
             stream,
             stop: self.stop,
-            temperature: self.temperature,
+            temperature: self.temperature.unwrap_or(1.0),
             max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
@@ -290,7 +290,7 @@ impl LanguageModelRequest {
                 candidate_count: Some(1),
                 stop_sequences: Some(self.stop),
                 max_output_tokens: None,
-                temperature: Some(self.temperature as f64),
+                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                 top_p: None,
                 top_k: None,
             }),
@@ -298,7 +298,12 @@
         }
     }
 
-    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
+    pub fn into_anthropic(
+        self,
+        model: String,
+        default_temperature: f32,
+        max_output_tokens: u32,
+    ) -> anthropic::Request {
         let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();
@@ -400,7 +405,7 @@ impl LanguageModelRequest {
             tool_choice: None,
             metadata: None,
             stop_sequences: Vec::new(),
-            temperature: Some(self.temperature),
+            temperature: self.temperature.or(Some(default_temperature)),
             top_k: None,
             top_p: None,
         }

@@ -99,6 +99,7 @@ impl AnthropicSettingsContent {
                 tool_override,
                 cache_configuration,
                 max_output_tokens,
+                default_temperature,
             } => Some(provider::anthropic::AvailableModel {
                 name,
                 display_name,
@@ -112,6 +113,7 @@ impl AnthropicSettingsContent {
                     },
                 ),
                 max_output_tokens,
+                default_temperature,
             }),
             _ => None,
         })

@@ -562,7 +562,7 @@ impl SummaryIndex {
             }],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
         let code_len = code.len();