You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
920 B
28 lines
920 B
import random
|
|
|
|
|
|
def subsample_data(data, subsample_size):
    """Draw a random subsample of paired inputs and outputs.

    Args:
        data: A pair ``(inputs, outputs)`` of equally long, indexable
            sequences; ``inputs[i]`` corresponds to ``outputs[i]``.
        subsample_size: Number of pairs to draw, without replacement.

    Returns:
        A tuple ``(inputs, outputs)`` of two new lists holding the drawn
        pairs in the order they were sampled.
    """
    xs, ys = data
    assert len(xs) == len(ys)

    # One shared draw of positions keeps every input aligned with its output.
    picked = random.sample(range(len(xs)), subsample_size)

    sub_inputs = []
    sub_outputs = []
    for position in picked:
        sub_inputs.append(xs[position])
        sub_outputs.append(ys[position])
    return sub_inputs, sub_outputs
|
|
|
|
|
|
def create_split(data, split_size):
    """Randomly split paired inputs and outputs into two disjoint parts.

    Args:
        data: A pair ``(inputs, outputs)`` of equally long, indexable
            sequences; ``inputs[i]`` corresponds to ``outputs[i]``.
        split_size: Number of pairs to place in the first part.

    Returns:
        ``((inputs1, outputs1), (inputs2, outputs2))``. The first part holds
        ``split_size`` randomly drawn pairs in sampling order; the second
        part holds every remaining pair in its original order.

    Raises:
        AssertionError: If ``inputs`` and ``outputs`` differ in length.
        ValueError: Raised by ``random.sample`` when ``split_size`` is
            negative or larger than the number of pairs.
    """
    inputs, outputs = data
    assert len(inputs) == len(outputs)

    indices = random.sample(range(len(inputs)), split_size)
    inputs1 = [inputs[i] for i in indices]
    outputs1 = [outputs[i] for i in indices]

    # Membership tests against a set are O(1); scanning the sampled list for
    # every index made building the remainder O(n * split_size).
    selected = set(indices)
    remaining = [i for i in range(len(inputs)) if i not in selected]
    inputs2 = [inputs[i] for i in remaining]
    outputs2 = [outputs[i] for i in remaining]

    return (inputs1, outputs1), (inputs2, outputs2)
|