Fast and precise algorithms, with tests

This commit is contained in:
Fedor Lyanguzov
2025-01-22 18:21:39 +03:00
parent 303ddee08f
commit dfd6e21f2b
5 changed files with 185 additions and 0 deletions
+43
View File
@@ -0,0 +1,43 @@
from heapq import heapify, heappush, heappop
from .util import mask, get_data, cidr4_to_node, make_cidr4
def solution(cidrs, M):
h = []
d = {}
for ip, l in cidrs:
h.append((32-l, ip, l))
d[(ip, l)] = 0
heapify(h)
while len(h)>M:
x1, ip1, l1 = heappop(h)
x2, ip2, l2 = h[0]
if l1==l2 and ip1 & mask[l1-1] == ip2 & mask[l2-1]:
heappop(h)
if (ip1 & mask[l1-1], l1-1) not in d:
heappush(h, (x1+1, ip1 & mask[l1-1], l1-1))
d[(ip1 & mask[l1-1], l1-1)] = d[(ip1, l1)] + d[(ip2, l2)]
del d[(ip2, l2)]
else:
if (ip1 & mask[l1-1], l1-1) not in d:
heappush(h, (x1+1, ip1 & mask[l1-1], l1-1))
d[(ip1 & mask[l1-1], l1-1)] = d[(ip1, l1)] + 2**x1
del d[(ip1, l1)]
s = sum(d.values())
cidrs = list(d.keys())
return cidrs, s
def main():
M = 20
a = get_data()
b = list(map(cidr4_to_node, a))
cidrs, s = solution(b, M)
cidrs = sorted([make_cidr4(*x) for x in cidrs])
print(cidrs, s, sep='\n')
if __name__=='__main__':
import cProfile
main()
+59
View File
@@ -0,0 +1,59 @@
from .util import mask, get_data, cidr4_to_node, make_cidr4
def f(x, y):
t = x
b = y
if x[1]>y[1]:
t = y
b = x
ip1, l1, a1 = t
ip2, l2, a2 = b
if ip1 & mask[l1] == ip2 & mask[l1]:
return (0, t)
t1 = t2 = 0
while not l1==l2:
t2 += 2**(32-l2)
l2 -= 1
ip2 = ip2 & mask[l2]
while not ip1 & mask[l1-1] == ip2 & mask[l2-1]:
t1 += 2**(32-l1)
l1 -= 1
ip1 = ip1 & mask[l1]
t2 += 2**(32-l2)
l2 -= 1
ip2 = ip2 & mask[l2]
r = (ip1 & mask[l1-1], l1-1, a1+a2+t1+t2)
return (t1+t2, r)
def solution(cidrs, M):
cidrs = sorted((ip, l, 0) for ip, l in cidrs)
while len(cidrs)>M:
t = (None, float('+inf'), None)
for i, (x, y) in enumerate(zip(cidrs, cidrs[1:])):
m, r = f(x, y)
if m<t[1]:
t = (i, m, r)
if m==0:
break
cidrs[t[0]] = t[2]
del cidrs[t[0]+1]
s = sum(x[2] for x in cidrs)
cidrs = [x[:2] for x in cidrs]
return cidrs, s
def main():
M = 20
a = get_data()
b = list(map(cidr4_to_node, a))
cidrs, s = solution(b, M)
cidrs = sorted([make_cidr4(*x) for x in cidrs])
print(cidrs, s, sep='\n')
if __name__=='__main__':
import cProfile
main()
+20
View File
@@ -0,0 +1,20 @@
mask = [((1 << i) - 1) << (32 - i) for i in range(33)]
def get_data(input_file='cidr4.txt'):
with open(input_file, "r") as file:
return file.read().splitlines()
def cidr4_to_node(cidr4: str):
ip_address, mask_len = cidr4.strip().split("/")
mask_len = int(mask_len)
a, b, c, d = list(map(int, ip_address.split(".")))
ip = a * 256**3 + b * 256**2 + c * 256**1 + d * 256**0
return ip, mask_len
def make_cidr4(ip, mask_len) -> str:
lst = [str(ip >> (i << 3) & 0xFF) for i in reversed(range(4))]
ip_address = ".".join(lst)
return f"{ip_address}/{mask_len}"