arrayfire-rust
Feature: Make argument `b` of `replace_scalar` type `T`
Description
pub fn replace_scalar<T>(a: &mut Array<T>, cond: &Array
Shouldn't the type of input `b` be `T` rather than `f64`?
@BA8F0D39 I think I deliberately left the scalar argument as `f64` when `Array` was refactored to be generic, to avoid an up-cast from `T` to `f64` at the library level. Given that the ArrayFire C API takes a `double` argument anyway, that up-cast didn't make sense: it would only add an unnecessary trait bound (something like the num crate's `ToPrimitive`/`AsPrimitive`) on the generic input type `T`, which in turn adds an extra trait bound to any user-written generic function that uses `replace_scalar`.
Instead, I feel that having the caller write `as f64` avoids these trait bounds, along with any pitfalls those mechanisms bring (such as a type that can't be converted to `f64`).

Feel free to share if you think there is a better workaround for this problem.
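To make the trade-off concrete, here is a dependency-free sketch of how a conversion bound would leak into user code. `ToF64` is a hypothetical local stand-in for the num crate's `ToPrimitive`/`AsPrimitive`, and the `scalar_for_c_api_*` and `user_helper_*` functions are hypothetical models of the two possible binding shapes (a slice stands in for `Array<T>`); none of them are arrayfire-rust APIs. Each model simply returns the `f64` that would be handed to the C layer.

```rust
// Hypothetical stand-in for num's ToPrimitive/AsPrimitive, defined locally
// so this sketch needs no external crates.
trait ToF64 {
    fn to_f64(&self) -> f64;
}

impl ToF64 for u64 {
    fn to_f64(&self) -> f64 {
        *self as f64
    }
}

// Shape A (how replace_scalar takes its scalar today): already an f64,
// so T needs no conversion bound. Returns what would reach the C API.
fn scalar_for_c_api_a<T>(_a: &mut [T], b: f64) -> f64 {
    b
}

// Shape B (the proposal): the scalar has type T, so the binding performs
// the up-cast itself and must require T: ToF64.
fn scalar_for_c_api_b<T: ToF64>(_a: &mut [T], b: T) -> f64 {
    b.to_f64()
}

// User-written generic helpers built on each shape: only the shape-B
// helper has to repeat the conversion bound in its own signature.
fn user_helper_a<T>(a: &mut [T], b: f64) -> f64 {
    scalar_for_c_api_a(a, b)
}

fn user_helper_b<T: ToF64>(a: &mut [T], b: T) -> f64 {
    scalar_for_c_api_b(a, b)
}

fn main() {
    let mut data: Vec<u64> = vec![1, 2, 3];
    // Shape A: the caller writes the cast at the call site.
    let via_a = user_helper_a(&mut data, 42_u64 as f64);
    // Shape B: the cast happens inside the library.
    let via_b = user_helper_b(&mut data, 42_u64);
    assert_eq!(via_a, via_b);
    println!("both shapes pass {} to the C layer", via_a);
}
```

Either way the C layer receives a `double`, so the shape only decides who writes the cast and whose signatures carry the bound.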
It doesn't seem to insert the values into the array correctly for `u64`:

```rust
fn main() {
    let A_dims = arrayfire::Dim4::new(&[6, 1, 1, 1]);
    let mut A = arrayfire::randu::<u64>(A_dims);
    // replace_scalar keeps A where the mask is true and writes the scalar where it is false
    let bool_cpu: Vec<bool> = vec![true, true, false, true, false, true];
    let boolarr = arrayfire::Array::new(&bool_cpu, arrayfire::Dim4::new(&[bool_cpu.len() as u64, 1, 1, 1]));
    arrayfire::replace_scalar(&mut A, &boolarr, 18446744073709551614.0);
    arrayfire::print_gen("A".to_string(), &A, Some(20));
}
```
```
A
[6 1 1 1]
    16242730742183356629
     6679402142117448868
                       0
     7657373526131797801
                       0
      511886651086640123
```
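For comparison, here is a plain-Rust model of what the call is expected to produce. It assumes ArrayFire's documented `replace` semantics (an element of `a` is kept where `cond` is true and overwritten where it is false), which is consistent with the output above: only positions 2 and 4, where the mask is false, changed, but they became `0` instead of 18446744073709551614. `replace_scalar_cpu` is a hypothetical helper, not part of arrayfire-rust, and it keeps the scalar as a `u64`, so nothing passes through `f64`.

```rust
// Hypothetical CPU reference for replace_scalar's semantics:
// keep a[i] where cond[i] is true, write `b` where it is false.
// The scalar stays a u64 here, so no rounding through f64 can occur.
fn replace_scalar_cpu(a: &mut [u64], cond: &[bool], b: u64) {
    assert_eq!(a.len(), cond.len(), "mask must match the array length");
    for (x, &keep) in a.iter_mut().zip(cond) {
        if !keep {
            *x = b;
        }
    }
}

fn main() {
    // Small placeholder values instead of randu output; the mask is the one from the report.
    let mut a: Vec<u64> = vec![10, 11, 12, 13, 14, 15];
    let cond = [true, true, false, true, false, true];
    replace_scalar_cpu(&mut a, &cond, 18446744073709551614);
    println!("{:?}", a); // [10, 11, 18446744073709551614, 13, 18446744073709551614, 15]
}
```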
@BA8F0D39 I was able to reproduce the issue. I'll post an update here once I know more about the root cause.
Found the problem: the `u64` value is so large that it can't be represented exactly as a `double`/`f64`. If I reduce the constant by a thousand or so, the value is copied correctly. I think this needs to be handled properly in ArrayFire's C API, by providing separate functions with type suffixes that take the exact input scalar type rather than approximating it with a `double`.
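The precision loss itself can be seen in plain Rust, without ArrayFire. An `f64` has a 53-bit significand, so every integer up to 2^53 is exact, but just below 2^64 adjacent doubles are 2048 apart and 18446744073709551614 rounds up to exactly 2^64, one past `u64::MAX`. `round_trips_through_f64` is a hypothetical helper for this sketch. Note that Rust's `f64`-to-`u64` cast saturates, whereas in C++ converting an out-of-range `double` to an unsigned 64-bit integer is undefined behavior, so this sketch does not reproduce the `0` seen in the report; it only shows why the value cannot survive the trip through `double`.

```rust
// Returns true if `x` survives a u64 -> f64 -> u64 round trip unchanged.
// Rust's f64 -> u64 `as` cast saturates: values >= 2^64 clamp to u64::MAX.
fn round_trips_through_f64(x: u64) -> bool {
    (x as f64) as u64 == x
}

fn main() {
    let reported: u64 = 18446744073709551614;
    // Just below 2^64 adjacent doubles are 2048 apart, so this rounds up to 2^64.
    assert_eq!(reported as f64, 18446744073709551616.0);
    assert!(!round_trips_through_f64(reported));

    // Every integer up to 2^53 is exact; 2^53 + 1 is the first that is not.
    let limit: u64 = 1 << 53;
    assert!(round_trips_through_f64(limit));
    assert!(!round_trips_through_f64(limit + 1));

    println!("{} as f64 = {}", reported, reported as f64);
}
```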
Until this use case is handled properly in the upstream project, you can use the following workaround:
```rust
// Build the fill value as a u64 array so it never passes through f64
let cnst = arrayfire::constant(18446744073709551614_u64, A_dims);
arrayfire::replace(&mut A, &boolarr, &cnst);
```
Sorry for the inconvenience, and thank you for reporting it! I'll post another update once a fix is available upstream.