Merge pull request #332 from EricLBuehler/device_map_fixes

Fixes and verbosity improvements for device mapping

EricLBuehler committed May 19, 2024
2 parents ac5dd0f + 9fc025d commit 8eafc85
Showing 6 changed files with 17 additions and 5 deletions.
mistralrs-bench/src/main.rs (1 change: 0 additions & 1 deletion)

@@ -310,7 +310,6 @@ fn main() -> anyhow::Result<()> {
         candle_core::utils::with_f16c()
     );
     info!("Sampling method: penalties -> temperature -> topk -> topp -> multinomial");
-    info!("Loading model `{}` on {device:?}...", loader.get_id());
     if use_flash_attn {
         info!("Using flash attention.");
     }
mistralrs-core/src/device_map.rs (10 changes: 7 additions & 3 deletions)

@@ -41,11 +41,15 @@ impl DeviceMapMetadata {
            }));
        };
        // How many host (cpu) layers, defaulting to automatically filling the rest.
-       let n_host_layers = self.host_layers.unwrap_or(model_layers - n_device_layers);
+       // If n_device_layers > model_layers, n_host_layers = 0
+       let n_host_layers = self
+           .host_layers
+           .unwrap_or(model_layers.saturating_sub(n_device_layers));
        if n_device_layers + n_host_layers != model_layers {
-           candle_core::bail!("Expected the number of device ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
+           candle_core::bail!("Expected the number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
        }
-       info!("Using {n_device_layers} layers on device and {n_host_layers} on host.");
+       info!("Model has {model_layers} repeating layers.");
+       info!("Using {n_device_layers} repeating layers on GPU and {n_host_layers} repeating layers on host.");
        let mut combined = vec![device.clone(); n_device_layers];
        // Always put the CPU layers at the end so that we reduce dtoh and htod copies
        combined.extend(vec![Device::Cpu; n_host_layers]);
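The functional fix in this hunk is the switch from a plain subtraction, which underflows (panicking in debug builds) when more device layers are requested than the model has, to saturating_sub, which clamps the host-layer count at zero so the explicit sum check can fail with a readable error instead. Below is a minimal standalone sketch of that logic; the function name and the hardcoded values are illustrative only, not the crate's API.

// Standalone sketch of the layer-split logic from `DeviceMapMetadata::into_mapper`.
fn split_layers(
    model_layers: usize,
    n_device_layers: usize,
    host_layers: Option<usize>,
) -> Result<(usize, usize), String> {
    // `saturating_sub` clamps at 0 instead of panicking (debug) or
    // wrapping (release) when n_device_layers > model_layers.
    let n_host_layers = host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
    if n_device_layers + n_host_layers != model_layers {
        return Err(format!(
            "Expected the number of GPU ({n_device_layers}) and host layers \
             ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})"
        ));
    }
    Ok((n_device_layers, n_host_layers))
}

fn main() {
    // 32-layer model, 40 requested GPU layers: previously an integer
    // underflow; now a clean error from the sum check.
    assert!(split_layers(32, 40, None).is_err());
    // 32-layer model, 24 GPU layers: the remaining 8 default to the host.
    assert_eq!(split_layers(32, 24, None), Ok((24, 8)));
}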
mistralrs-core/src/pipeline/ggml.rs (1 change: 1 addition & 0 deletions)

@@ -228,6 +228,7 @@ impl Loader for GGMLLoader {
        if !mapper.is_dummy() {
            warn!("GGML models do not support device mapping. Device mapping will not work. Please consider using a GGUF model.");
        }
+       info!("Loading model `{}` on {device:?}...", self.get_id());

        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = ggml_file::Content::read(&mut file, device)
mistralrs-core/src/pipeline/gguf.rs (5 changes: 5 additions & 0 deletions)

@@ -314,6 +314,11 @@ impl Loader for GGUFLoader {
                "You are trying to in-situ quantize a GGUF model. This will not do anything."
            );
        }
+       // Otherwise, the device mapper will print it
+       if mapper.is_dummy() {
+           info!("Loading model `{}` on {device:?}...", self.get_id());
+       }
+
        let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
        let model = gguf_file::Content::read(&mut file)
            .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;
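This gating, repeated in normal.rs below, avoids double logging: when a real device map is active, device_map.rs already announces the model's layer split, so the pipeline's "Loading model" message is printed only for the dummy mapper. A dependency-free sketch of the pattern follows; DeviceMapperLike and announce_load are illustrative stand-ins, not the crate's actual trait or function names, and println! stands in for the crate's info! macro.

// Illustrative stand-in for the crate's mapper interface (assumption);
// only the `is_dummy` query used by the pipelines is modeled.
trait DeviceMapperLike {
    // True when no device mapping was requested.
    fn is_dummy(&self) -> bool;
}

struct DummyMapper;
impl DeviceMapperLike for DummyMapper {
    fn is_dummy(&self) -> bool {
        true
    }
}

// Shared pattern from the GGUF and normal loaders: announce the target
// device only when the mapper will not print its own layer breakdown.
fn announce_load(mapper: &dyn DeviceMapperLike, model_id: &str, device: &str) {
    if mapper.is_dummy() {
        println!("Loading model `{model_id}` on {device}...");
    }
}

fn main() {
    announce_load(&DummyMapper, "mistralai/Mistral-7B-Instruct-v0.1", "Cuda(0)");
}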
mistralrs-core/src/pipeline/normal.rs (4 changes: 4 additions & 0 deletions)

@@ -211,6 +211,10 @@ impl Loader for NormalLoader {
        } else {
            DType::F32
        };
+       // Otherwise, the device mapper will print it
+       if mapper.is_dummy() {
+           info!("Loading model `{}` on {device:?}...", self.get_id());
+       }

        info!(
            "Model config: {:?}",
mistralrs-server/src/main.rs (1 change: 0 additions & 1 deletion)

@@ -266,7 +266,6 @@ async fn main() -> Result<()> {
         candle_core::utils::with_f16c()
     );
     info!("Sampling method: penalties -> temperature -> topk -> topp -> multinomial");
-    info!("Loading model `{}` on {device:?}...", loader.get_id());
     if use_flash_attn {
         info!("Using flash attention.");
     }
