import matplotlib.pyplot as plt import networkx as nx import subprocess import tkinter as tk from tkinter import filedialog def select_file(): root = tk.Tk() root.withdraw() # Hide the main window file_path = filedialog.askopenfilename( initialdir="/", # Set the initial directory to search for files title="Select file", filetypes=(("Net files", "*.net"), ("All files", "*.*")) ) return file_path def get_fann_data(network_file): # Adjust the path to the Rust executable as needed result = subprocess.run(['./extract_fann_data/target/debug/extract_fann_data.exe', network_file], capture_output=True, text=True) if result.returncode != 0: print("Error:", result.stderr) return None, None layer_sizes = [] connections = [] parsing_connections = False for line in result.stdout.splitlines(): if line.startswith("Layers:"): continue elif line.startswith("Connections:"): parsing_connections = True continue if parsing_connections: from_neuron, to_neuron, weight = map(float, line.split()) connections.append((int(from_neuron), int(to_neuron), weight)) else: layer_size, bias_count = map(int, line.split()) layer_sizes.append((layer_size, bias_count)) return layer_sizes, connections def visualize_fann_network(network_file): # Get network data layer_sizes, connections = get_fann_data(network_file) if layer_sizes is None or connections is None: return # Error handling in get_fann_data should provide error output # Create a directed graph G = nx.DiGraph() # Positions dictionary to hold the position of each neuron pos = {} node_count = 0 x_spacing = 1.0 y_spacing = 1.0 # Calculate the maximum layer size for proper spacing max_layer_size = max(size for size, bias in layer_sizes) # Build nodes and position them layer by layer from left to right for layer_index, (layer_size, bias_count) in enumerate(layer_sizes): y_positions = list(range(-layer_size-bias_count+1, 1, 1)) # Center-align vertically y_positions = [y * (max_layer_size / (layer_size + bias_count)) * y_spacing for y in y_positions] # Adjust spacing for neuron_index in range(layer_size + bias_count): # Include bias neurons node_label = f"L{layer_index}N{neuron_index}" G.add_node(node_count, label=node_label) pos[node_count] = (layer_index * x_spacing, y_positions[neuron_index % len(y_positions)]) node_count += 1 # Add connections to the graph for from_neuron, to_neuron, weight in connections: G.add_edge(from_neuron, to_neuron, weight=weight) max_weight = max(abs(weight) for _, _, weight in connections) print(f"Max weight: {max_weight}") # Draw nodes nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=200) nx.draw_networkx_labels(G, pos, font_size=7) # Custom function for edge properties def adjust_properties(weight): # if weight > 0: # print("Weight:", weight) color = 'green' if weight > 0 else 'red' alpha = min((abs(weight) / max_weight) ** 3, 1) # print(f"Color: {color}, Alpha: {alpha}") return color, alpha # Draw edges with custom properties for u, v, d in G.edges(data=True): color, alpha = adjust_properties(d['weight']) nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=color, alpha=alpha, width=1.5, arrows=False) # Show plot plt.title('FANN Network Visualization') plt.axis('off') # Turn off the axis plt.show() # Path to the FANN network file fann_path = select_file() visualize_fann_network(fann_path)