diff --git a/src/python/m5/proxy.py b/src/python/m5/proxy.py index 9d91b84d9c..d15b6f297c 100644 --- a/src/python/m5/proxy.py +++ b/src/python/m5/proxy.py @@ -55,7 +55,7 @@ class BaseProxy(object): def __init__(self, search_self, search_up): self._search_self = search_self self._search_up = search_up - self._multipliers = [] + self._ops = [] def __str__(self): if self._search_self and not self._search_up: @@ -72,29 +72,48 @@ class BaseProxy(object): "cannot set attribute '%s' on proxy object" % attr) super(BaseProxy, self).__setattr__(attr, value) - # support for multiplying proxies by constants or other proxies to - # other params - def __mul__(self, other): - if not (isinstance(other, (int, long, float)) or isproxy(other)): - raise TypeError( - "Proxy multiplier must be a constant or a proxy to a param") - self._multipliers.append(other) - return self + def _gen_op(operation): + def op(self, operand): + if not (isinstance(operand, (int, long, float)) or \ + isproxy(operand)): + raise TypeError( + "Proxy operand must be a constant or a proxy to a param") + self._ops.append((operation, operand)) + return self + return op + # Support for multiplying proxies by either constants or other proxies + __mul__ = _gen_op(lambda operand_a, operand_b : operand_a * operand_b) __rmul__ = __mul__ - def _mulcheck(self, result, base): + # Support for dividing proxies by either constants or other proxies + __truediv__ = _gen_op(lambda operand_a, operand_b : + operand_a / operand_b) + __floordiv__ = _gen_op(lambda operand_a, operand_b : + operand_a // operand_b) + + # Support for dividing constants by proxies + __rtruediv__ = _gen_op(lambda operand_a, operand_b : + operand_b / operand_a.getValue()) + __rfloordiv__ = _gen_op(lambda operand_a, operand_b : + operand_b // operand_a.getValue()) + + # After all the operators and operands have been defined, this function + # should be called to perform the actual operation + def _opcheck(self, result, base): from . import params - for multiplier in self._multipliers: - if isproxy(multiplier): - multiplier = multiplier.unproxy(base) - # assert that we are multiplying with a compatible - # param - if not isinstance(multiplier, params.NumericParamValue): - raise TypeError( - "Proxy multiplier must be a numerical param") - multiplier = multiplier.getValue() - result = result * multiplier + for operation, operand in self._ops: + # Get the operand's value + if isproxy(operand): + operand = operand.unproxy(base) + # assert that we are operating with a compatible param + if not isinstance(operand, params.NumericParamValue): + raise TypeError("Proxy operand must be a numerical param") + operand = operand.getValue() + + # Apply the operation + result = operation(result, operand) + return result def unproxy(self, base): @@ -128,7 +147,7 @@ class BaseProxy(object): raise RuntimeError("Cycle in unproxy") result = result.unproxy(obj) - return self._mulcheck(result, base) + return self._opcheck(result, base) def getindex(obj, index): if index == None: