Source
65
65
chiSq_(0.0),
66
66
chiSqV_(4,0.0),
67
67
lastChiSq_(0.0),dChiSq_(0.0),
68
68
sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
69
69
cvrgcount_(0),
70
70
par_(), parOK_(), parErr_(), lastPar_(),
71
71
dpar_(),
72
72
grad_(),hess_(),
73
73
lambda_(2.0),
74
74
optstep_(True),
75
+
doL1_(false),
76
+
L1clamp_(0),
77
+
doRMSThresh_(false),
78
+
RMSThresh_(0),
79
+
nRMSThresh_(0),
75
80
prtlev_(VCS2_PRTLEV)
76
81
{
77
82
if (prtlev()>0) cout << "VCS2::VCS2()" << endl;
78
83
}
79
84
85
+
VisCalSolver2::VisCalSolver2(String solmode, Vector<Float>& rmsthresh) :
86
+
SDBs_(NULL),
87
+
ve_(NULL),
88
+
svc_(NULL),
89
+
nPar_(0),
90
+
maxIter_(50),
91
+
chiSq_(0.0),
92
+
chiSqV_(4,0.0),
93
+
lastChiSq_(0.0),dChiSq_(0.0),
94
+
sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
95
+
cvrgcount_(0),
96
+
par_(), parOK_(), parErr_(), lastPar_(),
97
+
dpar_(),
98
+
grad_(),hess_(),
99
+
lambda_(2.0),
100
+
optstep_(True),
101
+
doL1_(false),
102
+
L1clamp_(std::vector<Float>({5e-3, 5e-4, 5e-5})),
103
+
doRMSThresh_(false),
104
+
RMSThresh_(rmsthresh), //
105
+
nRMSThresh_(rmsthresh.nelements()),
106
+
prtlev_(VCS2_PRTLEV)
107
+
{
108
+
if (prtlev()>0) cout << "VCS2::VCS2(solmode)" << endl;
109
+
110
+
if (solmode.contains("L1")) doL1_=true;
111
+
if (solmode.contains("R")) doRMSThresh_=true;
112
+
113
+
if (doRMSThresh_ && nRMSThresh_==0) {
114
+
RMSThresh_=Vector<Float>(std::vector<Float>({7.0,5.0,4.0,3.5,3.0,2.8,2.6,2.4,2.2}));
115
+
nRMSThresh_=RMSThresh_.nelements();
116
+
}
117
+
118
+
}
119
+
80
120
VisCalSolver2::~VisCalSolver2()
81
121
{
82
122
if (prtlev()>0) cout << "VCS2::~VCS2()" << endl;
83
123
}
84
124
85
-
// New VisBuffGroupAcc version
125
+
126
+
// New SDBList version
86
127
Bool VisCalSolver2::solve(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
87
128
88
-
if (prtlev()>1) cout << "VCS2::solve(,,VBGA)" << endl;
129
+
// If L1 and/or outlier flagging requested, call specialize method
130
+
if (doL1_ || doRMSThresh_)
131
+
return solveL1R(ve,svc,sdbs);
132
+
133
+
if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
89
134
90
135
/*
91
136
LogSink logsink;
92
137
{
93
138
LogMessage message(LogOrigin("VisCalSolver2", "solve"));
94
139
ostringstream o; o<<"Beginning solve...";
95
140
message.message(o);
96
141
logsink.post(message);
97
142
}
98
143
*/
228
273
}
229
274
230
275
}
231
276
else {
232
277
cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
233
278
}
234
279
235
280
return False;
236
281
237
282
}
283
+
284
+
// New L1(R)-capable version
285
+
Bool VisCalSolver2::solveL1R(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
286
+
287
+
if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
288
+
289
+
/*
290
+
LogSink logsink;
291
+
{
292
+
LogMessage message(LogOrigin("VisCalSolver2", "solve"));
293
+
ostringstream o; o<<"Beginning solve...";
294
+
message.message(o);
295
+
logsink.post(message);
296
+
}
297
+
*/
298
+
// Pointers to local ve,svc
299
+
ve_=&ve;
300
+
svc_=&svc;
301
+
SDBs_=&sdbs;
302
+
303
+
// Verify that VisEq has the correct svc:
304
+
// TBD?
305
+
306
+
// Initialize everything
307
+
initSolve();
308
+
309
+
Vector<Float> steplist(maxIter_+2,0.0);
310
+
Vector<Float> rsteplist(maxIter_+2,0.0);
311
+
312
+
// Verify Data's validity for solve w.r.t. baselines available
313
+
// (this sets parOK() on per-antenna basis (for focusChan)
314
+
// based on data weights and baseline participation)
315
+
Bool oktosolve = svc_->verifyConstraints(*SDBs_);
316
+
317
+
if (oktosolve) {
318
+
319
+
if (prtlev()>1) cout << "First guess:" << endl
320
+
<< "amp = " << amplitude(par()) << endl
321
+
<< "pha = " << phase(par())
322
+
<< endl;
323
+
324
+
// Iterate solution
325
+
Int iter(0);
326
+
Bool done(False);
327
+
Bool applyWorkingFlags(false);
328
+
Int L1iter(0), IRiter(0);
329
+
while (!done) {
330
+
331
+
if (prtlev()>2) cout << " Beginning iteration " << iter
332
+
<< "---------------------------------" << endl;
333
+
334
+
// Differentiate the VB and get current Chi2
335
+
differentiate2();
336
+
337
+
if (doRMSThresh_ && applyWorkingFlags) {
338
+
SDBs_->updateWorkingFlags();
339
+
applyWorkingFlags=false; // must be explicitly triggered below
340
+
}
341
+
342
+
// Set up working weights
343
+
if (doL1_)
344
+
SDBs_->updateWorkingWeights(doL1_,L1clamp_(L1iter));
345
+
else
346
+
SDBs_->updateWorkingWeights(false);
347
+
348
+
349
+
chiSquare2();
350
+
if (chiSq()==0.0) {
351
+
cout << "CHI2 IS SPURIOUSLY ZERO!*************************************" << endl;
352
+
//cout << "R() = " << R() << endl;
353
+
// cout << "sum(wtmat) = " << sum(wtmat) << endl;
354
+
return False;
355
+
}
356
+
357
+
dChiSq() = chiSq()-lastChiSq();
358
+
359
+
//cout << "iter=" << iter << " X2=" << chiSq() << " dX2=" << dChiSq() << " dX2/X2=" << dChiSq()/chiSq(); // << endl;
360
+
361
+
// Continuue if we haven't converged
362
+
if (!converged()) {
363
+
364
+
//if (dChiSq()<=0.0) {
365
+
if (true || dChiSq()<=0.0) {
366
+
// last step was good...
367
+
lastChiSq()=chiSq();
368
+
369
+
// so accumulate new grad/hess...
370
+
accGradHess2();
371
+
372
+
//...and adjust lambda downward
373
+
// lambda()/=2.0;
374
+
// lambda()=0.8;
375
+
lambda()=1.0;
376
+
}
377
+
else {
378
+
// cout << "reverting..." << chiSq() << " " << dChiSq() << " (" << iter << ")" << endl;
379
+
// last step was bad, revert to previous
380
+
revert();
381
+
//...with a larger lambda
382
+
// lambda()*=4.0;
383
+
lambda()=1.0;
384
+
}
385
+
386
+
// Solve for the parameter step
387
+
solveGradHess();
388
+
389
+
// Remember curr pars
390
+
lastPar()=par();
391
+
392
+
// Refine the step size by exploring chi2 in the
393
+
// gradient direction
394
+
if (optstep_ && !doL1_) // && cvrgcount_>=3)
395
+
optStepSize2();
396
+
397
+
// Update current parameters (saves a copy of them)
398
+
updatePar();
399
+
400
+
steplist(iter)=max(amplitude(dpar()));
401
+
rsteplist(iter)=max(amplitude(dpar())/amplitude(par()));
402
+
403
+
//cout << " rstep=" << rsteplist(iter) << endl;
404
+
405
+
}
406
+
else {
407
+
408
+
// Convergence means we're done, NOMINALLY
409
+
done=True;
410
+
411
+
// Override convergence if we need to solve again with
412
+
// revised weight/flag conditions for robustness
413
+
if (doL1_ && L1iter<Int(L1clamp_.nelements())-1) {
414
+
//cout << "*~*~*~*~*~*~* Converged w/ L1clamp = " << L1clamp_(L1iter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
415
+
done=false;
416
+
++L1iter;
417
+
iter=-1;
418
+
cvrgcount_=0;
419
+
lastChiSq()=DBL_MAX;
420
+
}
421
+
else if (doRMSThresh_ && IRiter<nRMSThresh_) {
422
+
//cout << "*~*~*~*~*~*~* Applying RMSThresh = " << RMSThresh_(IRiter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
423
+
RMSThresh(IRiter);
424
+
++IRiter;
425
+
applyWorkingFlags=true; // force apply of the RMSThresh'd flags at the top of loop _after_ differentiation
426
+
done=false;
427
+
L1iter=0;
428
+
iter=-1;
429
+
cvrgcount_=0;
430
+
lastChiSq()=DBL_MAX;
431
+
}
432
+
433
+
// If still done (robustness options absent or exhausted), escape solve loop
434
+
if (done) {
435
+
436
+
if (prtlev()>0) {
437
+
cout << "par()=" << par() << endl;
438
+
}
439
+
440
+
/*
441
+
cout << " good pars=" << ntrue(parOK())
442
+
<< " iterations=" << iter << endl
443
+
<< " steps=" << steplist(IPosition(1,0),IPosition(1,iter))
444
+
<< endl
445
+
<< " rsteps=" << rsteplist(IPosition(1,0),IPosition(1,iter))
446
+
<< endl;
447
+
*/
448
+
449
+
// Get parameter errors:
450
+
accGradHess2();
451
+
getErrors();
452
+
453
+
// Return, signaling success if at least 1 good solution
454
+
return (ntrue(parOK())>0);
455
+
}
456
+
457
+
} // converged?
458
+
459
+
// Escape iteration loop via iteration limit
460
+
if (iter==maxIter()) {
461
+
cout << "Reached iteration limit: " << iter << " iterations. " << endl;
462
+
// cout << " good pars = " << ntrue(parOK())
463
+
// << " steps = " << steplist
464
+
// << endl;
465
+
done=True;
466
+
}
467
+
468
+
// Advance iteration counter
469
+
iter++;
470
+
}
471
+
472
+
}
473
+
else {
474
+
cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
475
+
}
476
+
477
+
return False;
478
+
479
+
}
238
480
239
481
void VisCalSolver2::initSolve() {
240
482
241
483
if (prtlev()>2) cout << " VCS2::initSolve()" << endl;
242
484
243
485
// Get total number of cal parameters from svc info
244
486
nPar()=svc().nTotalPar();
245
487
246
488
if (prtlev()>2)
247
489
cout << " Total parameters in solve: " << nPar() << endl;
332
574
sumWtV()=0.0;
333
575
nWt()=0;
334
576
335
577
Cube<Complex> R;
336
578
337
579
// Loop over SDBs
338
580
for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
339
581
340
582
// Current SDB
341
583
SolveDataBuffer& sdb(sdbs()(isdb));
342
-
343
584
R.reference(sdb.residuals());
344
585
586
+
// _const_ access to working flags and weights
587
+
const Cube<Bool>& wFC(sdb.const_workingFlagCube());
588
+
const Cube<Float>& wWS(sdb.const_workingWtSpec());
589
+
345
590
// Shapes for iteration
346
591
IPosition shR(R.shape());
347
592
Int nCorr=shR(0);
348
593
Int nChan=shR(1);
349
594
Int nRow=shR(2);
350
595
351
596
// Simple indexed accumulation of chiSq
352
597
// TBD: optimize w.r.t. indexing?
353
598
Double chisq0(0.0);
354
599
for (Int irow=0;irow<nRow;++irow) {
355
600
if (!sdb.flagRow()(irow)) {
356
601
for (Int ich=0;ich<nChan;++ich) {
357
602
for (Int icorr=0;icorr<nCorr;++icorr) {
358
-
if (!sdb.residFlagCube()(icorr,ich,irow)) {
359
-
Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
603
+
//if (!sdb.residFlagCube()(icorr,ich,irow)) { // OLD: residFlagCube
604
+
const Bool& fl(wFC(icorr,ich,irow)); // NEW: workingFlagCube CORRECT?
605
+
if (!fl) {
606
+
const Float& wt(wWS(icorr,ich,irow));
360
607
if (wt>0.0) {
361
608
Complex& Ri(R(icorr,ich,irow));
362
609
363
610
// This element's contribution
364
611
chisq0=Double(wt*real(Ri*conj(Ri))); // cf: square(abs(R))?
365
612
366
613
// Accumulate per-corr
367
614
chiSqV()(icorr)+=chisq0;
368
615
sumWtV()(icorr)+=wt;
369
616
nWt()++;
370
617
} // wt>0
371
618
} // !flag
372
619
} // icorr
373
620
} // ich
374
621
} // !flagRow
375
622
} // irow
376
623
377
624
} // isdb
378
625
626
+
//cout << "chiSqV() = " << chiSqV() << endl;
627
+
379
628
// Totals over corrs
380
629
chiSq()=sum(chiSqV());
381
630
sumWt()=sum(sumWtV());
382
631
383
632
}
384
633
634
+
// RMS calculation (for thresholding)
635
+
void VisCalSolver2::RMSThresh(Int RejIter) {
636
+
637
+
if (prtlev()>2) cout << " VCS2::RMS(SDB version)" << endl;
638
+
639
+
const Float threshold(RMSThresh_(RejIter));
640
+
641
+
// TBD: per-ant/bln chiSq?
642
+
643
+
Int nCorr=sdbs().nCorrelations();
644
+
Vector<Double> xxV(nCorr,0.0);
645
+
Vector<Double> sWtV(nCorr,0.0);
646
+
647
+
Cube<Complex> R;
648
+
649
+
// Loop over SDBs
650
+
for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
651
+
652
+
// Current SDB
653
+
SolveDataBuffer& sdb(sdbs()(isdb));
654
+
R.reference(sdb.residuals());
655
+
656
+
// Shapes for iteration
657
+
IPosition shR(R.shape());
658
+
Int nCorr=shR(0);
659
+
Int nChan=shR(1);
660
+
Int nRow=shR(2);
661
+
662
+
const Cube<Bool>& wFC(sdb.const_workingFlagCube());
663
+
664
+
// Simple indexed accumulation of XX
665
+
Double xx0(0.0);
666
+
for (Int irow=0;irow<nRow;++irow) {
667
+
if (!sdb.flagRow()(irow)) {
668
+
for (Int ich=0;ich<nChan;++ich) {
669
+
for (Int icorr=0;icorr<nCorr;++icorr) {
670
+
if (!wFC(icorr,ich,irow)) {
671
+
Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
672
+
if (wt>0.0) {
673
+
Complex& Ri(R(icorr,ich,irow));
674
+
675
+
// This element's contribution
676
+
xx0=Double(wt*real(Ri*conj(Ri))); // cf: square(abs(R))?
677
+
678
+
// Accumulate per-corr
679
+
xxV(icorr)+=xx0;
680
+
sWtV(icorr)+=wt;
681
+
} // wt>0
682
+
} // !flag
683
+
} // icorr
684
+
} // ich
685
+
} // !flagRow
686
+
} // irow
687
+
688
+
} // isdb
689
+
690
+
Vector<Float> rmsV(nCorr,0.0);
691
+
for (Int icorr=0;icorr<nCorr;++icorr) {
692
+
if (sWtV(icorr)>0.0)
693
+
rmsV(icorr)=Float(sqrt(xxV(icorr)/sWtV(icorr)));
694
+
}
695
+
696
+
// Now Apply the threshold
697
+
698
+
// Loop over SDBs
699
+
for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
700
+
701
+
// Current SDB
702
+
SolveDataBuffer& sdb(sdbs()(isdb));
703
+
R.reference(sdb.residuals());
704
+
705
+
// Initialize wFC afresh
706
+
sdb.workingFlagCube().resize(0,0,0);
707
+
sdb.workingFlagCube().assign(sdb.residFlagCube());
708
+
709
+
// Shapes for iteration
710
+
IPosition shR(R.shape());
711
+
Int nCorr=shR(0);
712
+
Int nChan=shR(1);
713
+
Int nRow=shR(2);
714
+
715
+
for (Int irow=0;irow<nRow;++irow) {
716
+
if (!sdb.flagRow()(irow)) {
717
+
for (Int ich=0;ich<nChan;++ich) {
718
+
for (Int icorr=0;icorr<nCorr;++icorr) {
719
+
if (!sdb.residFlagCube()(icorr,ich,irow)) {
720
+
Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
721
+
if (wt>0.0) {
722
+
Float Ra(abs(R(icorr,ich,irow)));
723
+
if (Ra>(threshold*rmsV(icorr))) {
724
+
cout << "Flagging at [" << icorr << "," << ich << "," << irow << "] sig=" << Ra/rmsV(icorr) << " (threshold=" << threshold << ")" << endl;
725
+
sdb.workingFlagCube()(icorr,ich,irow)=true;
726
+
//sdb.workingWtSpec()(icorr,ich,irow)=0.0;
727
+
}
728
+
} // wt>0
729
+
} // !flag
730
+
} // icorr
731
+
} // ich
732
+
} // !flagRow
733
+
} // irow
734
+
735
+
} // isdb
736
+
737
+
}
738
+
739
+
385
740
386
741
Bool VisCalSolver2::converged() {
387
742
388
743
if (prtlev()>2) cout << " VCS2::converged()" << endl;
389
744
390
745
// Change in chi2
391
746
dChiSq() = chiSq()-lastChiSq();
392
747
Float fChiSq(dChiSq()/chiSq());
393
748
394
749
// Consider convergence if chi2 decreases...
458
813
459
814
// Loop over SDBs
460
815
for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
461
816
462
817
// Current SDB
463
818
SolveDataBuffer& sdb(sdbs()(isdb));
464
819
465
820
R.reference(sdb.residuals());
466
821
dR.reference(sdb.diffResiduals());
467
822
823
+
const Cube<Float>& wWS(sdb.const_workingWtSpec());
824
+
const Cube<Bool>& wFC(sdb.const_workingFlagCube());
825
+
468
826
IPosition dRip(dR.shape());
469
827
470
828
Int nRow(dRip(3));
471
829
Int nChan(dRip(2));
472
830
Int nParPerAnt(dRip(1)); // pars per antenna
473
831
Int nCorr(dRip(0));
474
832
475
833
// Simple indexed accumulation
476
834
for (Int irow=0;irow<nRow;++irow) {
477
835
if (!sdb.flagRow()(irow)) {
478
836
Int a1i= nParPerAnt*sdb.antenna1()(irow);
479
837
Int a2i= nParPerAnt*sdb.antenna2()(irow);
480
838
for (Int ichan=0;ichan<nChan;++ichan) {
481
839
for (int icorr=0;icorr<nCorr;++icorr) {
482
-
if (!sdb.residFlagCube()(icorr,ichan,irow)) {
483
-
Float& wt(sdb.infocusWtSpec()(icorr,ichan,irow));
840
+
//if (!sdb.residFlagCube()(icorr,ichan,irow)) { // OLD: residFlagCube
841
+
const Bool& fl(wFC(icorr,ichan,irow)); // NEW: workingFlagCube CORRECT?
842
+
if (!fl) {
843
+
const Float& wt(wWS(icorr,ichan,irow));
484
844
if (wt>0.0) {
485
845
Complex& Ri(R(icorr,ichan,irow));
486
846
for (Int ipar=0;ipar<nParPerAnt;++ipar) {
487
847
488
848
// Accumulate grad and hess for this icorr,ichan,irow,ipar
489
849
// for a1:
490
850
Complex& dR1(dR(IPosition(5,icorr,ipar,ichan,irow,0)));
491
851
grad()(a1i+ipar)+= DComplex(wt*(Ri*conj(dR1)));
492
852
hess()(a1i+ipar)+= Double(wt*real(dR1*conj(dR1)));
493
853
// for a2: