from theano import function
from theano import tensor as T
from theano import shared
state = shared(0)
inc = T.iscalar()
accumulator = function([inc],state,updates=[(state,state+inc)])
decrementor = function([inc],state,updates=[(state,state-inc)])
fn_of_state = state*2+inc
foo = T.scalar(dtype=state.dtype)
skip_shared = function([inc,foo],fn_of_state,givens=[(state,foo)])
print state.get_value()
print skip_shared(1,3)
----------------------------------------------------------------------
givens=[(state,foo)]은 state를 foo로 대체해준다.
따라서, inc =1 , foo= 3 이였고, given으로 state를 foo즉 3으로 바꿔서 fn_of_state가 계산된다.
그리고 원래의 state는 바뀌지 안는다.
이 givens는 shared valuable뿐만 아니라, 모든 심볼릭 변수들을 대체할 수 있다.
댓글 없음:
댓글 쓰기