This commit is contained in:
2019-03-20 01:11:31 +04:00
commit a8c4573152
8 changed files with 413 additions and 0 deletions

100
tests/test_all.py Normal file
View File

@@ -0,0 +1,100 @@
from time import time_ns
import numpy as np
from explicit_storage import StoreBulbsNaive, StoreBulbsSqrt
from range_based import StoreRangesBinSearch, StoreRangesNaive
from segment_tree import StoreBulbsSegmentTree
ReferenceSolution = StoreBulbsNaive
TestSolutions = [
StoreBulbsNaive,
StoreBulbsSqrt,
StoreBulbsSegmentTree,
StoreRangesNaive,
StoreRangesBinSearch,
]
def create_actions(n, query_count, toggle_count):
actions = []
for i in range(toggle_count):
j = np.random.randint(n + 1)
i = np.random.randint(n)
if j < n - i:
l = i
r = i + j + 1
else:
l = n - i - 1
r = j
actions.append({
'l': l,
'r': r,
})
for i in range(query_count):
actions.append({
'pos': np.random.randint(n),
})
np.random.shuffle(actions)
reference = ReferenceSolution(n)
for action in actions:
if 'l' in action:
reference.toggle(action['l'], action['r'])
else:
action['answer'] = reference.query(action['pos'])
return actions
def run_test_case(n, toggle_count, query_count):
actions = create_actions(n, query_count, toggle_count)
solutions = [Solution(n) for Solution in TestSolutions]
toggle_times = np.zeros(len(solutions))
query_times = np.zeros(len(solutions))
for solution_i, solution in enumerate(solutions):
for action_i, action in enumerate(actions):
start_time = time_ns()
if 'l' in action:
solution.toggle(action['l'], action['r'])
toggle_times[solution_i] += time_ns() - start_time
else:
assert solution.query(action['pos']) == action['answer'], action['pos']
query_times[solution_i] += time_ns() - start_time
total_times = toggle_times + query_times
print()
print(f'n = {n:6} T = {toggle_count:6} Q = {query_count:6}')
for i in np.argsort(total_times):
solution_name = TestSolutions[i].__name__
toggle_time = toggle_times[i]
query_time = query_times[i]
total_time = total_times[i]
print(
f' {solution_name:40} {toggle_time / 1e6:6.0f}ms + {query_time / 1e6:6.0f}ms = {total_time / 1e6:6.0f}ms')
def test_all():
np.random.seed(0)
run_test_case(1000, 1000, 1000)
run_test_case(10_000, 10, 10_000)
run_test_case(10_000, 10_000, 10)
run_test_case(1_000_000, 1_000, 1_000)