#!/usr/bin/env python
"""Parallel grid search over the -c, -g and -p parameters of svm-train.

For every (log2 c, log2 g, log2 p) combination in the configured ranges the
script runs ``svm-train -s 3 ... -v fold`` on the dataset, reads the number at
the end of the first output line containing "Cross", and keeps the combination
with the smallest such value (called ``mse`` throughout this file).  Jobs are
distributed over local worker threads and, optionally, over ssh hosts.
"""

import os, sys, traceback
import getpass
from threading import Thread
from subprocess import *

if(sys.hexversion < 0x03000000):
    import Queue
else:
    import queue as Queue


# svmtrain and gnuplot executable

is_win32 = (sys.platform == 'win32')
svmtrain_exe = "../svm-train"
gnuplot_exe = "/usr/bin/gnuplot"

# example for windows
# svmtrain_exe = r"c:\tmp\libsvm-2.4\windows\svmtrain.exe"
# gnuplot_exe = r"c:\tmp\gp373w32\pgnuplot.exe"

# global parameters and their default values

fold = 5
c_begin, c_end, c_step = -1, 6, 1
g_begin, g_end, g_step = 0, -8, -1
p_begin, p_end, p_step = -8, -1, 1

global dataset_pathname, dataset_title, pass_through_string
global out_filename, png_filename

# experimental

ssh_workers = []
# ssh_workers = ['linux1','linux1','linux2','linux2','linux3', 'linux4', 'linux6','linux7','linux8','linux8','linux9','linux10','linux11','linux12']
nr_local_worker = 1


# process command line options, set global parameters
def process_options(argv=sys.argv):
    """Parse ``argv`` and store the result in the module-level parameters.

    The last element of ``argv`` is always taken as the dataset pathname.
    Options this script does not recognise are collected and later passed
    through to svm-train unchanged.

    Exits with status 1 (after printing the usage text) when fewer than two
    arguments are given or when the renamed options ``-c``/``-g`` are used.

    Raises:
        AssertionError: if the svm-train executable, the gnuplot executable
            or the dataset does not exist.
    """
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global p_begin, p_end, p_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-log2p begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    # argv[0] is the program name and argv[-1] the dataset, so only the
    # elements in between are options.
    i = 1
    while i < len(argv) - 1:
        if argv[i] == "-log2c":
            i = i + 1
            (c_begin, c_end, c_step) = map(float, argv[i].split(","))
        elif argv[i] == "-log2g":
            i = i + 1
            (g_begin, g_end, g_step) = map(float, argv[i].split(","))
        elif argv[i] == "-log2p":
            i = i + 1
            (p_begin, p_end, p_step) = map(float, argv[i].split(","))
        elif argv[i] == "-v":
            i = i + 1
            fold = argv[i]
        elif argv[i] in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)
        elif argv[i] == '-svmtrain':
            i = i + 1
            svmtrain_exe = argv[i]
        elif argv[i] == '-gnuplot':
            i = i + 1
            gnuplot_exe = argv[i]
        elif argv[i] == '-out':
            i = i + 1
            out_filename = argv[i]
        elif argv[i] == '-png':
            i = i + 1
            png_filename = argv[i]
        else:
            pass_through_options.append(argv[i])
        i = i + 1

    pass_through_string = " ".join(pass_through_options)
    assert os.path.exists(svmtrain_exe), "svm-train executable not found"
    assert os.path.exists(gnuplot_exe), "gnuplot executable not found"
    assert os.path.exists(dataset_pathname), "dataset not found"
    # Plotting is currently disabled: the gnuplot pipe is never opened.
    # gnuplot = Popen(gnuplot_exe,stdin = PIPE).stdin


def range_f(begin, end, step):
    """Return ``[begin, begin+step, ...]`` up to and including ``end``.

    Like ``range`` but works on non-integers too, and the end point is
    inclusive.  ``step`` may be negative for a descending sequence.

    Raises:
        ValueError: if ``step`` is zero (the sequence would never end).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    while True:
        if step > 0 and begin > end:
            break
        if step < 0 and begin < end:
            break
        seq.append(begin)
        begin = begin + step
    return seq


def permute_sequence(seq):
    """Reorder ``seq`` so that the middle element comes first.

    The middle element is emitted first, followed by the recursively permuted
    left and right halves interleaved element by element.  Sequences of
    length 0 or 1 are returned as they are.
    """
    n = len(seq)
    if n <= 1:
        return seq

    mid = int(n / 2)
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid + 1:])

    ret = [seq[mid]]
    while left or right:
        if left:
            ret.append(left.pop(0))
        if right:
            ret.append(right.pop(0))

    return ret


def redraw(db, tofile=0):
    """Send a contour plot of ``db`` to the module-level ``gnuplot`` stream.

    Args:
        db: list of 3-tuples; the first two fields are plotted on the x and
            y axes and the third supplies the contour levels.  The list is
            sorted in place.
        tofile: when true, render to ``png_filename`` instead of a window.

    Does nothing when ``db`` is empty.  The commands are formatted as text and
    encoded before writing, because the stream is written as bytes.
    """
    if len(db) == 0:
        return

    def send(text):
        gnuplot.write(text.encode())

    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5
    if tofile:
        send("set term png transparent small color\n")
        send("set output \"%s\"\n" % png_filename.replace('\\', '\\\\'))
        #gnuplot.write("set term postscript color solid\n".encode())
        #gnuplot.write(("set output \"%s.ps\"\n" % dataset_title).encode())
    else:
        if is_win32:
            send("set term windows\n")
        else:
            send("set term x11\n")
    send("set xlabel \"lg(C)\"\n")
    send("set ylabel \"lg(gamma)\"\n")
    send("set xrange [%s:%s]\n" % (c_begin, c_end))
    send("set yrange [%s:%s]\n" % (g_begin, g_end))
    send("set contour\n")
    send("set cntrparam levels incremental %s,%s,100\n" % (begin_level, step_size))
    send("set nosurface\n")
    send("set view 0,0\n")
    send("set label \"%s\" at screen 0.4,0.9\n" % dataset_title)
    send("splot \"-\" with lines\n")

    # Group rows by their first field; a blank line separates the groups.
    db.sort(key=lambda x: (x[0], -x[1]))
    prevc = db[0][0]
    for line in db:
        if prevc != line[0]:
            send("\n")
            prevc = line[0]
        send("%s %s %s\n" % tuple(line))
    send("e\n")
    send("\n")  # force gnuplot back to prompt when term set failure
    gnuplot.flush()


def calculate_jobs():
    """Build the job list from the configured c, g and p ranges.

    Returns:
        A list with one entry per parameter combination; each entry is a
        one-element list holding a ``(c, g, p)`` tuple of log2 values.  Each
        axis is ordered by ``permute_sequence``.
    """
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    p_seq = permute_sequence(range_f(p_begin, p_end, p_step))

    jobs = []
    for c in c_seq:
        for g in g_seq:
            for p in p_seq:
                jobs.append([(c, g, p)])
    return jobs


class WorkerStopToken:  # used to notify the worker to stop
    pass


class Worker(Thread):
    """Thread that takes jobs from a queue and evaluates them with run_one().

    Subclasses provide ``run_one(c, g, p)``, which returns the parsed
    cross-validation value or ``None`` when none was found.
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token is seen or a job fails."""
        while True:
            (cexp, gexp, pexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so the other workers see it too.
                self.job_queue.put((cexp, gexp, pexp))
                # print 'worker %s stop.' % self.name
                break
            try:
                rate = self.run_one(2.0**cexp, 2.0**gexp, 2.0**pexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exc()
                self.job_queue.put((cexp, gexp, pexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, pexp, rate))

    def _run_and_parse(self, cmdline):
        """Run ``cmdline`` in a shell and parse its cross-validation value.

        Returns the number at the end of the first output line containing
        "Cross", or ``None`` if there is no such line.
        """
        # NOTE: the command is a shell string built from command-line
        # arguments; pathnames containing spaces or shell metacharacters
        # are not quoted.
        proc = Popen(cmdline, shell=True, stdout=PIPE)
        # communicate() reads to EOF, closes the pipe and waits for the child.
        output = proc.communicate()[0]
        for line in output.splitlines():
            if str(line).find("Cross") != -1:
                return float(line.split()[-1])
        return None


class LocalWorker(Worker):
    """Worker that runs svm-train on the local machine."""

    def run_one(self, c, g, p):
        cmdline = '%s -s 3 -c %s -g %s -p %s -v %s %s %s' % \
            (svmtrain_exe, c, g, p, fold, pass_through_string, dataset_pathname)
        return self._run_and_parse(cmdline)


class SSHWorker(Worker):
    """Worker that runs svm-train on ``host`` through ssh.

    The remote command first changes into the directory this script was
    started from.
    """

    def __init__(self, name, job_queue, result_queue, host):
        Worker.__init__(self, name, job_queue, result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g, p):
        cmdline = 'ssh %s "cd %s; %s -s 3 -c %s -g %s -p %s -v %s %s %s"' % \
            (self.host, self.cwd,
             svmtrain_exe, c, g, p, fold, pass_through_string, dataset_pathname)
        # print cmdline
        return self._run_and_parse(cmdline)


def main():
    """Run the grid search and print the best parameters found.

    Every result is appended to ``out_filename`` as it arrives; the final
    line printed is ``best_c best_g best_p best_mse``.
    """
    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    #print(len(jobs))
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c, g, p) in line:
            job_queue.put((c, g, p))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. If we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end

    job_queue._put = job_queue.queue.appendleft

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results

    done_jobs = {}
    best_mse = float('+inf')
    # Defined up front so the summary below cannot hit an unbound name when
    # no result ever improves on +inf.
    best_c = best_g = best_p = None

    result_file = open(out_filename, 'w')
    try:
        for line in jobs:
            for (c, g, p) in line:
                # Results arrive in any order; keep reading until the job
                # we are waiting for has been reported.
                while (c, g, p) not in done_jobs:
                    (worker, c1, g1, p1, mse) = result_queue.get()
                    done_jobs[(c1, g1, p1)] = mse
                    result_file.write('%s %s %s %s\n' % (c1, g1, p1, mse))
                    result_file.flush()
                    if mse < best_mse:
                        best_mse = mse
                        best_c = 2.0**c1
                        best_g = 2.0**g1
                        best_p = 2.0**p1
                    print("[%s] %s %s %s %s (best c=%s, g=%s, p=%s, mse=%s)" % \
                        (worker, c1, g1, p1, mse, best_c, best_g, best_p, best_mse))
    finally:
        result_file.close()
        # Always tell the workers to stop, even if gathering was interrupted.
        job_queue.put((WorkerStopToken, None, None))

    print("%s %s %s %s" % (best_c, best_g, best_p, best_mse))


if __name__ == "__main__":
    main()