Merge pull request #333 from EricLBuehler/remove_candle_layer_norm
Remove candle-layer-norm dep
EricLBuehler committed May 19, 2024
2 parents 1d5f9f3 + 5eff6c2 commit 89d08a6
Showing 3 changed files with 15 additions and 29 deletions.
Cargo.toml (4 changes: 2 additions & 2 deletions)
@@ -20,8 +20,8 @@ license = "MIT"
 
 [workspace.dependencies]
 anyhow = { version = "1.0.80", feature = "std" }
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
 serde = "1.0.197"
 serde_json = "1.0.114"
 indexmap = { version = "2.2.5", features = ["serde"] }
mistralrs-core/Cargo.toml (4 changes: 2 additions & 2 deletions)
@@ -17,8 +17,8 @@ candle-core.workspace = true
 candle-nn.workspace = true
 serde.workspace = true
 serde_json.workspace = true
-candle-transformers = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0" }
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", optional = true }
+candle-transformers = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", branch = "remove_candle_layer_norm" }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", optional = true, branch = "remove_candle_layer_norm" }
 dirs = "5.0.1"
 hf-hub = "0.3.2"
 thiserror = "1.0.57"
mistralrs-core/src/layers.rs (36 changes: 11 additions & 25 deletions)
@@ -14,10 +14,7 @@ use candle_core::{
     quantized::{gguf_file, QMatMul, QTensor},
     DType, Device, IndexOp, Result, Tensor, WithDType,
 };
-use candle_nn::{
-    layer_norm::{RmsNormNonQuantized, RmsNormQuantized},
-    Linear, Module, VarBuilder,
-};
+use candle_nn::{Linear, Module, VarBuilder};
 use once_cell::sync::Lazy;
 
 // (bs, tgt_len, past_kv_len)
@@ -28,7 +25,6 @@ use crate::{cublaslt::CUBLASLT_HANDLE, models::phi3, INHIBIT_GEMM_F16};
 
 #[derive(Debug, Clone)]
 pub struct RmsNorm {
-    inner: candle_nn::RmsNorm<RmsNormNonQuantized>,
     eps: f64,
     weight: Tensor,
 }
@@ -37,47 +33,37 @@ impl RmsNorm {
     pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
         let w = inner.inner().weight().clone();
-        Ok(Self {
-            inner,
-            eps,
-            weight: w,
-        })
+        Ok(Self { eps, weight: w })
     }
 
     pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
-        let inner = candle_nn::RmsNorm::<RmsNormNonQuantized>::new(w.clone(), eps);
-        Ok(Self {
-            inner,
-            eps,
-            weight: w,
-        })
+        Ok(Self { eps, weight: w })
     }
 }
 
 impl Module for RmsNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        if x.device().is_cpu() {
-            // Handle device mapping case
-            return candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32);
-        }
-        self.inner.forward(x)
+        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct QRmsNorm {
-    inner: candle_nn::RmsNorm<RmsNormQuantized>,
+    eps: f64,
+    weight: Tensor,
 }
 
 impl QRmsNorm {
     pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let scale = scale.dequantize(&scale.device())?;
-        let inner = candle_nn::RmsNorm::<RmsNormQuantized>::new(scale, eps as f64);
-        Ok(Self { inner })
+        Ok(Self {
+            eps: eps as f64,
+            weight: scale,
+        })
     }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        self.inner.forward(x)
+        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
     }
 }

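After this change, RmsNorm and QRmsNorm keep only a weight tensor and an eps value and delegate to the functional candle_nn::ops::rms_norm in forward, so the fork's RmsNormNonQuantized/RmsNormQuantized wrapper types (and their candle-layer-norm backing) are no longer referenced here. For context, a minimal sketch of what that functional call computes, assuming the standard RMSNorm definition y = x / sqrt(mean(x^2) + eps) * weight; the helper name rms_norm_reference, the shapes, and the 1e-6 eps are illustrative and not part of this commit.

use candle_core::{DType, Device, Result, Tensor, D};

// Reference RMSNorm: normalize by the root mean square over the last
// dimension, then scale by the learned weight. Illustrative only; it mirrors
// what the simplified forward implementations above delegate to.
fn rms_norm_reference(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    // Root mean square over the hidden dimension, kept for broadcasting.
    let rms = x.sqr()?.mean_keepdim(D::Minus1)?.affine(1.0, eps)?.sqrt()?;
    x.broadcast_div(&rms)?.broadcast_mul(weight)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (2, 4, 16), &dev)?;
    let w = Tensor::ones(16, DType::F32, &dev)?;

    let manual = rms_norm_reference(&x, &w, 1e-6)?;
    let fused = candle_nn::ops::rms_norm(&x.contiguous()?, &w, 1e-6f32)?;

    // The two paths should agree up to floating-point tolerance.
    let max_diff = (manual - fused)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
    println!("max abs difference: {max_diff}");
    Ok(())
}

One practical detail carried over from the diff: both forward implementations call contiguous() on the input before normalizing, so the functional path works under device mapping as well as on a single device.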
