#! /usr/bin/python 

# Prior over smoking level s: p(s).
p_s = {'no': 0.8, 'light': 0.15, 'heavy': 0.05}

# Conditional table p(c|s): probability of cancer outcome c given smoking
# level s, stored row-wise as p_c_s[c][s]. Each inner dict is one row (a
# fixed c); fixing s and reading across all rows gives one column.
c_labels = ['none', 'benign', 'malign']

p_c_s = {
    'none': {'no': 0.96, 'light': 0.88, 'heavy': 0.60},
    'benign': {'no': 0.03, 'light': 0.08, 'heavy': 0.25},
    'malign': {'no': 0.01, 'light': 0.04, 'heavy': 0.15},
}

# Sanity checks: the prior p(s) must sum to 1, and so must every column of
# the conditional table p(c|s), i.e. for each smoking level s the
# probabilities of all cancer outcomes c.
# note: never compare float numbers by equality!!

prior_total = sum(p_s.values())
assert abs(prior_total - 1) < 0.0001, "p(s) sums to %r, not 1" % prior_total

for s in p_s:
    col_total = sum(p_c_s[c][s] for c in p_c_s)
    assert abs(col_total - 1) < 0.0001, \
        "p(c|s=%s) sums to %r, not 1" % (s, col_total)

# print p(c|s): one row per cancer outcome c, one column per smoking level s.
# The header and every row iterate the same s_cols list, so each value sits
# under its own label. Two separate dicts are not guaranteed to iterate their
# keys in the same order (Python 2 dict order is arbitrary).

s_cols = list(p_s)
print("\t".join([''] + s_cols))

for c in p_c_s:
    # c label followed by p(c|s) for each column s, converted to str
    line = [c] + [str(p_c_s[c][s]) for s in s_cols]
    print("\t".join(line))

# print("") emits a blank line under both Python 2 and Python 3
print("")

# compute the joint distribution p(c,s) = p(c|s) * p(s)

p_cs = {}

for c in p_c_s:
    # one row per cancer outcome c, filled for every smoking level s
    p_cs[c] = {}
    for s in p_c_s[c]:
        p_cs[c][s] = p_c_s[c][s] * p_s[s]

# ...aaand printout: one row per c, one column per s. The header and every
# row iterate the same s_cols list, so each value sits under its own label
# regardless of how the individual dicts order their keys.

s_cols = list(p_s)
print("\t".join([''] + s_cols))

for c in p_cs:
    # c label followed by p(c,s) for each column s, converted to str
    line = [c] + [str(p_cs[c][s]) for s in s_cols]
    print("\t".join(line))

# print("") emits a blank line under both Python 2 and Python 3
print("")

# compute p(s|c) = p(c,s) / p(c)

# need the marginal p(c) = sum over s of p(c,s) for this:

p_c = {}

for c in p_cs:
    p_c[c] = sum(p_cs[c][s] for s in p_cs[c])
    # checked once per c rather than once per (s, c) pair: p(s|c) is
    # undefined when p(c) is zero
    assert p_c[c] != 0, "p(c=%s) is zero; p(s|c) is undefined" % c

# build p_s_c row-wise: p_s_c[s][c] = p(s|c)
p_s_c = {}
for s in p_s:
    p_s_c[s] = {}
    for c in p_cs:
        p_s_c[s][c] = p_cs[c][s] / p_c[c]

# ...aaand printout: one row per smoking level s, one column per outcome c.
# The header and every row iterate the same c_cols list, so each value sits
# under its own label regardless of how the individual dicts order their keys.

c_cols = list(p_c)
print("\t".join([''] + c_cols))

for s in p_s_c:
    # s label followed by p(s|c) for each column c, to 4 decimal places
    line = [s] + ["%.4f" % p_s_c[s][c] for c in c_cols]
    print("\t".join(line))


