# Source code for rx.linq.observable.average

from rx.core import Observable
from rx.internal import extensionmethod


class AverageValue(object):
    """Immutable-by-convention accumulator used by ``average``.

    Pairs the running total of the elements seen so far with how many
    elements contributed to that total.
    """

    def __init__(self, sum, count):
        # ``sum`` shadows the builtin inside this method only; the name is
        # kept because callers construct this with ``sum=...`` keywords.
        self.sum, self.count = sum, count


@extensionmethod(Observable)
def average(self, key_selector=None):
    """Computes the average of an observable sequence of values that are in
    the sequence or obtained by invoking a transform function on each
    element of the input sequence if present.

    Example
    res = source.average()
    res = source.average(lambda x: x.value)

    :param Observable self: Observable to average.
    :param types.FunctionType key_selector: [Optional] A transform function
        to apply to each element before averaging. When omitted (``None``),
        the elements themselves are averaged.

    :returns: An observable sequence containing a single element with the
        average of the sequence of values.
    :rtype: Observable
    """

    # Test for "a selector was supplied" explicitly rather than relying on
    # the truthiness of the selector object (PEP 8: compare to None with
    # ``is``/``is not``).
    if key_selector is not None:
        return self.map(key_selector).average()

    def accumulator(prev, cur):
        # Return a fresh accumulator instead of mutating ``prev``, so the
        # ``seed`` object created below is never modified.
        return AverageValue(sum=prev.sum + cur, count=prev.count + 1)

    def mapper(s):
        # Defensive guard against dividing by zero when no element was
        # accumulated. NOTE(review): whether this is reachable depends on
        # how scan()/last() treat an empty source — confirm.
        if s.count == 0:
            raise Exception('The input sequence was empty')

        # float() makes the division floating-point on Python 2 as well,
        # where int / int would truncate.
        return s.sum / float(s.count)

    seed = AverageValue(sum=0, count=0)
    # Fold the elements into a running AverageValue, keep only the final
    # accumulator, and turn it into the mean.
    return self.scan(accumulator, seed).last().map(mapper)