Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share


rust - What's the fastest way of finding the index of the maximum value in an array?

I have a 2D array of type f32 (from ndarray::ArrayView2), and I want to find the index of the maximum value in each row and store those indices in another array.

The equivalent in Python is something like:

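import numpy as np

# batch, vectors, best_rows, max_val and batch_size come from the surrounding code
for i in range(0, max_val, batch_size):
    # sims is the dot product of batch and vectors.T;
    # its shape is, for example, (1024, 10000)
    sims = np.dot(batch, vectors.T)

    # index of the largest value in each row
    best_rows[i:i + batch_size] = sims.argmax(axis=1)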

In Python, NumPy's .argmax is very fast, but I don't see an equivalent function in Rust. What's the fastest way to do this?



1 Answer


Consider the easy case of a general Ord type first. The answer differs slightly depending on whether you know the values are Copy, but here's the code:

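fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    // Pair each element with its index, compare by value only, keep the index.
    slice.iter().enumerate().max_by_key(|&(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    // Same idea for non-Copy types: compare the borrowed values with Ord::cmp.
    slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}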

The basic idea is that we pair [a reference to] each item in the array (really a slice, so it doesn't matter whether it's a Vec, an array, or something more exotic) with its index, use std::iter::Iterator methods to find the maximum according to the value only (not the index), and then return just the index. If the slice is empty, None is returned. Per the documentation, if several elements are equally maximal, the rightmost index is returned; if you need the leftmost, call rev() after enumerate().
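To see the tie-breaking behaviour described above, here is a small runnable sketch. It repeats the two functions and adds a hypothetical `position_max_leftmost` helper that applies the rev() trick; the sample data is made up for illustration.

```rust
/// Index of the maximum element; on ties, the last (rightmost) one wins.
fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by_key(|&(_, &value)| value).map(|(idx, _)| idx)
}

/// Same result for non-Copy types, comparing references with Ord::cmp.
fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by(|(_, a), (_, b)| a.cmp(b)).map(|(idx, _)| idx)
}

/// Hypothetical helper: iterate back to front (rev() after enumerate()), so the
/// leftmost maximum is the last one seen and therefore the one returned.
fn position_max_leftmost<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice
        .iter()
        .enumerate()
        .rev()
        .max_by_key(|&(_, &value)| value)
        .map(|(idx, _)| idx)
}

fn main() {
    let v = [3, 7, 1, 7, 2];
    println!("{:?}", position_max_copy(&v)); // Some(3): the last 7
    println!("{:?}", position_max_leftmost(&v)); // Some(1): the first 7

    let words = vec![String::from("b"), String::from("c"), String::from("a")];
    println!("{:?}", position_max(&words)); // Some(1)

    let empty: [i32; 0] = [];
    println!("{:?}", position_max_copy(&empty)); // None
}
```

Note that the indices returned by rev() are still the original positions, because enumerate() runs before the reversal.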

rev(), enumerate(), max_by_key() and max_by() are all methods of std::iter::Iterator; slice::iter() is documented with the primitive slice type (but that one belongs on your shortlist of things to recall without documentation as a Rust dev); map here is Option::map() (ditto). Finally, cmp is Ord::cmp, but most of the time you can use the Copy version, which doesn't need it (e.g. if you're comparing integers).


Now here's the catch: f32 isn't Ord because of the way IEEE floats work. NaN is not equal to itself and is unordered relative to every other value, so f32 only implements PartialOrd. Most languages ignore this and have subtly wrong algorithms. The most popular crate for providing a total order on floats (by declaring all NaNs equal to each other and greater than all numbers) seems to be ordered-float. Assuming it's implemented correctly, it should be very lightweight. It does pull in num_traits, but that is part of the most popular numerics library, so it may well be pulled in by your other dependencies already.

You'd use it here by mapping ordered_float::OrderedFloat (the constructor of the tuple struct) over the copied slice iterator: slice.iter().copied().map(ordered_float::OrderedFloat). The copied() matters because slice.iter() yields &f32, and the wrapper needs the f32 itself to be Ord. Since you only want the position of the maximum element, there's no need to extract the f32 afterward.
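Putting the pieces together for the row-wise question, here is a sketch that uses only the standard library. Because ordered-float is an external crate, it substitutes a hypothetical `TotalF32` newtype that imitates the ordering described above (NaNs equal to each other and greater than all numbers); with the real crate you would map `ordered_float::OrderedFloat` instead. It also assumes the 2D data is available as a contiguous row-major slice (for an ndarray view in standard layout, `as_slice()` should give you one). As a swapped-in alternative not mentioned in the answer, it includes a version built on the standard library's `f32::total_cmp`.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for `ordered_float::OrderedFloat<f32>`, so the sketch
/// runs without the crate. Assumed ordering: all NaNs are equal to each other
/// and greater than every number; everything else uses the normal f32 order.
#[derive(Clone, Copy, Debug)]
struct TotalF32(f32);

impl Ord for TotalF32 {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither value is NaN, so partial_cmp always returns Some.
            (false, false) => self.0.partial_cmp(&other.0).unwrap(),
        }
    }
}

impl PartialOrd for TotalF32 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for TotalF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for TotalF32 {}

/// Argmax of one row: wrap each f32 so it becomes Ord, then reuse max_by_key.
/// Ties go to the last index; an empty row gives None.
fn argmax(row: &[f32]) -> Option<usize> {
    row.iter()
        .copied()
        .map(TotalF32)
        .enumerate()
        .max_by_key(|&(_, value)| value)
        .map(|(idx, _)| idx)
}

/// Alternative without a wrapper type: f32::total_cmp (Rust 1.62+) implements
/// IEEE 754 totalOrder. It orders NaNs by their sign bit, so a negative NaN
/// sorts below every number, unlike TotalF32 above.
fn argmax_total_cmp(row: &[f32]) -> Option<usize> {
    row.iter()
        .copied()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(idx, _)| idx)
}

/// Row-wise argmax over a row-major buffer with `ncols` columns, analogous to
/// NumPy's `sims.argmax(axis=1)`. Panics if `ncols == 0` (chunks_exact needs a
/// non-zero chunk size); a trailing partial row is ignored by chunks_exact.
fn argmax_rows(data: &[f32], ncols: usize) -> Vec<usize> {
    data.chunks_exact(ncols)
        .map(|row| argmax(row).expect("each row has ncols > 0 elements"))
        .collect()
}

fn main() {
    // A made-up 2 x 3 matrix stored row by row.
    let sims: [f32; 6] = [0.1, 0.9, 0.3, 2.0, -1.0, 2.0];
    println!("{:?}", argmax_rows(&sims, 3)); // [1, 2]
    println!("{:?}", argmax(&[1.0, f32::NAN, 3.0])); // Some(1)
    println!("{:?}", argmax_total_cmp(&[-1.0, f32::INFINITY, 2.0])); // Some(1)
}
```

One behavioural difference worth knowing: like max_by_key, these return the last index on ties (row 1 above gives 2), whereas NumPy's argmax returns the first occurrence, so apply the rev() trick if you need identical results.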


...