Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,33 +153,71 @@ def _cmwc_random_sequence(num_elements, seed):

# Create constants needed for the algorithm. The constants and notation
# follows from the above reference.
a = tf.tile(tf.constant([3636507990], tf.int64), [parallelism])
b = tf.tile(tf.constant([2**32], tf.int64), [parallelism])
logb_scalar = tf.constant(32, tf.int64)
a = tf.tile(tf.constant([3636507990], tf.uint64), [parallelism])
b = tf.tile(tf.constant([2**32], tf.uint64), [parallelism])
logb_scalar = tf.constant(32, tf.uint64)
logb = tf.tile([logb_scalar], [parallelism])
f = tf.tile(tf.constant([0], dtype=tf.int64), [parallelism])
bits = tf.constant(0, dtype=tf.int64, name='bits')
f = tf.tile(tf.constant([0], dtype=tf.uint64), [parallelism])
bits = tf.constant(0, dtype=tf.uint64, name='bits')

# TensorArray used in tf.while_loop for efficiency.
values = tf.TensorArray(
dtype=tf.float64, size=num_iters, element_shape=[parallelism])
# Iteration counter.
num = tf.constant(0, dtype=tf.int32, name='num')
# TensorFlow constant to be used at multiple places.
val_53 = tf.constant(53, tf.int64, name='val_53')
val_53 = tf.constant(53, tf.uint64, name='val_53')

# Construct initial sequence of seeds.
# From a single input seed, we construct multiple starting seeds for the
# sequences to be computed in parallel.
def next_seed_fn(i, val, q):
val = val**7 + val**6 + 1 # PRBS7.
"""Generates the next seed using a 7-bit LFSR.

This function implements a proper 7-bit Fibonacci LFSR with the polynomial
x^7 + x^6 + 1. It takes the lower 7 bits of `val` as the current state,
computes the next state, and writes it to the TensorArray `q`.

Args:
i: The current index in the while loop.
val: The current seed value (tf.uint64). The lower 7 bits are used as
the LFSR state.
q: The tf.TensorArray to write the generated seed into.

Returns:
A tuple of (i + 1, new_val, q), where `new_val` is the next state of the
LFSR.
"""
state = tf.bitwise.bitwise_and(val, tf.constant(0x7F, tf.uint64))
# Avoid zero state, which is a trapping state for this LFSR polynomial.
state = tf.bitwise.bitwise_or(
state,
tf.cast(tf.equal(state, tf.constant(0, tf.uint64)), tf.uint64)
)
# Feedback bit = bit 7 (index 6) ^ bit 6 (index 5)
feedback = tf.bitwise.bitwise_and(
tf.bitwise.bitwise_xor(
tf.bitwise.right_shift(state, tf.constant(6, tf.uint64)),
tf.bitwise.right_shift(state, tf.constant(5, tf.uint64))
),
tf.constant(1, tf.uint64)
)
# Shift left and insert feedback
val = tf.bitwise.bitwise_and(
tf.bitwise.bitwise_or(
tf.bitwise.left_shift(state, tf.constant(1, tf.uint64)),
feedback
),
tf.constant(0x7F, tf.uint64)
)
q = q.write(i, val)
return i + 1, val, q

q = tf.TensorArray(dtype=tf.int64, size=parallelism, element_shape=())
q = tf.TensorArray(dtype=tf.uint64, size=parallelism, element_shape=())
seed_u64 = tf.cast(seed, tf.uint64)
_, _, q = tf.while_loop(lambda i, _, __: i < parallelism,
next_seed_fn,
[tf.constant(0), seed, q])
[tf.constant(0), seed_u64, q])
c = q = q.stack()

# The random sequence generation code.
Expand All @@ -193,9 +231,10 @@ def cmwc_step(f, bits, q, c, num, values):
f.set_shape((1,)) # Correct for failed shape inference.
bits += logb_scalar
def add_val(bits, f, values, num):
mask_53 = tf.constant(2**53 - 1, tf.uint64)
new_val = tf.cast(
tf.bitwise.bitwise_and(f, (2**val_53 - 1)),
dtype=tf.float64) * (1 / 2**val_53)
tf.bitwise.bitwise_and(f, mask_53),
dtype=tf.float64) * (1.0 / 2.0**53)
values = values.write(num, new_val)
f += tf.bitwise.right_shift(f, val_53)
bits -= val_53
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ def test_tf_int32_seed_raises(self):
with self.assertRaisesRegex(TypeError, 'tf.int64 Tensor'):
tf_utils._cmwc_random_sequence(10, tf.constant(123, tf.int32))

def test_reproduction_b511305971(self):
"""Verifies that the PRNG does not produce negative states or bounds violations."""
# Reproduction steps from b/511305971
sequence = tf_utils._cmwc_random_sequence(
1000, tf.constant(12345, dtype=tf.int64))
sequence = self.evaluate(sequence)
self.assertAllGreaterEqual(sequence, 0.0)
self.assertAllLessEqual(sequence, 1.0)


class RandomSignsCMWCTests(tf.test.TestCase, parameterized.TestCase):
"""Tests for `random_signs_cmwc` method."""
Expand Down
Loading