Reference

Recording data

AxonaTrial

Bases: TrialInterface

Source code in ephysiopy/io/recording.py
class AxonaTrial(TrialInterface):
    def __init__(self, pname: Path, **kwargs) -> None:
        pname = Path(pname)
        super().__init__(pname, **kwargs)
        self._settings = None
        use_volts = kwargs.get("volts", True)
        self.TETRODE = TetrodeDict(
            str(self.pname.with_suffix("")), volts=use_volts)
        self.load_settings()

    def load_lfp(self, *args, **kwargs):
        from ephysiopy.axona.axonaIO import EEG

        if "egf" in args:
            lfp = EEG(self.pname, egf=1)
        else:
            lfp = EEG(self.pname)
        if lfp is not None:
            self.EEGCalcs = EEGCalcsGeneric(lfp.sig, lfp.sample_rate)

    def load_neural_data(self, *args, **kwargs):
        if "tetrode" in kwargs.keys():
            use_volts = kwargs.get("volts", True)
            self.TETRODE[kwargs["tetrode"], use_volts]  # lazy load

    def load_cluster_data(self, *args, **kwargs):
        return False

    def load_settings(self, *args, **kwargs):
        if self._settings is None:
            try:
                settings_io = IO()
                self.settings = settings_io.getHeader(str(self.pname))
            except IOError:
                print(".set file not loaded")
                self.settings = None

    def load_pos_data(
        self, ppm: int = 300, jumpmax: int = 100, *args, **kwargs
    ) -> None:
        try:
            AxonaPos = Pos(Path(self.pname))
            P = PosCalcsGeneric(
                AxonaPos.led_pos[:, 0],
                AxonaPos.led_pos[:, 1],
                cm=True,
                ppm=ppm,
                jumpmax=jumpmax,
            )
            P.sample_rate = AxonaPos.getHeaderVal(
                AxonaPos.header, "sample_rate")
            P.xyTS = AxonaPos.ts / P.sample_rate  # in seconds now
            P.postprocesspos(tracker_params={"SampleRate": P.sample_rate})
            print("Loaded pos data")
            self.PosCalcs = P
        except IOError:
            print("Couldn't load the pos data")

    def load_ttl(self, *args, **kwargs) -> bool:
        from ephysiopy.axona.axonaIO import Stim
        try:
            self.ttl_data = Stim(self.pname)
            # ttl times in Stim are in ms
        except IOError:
            return False
        return True

    def get_spike_times(self,
                        tetrode=None,
                        cluster: int = None,
                        *args,
                        **kwargs):
        """
        Args:
            tetrode (int): 
            cluster (int): 

        Returns:
            spike_times (np.ndarray): 
        """
        if tetrode is not None:
            return self.TETRODE.get_spike_samples(int(tetrode), int(cluster))

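A minimal usage sketch (the .set file path and the tetrode/cluster numbers below are hypothetical): construct an AxonaTrial from the path to a recording's .set file, load the position data, then pull the spike times for one cluster.

from pathlib import Path

from ephysiopy.io.recording import AxonaTrial

# hypothetical recording - substitute the path to your own .set file
trial = AxonaTrial(Path("/path/to/recording.set"))
trial.load_pos_data(ppm=300, jumpmax=100)
spike_times = trial.get_spike_times(tetrode=1, cluster=2)
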
get_spike_times(tetrode=None, cluster=None, *args, **kwargs)

Parameters:

Name     Type  Description                                  Default
tetrode  int   The tetrode to get spike times from          None
cluster  int   The cluster whose spike times are returned   None

Returns:

Name         Type     Description
spike_times  ndarray  The spike times for the cluster
Source code in ephysiopy/io/recording.py
def get_spike_times(self,
                    tetrode=None,
                    cluster: int = None,
                    *args,
                    **kwargs):
    """
    Args:
        tetrode (int): 
        cluster (int): 

    Returns:
        spike_times (np.ndarray): 
    """
    if tetrode is not None:
        return self.TETRODE.get_spike_samples(int(tetrode), int(cluster))

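For example, with a trial constructed as above, the spike times for cluster 3 on tetrode 2 would be fetched with (numbers illustrative):

spike_times = trial.get_spike_times(tetrode=2, cluster=3)
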
OpenEphysBase

Bases: TrialInterface

Source code in ephysiopy/io/recording.py
class OpenEphysBase(TrialInterface):
    def __init__(self, pname: Path, **kwargs) -> None:
        pname = Path(pname)
        super().__init__(pname, **kwargs)
        setattr(self, "sync_message_file", None)
        self.load_settings()
        # The numbers after the strings in this list are the node ids
        # in openephys
        record_methods = [
            "Acquisition Board [0-9][0-9][0-9]",
            "Acquisition Board",
            "Neuropix-PXI [0-9][0-9][0-9]",
            "Neuropix-PXI",
            "Sources/Neuropix-PXI [0-9][0-9][0-9]",
            "Rhythm FPGA [0-9][0-9][0-9]",
            "Rhythm",
            "Sources/Rhythm FPGA [0-9][0-9][0-9]",
        ]
        rec_method = [
            re.search(m, k).string
            for k in self.settings.processors.keys()
            for m in record_methods
            if re.search(m, k) is not None
        ][0]
        if "Sources/" in rec_method:
            rec_method = rec_method.lstrip("Sources/")

        self.rec_kind = Xml2RecordingKind[rec_method.rpartition(" ")[0]]

        # Attempt to find the files contained in the parent directory
        # related to the recording with the default experiment and
        # recording name
        self.find_files(pname)
        self.sample_rate = None
        self.sample_rate = self.settings.processors[rec_method].sample_rate
        if self.sample_rate is None:
            if self.rec_kind == RecordingKind.NEUROPIXELS:
                self.sample_rate = 30000
        else:  # crude fix: many header values are strings and need casting to int/float
            self.sample_rate = float(self.sample_rate)
        self.channel_count = self.settings.processors[rec_method].channel_count
        if self.channel_count is None:
            if self.rec_kind == RecordingKind.NEUROPIXELS:
                self.channel_count = 384
        self.kilodata = None
        self.template_model = None

    def _get_recording_start_time(self) -> float:
        """
        Get the recording start time from the sync_messages.txt file
        """
        recording_start_time = 0.0
        if self.sync_message_file is not None:
            with open(self.sync_message_file, "r") as f:
                sync_strs = f.read()
            sync_lines = sync_strs.split("\n")
            for line in sync_lines:
                if "Start Time" in line:
                    tokens = line.split(":")
                    start_time = int(tokens[-1])
                    sample_rate = int(tokens[0].split(
                        "@")[-1].strip().split()[0])
                    recording_start_time = start_time / float(sample_rate)
        return recording_start_time

    @cache
    def get_spike_times(self,
                        tetrode: int = None,
                        cluster: int = None,
                        *args, **kwargs):
        """
        Args:
            tetrode (int): 
            cluster (int): 

        Returns:
            spike_times (np.ndarray): 
        """
        if not self.clusterData:
            self.load_cluster_data()
        ts = self.clusterData.spk_times
        if cluster in self.clusterData.spk_clusters:
            times = ts[self.clusterData.spk_clusters == cluster]
            return times.astype(np.int64) / self.sample_rate
        else:
            warnings.warn("Cluster not present")

    @cache
    def load_lfp(self, *args, **kwargs):
        """
        Valid kwargs are:
        'target_sample_rate' - int
            the sample rate to downsample to from the original
        """
        from scipy import signal

        if self.path2LFPdata is not None:
            lfp = memmapBinaryFile(
                os.path.join(self.path2LFPdata, "continuous.dat"),
                n_channels=self.channel_count,
            )
            channel = 0
            if "channel" in kwargs.keys():
                channel = kwargs["channel"]
            target_sample_rate = 500
            if "target_sample_rate" in kwargs.keys():
                target_sample_rate = kwargs["target_sample_rate"]
            n_samples = np.shape(lfp[channel, :])[0]
            sig = signal.resample(
                lfp[channel, :], int(
                    n_samples / self.sample_rate) * target_sample_rate
            )
            self.EEGCalcs = EEGCalcsGeneric(sig, target_sample_rate)

    @cache
    def load_neural_data(self, *args, **kwargs) -> None:
        if "path2APdata" in kwargs.keys():
            self.path2APdata: Path = Path(kwargs["path2APdata"])
        n_channels: int = self.channel_count or kwargs["nChannels"]
        try:
            self.template_model = TemplateModel(
                dir_path=self.path2APdata,
                sample_rate=3e4,
                dat_path=Path(self.path2APdata).joinpath("continuous.dat"),
                n_channels_dat=int(n_channels),
            )
            print("Loaded neural data")
        except Exception:
            warnings.warn("Could not find raw data file")

    @cache
    def load_settings(self, *args, **kwargs):
        if self._settings is None:
            # pname_root gets walked through and over-written with
            # correct location of settings.xml
            self.settings = Settings(self.pname)
            print("Loaded settings data")

    @cache
    def load_cluster_data(
            self, removeNoiseClusters=True, *args, **kwargs) -> bool:
        if self.path2KiloSortData is not None:
            clusterData = KiloSortSession(self.pname)
        else:
            return False
        if clusterData is not None:
            if clusterData.load():
                print("Loaded KiloSort data")
                if removeNoiseClusters:
                    try:
                        clusterData.removeKSNoiseClusters()
                        print("Removed noise clusters")
                    except Exception:
                        pass
        else:
            return False
        self.clusterData = clusterData
        return True

    @cache
    def load_pos_data(
        self, ppm: int = 300, jumpmax: int = 100, *args, **kwargs
    ) -> None:
        # Valid kwargs keys: "loadTTLPos" - if present, load the TTL
        # timestamps rather than the ones in the plugin folder

        # Only sub-class that doesn't use this is OpenEphysNWB
        # which needs updating
        # TODO: Update / overhaul OpenEphysNWB
        # Load the start time from the sync_messages file
        if "cm" in kwargs:
            cm = kwargs["cm"]
        else:
            cm = True

        recording_start_time = self._get_recording_start_time()

        if self.path2PosData is not None:
            pos_method = [
                "Pos Tracker [0-9][0-9][0-9]",
                "PosTracker [0-9][0-9][0-9]",
                "TrackMe [0-9][0-9][0-9]",
                "TrackingPlugin [0-9][0-9][0-9]",
                "Tracking Port"
            ]
            pos_plugin_name = [
                re.search(m, k).string
                for k in self.settings.processors.keys()
                for m in pos_method
                if re.search(m, k) is not None
            ][0]
            if "Sources/" in pos_plugin_name:
                pos_plugin_name = pos_plugin_name.lstrip("Sources/")

            self.pos_plugin_name = pos_plugin_name

            if "Tracker" in pos_plugin_name:
                print("Loading Tracker data...")
                pos_data = np.load(os.path.join(
                    self.path2PosData, "data_array.npy"))
            if "Tracking Port" in pos_plugin_name:
                print("Loading Tracking Port data...")
                pos_data = loadTrackingPluginData(
                    os.path.join(self.path2PosData, "data_array.npy"))
            if "TrackingPlugin" in pos_plugin_name:
                print("Loading TrackingPlugin data...")
                pos_data = loadTrackingPluginData(
                    os.path.join(self.path2PosData, "data_array.npy")
                )

            pos_ts = np.load(os.path.join(self.path2PosData, "timestamps.npy"))
            # pos_ts in seconds
            pos_ts = np.ravel(pos_ts)
            if "TrackMe" in pos_plugin_name:
                print("Loading TrackMe data...")
                n_pos_chans = int(
                    self.settings.processors[pos_plugin_name].channel_count
                )
                pos_data = loadTrackMePluginData(
                    Path(os.path.join(self.path2PosData, "continuous.dat")),
                    n_channels=n_pos_chans,
                )
                if "loadTTLPos" in kwargs.keys():
                    pos_ts = loadTrackMeTTLTimestamps(
                        Path(self.path2EventsData))
                else:
                    pos_ts = loadTrackMeTimestamps(Path(self.path2PosData))
                pos_ts = pos_ts[0:len(pos_data)]
            sample_rate = self.settings.processors[pos_plugin_name].sample_rate
            sample_rate = float(sample_rate) if sample_rate is not None else 50
            # the timestamps for the Tracking Port plugin are unreliable
            # so we have to infer them from the shape of the position data
            if "Tracking Port" in pos_plugin_name:
                sample_rate = kwargs.get("sample_rate") or 50
                # pos_ts in seconds
                pos_ts = np.arange(
                    0, pos_data.shape[0]/sample_rate, 1.0/sample_rate)
                print(f"Tracker first and last ts: {pos_ts[0]} & {pos_ts[-1]}")
            if pos_plugin_name != "TrackMe":
                xyTS = pos_ts - recording_start_time
            else:
                xyTS = pos_ts
            if self.sync_message_file is not None:
                recording_start_time = xyTS[0]
            print(
                f"First & last ts before PosCalcs: {pos_ts[0]} & {pos_ts[-1]}")
            P = PosCalcsGeneric(
                pos_data[:, 0],
                pos_data[:, 1],
                cm=cm,
                ppm=ppm,
                jumpmax=jumpmax,
            )
            P.xyTS = xyTS
            P.sample_rate = sample_rate
            P.postprocesspos({"SampleRate": sample_rate})
            print("Loaded pos data")
            self.PosCalcs = P
        else:
            warnings.warn(
                "Could not find the pos data. \
                Make sure there is a pos_data folder containing \
                data_array.npy and timestamps.npy"
            )
        self.recording_start_time = recording_start_time

    @cache
    def load_ttl(self, *args, **kwargs) -> bool:
        """
        Args:
            StimControl_id (str): This is the string 
                "StimControl [0-9][0-9][0-9]" where the numbers
                are the node id in the openephys signal chain
            TTL_channel_number (int): The integer value in the "states.npy"
                file that corresponds to the
                identity of the TTL input on the Digital I/O board on the
                openephys recording system. i.e. if there is input to BNC
                port 3 on the digital I/O board then values of 3 in the
                states.npy file are high TTL values on this input and -3
                are low TTL values (I think)

        Returns:
            bool: True if TTL data was loaded, False otherwise. Also
            sets some keys/values in a dict on 'self' called
            ttl_data, namely:

            ttl_timestamps (list): the times of high ttl pulses in ms
            stim_duration (int): the duration of the ttl pulse in ms
        """
        if not Path(self.path2EventsData).exists():
            return False
        ttl_ts = np.load(os.path.join(self.path2EventsData, "timestamps.npy"))
        states = np.load(os.path.join(self.path2EventsData, "states.npy"))
        recording_start_time = self._get_recording_start_time()
        self.ttl_data = {}
        if "StimControl_id" in kwargs.keys():
            stim_id = kwargs["StimControl_id"]
            if stim_id in self.settings.processors.keys():
                duration = getattr(
                    self.settings.processors[stim_id], "Duration")
            else:
                return False
            self.ttl_data["stim_duration"] = int(duration)
        if "TTL_channel_number" in kwargs.keys():
            chan = kwargs["TTL_channel_number"]
            high_ttl = ttl_ts[states == chan]
            # subtract the recording start time (s), then convert to ms
            high_ttl = (high_ttl - recording_start_time) * 1000.0
            self.ttl_data['ttl_timestamps'] = high_ttl
        if not self.ttl_data:
            return False
        print("Loaded ttl data")
        return True

    @cache
    def find_files(
        self,
        pname_root: str,
        experiment_name: str = "experiment1",
        rec_name: str = "recording1",
    ):
        exp_name = Path(experiment_name)
        PosTracker_match = (
            exp_name / rec_name / "events" / "*Pos_Tracker*/BINARY_group*"
        )
        TrackingPlugin_match = (
            exp_name / rec_name / "events" / "*Tracking_Port*/BINARY_group*"
        )
        TrackMe_match = (
            exp_name / rec_name / "continuous" /
            "TrackMe-[0-9][0-9][0-9].TrackingNode"
        )
        sync_file_match = exp_name / rec_name
        acq_method = ""
        if self.rec_kind == RecordingKind.NEUROPIXELS:
            # the old OE NPX plugins saved two forms of the data,
            # one for AP @30kHz and one for LFP @??Hz
            # the newer plugin saves only the 30kHz data. Also, the
            # 2.0 probes are saved with Probe[A-Z] appended to the end
            # of the folder
            # the older way:
            acq_method = "Neuropix-PXI-[0-9][0-9][0-9]."
            APdata_match = exp_name / rec_name / \
                "continuous" / (acq_method + "0")
            LFPdata_match = exp_name / rec_name / \
                "continuous" / (acq_method + "1")
            # the new way:
            Rawdata_match = (
                exp_name / rec_name / "continuous" /
                (acq_method + "Probe[A-Z]")
            )
        elif self.rec_kind == RecordingKind.FPGA:
            acq_method = "Rhythm_FPGA-[0-9][0-9][0-9]."
            APdata_match = exp_name / rec_name / \
                "continuous" / (acq_method + "0")
            LFPdata_match = exp_name / rec_name / \
                "continuous" / (acq_method + "1")
            Rawdata_match = (
                exp_name / rec_name / "continuous" /
                (acq_method + "Probe[A-Z]")
            )
        else:
            acq_method = "Acquisition_Board-[0-9][0-9][0-9].*"
            APdata_match = exp_name / rec_name / "continuous" / acq_method
            LFPdata_match = exp_name / rec_name / "continuous" / acq_method
            Rawdata_match = (
                exp_name / rec_name / "continuous" /
                (acq_method + "Probe[A-Z]")
            )
        Events_match = (
            # only dealing with a single TTL channel at the moment
            exp_name
            / rec_name
            / "events"
            / acq_method
            / "TTL"
        )

        if pname_root is None:
            pname_root = self.pname_root

        for d, c, f in os.walk(pname_root):
            for ff in f:
                if "." not in c:  # ignore hidden directories
                    if "data_array.npy" in ff:
                        if PurePath(d).match(str(PosTracker_match)):
                            if self.path2PosData is None:
                                self.path2PosData = os.path.join(d)
                                print(f"Pos data at: {self.path2PosData}")
                            self.path2PosOEBin = Path(d).parents[1]
                        if PurePath(d).match("*pos_data*"):
                            if self.path2PosData is None:
                                self.path2PosData = os.path.join(d)
                                print(f"Pos data at: {self.path2PosData}")
                        if PurePath(d).match(str(TrackingPlugin_match)):
                            if self.path2PosData is None:
                                self.path2PosData = os.path.join(d)
                                print(f"Pos data at: {self.path2PosData}")
                    if "continuous.dat" in ff:
                        if PurePath(d).match(str(APdata_match)):
                            self.path2APdata = os.path.join(d)
                            print(f"Continuous AP data at: {self.path2APdata}")
                            self.path2APOEBin = Path(d).parents[1]
                        if PurePath(d).match(str(LFPdata_match)):
                            self.path2LFPdata = os.path.join(d)
                            print(
                                f"Continuous LFP data at: {self.path2LFPdata}")
                        if PurePath(d).match(str(Rawdata_match)):
                            self.path2APdata = os.path.join(d)
                            self.path2LFPdata = os.path.join(d)
                        if PurePath(d).match(str(TrackMe_match)):
                            self.path2PosData = os.path.join(d)
                            print(f"TrackMe posdata at: {self.path2PosData}")
                    if "sync_messages.txt" in ff:
                        if PurePath(d).match(str(sync_file_match)):
                            sync_file = os.path.join(d, "sync_messages.txt")
                            if fileContainsString(sync_file, "Start Time"):
                                self.sync_message_file = sync_file
                                print(f"sync_messages file at: {sync_file}")
                    if "full_words.npy" in ff:
                        if PurePath(d).match(str(Events_match)):
                            self.path2EventsData = os.path.join(d)
                            print(f"Event data at: {self.path2EventsData}")
                    if ".nwb" in ff:
                        self.path2NWBData = os.path.join(d, ff)
                        print(f"nwb data at: {self.path2NWBData}")
                    if "spike_templates.npy" in ff:
                        self.path2KiloSortData = os.path.join(d)
                        print(
                            f"Found KiloSort data at {self.path2KiloSortData}")

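A minimal usage sketch, assuming pname points at the root directory of an openephys recording (the folder that find_files walks for the experiment1/recording1 tree); the path is hypothetical:

from pathlib import Path

from ephysiopy.io.recording import OpenEphysBase

trial = OpenEphysBase(Path("/path/to/openephys/session"))
trial.load_pos_data(ppm=300, jumpmax=100)
trial.load_cluster_data()  # KiloSort output, if found by find_files
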
get_spike_times(tetrode=None, cluster=None, *args, **kwargs) cached

Parameters:

Name     Type  Description                                        Default
tetrode  int   Unused here; present for interface compatibility   None
cluster  int   The KiloSort cluster whose spike times are         None
               returned

Returns:

Name         Type     Description
spike_times  ndarray  The spike times, in seconds
Source code in ephysiopy/io/recording.py
@cache
def get_spike_times(self,
                    tetrode: int = None,
                    cluster: int = None,
                    *args, **kwargs):
    """
    Args:
        tetrode (int): 
        cluster (int): 

    Returns:
        spike_times (np.ndarray): 
    """
    if not self.clusterData:
        self.load_cluster_data()
    ts = self.clusterData.spk_times
    if cluster in self.clusterData.spk_clusters:
        times = ts[self.clusterData.spk_clusters == cluster]
        return times.astype(np.int64) / self.sample_rate
    else:
        warnings.warn("Cluster not present")

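Cluster ids here are those assigned by KiloSort; a hypothetical call (cluster number illustrative):

ts = trial.get_spike_times(cluster=11)  # spike times in seconds
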
load_lfp(*args, **kwargs) cached

Valid kwargs are: 'channel' (int) - the channel to load (default 0); 'target_sample_rate' (int) - the sample rate to downsample to from the original

Source code in ephysiopy/io/recording.py
@cache
def load_lfp(self, *args, **kwargs):
    """
    Valid kwargs are:
    'target_sample_rate' - int
        the sample rate to downsample to from the original
    """
    from scipy import signal

    if self.path2LFPdata is not None:
        lfp = memmapBinaryFile(
            os.path.join(self.path2LFPdata, "continuous.dat"),
            n_channels=self.channel_count,
        )
        channel = 0
        if "channel" in kwargs.keys():
            channel = kwargs["channel"]
        target_sample_rate = 500
        if "target_sample_rate" in kwargs.keys():
            target_sample_rate = kwargs["target_sample_rate"]
        n_samples = np.shape(lfp[channel, :])[0]
        sig = signal.resample(
            lfp[channel, :], int(
                n_samples / self.sample_rate) * target_sample_rate
        )
        self.EEGCalcs = EEGCalcsGeneric(sig, target_sample_rate)

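For example, to load channel 4 downsampled to 250 Hz (both values illustrative):

trial.load_lfp(channel=4, target_sample_rate=250)
eeg = trial.EEGCalcs  # EEGCalcsGeneric built from the resampled signal
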
load_ttl(*args, **kwargs) cached

Parameters:

Name                Type  Description                                     Default
StimControl_id      str   The string "StimControl [0-9][0-9][0-9]" where  required
                          the numbers are the node id in the openephys
                          signal chain
TTL_channel_number  int   The integer value in the "states.npy" file      required
                          that corresponds to the identity of the TTL
                          input on the Digital I/O board on the
                          openephys recording system, i.e. if there is
                          input to BNC port 3 on the digital I/O board
                          then values of 3 in the states.npy file are
                          high TTL values on this input and -3 are low
                          TTL values (I think)

Returns:

Name  Type  Description
      bool  True if TTL data was loaded, False otherwise. Also sets
            keys/values in a dict on 'self' called ttl_data, namely
            ttl_timestamps (list), the times of high ttl pulses in ms,
            and stim_duration (int), the duration of the ttl pulse in ms

Source code in ephysiopy/io/recording.py
@cache
def load_ttl(self, *args, **kwargs) -> bool:
    """
    Args:
        StimControl_id (str): This is the string 
            "StimControl [0-9][0-9][0-9]" where the numbers
            are the node id in the openephys signal chain
        TTL_channel_number (int): The integer value in the "states.npy"
            file that corresponds to the
            identity of the TTL input on the Digital I/O board on the
            openephys recording system. i.e. if there is input to BNC
            port 3 on the digital I/O board then values of 3 in the
            states.npy file are high TTL values on this input and -3
            are low TTL values (I think)

    Returns:
        bool: True if TTL data was loaded, False otherwise. Also
        sets some keys/values in a dict on 'self' called
        ttl_data, namely:

        ttl_timestamps (list): the times of high ttl pulses in ms
        stim_duration (int): the duration of the ttl pulse in ms
    """
    if not Path(self.path2EventsData).exists():
        return False
    ttl_ts = np.load(os.path.join(self.path2EventsData, "timestamps.npy"))
    states = np.load(os.path.join(self.path2EventsData, "states.npy"))
    recording_start_time = self._get_recording_start_time()
    self.ttl_data = {}
    if "StimControl_id" in kwargs.keys():
        stim_id = kwargs["StimControl_id"]
        if stim_id in self.settings.processors.keys():
            duration = getattr(
                self.settings.processors[stim_id], "Duration")
        else:
            return False
        self.ttl_data["stim_duration"] = int(duration)
    if "TTL_channel_number" in kwargs.keys():
        chan = kwargs["TTL_channel_number"]
        high_ttl = ttl_ts[states == chan]
        # subtract the recording start time (s), then convert to ms
        high_ttl = (high_ttl - recording_start_time) * 1000.0
        self.ttl_data['ttl_timestamps'] = high_ttl
    if not self.ttl_data:
        return False
    print("Loaded ttl data")
    return True

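A hypothetical call, assuming a StimControl node with id 102 in the signal chain and TTL input on BNC port 3 (both values illustrative):

if trial.load_ttl(StimControl_id="StimControl 102", TTL_channel_number=3):
    print(trial.ttl_data["ttl_timestamps"])
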
TrialInterface

Bases: FigureMaker

Defines a minimal and required set of methods for loading electrophysiology data recorded using Axona or OpenEphys (OpenEphysNWB is there but not used)

Source code in ephysiopy/io/recording.py
class TrialInterface(FigureMaker, metaclass=abc.ABCMeta):
    """
    Defines a minimal and required set of methods for loading
    electrophysiology data recorded using Axona or OpenEphys
    (OpenEphysNWB is there but not used)
    """

    def __init__(self, pname: Path, **kwargs) -> None:
        assert Path(pname).exists(), f"Path provided doesn't exist: {pname}"
        self._pname = pname
        self._settings = None
        self._PosCalcs = None
        self._RateMap = None
        self._EEGCalcs = None
        self._sync_message_file = None
        self._clusterData = None  # Kilosort or .cut / .clu file
        self._recording_start_time = None  # float
        self._ttl_data = None  # dict
        self._accelerometer_data = None
        self._path2PosData = None  # Path or str

    @classmethod
    def __subclasshook__(cls, subclass):
        return (
            hasattr(subclass, "load_neural_data")
            and callable(subclass.load_neural_data)
            and hasattr(subclass, "load_lfp")
            and callable(subclass.load_lfp)
            and hasattr(subclass, "load_pos")
            and callable(subclass.load_pos)
            and hasattr(subclass, "load_cluster_data")
            and callable(subclass.load_cluster_data)
            and hasattr(subclass, "load_settings")
            and callable(subclass.load_settings)
            and hasattr(subclass, "get_spike_times")
            and callable(subclass.get_spike_times)
            and hasattr(subclass, "load_ttl")
            and callable(subclass.load_ttl)
            or NotImplemented
        )

    @property
    def pname(self):
        return self._pname

    @pname.setter
    def pname(self, val):
        self._pname = val

    @property
    def settings(self):
        return self._settings

    @settings.setter
    def settings(self, val):
        self._settings = val

    @property
    def PosCalcs(self):
        return self._PosCalcs

    @PosCalcs.setter
    def PosCalcs(self, val):
        self._PosCalcs = val

    @property
    def RateMap(self):
        return self._RateMap

    @RateMap.setter
    def RateMap(self, value):
        self._RateMap = value

    @property
    def EEGCalcs(self):
        return self._EEGCalcs

    @EEGCalcs.setter
    def EEGCalcs(self, val):
        self._EEGCalcs = val

    @property
    def clusterData(self):
        return self._clusterData

    @clusterData.setter
    def clusterData(self, val):
        self._clusterData = val

    @property
    def recording_start_time(self):
        return self._recording_start_time

    @recording_start_time.setter
    def recording_start_time(self, val):
        self._recording_start_time = val

    @property
    def sync_message_file(self):
        return self._sync_message_file

    @sync_message_file.setter
    def sync_message_file(self, val):
        self._sync_message_file = val

    @property
    def ttl_data(self):
        return self._ttl_data

    @ttl_data.setter
    def ttl_data(self, val):
        self._ttl_data = val

    @property
    def accelerometer_data(self):
        return self._accelerometer_data

    @accelerometer_data.setter
    def accelerometer_data(self, val):
        self._accelerometer_data = val

    @property
    def path2PosData(self):
        return self._path2PosData

    @path2PosData.setter
    def path2PosData(self, val):
        self._path2PosData = val

    @abc.abstractmethod
    def load_lfp(self, *args, **kwargs) -> NoReturn:
        """Load the LFP data"""
        raise NotImplementedError

    @abc.abstractmethod
    def load_neural_data(self, *args, **kwargs) -> NoReturn:
        """Load the neural data"""
        raise NotImplementedError

    @abc.abstractmethod
    def load_pos_data(
        self, ppm: int = 300, jumpmax: int = 100, *args, **kwargs
    ) -> NoReturn:
        """
        Load the position data

        Args:
            pname (Path): Path to base directory containing pos data
            ppm (int): pixels per metre
            jumpmax (int): max jump in pixels between positions, more
                than this and the position is interpolated over
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load_cluster_data(self, *args, **kwargs) -> bool:
        """Load the cluster data (Kilosort/ Axona cut/ whatever else"""
        raise NotImplementedError

    @abc.abstractmethod
    def load_settings(self, *args, **kwargs) -> NoReturn:
        """Loads the format specific settings file"""
        raise NotImplementedError

    @abc.abstractmethod
    def load_ttl(self, *args, **kwargs) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def get_spike_times(self, cluster: int, tetrode: int, *args, **kwargs):
        """Returns the times of an individual cluster"""
        raise NotImplementedError

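Because of the __subclasshook__ above, anything implementing these methods is treated as a TrialInterface; supporting a new recording format therefore means filling in the abstract methods. A skeletal, hypothetical sketch:

class MyFormatTrial(TrialInterface):
    def load_lfp(self, *args, **kwargs):
        ...  # populate self.EEGCalcs

    def load_neural_data(self, *args, **kwargs):
        ...

    def load_pos_data(self, ppm=300, jumpmax=100, *args, **kwargs):
        ...  # populate self.PosCalcs

    def load_cluster_data(self, *args, **kwargs) -> bool:
        return False  # no cluster data in this format

    def load_settings(self, *args, **kwargs):
        ...

    def load_ttl(self, *args, **kwargs) -> bool:
        return False

    def get_spike_times(self, cluster, tetrode, *args, **kwargs):
        ...  # return np.ndarray of spike times
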
get_spike_times(cluster, tetrode, *args, **kwargs) abstractmethod

Returns the times of an individual cluster

Source code in ephysiopy/io/recording.py
@abc.abstractmethod
def get_spike_times(self, cluster: int, tetrode: int, *args, **kwargs):
    """Returns the times of an individual cluster"""
    raise NotImplementedError

load_cluster_data(*args, **kwargs) abstractmethod

Load the cluster data (KiloSort / Axona cut / etc.)

Source code in ephysiopy/io/recording.py
@abc.abstractmethod
def load_cluster_data(self, *args, **kwargs) -> bool:
    """Load the cluster data (Kilosort/ Axona cut/ whatever else"""
    raise NotImplementedError

load_lfp(*args, **kwargs) abstractmethod

Load the LFP data

Source code in ephysiopy/io/recording.py
@abc.abstractmethod
def load_lfp(self, *args, **kwargs) -> NoReturn:
    """Load the LFP data"""
    raise NotImplementedError

load_neural_data(*args, **kwargs) abstractmethod

Load the neural data

Source code in ephysiopy/io/recording.py
@abc.abstractmethod
def load_neural_data(self, *args, **kwargs) -> NoReturn:
    """Load the neural data"""
    raise NotImplementedError

load_pos_data(ppm=300, jumpmax=100, *args, **kwargs) abstractmethod

Load the position data

Parameters:

Name     Type  Description                                       Default
pname    Path  Path to base directory containing pos data        required
ppm      int   pixels per metre                                  300
jumpmax  int   max jump in pixels between positions; more than   100
               this and the position is interpolated over
Source code in ephysiopy/io/recording.py
@abc.abstractmethod
def load_pos_data(
    self, ppm: int = 300, jumpmax: int = 100, *args, **kwargs
) -> NoReturn:
    """
    Load the position data

    Args:
        pname (Path): Path to base directory containing pos data
        ppm (int): pixels per metre
        jumpmax (int): max jump in pixels between positions, more
            than this and the position is interpolated over
    """
    raise NotImplementedError

load_settings(*args, **kwargs) abstractmethod

Loads the format specific settings file

Source code in ephysiopy/io/recording.py
278
279
280
281
@abc.abstractmethod
def load_settings(self, *args, **kwargs) -> NoReturn:
    """Loads the format specific settings file"""
    raise NotImplementedError

memmapBinaryFile(path2file, n_channels=384, **kwargs)

Returns a numpy memmap of the int16 data in the file path2file, if present

Source code in ephysiopy/io/recording.py
def memmapBinaryFile(path2file: str, n_channels=384, **kwargs) -> np.ndarray:
    """
    Returns a numpy memmap of the int16 data in the
    file path2file, if present
    """
    import os

    if "data_type" in kwargs.keys():
        data_type = kwargs["data_type"]
    else:
        data_type = np.int16

    if os.path.exists(path2file):
        # make sure n_channels is int as could be str
        n_channels = int(n_channels)
        status = os.stat(path2file)
        n_samples = int(status.st_size / (2.0 * n_channels))
        mmap = np.memmap(
            path2file, data_type, "r", 0, (n_channels, n_samples), order="F"
        )
        return mmap
    else:
        return np.empty(0)

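For example, memory-mapping a hypothetical 64-channel int16 recording (path and channel count illustrative):

data = memmapBinaryFile("/path/to/continuous.dat", n_channels=64)
first_channel = data[0, :]  # shape is (n_channels, n_samples)
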
Plotting the results

FigureMaker

Bases: object

A mixin class for TrialInterface that deals solely with producing graphical output.

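The methods below are called on a TrialInterface subclass instance once the position data has been loaded; e.g. (cluster and channel values illustrative):

trial.load_pos_data()
trial.rate_map(cluster=1, channel=1)
trial.hd_map(cluster=[1, 2, 3], channel=1)  # one polar plot per cluster
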
Source code in ephysiopy/visualise/plotting.py
class FigureMaker(object):
    """
    A mixin class for TrialInterface that deals solely with
    producing graphical output.
    """

    def __init__(self):
        """
        Initializes the FigureMaker object.
        """
        self.PosCalcs = None

    def initialise(self):
        """
        Initializes the FigureMaker object with data from PosCalcs.
        """
        self.RateMap = RateMap(self.PosCalcs.xy,
                               self.PosCalcs.dir,
                               self.PosCalcs.speed,)
        self.npos = self.PosCalcs.xy.shape[1]

    def _plot_multiple_clusters(self,
                                func,
                                clusters: list,
                                channel: int,
                                **kwargs):
        """
        Plots multiple clusters.

        Args:
            func (function): The function to apply to each cluster.
            clusters (list): The list of clusters to plot.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        fig = plt.figure()
        nrows = int(np.ceil(len(clusters) / 5))
        if 'projection' in kwargs.keys():
            proj = kwargs.pop('projection')
        else:
            proj = None
        for i, c in enumerate(clusters):
            ax = fig.add_subplot(nrows, 5, i+1, projection=proj)
            ts = self.get_spike_times(channel, c)
            func(ts, ax=ax, **kwargs)

    def rate_map(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the rate map for the specified cluster(s) and channel.

        Args:
            cluster (int | list): The cluster(s) to get the rate map for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeRateMap,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeRateMap(ts, **kwargs)
        plt.show()

    def hd_map(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the head direction map for the specified cluster(s) and channel.

        Args:
            cluster (int | list): The cluster(s) to get the head direction map
                for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeHDPlot,
                                         cluster,
                                         channel,
                                         projection="polar",
                                         strip_axes=True,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeHDPlot(ts, **kwargs)
        plt.show()

    def spike_path(self, cluster=None, channel=None, **kwargs):
        """
        Gets the spike path for the specified cluster(s) and channel.

        Args:
            cluster (int | list | None): The cluster(s) to get the spike path
                for.
            channel (int | None): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeSpikePathPlot,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            if channel is not None and cluster is not None:
                ts = self.get_spike_times(channel, cluster)
            else:
                ts = None
            self.makeSpikePathPlot(ts, **kwargs)
        plt.show()

    def eb_map(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the ego-centric boundary map for the specified cluster(s) and
        channel.

        Args:
            cluster (int | list): The cluster(s) to get the ego-centric
                boundary map for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeEgoCentricBoundaryMap,
                                         cluster,
                                         channel,
                                         projection='polar',
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeEgoCentricBoundaryMap(ts, **kwargs)
        plt.show()

    def eb_spikes(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the ego-centric boundary spikes for the specified cluster(s)
        and channel.

        Args:
            cluster (int | list): The cluster(s) to get the ego-centric
                boundary spikes for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeEgoCentricBoundarySpikePlot,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeEgoCentricBoundarySpikePlot(ts, **kwargs)
        plt.show()

    def sac(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the spatial autocorrelation for the specified cluster(s) and
        channel.

        Args:
            cluster (int | list): The cluster(s) to get the spatial
                autocorrelation for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeSAC,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeSAC(ts, **kwargs)
        plt.show()

    def speed_v_rate(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the speed versus rate plot for the specified cluster(s) and
        channel.

        Args:
            cluster (int | list): The cluster(s) to get the speed versus rate
                plot for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeSpeedVsRatePlot,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeSpeedVsRatePlot(ts, **kwargs)
        plt.show()

    def speed_v_hd(self, cluster: int | list, channel: int, **kwargs):
        """
        Gets the speed versus head direction plot for the specified cluster(s)
        and channel.

        Args:
            cluster (int | list): The cluster(s) to get the speed versus head
                direction plot for.
            channel (int): The channel number.
            **kwargs: Additional keyword arguments for the function.
        """
        if isinstance(cluster, list):
            self._plot_multiple_clusters(self.makeSpeedVsHeadDirectionPlot,
                                         cluster,
                                         channel,
                                         **kwargs)
        else:
            ts = self.get_spike_times(channel, cluster)
            self.makeSpeedVsHeadDirectionPlot(ts, **kwargs)
        plt.show()

    def power_spectrum(self, **kwargs):
        """
        Gets the power spectrum.

        Args:
            **kwargs: Additional keyword arguments for the function.
        """
        p = self.EEGCalcs.calcEEGPowerSpectrum()
        self.makePowerSpectrum(p[0], p[1], p[2], p[3], p[4], **kwargs)
        plt.show()

    def getSpikePosIndices(self, spk_times: np.ndarray):
        """
        Returns the indices into the position data at which some spike times
        occurred.

        Args:
            spk_times (np.ndarray): The spike times in seconds.

        Returns:
            np.ndarray: The indices into the position data at which the spikes
                occurred.
        """
        pos_times = getattr(self.PosCalcs, "xyTS")
        idx = np.searchsorted(pos_times, spk_times) - 1
        return idx

    def makeSummaryPlot(self, spk_times: np.ndarray):
        """
        Creates a summary plot with spike path, rate map, head direction plot,
        and spatial autocorrelation.

        Args:
            spk_times (np.ndarray): The spike times in seconds.

        Returns:
            matplotlib.figure.Figure: The created figure.
        """
        fig = plt.figure()
        ax = plt.subplot(221)
        self.makeSpikePathPlot(spk_times, ax=ax, markersize=2)
        ax = plt.subplot(222)
        self.makeRateMap(spk_times, ax=ax)
        ax = plt.subplot(223, projection="polar")
        self.makeHDPlot(spk_times, ax=ax)
        ax = plt.subplot(224)
        try:
            self.makeSAC(spk_times, ax=ax)
        except IndexError:
            pass
        return fig

    def makeSpikePlot(self,
                      mean_waveform: bool = True,
                      ax: matplotlib.axes = None,
                      **kwargs) -> matplotlib.figure:
        if not self.SpikeCalcs:
            Warning("No spike data loaded")
            return
        waves = self.SpikeCalcs.waveforms(range(4))
        if ax is None:
            fig = plt.figure()
        else:
            fig = ax.figure
        spike_at = np.shape(waves)[2] // 2
        if spike_at > 25:  # OE data
            # this should be equal to range(25, 75)
            t = range(spike_at - self.SpikeCalcs.pre_spike_samples,
                      spike_at + self.SpikeCalcs.post_spike_samples)
        else:  # Axona data
            t = range(50)
        if mean_waveform:
            for i in range(4):
                ax = fig.add_subplot(2, 2, i+1)
                ax = self._plotSpikes(np.mean(
                    waves[:, :, t], 0)[i, :], ax=ax, **kwargs)
                if spike_at > 25:  # OE data
                    ax.invert_yaxis()
        else:
            for i in range(4):
                ax = fig.add_subplot(2, 2, i+1)
                ax = self._plotSpikes(waves[:, i, t], ax=ax, **kwargs)
                if spike_at > 25:  # OE data
                    ax.invert_yaxis()
        return fig

    @stripAxes
    def _plotSpikes(self, waves: np.ndarray,
                    ax: matplotlib.axes,
                    **kwargs) -> matplotlib.axes:
        ax.plot(waves, c='k', **kwargs)
        return ax

    @stripAxes
    def makeRateMap(self,
                    spk_times: np.ndarray,
                    ax: matplotlib.axes = None,
                    **kwargs) -> matplotlib.axes:
        """
        Creates a rate map plot.

        Args:
            spk_times (np.ndarray): The spike times in seconds.
            ax (matplotlib.axes, optional): The axes to plot on. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()
        spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
        spk_weights = np.bincount(
            spk_times_in_pos_samples, minlength=self.npos)
        rmap = self.RateMap.getMap(spk_weights)
        ratemap = np.ma.MaskedArray(rmap[0], np.isnan(rmap[0]), copy=True)
        x, y = np.meshgrid(rmap[1][1][0:-1].data, rmap[1][0][0:-1].data)
        vmax = np.nanmax(np.ravel(ratemap))
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.pcolormesh(
            x, y, ratemap,
            cmap=jet_cmap,
            edgecolors="face",
            vmax=vmax,
            shading="auto",
            **kwargs
        )
        ax.set_aspect("equal")
        return ax

    @stripAxes
    def makeSpikePathPlot(self,
                          spk_times: np.ndarray = None,
                          ax: matplotlib.axes = None,
                          **kwargs) -> matplotlib.axes:
        """
        Creates a spike path plot.

        Args:
            spk_times (np.ndarray, optional): The spike times in seconds.
                If None, no spikes are plotted.
            ax (matplotlib.axes, optional): The axes to plot on.
                If None, new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()
        if "c" in kwargs:
            col = kwargs.pop("c")
        else:
            col = tcols.colours[1]
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.plot(
            self.PosCalcs.xy[0, :],
            self.PosCalcs.xy[1, :],
            c=tcols.colours[0], zorder=1
        )
        ax.set_aspect("equal")
        if spk_times is not None:
            idx = self.getSpikePosIndices(spk_times)
            ax.plot(
                self.PosCalcs.xy[0, idx],
                self.PosCalcs.xy[1, idx],
                "s", c=col, **kwargs
            )
        return ax

    def makeEgoCentricBoundaryMap(self,
                                  spk_times: np.ndarray,
                                  ax: matplotlib.axes = None,
                                  **kwargs) -> matplotlib.axes:
        """
        Creates an ego-centric boundary map plot.

        Args:
            spk_times (np.ndarray): The spike times in seconds.
            ax (matplotlib.axes, optional): The axes to plot on. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()

        # parse kwargs; pop them so anything left over (e.g. cmap) can be
        # safely forwarded to pcolormesh below
        degs_per_bin = kwargs.pop("degs_per_bin", 3)
        xy_binsize = kwargs.pop("xy_binsize", 2.5)
        arena_type = kwargs.pop("arena_type", "circle")
        strip_axes = kwargs.pop("strip_axes", False)
        return_ratemap = kwargs.pop("return_ratemap", False)

        idx = self.getSpikePosIndices(spk_times)
        spk_weights = np.bincount(idx, minlength=len(self.RateMap.dir))
        ego_map = self.RateMap.get_egocentric_boundary_map(spk_weights,
                                                           degs_per_bin,
                                                           xy_binsize,
                                                           arena_type)
        rmap = ego_map.rmap
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(projection='polar')
        theta = np.arange(0, 2*np.pi, 2*np.pi/rmap.shape[1])
        phi = np.arange(0, rmap.shape[0] * xy_binsize, xy_binsize)
        X, Y = np.meshgrid(theta, phi)
        ax.pcolormesh(X, Y, rmap, **kwargs)
        ax.set_xticks(np.arange(0, 2*np.pi, np.pi/4))
        # ax.set_xticklabels(np.arange(0, 2*np.pi, np.pi/4))
        ax.set_yticks(np.arange(0, 50, 10))
        ax.set_yticklabels(np.arange(0, 50, 10))
        ax.set_xlabel('Angle (deg)')
        ax.set_ylabel('Distance (cm)')
        if strip_axes:
            return stripAxes(ax)
        if return_ratemap:
            return ax, rmap
        return ax

    @stripAxes
    def makeEgoCentricBoundarySpikePlot(self,
                                        spk_times: np.ndarray,
                                        add_colour_wheel: bool = False,
                                        ax: matplotlib.axes = None,
                                        **kwargs) -> matplotlib.axes:
        """
        Creates an ego-centric boundary spike plot.

        Args:
            spk_times (np.ndarray): The spike times in seconds.
            add_colour_wheel (bool, optional): Whether to add a colour wheel
                to the plot. Defaults to False.
            ax (matplotlib.axes, optional): The axes to plot on. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()
        # get the index into a circular colormap based
        # on directional heading, then create a LineCollection
        num_dir_bins = kwargs.pop("num_dir_bins", 60)
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        if "ms" in kwargs.keys():
            rect_size = kwargs.pop("ms")
        else:
            rect_size = 1
        dir_colours = sns.color_palette('hls', num_dir_bins)
        # need to create line colours and line widths for the collection
        idx = self.getSpikePosIndices(spk_times)
        dir_spike_fired_at = self.RateMap.dir[idx]
        idx_of_dir_to_colour = np.floor(
            dir_spike_fired_at / (360 / num_dir_bins)).astype(int)
        rects = [Rectangle(self.RateMap.xy[:, i],
                           width=rect_size, height=rect_size)
                 for i in idx]
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot()
        else:
            # needed for the colour wheel inset below
            fig = ax.figure
        # plot the path
        ax.plot(self.RateMap.xy[0],
                self.RateMap.xy[1],
                c=tcols.colours[0],
                zorder=1,
                alpha=0.3)
        ax.set_aspect('equal')
        for col_idx, r in zip(idx_of_dir_to_colour, rects):
            ax.add_artist(r)
            r.set_clip_box(ax.bbox)
            r.set_facecolor(dir_colours[col_idx])
            r.set_rasterized(True)
        if add_colour_wheel:
            ax_col = ax.inset_axes(bounds=[0.75, 0.75, 0.15, 0.15],
                                   projection='polar',
                                   transform=fig.transFigure)
            ax_col.set_theta_zero_location("N")
            theta = np.linspace(0, 2*np.pi, 1000)
            phi = np.linspace(0, 1, 2)
            X, Y = np.meshgrid(phi, theta)
            norm = matplotlib.colors.Normalize(0, 2*np.pi)
            col_map = sns.color_palette('hls', as_cmap=True)
            ax_col.pcolormesh(theta, phi, Y.T, norm=norm, cmap=col_map)
            ax_col.set_yticklabels([])
            ax_col.spines['polar'].set_visible(False)
            ax_col.set_thetagrids([0, 90])
        if strip_axes:
            return stripAxes(ax)
        return ax

    @stripAxes
    def makeSAC(
        self, spk_times: np.array = None, ax: matplotlib.axes = None, **kwargs
    ) -> matplotlib.axes:
        """
        Creates a spatial autocorrelation plot.

        Args:
            spk_times (np.array, optional): The spike times in seconds. If
                None, no spikes are plotted.
            ax (matplotlib.axes, optional): The axes to plot on. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()
        spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
        spk_weights = np.bincount(
            spk_times_in_pos_samples, minlength=self.npos)
        sac = self.RateMap.getSAC(spk_weights)
        from ephysiopy.common.gridcell import SAC

        S = SAC()
        measures = S.getMeasures(sac)
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax = self.show_SAC(sac, measures, ax)
        return ax

    def makeHDPlot(
        self, spk_times: np.array = None, ax: matplotlib.axes = None, **kwargs
    ) -> matplotlib.axes:
        """
        Creates a head direction plot.

        Args:
            spk_times (np.array, optional): The spike times in seconds. If
                None, no spikes are plotted.
            ax (matplotlib.axes, optional): The axes to plot on. If None, new
                axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if not self.RateMap:
            self.initialise()
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = True
        spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
        spk_weights = np.bincount(
            spk_times_in_pos_samples, minlength=self.npos)
        rmap = self.RateMap.getMap(spk_weights, varType=VariableToBin.DIR)
        if ax is None:
            fig = plt.figure()
            # default to polar axes; a Cartesian axes can still be
            # supplied via the ax argument
            ax = fig.add_subplot(111, projection="polar")
        theta = np.deg2rad(rmap[1][0])
        ax.clear()
        r = rmap[0]  # in samples so * pos sample_rate
        r = np.insert(r, -1, r[0])
        # only polar axes support set_theta_zero_location, so guard
        # against a supplied Cartesian axes
        if "polar" in ax.name:
            ax.set_theta_zero_location("N")
            ax.plot(theta, r)
            if "fill" in kwargs:
                ax.fill(theta, r, alpha=0.5)
            ax.set_aspect("equal")

        # See if we should add the mean resultant vector (mrv)
        if "add_mrv" in kwargs:
            from ephysiopy.common.statscalcs import mean_resultant_vector

            angles = self.PosCalcs.dir[spk_times_in_pos_samples]
            r, th = mean_resultant_vector(np.deg2rad(angles))
            ax.plot([th, th], [0, r * np.max(rmap[0])], "r")
        if "polar" in ax.name:
            ax.set_thetagrids([0, 90, 180, 270])
        if strip_axes:
            return stripAxes(ax)
        return ax

    def makeSpeedVsRatePlot(
        self,
        spk_times: np.array,
        minSpeed: float = 0.0,
        maxSpeed: float = 40.0,
        sigma: float = 3.0,
        ax: matplotlib.axes = None,
        **kwargs
    ) -> matplotlib.axes:
        """
        Plots the instantaneous firing rate of a cell against running speed:
        the smoothed spike train is binned by speed and the mean rate and its
        standard deviation per speed bin are plotted (cf. Kropff et al., 2015).

        Args:
            spk_times (np.array): The spike times in seconds.
            minSpeed (float, optional): The minimum speed. Defaults to 0.0.
            maxSpeed (float, optional): The maximum speed. Defaults to 40.0.
            sigma (float, optional): The sigma value. Defaults to 3.0.
            ax (matplotlib.axes, optional): The axes to plot on. If None, new
                axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        if not self.RateMap:
            self.initialise()
        spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)

        speed = np.ravel(self.PosCalcs.speed)
        if np.nanmax(speed) < maxSpeed:
            maxSpeed = np.nanmax(speed)
        spd_bins = np.arange(minSpeed, maxSpeed, 1.0)
        # Construct the mask
        speed_filt = np.ma.MaskedArray(speed)
        speed_filt = np.ma.masked_where(speed_filt < minSpeed, speed_filt)
        speed_filt = np.ma.masked_where(speed_filt > maxSpeed, speed_filt)
        from ephysiopy.common.spikecalcs import SpikeCalcsGeneric

        x1 = spk_times_in_pos_samples
        S = SpikeCalcsGeneric(x1)
        spk_sm = S.smooth_spike_train(x1,
                                      self.PosCalcs.xyTS.shape[0],
                                      sigma, None)
        spk_sm = np.ma.MaskedArray(spk_sm, mask=np.ma.getmask(speed_filt))
        spd_dig = np.digitize(speed_filt, spd_bins, right=True)
        mn_rate = np.array(
            [np.ma.mean(spk_sm[spd_dig == i]) for i in range(0, len(spd_bins))]
        )
        var = np.array(
            [np.ma.std(spk_sm[spd_dig == i]) for i in range(0, len(spd_bins))]
        )
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.errorbar(spd_bins, mn_rate * self.PosCalcs.sample_rate,
                    yerr=var, color="k")
        ax.set_xlim(spd_bins[0], spd_bins[-1])
        ax.set_xticks(
            [spd_bins[0], spd_bins[-1]],
            labels=["0", "{:.2g}".format(spd_bins[-1])],
            fontweight="normal",
            size=6,
        )
        ax.set_yticks(
            [0, np.nanmax(mn_rate) * self.PosCalcs.sample_rate],
            labels=["0", "{:.2f}".format(np.nanmax(mn_rate))],
            fontweight="normal",
            size=6,
        )
        if strip_axes:
            return stripAxes(ax)
        return ax

    def makeSpeedVsHeadDirectionPlot(
        self, spk_times: np.array, ax: matplotlib.axes = None, **kwargs
    ) -> matplotlib.axes:
        """
        Creates a speed versus head direction plot.

        Args:
            spk_times (np.array): The spike times in seconds.
            ax (matplotlib.axes, optional): The axes to plot on. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        if not self.RateMap:
            self.initialise()
        spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
        idx = np.array(spk_times_in_pos_samples, dtype=int)
        w = np.bincount(idx, minlength=self.PosCalcs.speed.shape[0])
        if np.ma.is_masked(self.PosCalcs.speed):
            w[self.PosCalcs.speed.mask] = 0

        dir_bins = np.arange(0, 360, 6)
        spd_bins = np.arange(0, 30, 1)
        h = np.histogram2d(self.PosCalcs.dir,
                           self.PosCalcs.speed,
                           [dir_bins, spd_bins], weights=w)
        from ephysiopy.common.utils import blurImage

        im = blurImage(h[0], 5, ftype="gaussian")
        im = np.ma.MaskedArray(im)
        # mask low rates...
        im = np.ma.masked_where(im <= 1, im)
        # ... and where less than 0.5% of data is accounted for
        x, y = np.meshgrid(dir_bins, spd_bins)
        vmax = np.max(np.ravel(im))
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.pcolormesh(x, y, im.T,
                      cmap=jet_cmap, edgecolors="face",
                      vmax=vmax, shading="auto")
        ax.set_xticks([90, 180, 270], labels=['90', '180', '270'],
                      fontweight="normal", size=6)
        ax.set_yticks([10, 20], labels=['10', '20'],
                      fontweight="normal", size=6)
        ax.set_xlabel("Heading", fontweight="normal", size=6)
        if strip_axes:
            stripAxes(ax)
        return ax

    def makePowerSpectrum(
        self,
        freqs: np.array,
        power: np.array,
        sm_power: np.array,
        band_max_power: float,
        freq_at_band_max_power: float,
        max_freq: int = 50,
        theta_range: tuple = [6, 12],
        ax: matplotlib.axes = None,
        **kwargs
    ) -> matplotlib.axes:
        """
        Plots the power spectrum. The parameters can be obtained from
        calcEEGPowerSpectrum() in the EEGCalcsGeneric class.

        Args:
            freqs (np.array): The frequencies.
            power (np.array): The power values.
            sm_power (np.array): The smoothed power values.
            band_max_power (float): The maximum power in the band.
            freq_at_band_max_power (float): The frequency at which the maximum
                power in the band occurs.
            max_freq (int, optional): The maximum frequency. Defaults to 50.
            theta_range (tuple, optional): The theta range.
                Defaults to [6, 12].
            ax (matplotlib.axes, optional): The axes to plot on. If None, new
                axes are created.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        # downsample frequencies and power
        freqs = freqs[0::50]
        power = power[0::50]
        sm_power = sm_power[0::50]
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.plot(freqs, power, alpha=0.5, color=[0.8627, 0.8627, 0.8627])
        ax.plot(freqs, sm_power)
        ax.set_xlim(0, max_freq)
        ylim = [0, np.max(sm_power[freqs < max_freq])]
        if "ylim" in kwargs:
            ylim = kwargs["ylim"]
        ax.set_ylim(ylim)
        ax.set_ylabel("Power")
        ax.set_xlabel("Frequency")
        ax.text(
            x=theta_range[1] / 0.9,
            y=band_max_power,
            s=str(freq_at_band_max_power)[0:4],
            fontsize=20,
        )
        from matplotlib.patches import Rectangle

        r = Rectangle(
            (theta_range[0], 0),
            width=np.diff(theta_range)[0],
            height=np.diff(ax.get_ylim())[0],
            alpha=0.25,
            color="r",
            ec="none",
        )
        ax.add_patch(r)
        if strip_axes:
            return stripAxes(ax)
        return ax

    def makeXCorr(
        self, spk_times: np.array, ax: matplotlib.axes = None, **kwargs
    ) -> matplotlib.axes:
        """
        Returns an axis containing the autocorrelogram of the spike
        times provided over the range +/-500ms.

        Args:
            spk_times (np.array): Spike times in seconds.
            ax (matplotlib.axes, optional): The axes to plot into. If None,
                new axes are created.
            **kwargs: Additional keyword arguments for the function.
                binsize (float, optional): The size of the bins in seconds.
                Gets passed to SpikeCalcsGeneric.acorr(). Defaults to 0.001.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        # spike times are provided, and plotted, in seconds
        S = SpikeCalcsGeneric(spk_times)
        c, b = S.acorr(spk_times, **kwargs)
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        if 'binsize' in kwargs.keys():
            binsize = kwargs['binsize']
        else:
            binsize = 0.001
        if "Trange" in kwargs.keys():
            xrange = kwargs["Trange"]
        else:
            xrange = [-0.5, 0.5]
        ax.bar(b[:-1], c, width=binsize, color="k", align="edge")
        ax.set_xlim(xrange)
        ax.set_xticks((xrange[0], 0, xrange[1]))
        ax.set_xticklabels("")
        ax.tick_params(axis="both", which="both", left=False, right=False,
                       bottom=False, top=False)
        ax.set_yticklabels("")
        ax.xaxis.set_ticks_position("bottom")
        if strip_axes:
            return stripAxes(ax)
        return ax

    def makeRaster(
        self,
        spk_times: np.array,
        dt=(-50, 100),
        prc_max: float = 0.5,
        ax: matplotlib.axes = None,
        ms_per_bin: int = 1,
        sample_rate: float = 3e4,  # OE=3e4, Axona=96000
        **kwargs
    ) -> matplotlib.axes:
        """
        Plots a raster plot for a specified tetrode/cluster.

        Args:
            spk_times (np.array): The spike times in samples.
            dt (tuple, optional): The window of time in ms to examine zeroed
                on the event of interest i.e. the first value will probably
                be negative as in the example. Defaults to (-50, 100).
            prc_max (float, optional): The proportion of firing the cell has
                to 'lose' to count as silent; a float between 0 and 1.
                Defaults to 0.5.
            ax (matplotlib.axes, optional): The axes to plot into.
                If not provided a new figure is created. Defaults to None.
            ms_per_bin (int, optional): The number of milliseconds in each bin
                of the raster plot. Defaults to 1.
            sample_rate (float, optional): The sample rate. Defaults to 3e4.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.
        """
        assert hasattr(self, "ttl_data")

        if "strip_axes" in kwargs.keys():
            strip_axes = kwargs.pop("strip_axes")
        else:
            strip_axes = False
        x1 = spk_times / sample_rate * 1000.0  # get into ms
        x1.sort()
        on_good = self.ttl_data["ttl_timestamps"]
        dt = np.array(dt)
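        # broadcast the dt window over every TTL event to give an
        # (n_events, 2) array of [start, stop] times; searchsorted then
        # finds the spikes falling inside each window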
        irange = on_good[:, np.newaxis] + dt[np.newaxis, :]
        dts = np.searchsorted(x1, irange)
        y = []
        x = []
        for i, t in enumerate(dts):
            tmp = x1[t[0]:t[1]] - on_good[i]
            x.extend(tmp)
            y.extend(np.repeat(i, len(tmp)))
        if ax is None:
            fig = plt.figure(figsize=(4.0, 7.0))
            axScatter = fig.add_subplot(111)
        else:
            axScatter = ax
        histColor = [1 / 255.0, 1 / 255.0, 1 / 255.0]
        axScatter.scatter(x, y, marker=".", s=2,
                          rasterized=False, color=histColor)
        divider = make_axes_locatable(axScatter)
        axScatter.set_xticks((dt[0], 0, dt[1]))
        axScatter.set_xticklabels((str(dt[0]), "0", str(dt[1])))
        axHistx = divider.append_axes("top", 0.95, pad=0.2,
                                      sharex=axScatter,
                                      transform=axScatter.transAxes)
        scattTrans = transforms.blended_transform_factory(
            axScatter.transData, axScatter.transAxes
        )
        stim_pwidth = self.ttl_data["stim_duration"]
        if stim_pwidth is None:
            raise ValueError("stim duration is None")

        axScatter.add_patch(
            Rectangle(
                (0, 0),
                width=stim_pwidth,
                height=1,
                transform=scattTrans,
                color=[0, 0, 1],
                alpha=0.3,
            )
        )
        histTrans = transforms.blended_transform_factory(
            axHistx.transData, axHistx.transAxes
        )
        axHistx.add_patch(
            Rectangle(
                (0, 0),
                width=stim_pwidth,
                height=1,
                transform=histTrans,
                color=[0, 0, 1],
                alpha=0.3,
            )
        )
        axScatter.set_ylabel("Laser stimulation events", labelpad=-2.5)
        axScatter.set_xlabel("Time to stimulus onset(ms)")
        nStms = len(on_good)
        axScatter.set_ylim(0, nStms)
        # Label only the min and max of the y-axis
        ylabels = axScatter.get_yticklabels()
        for i in range(1, len(ylabels) - 1):
            ylabels[i].set_visible(False)
        yticks = axScatter.get_yticklines()
        for i in range(1, len(yticks) - 1):
            yticks[i].set_visible(False)

        axHistx.hist(
            x,
            bins=np.arange(dt[0], dt[1] + ms_per_bin, ms_per_bin),
            color=histColor,
            range=dt,
            rasterized=True,
            histtype="stepfilled",
        )
        axHistx.set_ylabel("Spike count", labelpad=-2.5)
        plt.setp(axHistx.get_xticklabels(), visible=False)
        # Label only the min and max of the y-axis
        ylabels = axHistx.get_yticklabels()
        for i in range(1, len(ylabels) - 1):
            ylabels[i].set_visible(False)
        yticks = axHistx.get_yticklines()
        for i in range(1, len(yticks) - 1):
            yticks[i].set_visible(False)
        axHistx.set_xlim(dt)
        axScatter.set_xlim(dt)
        if strip_axes:
            return stripAxes(axScatter)
        return axScatter

    '''
    def getRasterHist(
            self, spike_ts: np.array,
            sample_rate: int,
            dt=(-50, 100), hist=True):
        """
        MOVE TO SPIKECALCS

        Calculates the histogram of the raster of spikes during a series of
        events

        Parameters
        ----------
        tetrode : int
        cluster : int
        dt : tuple
            the window of time in ms to examine zeroed on the event of interest
            i.e. the first value will probably be negative as in the example
        hist : bool
            not sure
        """
        spike_ts = spike_ts * float(sample_rate)  # in ms
        spike_ts.sort()
        on_good = getattr(self, 'ttl_timestamps') / sample_rate / float(1000)
        dt = np.array(dt)
        irange = on_good[:, np.newaxis] + dt[np.newaxis, :]
        dts = np.searchsorted(spike_ts, irange)
        y = []
        x = []
        for i, t in enumerate(dts):
            tmp = spike_ts[t[0]:t[1]] - on_good[i]
            x.extend(tmp)
            y.extend(np.repeat(i, len(tmp)))

        if hist:
            nEvents = int(self.STM["num_stm_samples"])
            return np.histogram2d(
                x, y, bins=[np.arange(
                    dt[0], dt[1]+1, 1), np.arange(0, nEvents+1, 1)])[0]
        else:
            return np.histogram(
                x, bins=np.arange(
                    dt[0], dt[1]+1, 1), range=dt)[0]
    '''

    @stripAxes
    def show_SAC(
        self, A: np.array, inDict: dict, ax: matplotlib.axes = None, **kwargs
    ) -> matplotlib.axes:
        """
        Displays the result of performing a spatial autocorrelation (SAC)
        on a grid cell.

        Uses the dictionary containing measures of the grid cell SAC to
        make a pretty picture.

        Args:
            A (np.array): The spatial autocorrelogram.
            inDict (dict): The dictionary of measures calculated by
                getMeasures().
            ax (matplotlib.axes, optional): If given the plot will get drawn
                in these axes. Default None.
            **kwargs: Additional keyword arguments for the function.

        Returns:
            matplotlib.axes: The axes with the plot.

        See Also:
            ephysiopy.common.binning.RateMap.autoCorr2D()
            ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        Am = A.copy()
        Am[~inDict["dist_to_centre"]] = np.nan
        Am = np.ma.masked_invalid(np.atleast_2d(Am))
        x, y = np.meshgrid(np.arange(0, np.shape(A)[1]),
                           np.arange(0, np.shape(A)[0]))
        vmax = np.nanmax(np.ravel(A))
        ax.pcolormesh(x, y, A, cmap=grey_cmap, edgecolors="face",
                      vmax=vmax, shading="auto")
        import copy

        cmap = copy.copy(jet_cmap)
        cmap.set_bad("w", 0)
        ax.pcolormesh(x, y, Am, cmap=cmap,
                      edgecolors="face", vmax=vmax, shading="auto")
        # horizontal green line at 3 o'clock
        _y = (np.shape(A)[0] / 2, np.shape(A)[0] / 2)
        _x = (np.shape(A)[1] / 2, np.shape(A)[0])
        ax.plot(_x, _y, c="g")
        mag = inDict["scale"] * 0.5
        th = np.linspace(0, inDict["orientation"], 50)
        from ephysiopy.common.utils import rect

        [x, y] = rect(mag, th, deg=1)
        # angle subtended by orientation
        ax.plot(
            x + (inDict["dist_to_centre"].shape[1] / 2),
            (inDict["dist_to_centre"].shape[0] / 2) - y,
            "r",
            **kwargs
        )
        # plot lines from centre to peaks above middle
        for p in inDict["closest_peak_coords"]:
            if p[0] <= inDict["dist_to_centre"].shape[0] / 2:
                ax.plot(
                    (inDict["dist_to_centre"].shape[1] / 2, p[1]),
                    (inDict["dist_to_centre"].shape[0] / 2, p[0]),
                    "k",
                    **kwargs
                )
        ax.invert_yaxis()
        all_ax = ax.axes
        all_ax.set_aspect("equal")
        all_ax.set_xlim((0.5, inDict["dist_to_centre"].shape[1] - 1.5))
        all_ax.set_ylim((inDict["dist_to_centre"].shape[0] - 0.5, -0.5))
        return ax

    def plotSpectrogramByDepth(
        self,
        nchannels: int = 384,
        nseconds: int = 100,
        maxFreq: int = 125,
        channels: list = [],
        frequencies: list = [],
        frequencyIncrement: int = 1,
        **kwargs
    ):
        """
        Plots a heat map spectrogram of the LFP for each channel.
        Line plots of power per frequency band and power on a subset of
        channels are also displayed to the right and above the main plot.

        Args:
            nchannels (int): The number of channels on the probe.
            nseconds (int, optional): How long in seconds from the start of
                the trial to do the spectrogram for (for speed).
                Default is 100.
            maxFreq (int): The maximum frequency in Hz to plot the spectrogram
                out to. Maximum is 1250. Default is 125.
            channels (list): The channels to plot separately on the top plot.
            frequencies (list): The specific frequencies to examine across
                all channels. The mean power from frequency to
                frequency + frequencyIncrement is calculated and plotted on
                the left hand side of the plot.
            frequencyIncrement (int): The amount to add to each value of
                the frequencies list above.
            **kwargs: Additional keyword arguments for the function.
                Valid key value pairs:
                    "saveas" - save the figure to this location, needs absolute
                    path and filename.

        Notes:
            Should also allow kwargs to specify exactly which channels
            and / or frequency bands to do the line plots for.
        """
        if not self.path2LFPdata:
            raise TypeError("Not a probe recording so not plotting")
        import os

        lfp_file = os.path.join(self.path2LFPdata, "continuous.dat")
        status = os.stat(lfp_file)
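        # the .dat file is flat int16, so 2 bytes per sample, interleaved
        # across nchannels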
        nsamples = int(status.st_size / 2 / nchannels)
        mmap = np.memmap(lfp_file, np.int16, "r", 0,
                         (nchannels, nsamples), order="F")
        # Load the channel map NB assumes this is in the AP data
        # location and that kilosort was run there
        channel_map = np.squeeze(
            np.load(os.path.join(self.path2APdata, "channel_map.npy"))
        )
        lfp_sample_rate = 2500
        data = np.array(mmap[channel_map, 0:nseconds * lfp_sample_rate])
        from ephysiopy.common.ephys_generic import EEGCalcsGeneric

        E = EEGCalcsGeneric(data[0, :], lfp_sample_rate)
        E.calcEEGPowerSpectrum()
        spec_data = np.zeros(shape=(data.shape[0], len(E.sm_power[0::50])))
        for chan in range(data.shape[0]):
            E = EEGCalcsGeneric(data[chan, :], lfp_sample_rate)
            E.calcEEGPowerSpectrum()
            spec_data[chan, :] = E.sm_power[0::50]

        x, y = np.meshgrid(E.freqs[0::50], channel_map)
        import matplotlib.colors as colors
        from matplotlib.pyplot import cm
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        _, spectoAx = plt.subplots()
        spectoAx.pcolormesh(x, y, spec_data,
                            edgecolors="face", cmap="bone",
                            norm=colors.LogNorm())
        spectoAx.set_xlim(0, maxFreq)
        spectoAx.set_ylim(channel_map[0], channel_map[-1])
        spectoAx.set_xlabel("Frequency (Hz)")
        spectoAx.set_ylabel("Channel")
        divider = make_axes_locatable(spectoAx)
        channel_spectoAx = divider.append_axes("top", 1.2, pad=0.1,
                                               sharex=spectoAx)
        meanfreq_powerAx = divider.append_axes("right", 1.2, pad=0.1,
                                               sharey=spectoAx)
        plt.setp(channel_spectoAx.get_xticklabels()
                 + meanfreq_powerAx.get_yticklabels(),
                 visible=False)

        # plot mean power across some channels
        mn_power = np.mean(spec_data, 0)
        if not channels:
            channels = range(1, nchannels, 60)
        cols = iter(cm.rainbow(np.linspace(0, 1, len(channels))))
        for chan in channels:
            c = next(cols)
            channel_spectoAx.plot(
                E.freqs[0::50],
                10 * np.log10(spec_data[chan, :] / mn_power),
                c=c,
                label=str(chan),
            )

        channel_spectoAx.set_ylabel("Channel power(dB)")
        channel_spectoAx.legend(
            bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
            loc="lower left",
            mode="expand",
            fontsize="x-small",
            ncol=4,
        )

        # plot mean frequencies across all channels
        if not frequencyIncrement:
            freq_inc = 6
        else:
            freq_inc = frequencyIncrement
        if not frequencies:
            lower_freqs = np.arange(1, maxFreq - freq_inc, freq_inc)
        else:
            lower_freqs = frequencies
        upper_freqs = [f + freq_inc for f in lower_freqs]
        cols = iter(cm.nipy_spectral(np.linspace(0, 1, len(upper_freqs))))
        mn_power = np.mean(spec_data, 1)
        for freqs in zip(lower_freqs, upper_freqs):
            freq_mask = np.logical_and(
                E.freqs[0::50] > freqs[0], E.freqs[0::50] < freqs[1]
            )
            mean_power = 10 * np.log10(np.mean(
                spec_data[:, freq_mask], 1) / mn_power)
            c = next(cols)
            meanfreq_powerAx.plot(
                mean_power,
                channel_map,
                c=c,
                label=str(freqs[0]) + " - " + str(freqs[1]),
            )
        meanfreq_powerAx.set_xlabel("Mean freq. band power(dB)")
        meanfreq_powerAx.legend(
            bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
            loc="lower left",
            mode="expand",
            fontsize="x-small",
            ncol=1,
        )
        if "saveas" in kwargs:
            saveas = kwargs["saveas"]
            plt.savefig(saveas)
        plt.show()

    '''
    def plotDirFilteredRmaps(self, tetrode, cluster, maptype='rmap', **kwargs):
        """
        Plots out directionally filtered ratemaps for the tetrode/ cluster

        Parameters
        ----------
        tetrode : int
        cluster : int
        maptype : str
            Valid values include 'rmap', 'polar', 'xcorr'
        """
        inc = 8.0
        step = 360/inc
        dirs_st = np.arange(-step/2, 360-(step/2), step)
        dirs_en = np.arange(step/2, 360, step)
        dirs_st[0] = dirs_en[-1]

        if 'polar' in maptype:
            _, axes = plt.subplots(
                nrows=3, ncols=3, subplot_kw={'projection': 'polar'})
        else:
            _, axes = plt.subplots(nrows=3, ncols=3)
        ax0 = axes[0][0]  # top-left
        ax1 = axes[0][1]  # top-middle
        ax2 = axes[0][2]  # top-right
        ax3 = axes[1][0]  # middle-left
        ax4 = axes[1][1]  # middle
        ax5 = axes[1][2]  # middle-right
        ax6 = axes[2][0]  # bottom-left
        ax7 = axes[2][1]  # bottom-middle
        ax8 = axes[2][2]  # bottom-right

        max_rate = 0
        for d in zip(dirs_st, dirs_en):
            self.posFilter = {'dir': (d[0], d[1])}
            if 'polar' in maptype:
                rmap = self._getMap(
                    tetrode=tetrode, cluster=cluster, var2bin='dir')[0]
            elif 'xcorr' in maptype:
                x1 = self.TETRODE[tetrode].getClustTS(cluster) / (96000/1000)
                rmap = self.spikecalcs.acorr(
                    x1, x1, Trange=np.array([-500, 500]))
            else:
                rmap = self._getMap(tetrode=tetrode, cluster=cluster)[0]
            if np.nanmax(rmap) > max_rate:
                max_rate = np.nanmax(rmap)

        from collections import OrderedDict
        dir_rates = OrderedDict.fromkeys(dirs_st, None)

        ax_collection = [ax5, ax2, ax1, ax0, ax3, ax6, ax7, ax8]
        for d in zip(dirs_st, dirs_en, ax_collection):
            self.posFilter = {'dir': (d[0], d[1])}
            npos = np.count_nonzero(np.ma.compressed(~self.POS.dir.mask))
            print("npos = {}".format(npos))
            nspikes = np.count_nonzero(
                np.ma.compressed(
                    ~self.TETRODE[tetrode].getClustSpks(
                        cluster).mask[:, 0, 0]))
            print("nspikes = {}".format(nspikes))
            dir_rates[d[0]] = nspikes  # / (npos/50.0)
            if 'spikes' in maptype:
                self.plotSpikesOnPath(
                    tetrode, cluster, ax=d[2], markersize=4)
            elif 'rmap' in maptype:
                self._plotMap(
                    tetrode, cluster, ax=d[2], vmax=max_rate)
            elif 'polar' in maptype:
                self._plotMap(
                    tetrode, cluster, var2bin='dir', ax=d[2], vmax=max_rate)
            elif 'xcorr' in maptype:
                self.plotXCorr(
                    tetrode, cluster, ax=d[2])
                x1 = self.TETRODE[tetrode].getClustTS(cluster) / (96000/1000)
                print("x1 len = {}".format(len(x1)))
                dir_rates[d[0]] = self.spikecalcs.theta_band_max_freq(x1)
                d[2].set_xlabel('')
                d[2].set_title('')
                d[2].set_xticklabels('')
            d[2].set_title("nspikes = {}".format(nspikes))
        self.posFilter = None
        if 'spikes' in maptype:
            self.plotSpikesOnPath(tetrode, cluster, ax=ax4)
        elif 'rmap' in maptype:
            self._plotMap(tetrode, cluster, ax=ax4)
        elif 'polar' in maptype:
            self._plotMap(tetrode, cluster, var2bin='dir', ax=ax4)
        elif 'xcorr' in maptype:
            self.plotXCorr(tetrode, cluster, ax=ax4)
            ax4.set_xlabel('')
            ax4.set_title('')
            ax4.set_xticklabels('')
        return dir_rates
        '''

__init__()

Initializes the FigureMaker object.

Source code in ephysiopy/visualise/plotting.py
def __init__(self):
    """
    Initializes the FigureMaker object.
    """
    self.PosCalcs = None

eb_map(cluster, channel, **kwargs)

Gets the ego-centric boundary map for the specified cluster(s) and channel.

Parameters:

    cluster (int | list): The cluster(s) to get the ego-centric boundary
        map for. Required.
    channel (int): The channel number. Required.
    **kwargs: Additional keyword arguments for the function.
Source code in ephysiopy/visualise/plotting.py
def eb_map(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the ego-centric boundary map for the specified cluster(s) and
    channel.

    Args:
        cluster (int | list): The cluster(s) to get the ego-centric
            boundary map for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeEgoCentricBoundaryMap,
                                     cluster,
                                     channel,
                                     projection='polar',
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeEgoCentricBoundaryMap(ts, **kwargs)
    plt.show()
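
A short sketch of both call patterns (assuming the trial object from the
usage sketch above; kwargs such as degs_per_bin are passed through to
makeEgoCentricBoundaryMap):

trial.eb_map(cluster=1, channel=3, degs_per_bin=6)  # single cluster
trial.eb_map(cluster=[1, 2, 4], channel=3)          # one subplot per cluster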

eb_spikes(cluster, channel, **kwargs)

Gets the ego-centric boundary spikes for the specified cluster(s) and channel.

Parameters:

    cluster (int | list): The cluster(s) to get the ego-centric boundary
        spikes for. Required.
    channel (int): The channel number. Required.
    **kwargs: Additional keyword arguments for the function.
Source code in ephysiopy/visualise/plotting.py
def eb_spikes(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the ego-centric boundary spikes for the specified cluster(s)
    and channel.

    Args:
        cluster (int | list): The cluster(s) to get the ego-centric
            boundary spikes for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeEgoCentricBoundarySpikePlot,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeEgoCentricBoundarySpikePlot(ts, **kwargs)
    plt.show()

getSpikePosIndices(spk_times)

Returns the indices into the position data at which some spike times occurred.

Parameters:

    spk_times (np.ndarray): The spike times in seconds. Required.

Returns:

    np.ndarray: The indices into the position data at which the spikes
        occurred.

Source code in ephysiopy/visualise/plotting.py
def getSpikePosIndices(self, spk_times: np.ndarray):
    """
    Returns the indices into the position data at which some spike times
    occurred.

    Args:
        spk_times (np.ndarray): The spike times in seconds.

    Returns:
        np.ndarray: The indices into the position data at which the spikes
            occurred.
    """
    pos_times = getattr(self.PosCalcs, "xyTS")
    idx = np.searchsorted(pos_times, spk_times) - 1
    return idx
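
The indexing can be checked with a toy example (values invented for
illustration): with position samples every 20 ms, a spike at 0.05 s falls
between the samples at 0.04 s and 0.06 s and is assigned to the earlier one.

import numpy as np

pos_times = np.arange(0, 1, 0.02)   # 50 Hz position timestamps
spk_times = np.array([0.05, 0.41])
idx = np.searchsorted(pos_times, spk_times) - 1
print(idx)  # [ 2 20] i.e. the samples at 0.04 s and 0.40 s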

hd_map(cluster, channel, **kwargs)

Gets the head direction map for the specified cluster(s) and channel.

Parameters:

    cluster (int | list): The cluster(s) to get the head direction map
        for. Required.
    channel (int): The channel number. Required.
    **kwargs: Additional keyword arguments for the function.
Source code in ephysiopy/visualise/plotting.py
def hd_map(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the head direction map for the specified cluster(s) and channel.

    Args:
        cluster (int | list): The cluster(s) to get the head direction map
            for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeHDPlot,
                                     cluster,
                                     channel,
                                     projection="polar",
                                     strip_axes=True,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeHDPlot(ts, **kwargs)
    plt.show()

initialise()

Initializes the FigureMaker object with data from PosCalcs.

Source code in ephysiopy/visualise/plotting.py
def initialise(self):
    """
    Initializes the FigureMaker object with data from PosCalcs.
    """
    self.RateMap = RateMap(self.PosCalcs.xy,
                           self.PosCalcs.dir,
                           self.PosCalcs.speed,)
    self.npos = self.PosCalcs.xy.shape[1]

makeEgoCentricBoundaryMap(spk_times, ax=None, **kwargs)

Creates an ego-centric boundary map plot.

Parameters:

    spk_times (np.ndarray): The spike times in seconds. Required.
    ax (matplotlib.axes, optional): The axes to plot on. If None, new
        axes are created. Defaults to None.
    **kwargs: Additional keyword arguments for the function.

Returns:

    matplotlib.axes: The axes with the plot.

Source code in ephysiopy/visualise/plotting.py
def makeEgoCentricBoundaryMap(self,
                              spk_times: np.ndarray,
                              ax: matplotlib.axes = None,
                              **kwargs) -> matplotlib.axes:
    """
    Creates an ego-centric boundary map plot.

    Args:
        spk_times (np.ndarray): The spike times in seconds.
        ax (matplotlib.axes, optional): The axes to plot on. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()

    # parse kwargs; pop them so anything left over (e.g. cmap) can be
    # safely forwarded to pcolormesh below
    degs_per_bin = kwargs.pop("degs_per_bin", 3)
    xy_binsize = kwargs.pop("xy_binsize", 2.5)
    arena_type = kwargs.pop("arena_type", "circle")
    strip_axes = kwargs.pop("strip_axes", False)
    return_ratemap = kwargs.pop("return_ratemap", False)

    idx = self.getSpikePosIndices(spk_times)
    spk_weights = np.bincount(idx, minlength=len(self.RateMap.dir))
    ego_map = self.RateMap.get_egocentric_boundary_map(spk_weights,
                                                       degs_per_bin,
                                                       xy_binsize,
                                                       arena_type)
    rmap = ego_map.rmap
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection='polar')
    theta = np.arange(0, 2*np.pi, 2*np.pi/rmap.shape[1])
    phi = np.arange(0, rmap.shape[0] * xy_binsize, xy_binsize)
    X, Y = np.meshgrid(theta, phi)
    ax.pcolormesh(X, Y, rmap, **kwargs)
    ax.set_xticks(np.arange(0, 2*np.pi, np.pi/4))
    # ax.set_xticklabels(np.arange(0, 2*np.pi, np.pi/4))
    ax.set_yticks(np.arange(0, 50, 10))
    ax.set_yticklabels(np.arange(0, 50, 10))
    ax.set_xlabel('Angle (deg)')
    ax.set_ylabel('Distance (cm)')
    if strip_axes:
        return stripAxes(ax)
    if return_ratemap:
        return ax, rmap
    return ax
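
A sketch of calling the method directly, assuming the trial object from the
earlier usage sketch and that get_spike_times returns times in seconds as
the plotting methods expect. Once the named options are popped, any
remaining kwargs, e.g. cmap, are forwarded to pcolormesh:

ts = trial.get_spike_times(tetrode=3, cluster=1)
ax, rmap = trial.makeEgoCentricBoundaryMap(
    ts, degs_per_bin=6, xy_binsize=2.5,
    return_ratemap=True, cmap="jet")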

makeEgoCentricBoundarySpikePlot(spk_times, add_colour_wheel=False, ax=None, **kwargs)

Creates an ego-centric boundary spike plot.

Parameters:

    spk_times (np.ndarray): The spike times in seconds. Required.
    add_colour_wheel (bool, optional): Whether to add a colour wheel to
        the plot. Defaults to False.
    ax (matplotlib.axes, optional): The axes to plot on. If None, new
        axes are created. Defaults to None.
    **kwargs: Additional keyword arguments for the function.

Returns:

    matplotlib.axes: The axes with the plot.

Source code in ephysiopy/visualise/plotting.py
@stripAxes
def makeEgoCentricBoundarySpikePlot(self,
                                    spk_times: np.ndarray,
                                    add_colour_wheel: bool = False,
                                    ax: matplotlib.axes = None,
                                    **kwargs) -> matplotlib.axes:
    """
    Creates an ego-centric boundary spike plot.

    Args:
        spk_times (np.ndarray): The spike times in seconds.
        add_colour_wheel (bool, optional): Whether to add a colour wheel
            to the plot. Defaults to False.
        ax (matplotlib.axes, optional): The axes to plot on. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()
    # get the index into a circular colormap based
    # on directional heading, then create a LineCollection
    num_dir_bins = kwargs.pop("num_dir_bins", 60)
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    if "ms" in kwargs.keys():
        rect_size = kwargs.pop("ms")
    else:
        rect_size = 1
    dir_colours = sns.color_palette('hls', num_dir_bins)
    # need to create line colours and line widths for the collection
    idx = self.getSpikePosIndices(spk_times)
    dir_spike_fired_at = self.RateMap.dir[idx]
    idx_of_dir_to_colour = np.floor(
        dir_spike_fired_at / (360 / num_dir_bins)).astype(int)
    rects = [Rectangle(self.RateMap.xy[:, i],
                       width=rect_size, height=rect_size)
             for i in idx]
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot()
    else:
        # needed for the colour wheel inset below
        fig = ax.figure
    # plot the path
    ax.plot(self.RateMap.xy[0],
            self.RateMap.xy[1],
            c=tcols.colours[0],
            zorder=1,
            alpha=0.3)
    ax.set_aspect('equal')
    for col_idx, r in zip(idx_of_dir_to_colour, rects):
        ax.add_artist(r)
        r.set_clip_box(ax.bbox)
        r.set_facecolor(dir_colours[col_idx])
        r.set_rasterized(True)
    if add_colour_wheel:
        ax_col = ax.inset_axes(bounds=[0.75, 0.75, 0.15, 0.15],
                               projection='polar',
                               transform=fig.transFigure)
        ax_col.set_theta_zero_location("N")
        theta = np.linspace(0, 2*np.pi, 1000)
        phi = np.linspace(0, 1, 2)
        X, Y = np.meshgrid(phi, theta)
        norm = matplotlib.colors.Normalize(0, 2*np.pi)
        col_map = sns.color_palette('hls', as_cmap=True)
        ax_col.pcolormesh(theta, phi, Y.T, norm=norm, cmap=col_map)
        ax_col.set_yticklabels([])
        ax_col.spines['polar'].set_visible(False)
        ax_col.set_thetagrids([0, 90])
    if strip_axes:
        return stripAxes(ax)
    return ax
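
For example (assuming trial and ts as above; ms sets the size of the
coloured rectangles drawn at each spike position):

trial.makeEgoCentricBoundarySpikePlot(ts, add_colour_wheel=True, ms=2)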

makeHDPlot(spk_times=None, ax=None, **kwargs)

Creates a head direction plot.

Parameters:

    spk_times (np.array, optional): The spike times in seconds. If None,
        no spikes are plotted. Defaults to None.
    ax (matplotlib.axes, optional): The axes to plot on. If None, new
        axes are created. Defaults to None.
    **kwargs: Additional keyword arguments for the function.

Returns:

    matplotlib.axes: The axes with the plot.

Source code in ephysiopy/visualise/plotting.py
def makeHDPlot(
    self, spk_times: np.array = None, ax: matplotlib.axes = None, **kwargs
) -> matplotlib.axes:
    """
    Creates a head direction plot.

    Args:
        spk_times (np.array, optional): The spike times in seconds. If
            None, no spikes are plotted.
        ax (matplotlib.axes, optional): The axes to plot on. If None, new
            axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = True
    spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
    spk_weights = np.bincount(
        spk_times_in_pos_samples, minlength=self.npos)
    rmap = self.RateMap.getMap(spk_weights, varType=VariableToBin.DIR)
    if ax is None:
        fig = plt.figure()
        # default to polar axes; a Cartesian axes can still be
        # supplied via the ax argument
        ax = fig.add_subplot(111, projection="polar")
    theta = np.deg2rad(rmap[1][0])
    ax.clear()
    r = rmap[0]  # in samples so * pos sample_rate
    r = np.insert(r, -1, r[0])
    # only polar axes support set_theta_zero_location, so guard
    # against a supplied Cartesian axes
    if "polar" in ax.name:
        ax.set_theta_zero_location("N")
        ax.plot(theta, r)
        if "fill" in kwargs:
            ax.fill(theta, r, alpha=0.5)
        ax.set_aspect("equal")

    # See if we should add the mean resultant vector (mrv)
    if "add_mrv" in kwargs:
        from ephysiopy.common.statscalcs import mean_resultant_vector

        angles = self.PosCalcs.dir[spk_times_in_pos_samples]
        r, th = mean_resultant_vector(np.deg2rad(angles))
        ax.plot([th, th], [0, r * np.max(rmap[0])], "r")
    if "polar" in ax.name:
        ax.set_thetagrids([0, 90, 180, 270])
    if strip_axes:
        return stripAxes(ax)
    return ax
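
Because only polar axes are drawn into, a polar axes can be created and
passed in explicitly; fill and add_mrv are picked up from kwargs (a sketch,
assuming trial and ts as above):

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection="polar")
trial.makeHDPlot(ts, ax=ax, fill=True, add_mrv=True)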

makePowerSpectrum(freqs, power, sm_power, band_max_power, freq_at_band_max_power, max_freq=50, theta_range=[6, 12], ax=None, **kwargs)

Plots the power spectrum. The parameters can be obtained from calcEEGPowerSpectrum() in the EEGCalcsGeneric class.

Parameters:

    freqs (np.array): The frequencies. Required.
    power (np.array): The power values. Required.
    sm_power (np.array): The smoothed power values. Required.
    band_max_power (float): The maximum power in the band. Required.
    freq_at_band_max_power (float): The frequency at which the maximum
        power in the band occurs. Required.
    max_freq (int, optional): The maximum frequency. Defaults to 50.
    theta_range (tuple, optional): The theta range. Defaults to [6, 12].
    ax (matplotlib.axes, optional): The axes to plot on. If None, new
        axes are created. Defaults to None.
    **kwargs: Additional keyword arguments for the function.

Returns:

    matplotlib.axes: The axes with the plot.

Source code in ephysiopy/visualise/plotting.py
def makePowerSpectrum(
    self,
    freqs: np.array,
    power: np.array,
    sm_power: np.array,
    band_max_power: float,
    freq_at_band_max_power: float,
    max_freq: int = 50,
    theta_range: tuple = [6, 12],
    ax: matplotlib.axes = None,
    **kwargs
) -> matplotlib.axes:
    """
    Plots the power spectrum. The parameters can be obtained from
    calcEEGPowerSpectrum() in the EEGCalcsGeneric class.

    Args:
        freqs (np.array): The frequencies.
        power (np.array): The power values.
        sm_power (np.array): The smoothed power values.
        band_max_power (float): The maximum power in the band.
        freq_at_band_max_power (float): The frequency at which the maximum
            power in the band occurs.
        max_freq (int, optional): The maximum frequency. Defaults to 50.
        theta_range (tuple, optional): The theta range.
            Defaults to [6, 12].
        ax (matplotlib.axes, optional): The axes to plot on. If None, new
            axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    # downsample frequencies and power
    freqs = freqs[0::50]
    power = power[0::50]
    sm_power = sm_power[0::50]
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax.plot(freqs, power, alpha=0.5, color=[0.8627, 0.8627, 0.8627])
    ax.plot(freqs, sm_power)
    ax.set_xlim(0, max_freq)
    ylim = [0, np.max(sm_power[freqs < max_freq])]
    if "ylim" in kwargs:
        ylim = kwargs["ylim"]
    ax.set_ylim(ylim)
    ax.set_ylabel("Power")
    ax.set_xlabel("Frequency")
    ax.text(
        x=theta_range[1] / 0.9,
        y=band_max_power,
        s=str(freq_at_band_max_power)[0:4],
        fontsize=20,
    )
    from matplotlib.patches import Rectangle

    r = Rectangle(
        (theta_range[0], 0),
        width=np.diff(theta_range)[0],
        height=np.diff(ax.get_ylim())[0],
        alpha=0.25,
        color="r",
        ec="none",
    )
    ax.add_patch(r)
    if strip_axes:
        return stripAxes(ax)
    return ax
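
A minimal sketch, assuming `trial.EEGCalcs` has been populated (e.g. by `load_lfp()`); `calcEEGPowerSpectrum()` returns the five values in the order used by `power_spectrum()` further down this page:

```python
trial.load_lfp()  # populates trial.EEGCalcs
p = trial.EEGCalcs.calcEEGPowerSpectrum()
ax = trial.makePowerSpectrum(p[0], p[1], p[2], p[3], p[4],
                             max_freq=40, strip_axes=True)
```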

makeRaster(spk_times, dt=(-50, 100), prc_max=0.5, ax=None, ms_per_bin=1, sample_rate=30000.0, **kwargs)

Plots a raster for a specified tetrode/cluster.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | array | The spike times in samples. | required |
| `dt` | tuple | The window of time in ms to examine zeroed on the event of interest i.e. the first value will probably be negative as in the example. Defaults to (-50, 100). | `(-50, 100)` |
| `prc_max` | float | The proportion of firing the cell has to 'lose' to count as silent; a float between 0 and 1. Defaults to 0.5. | `0.5` |
| `ax` | axes | The axes to plot into. If not provided a new figure is created. Defaults to None. | `None` |
| `ms_per_bin` | int | The number of milliseconds in each bin of the raster plot. Defaults to 1. | `1` |
| `sample_rate` | float | The sample rate. Defaults to 3e4. | `30000.0` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
def makeRaster(
    self,
    spk_times: np.array,
    dt=(-50, 100),
    prc_max: float = 0.5,
    ax: matplotlib.axes = None,
    ms_per_bin: int = 1,
    sample_rate: float = 3e4,  # OE=3e4, Axona=96000
    **kwargs
) -> matplotlib.axes:
    """
    Plots a raster for a specified tetrode/cluster.

    Args:
        spk_times (np.array): The spike times in samples.
        dt (tuple, optional): The window of time in ms to examine zeroed
            on the event of interest i.e. the first value will probably
            be negative as in the example. Defaults to (-50, 100).
        prc_max (float, optional): The proportion of firing the cell has
            to 'lose' to count as silent; a float between 0 and 1.
            Defaults to 0.5.
        ax (matplotlib.axes, optional): The axes to plot into.
            If not provided a new figure is created. Defaults to None.
        ms_per_bin (int, optional): The number of milliseconds in each bin
            of the raster plot. Defaults to 1.
        sample_rate (float, optional): The sample rate. Defaults to 3e4.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    assert hasattr(self, "ttl_data")

    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    x1 = spk_times / sample_rate * 1000.0  # get into ms
    x1.sort()
    on_good = self.ttl_data["ttl_timestamps"]
    dt = np.array(dt)
    irange = on_good[:, np.newaxis] + dt[np.newaxis, :]
    dts = np.searchsorted(x1, irange)
    y = []
    x = []
    for i, t in enumerate(dts):
        tmp = x1[t[0]:t[1]] - on_good[i]
        x.extend(tmp)
        y.extend(np.repeat(i, len(tmp)))
    if ax is None:
        fig = plt.figure(figsize=(4.0, 7.0))
        axScatter = fig.add_subplot(111)
    else:
        axScatter = ax
    histColor = [1 / 255.0, 1 / 255.0, 1 / 255.0]
    axScatter.scatter(x, y, marker=".", s=2,
                      rasterized=False, color=histColor)
    divider = make_axes_locatable(axScatter)
    axScatter.set_xticks((dt[0], 0, dt[1]))
    axScatter.set_xticklabels((str(dt[0]), "0", str(dt[1])))
    axHistx = divider.append_axes("top", 0.95, pad=0.2,
                                  sharex=axScatter,
                                  transform=axScatter.transAxes)
    scattTrans = transforms.blended_transform_factory(
        axScatter.transData, axScatter.transAxes
    )
    stim_pwidth = self.ttl_data["stim_duration"]
    if stim_pwidth is None:
        raise ValueError("stim duration is None")

    axScatter.add_patch(
        Rectangle(
            (0, 0),
            width=stim_pwidth,
            height=1,
            transform=scattTrans,
            color=[0, 0, 1],
            alpha=0.3,
        )
    )
    histTrans = transforms.blended_transform_factory(
        axHistx.transData, axHistx.transAxes
    )
    axHistx.add_patch(
        Rectangle(
            (0, 0),
            width=stim_pwidth,
            height=1,
            transform=histTrans,
            color=[0, 0, 1],
            alpha=0.3,
        )
    )
    axScatter.set_ylabel("Laser stimulation events", labelpad=-2.5)
    axScatter.set_xlabel("Time to stimulus onset (ms)")
    nStms = len(on_good)
    axScatter.set_ylim(0, nStms)
    # Label only the min and max of the y-axis
    ylabels = axScatter.get_yticklabels()
    for i in range(1, len(ylabels) - 1):
        ylabels[i].set_visible(False)
    yticks = axScatter.get_yticklines()
    for i in range(1, len(yticks) - 1):
        yticks[i].set_visible(False)

    axHistx.hist(
        x,
        bins=np.arange(dt[0], dt[1] + ms_per_bin, ms_per_bin),
        color=histColor,
        range=dt,
        rasterized=True,
        histtype="stepfilled",
    )
    axHistx.set_ylabel("Spike count", labelpad=-2.5)
    plt.setp(axHistx.get_xticklabels(), visible=False)
    # Label only the min and max of the y-axis
    ylabels = axHistx.get_yticklabels()
    for i in range(1, len(ylabels) - 1):
        ylabels[i].set_visible(False)
    yticks = axHistx.get_yticklines()
    for i in range(1, len(yticks) - 1):
        yticks[i].set_visible(False)
    axHistx.set_xlim(dt)
    axScatter.set_xlim(dt)
    if strip_axes:
        return stripAxes(axScatter)
    return axScatter
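
A minimal sketch, assuming a trial with stimulation TTLs on disk; `load_ttl()` must have populated `ttl_data` (the method asserts this) and the spike times are in samples:

```python
if trial.load_ttl():  # fills trial.ttl_data (TTL times in ms)
    spk_samples = trial.get_spike_times(tetrode=1, cluster=2)  # in samples
    ax = trial.makeRaster(spk_samples, dt=(-50, 100),
                          ms_per_bin=2, sample_rate=96000)  # Axona rate
```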

makeRateMap(spk_times, ax=None, **kwargs)

Creates a rate map plot.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | ndarray | The spike times in seconds. | required |
| `ax` | axes | The axes to plot on. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
@stripAxes
def makeRateMap(self,
                spk_times: np.ndarray,
                ax: matplotlib.axes = None,
                **kwargs) -> matplotlib.axes:
    """
    Creates a rate map plot.

    Args:
        spk_times (np.ndarray): The spike times in seconds.
        ax (matplotlib.axes, optional): The axes to plot on. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()
    spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
    spk_weights = np.bincount(
        spk_times_in_pos_samples, minlength=self.npos)
    rmap = self.RateMap.getMap(spk_weights)
    ratemap = np.ma.MaskedArray(rmap[0], np.isnan(rmap[0]), copy=True)
    x, y = np.meshgrid(rmap[1][1][0:-1].data, rmap[1][0][0:-1].data)
    vmax = np.nanmax(np.ravel(ratemap))
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax.pcolormesh(
        x, y, ratemap,
        cmap=jet_cmap,
        edgecolors="face",
        vmax=vmax,
        shading="auto",
        **kwargs
    )
    ax.set_aspect("equal")
    return ax
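
A minimal sketch; the spike times are assumed to already be in seconds:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
trial.makeRateMap(spk_ts, ax=ax)  # spk_ts: spike times in seconds
plt.show()
```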

makeSAC(spk_times=None, ax=None, **kwargs)

Creates a spatial autocorrelation plot.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | array | The spike times in seconds. If None, no spikes are plotted. | `None` |
| `ax` | axes | The axes to plot on. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
@stripAxes
def makeSAC(
    self, spk_times: np.array = None, ax: matplotlib.axes = None, **kwargs
) -> matplotlib.axes:
    """
    Creates a spatial autocorrelation plot.

    Args:
        spk_times (np.array, optional): The spike times in seconds. If
            None, no spikes are plotted.
        ax (matplotlib.axes, optional): The axes to plot on. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()
    spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
    spk_weights = np.bincount(
        spk_times_in_pos_samples, minlength=self.npos)
    sac = self.RateMap.getSAC(spk_weights)
    from ephysiopy.common.gridcell import SAC

    S = SAC()
    measures = S.getMeasures(sac)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax = self.show_SAC(sac, measures, ax)
    return ax
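
A minimal sketch; the SAC is derived from the binned spike weights and annotated with grid measures via `show_SAC()` (documented further down this page):

```python
ax = trial.makeSAC(spk_ts)  # spk_ts: spike times in seconds
```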

makeSpeedVsHeadDirectionPlot(spk_times, ax=None, **kwargs)

Creates a speed versus head direction plot.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | array | The spike times in seconds. | required |
| `ax` | axes | The axes to plot on. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
def makeSpeedVsHeadDirectionPlot(
    self, spk_times: np.array, ax: matplotlib.axes = None, **kwargs
) -> matplotlib.axes:
    """
    Creates a speed versus head direction plot.

    Args:
        spk_times (np.array): The spike times in seconds.
        ax (matplotlib.axes, optional): The axes to plot on. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    if not self.RateMap:
        self.initialise()
    spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)
    idx = np.array(spk_times_in_pos_samples, dtype=int)
    w = np.bincount(idx, minlength=self.PosCalcs.speed.shape[0])
    if np.ma.is_masked(self.PosCalcs.speed):
        w[self.PosCalcs.speed.mask] = 0

    dir_bins = np.arange(0, 360, 6)
    spd_bins = np.arange(0, 30, 1)
    h = np.histogram2d(self.PosCalcs.dir,
                       self.PosCalcs.speed,
                       [dir_bins, spd_bins], weights=w)
    from ephysiopy.common.utils import blurImage

    im = blurImage(h[0], 5, ftype="gaussian")
    im = np.ma.MaskedArray(im)
    # mask low rates; NB bins holding less than 0.5% of the data are
    # meant to be masked too, but that mask is not currently applied
    im = np.ma.masked_where(im <= 1, im)
    x, y = np.meshgrid(dir_bins, spd_bins)
    vmax = np.max(np.ravel(im))
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax.pcolormesh(x, y, im.T,
                  cmap=jet_cmap, edgecolors="face",
                  vmax=vmax, shading="auto")
    ax.set_xticks([90, 180, 270], labels=['90', '180', '270'],
                  fontweight="normal", size=6)
    ax.set_yticks([10, 20], labels=['10', '20'],
                  fontweight="normal", size=6)
    ax.set_xlabel("Heading", fontweight="normal", size=6)
    if strip_axes:
        stripAxes(ax)
    return ax
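
A minimal sketch; direction is binned in 6-degree bins against 1-unit speed bins up to 30 (cm/s assumed):

```python
ax = trial.makeSpeedVsHeadDirectionPlot(spk_ts, strip_axes=True)
```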

makeSpeedVsRatePlot(spk_times, minSpeed=0.0, maxSpeed=40.0, sigma=3.0, ax=None, **kwargs)

Plots the instantaneous firing rate of a cell against running speed. Also outputs a couple of measures, as in Kropff et al. (2015): the Pearson correlation and the depth of modulation (dom).

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | array | The spike times in seconds. | required |
| `minSpeed` | float | The minimum speed. Defaults to 0.0. | `0.0` |
| `maxSpeed` | float | The maximum speed. Defaults to 40.0. | `40.0` |
| `sigma` | float | The sigma value. Defaults to 3.0. | `3.0` |
| `ax` | axes | The axes to plot on. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
def makeSpeedVsRatePlot(
    self,
    spk_times: np.array,
    minSpeed: float = 0.0,
    maxSpeed: float = 40.0,
    sigma: float = 3.0,
    ax: matplotlib.axes = None,
    **kwargs
) -> matplotlib.axes:
    """
    Plots the instantaneous firing rate of a cell against running speed.
    Also outputs a couple of measures, as in Kropff et al. (2015): the
    Pearson correlation and the depth of modulation (dom).

    Args:
        spk_times (np.array): The spike times in seconds.
        minSpeed (float, optional): The minimum speed. Defaults to 0.0.
        maxSpeed (float, optional): The maximum speed. Defaults to 40.0.
        sigma (float, optional): The sigma value. Defaults to 3.0.
        ax (matplotlib.axes, optional): The axes to plot on. If None, new
            axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    if not self.RateMap:
        self.initialise()
    spk_times_in_pos_samples = self.getSpikePosIndices(spk_times)

    speed = np.ravel(self.PosCalcs.speed)
    if np.nanmax(speed) < maxSpeed:
        maxSpeed = np.nanmax(speed)
    spd_bins = np.arange(minSpeed, maxSpeed, 1.0)
    # Construct the mask
    speed_filt = np.ma.MaskedArray(speed)
    speed_filt = np.ma.masked_where(speed_filt < minSpeed, speed_filt)
    speed_filt = np.ma.masked_where(speed_filt > maxSpeed, speed_filt)
    from ephysiopy.common.spikecalcs import SpikeCalcsGeneric

    x1 = spk_times_in_pos_samples
    S = SpikeCalcsGeneric(x1)
    spk_sm = S.smooth_spike_train(x1,
                                  self.PosCalcs.xyTS.shape[0],
                                  sigma, None)
    spk_sm = np.ma.MaskedArray(spk_sm, mask=np.ma.getmask(speed_filt))
    spd_dig = np.digitize(speed_filt, spd_bins, right=True)
    mn_rate = np.array(
        [np.ma.mean(spk_sm[spd_dig == i]) for i in range(0, len(spd_bins))]
    )
    var = np.array(
        [np.ma.std(spk_sm[spd_dig == i]) for i in range(0, len(spd_bins))]
    )
    # NB this per-bin spike count is computed but never used
    np.array([np.ma.sum(spk_sm[spd_dig == i]) for i in range(
        0, len(spd_bins))])
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax.errorbar(spd_bins, mn_rate * self.PosCalcs.sample_rate,
                yerr=var, color="k")
    ax.set_xlim(spd_bins[0], spd_bins[-1])
    ax.set_xticks(
        [spd_bins[0], spd_bins[-1]],
        labels=["0", "{:.2g}".format(spd_bins[-1])],
        fontweight="normal",
        size=6,
    )
    ax.set_yticks(
        [0, np.nanmax(mn_rate) * self.PosCalcs.sample_rate],
        labels=["0", "{:.2f}".format(np.nanmax(mn_rate))],
        fontweight="normal",
        size=6,
    )
    if strip_axes:
        return stripAxes(ax)
    return ax
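
A minimal sketch, restricting the fit to a plausible range of running speeds:

```python
ax = trial.makeSpeedVsRatePlot(spk_ts, minSpeed=2.0, maxSpeed=30.0, sigma=3.0)
```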

makeSpikePathPlot(spk_times=None, ax=None, **kwargs)

Creates a spike path plot.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | ndarray | The spike times in seconds. If None, no spikes are plotted. | `None` |
| `ax` | axes | The axes to plot on. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
@stripAxes
def makeSpikePathPlot(self,
                      spk_times: np.ndarray = None,
                      ax: matplotlib.axes = None,
                      **kwargs) -> matplotlib.axes:
    """
    Creates a spike path plot.

    Args:
        spk_times (np.ndarray, optional): The spike times in seconds.
            If None, no spikes are plotted.
        ax (matplotlib.axes, optional): The axes to plot on.
            If None, new axes are created.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if not self.RateMap:
        self.initialise()
    if "c" in kwargs:
        col = kwargs.pop("c")
    else:
        col = tcols.colours[1]
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    ax.plot(
        self.PosCalcs.xy[0, :],
        self.PosCalcs.xy[1, :],
        c=tcols.colours[0], zorder=1
    )
    ax.set_aspect("equal")
    if spk_times is not None:
        idx = self.getSpikePosIndices(spk_times)
        ax.plot(
            self.PosCalcs.xy[0, idx],
            self.PosCalcs.xy[1, idx],
            "s", c=col, **kwargs
        )
    return ax
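
A minimal sketch; with no spike times only the path is drawn, and extra kwargs are forwarded to the spike marker plot:

```python
trial.makeSpikePathPlot()                             # path only
trial.makeSpikePathPlot(spk_ts, c="r", markersize=3)  # spikes overlaid in red
```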

makeSummaryPlot(spk_times)

Creates a summary plot with spike path, rate map, head direction plot, and spatial autocorrelation.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | ndarray | The spike times in seconds. | required |

Returns:

| Type | Description |
| ---- | ----------- |
|  | matplotlib.figure.Figure: The created figure. |

Source code in ephysiopy/visualise/plotting.py
def makeSummaryPlot(self, spk_times: np.ndarray):
    """
    Creates a summary plot with spike path, rate map, head direction plot,
    and spatial autocorrelation.

    Args:
        spk_times (np.ndarray): The spike times in seconds.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    fig = plt.figure()
    ax = plt.subplot(221)
    self.makeSpikePathPlot(spk_times, ax=ax, markersize=2)
    ax = plt.subplot(222)
    self.makeRateMap(spk_times, ax=ax)
    ax = plt.subplot(223, projection="polar")
    self.makeHDPlot(spk_times, ax=ax)
    ax = plt.subplot(224)
    try:
        self.makeSAC(spk_times, ax=ax)
    except IndexError:
        pass
    return fig
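
A minimal sketch; the returned figure can be saved like any matplotlib figure:

```python
fig = trial.makeSummaryPlot(spk_ts)
fig.savefig("cluster_summary.png", dpi=150)
```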

makeXCorr(spk_times, ax=None, **kwargs)

Returns an axis containing the autocorrelogram of the spike times provided over the range +/-500ms.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `spk_times` | array | Spike times in seconds. | required |
| `ax` | axes | The axes to plot into. If None, new axes are created. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. `binsize` (float, optional): the size of the bins in seconds; gets passed to SpikeCalcsGeneric.acorr(). Defaults to 0.001. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

Source code in ephysiopy/visualise/plotting.py
def makeXCorr(
    self, spk_times: np.array, ax: matplotlib.axes = None, **kwargs
) -> matplotlib.axes:
    """
    Returns an axis containing the autocorrelogram of the spike
    times provided over the range +/-500ms.

    Args:
        spk_times (np.array): Spike times in seconds.
        ax (matplotlib.axes, optional): The axes to plot into. If None,
            new axes are created.
        **kwargs: Additional keyword arguments for the function.
            binsize (float, optional): The size of the bins in seconds.
            Gets passed to SpikeCalcsGeneric.acorr(). Defaults to 0.001.

    Returns:
        matplotlib.axes: The axes with the plot.
    """
    if "strip_axes" in kwargs.keys():
        strip_axes = kwargs.pop("strip_axes")
    else:
        strip_axes = False
    # spk_times are provided in seconds; the default binsize (0.001)
    # and Trange (+/-0.5) used below are therefore also in seconds
    S = SpikeCalcsGeneric(spk_times)
    c, b = S.acorr(spk_times, **kwargs)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    if 'binsize' in kwargs.keys():
        binsize = kwargs['binsize']
    else:
        binsize = 0.001
    if "Trange" in kwargs.keys():
        xrange = kwargs["Trange"]
    else:
        xrange = [-0.5, 0.5]
    ax.bar(b[:-1], c, width=binsize, color="k", align="edge")
    ax.set_xlim(xrange)
    ax.set_xticks((xrange[0], 0, xrange[1]))
    ax.set_xticklabels("")
    ax.tick_params(axis="both", which="both", left=False, right=False,
                   bottom=False, top=False)
    ax.set_yticklabels("")
    ax.xaxis.set_ticks_position("bottom")
    if strip_axes:
        return stripAxes(ax)
    return ax
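
A minimal sketch; note that the bar widths treat `binsize` as seconds (the in-code fallback is 0.001):

```python
ax = trial.makeXCorr(spk_ts, binsize=0.005)  # 5 ms bins over +/-0.5 s
```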

plotSpectrogramByDepth(nchannels=384, nseconds=100, maxFreq=125, channels=[], frequencies=[], frequencyIncrement=1, **kwargs)

Plots a heat map spectrogram of the LFP for each channel. Line plots of power per frequency band and power on a subset of channels are also displayed to the right and above the main plot.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `nchannels` | int | The number of channels on the probe. | `384` |
| `nseconds` | int | How long in seconds from the start of the trial to do the spectrogram for (for speed). Default is 100. | `100` |
| `maxFreq` | int | The maximum frequency in Hz to plot the spectrogram out to. Maximum is 1250. Default is 125. | `125` |
| `channels` | list | The channels to plot separately on the top plot. | `[]` |
| `frequencies` | list | The specific frequencies to examine across all channels. The mean power from frequency to frequency+frequencyIncrement is calculated and plotted on the right hand side of the plot. | `[]` |
| `frequencyIncrement` | int | The amount to add to each value of the frequencies list above. | `1` |
| `**kwargs` |  | Additional keyword arguments for the function. Valid key value pairs: "saveas" - save the figure to this location, needs absolute path and filename. | `{}` |
Notes

Should also allow kwargs to specify exactly which channels and / or frequency bands to do the line plots for.

Source code in ephysiopy/visualise/plotting.py
def plotSpectrogramByDepth(
    self,
    nchannels: int = 384,
    nseconds: int = 100,
    maxFreq: int = 125,
    channels: list = [],
    frequencies: list = [],
    frequencyIncrement: int = 1,
    **kwargs
):
    """
    Plots a heat map spectrogram of the LFP for each channel.
    Line plots of power per frequency band and power on a subset of
    channels are also displayed to the right and above the main plot.

    Args:
        nchannels (int): The number of channels on the probe.
        nseconds (int, optional): How long in seconds from the start of
            the trial to do the spectrogram for (for speed).
            Default is 100.
        maxFreq (int): The maximum frequency in Hz to plot the spectrogram
            out to. Maximum is 1250. Default is 125.
        channels (list): The channels to plot separately on the top plot.
        frequencies (list): The specific frequencies to examine across
            all channels. The mean power from frequency to
            frequency+frequencyIncrement is calculated and plotted on
            the right hand side of the plot.
        frequencyIncrement (int): The amount to add to each value of
            the frequencies list above.
        **kwargs: Additional keyword arguments for the function.
            Valid key value pairs:
                "saveas" - save the figure to this location, needs absolute
                path and filename.

    Notes:
        Should also allow kwargs to specify exactly which channels
        and / or frequency bands to do the line plots for.
    """
    if not self.path2LFPdata:
        raise TypeError("Not a probe recording so not plotting")
    import os

    lfp_file = os.path.join(self.path2LFPdata, "continuous.dat")
    status = os.stat(lfp_file)
    nsamples = int(status.st_size / 2 / nchannels)
    mmap = np.memmap(lfp_file, np.int16, "r", 0,
                     (nchannels, nsamples), order="F")
    # Load the channel map NB assumes this is in the AP data
    # location and that kilosort was run there
    channel_map = np.squeeze(
        np.load(os.path.join(self.path2APdata, "channel_map.npy"))
    )
    lfp_sample_rate = 2500
    data = np.array(mmap[channel_map, 0:nseconds * lfp_sample_rate])
    from ephysiopy.common.ephys_generic import EEGCalcsGeneric

    E = EEGCalcsGeneric(data[0, :], lfp_sample_rate)
    E.calcEEGPowerSpectrum()
    spec_data = np.zeros(shape=(data.shape[0], len(E.sm_power[0::50])))
    for chan in range(data.shape[0]):
        E = EEGCalcsGeneric(data[chan, :], lfp_sample_rate)
        E.calcEEGPowerSpectrum()
        spec_data[chan, :] = E.sm_power[0::50]

    x, y = np.meshgrid(E.freqs[0::50], channel_map)
    import matplotlib.colors as colors
    from matplotlib.pyplot import cm
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    _, spectoAx = plt.subplots()
    spectoAx.pcolormesh(x, y, spec_data,
                        edgecolors="face", cmap="bone",
                        norm=colors.LogNorm())
    spectoAx.set_xlim(0, maxFreq)
    spectoAx.set_ylim(channel_map[0], channel_map[-1])
    spectoAx.set_xlabel("Frequency (Hz)")
    spectoAx.set_ylabel("Channel")
    divider = make_axes_locatable(spectoAx)
    channel_spectoAx = divider.append_axes("top", 1.2, pad=0.1,
                                           sharex=spectoAx)
    meanfreq_powerAx = divider.append_axes("right", 1.2, pad=0.1,
                                           sharey=spectoAx)
    plt.setp(channel_spectoAx.get_xticklabels()
             + meanfreq_powerAx.get_yticklabels(),
             visible=False)

    # plot mean power across some channels
    mn_power = np.mean(spec_data, 0)
    if not channels:
        channels = range(1, nchannels, 60)
    cols = iter(cm.rainbow(np.linspace(0, 1, len(channels))))
    for chan in channels:
        c = next(cols)
        channel_spectoAx.plot(
            E.freqs[0::50],
            10 * np.log10(spec_data[chan, :] / mn_power),
            c=c,
            label=str(chan),
        )

    channel_spectoAx.set_ylabel("Channel power (dB)")
    channel_spectoAx.legend(
        bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
        loc="lower left",
        mode="expand",
        fontsize="x-small",
        ncol=4,
    )

    # plot mean frequencies across all channels
    if not frequencyIncrement:
        freq_inc = 6
    else:
        freq_inc = frequencyIncrement
    if not frequencies:
        lower_freqs = np.arange(1, maxFreq - freq_inc, freq_inc)
    else:
        lower_freqs = frequencies
    upper_freqs = [f + freq_inc for f in lower_freqs]
    cols = iter(cm.nipy_spectral(np.linspace(0, 1, len(upper_freqs))))
    mn_power = np.mean(spec_data, 1)
    for freqs in zip(lower_freqs, upper_freqs):
        freq_mask = np.logical_and(
            E.freqs[0::50] > freqs[0], E.freqs[0::50] < freqs[1]
        )
        mean_power = 10 * np.log10(np.mean(
            spec_data[:, freq_mask], 1) / mn_power)
        c = next(cols)
        meanfreq_powerAx.plot(
            mean_power,
            channel_map,
            c=c,
            label=str(freqs[0]) + " - " + str(freqs[1]),
        )
    meanfreq_powerAx.set_xlabel("Mean freq. band power (dB)")
    meanfreq_powerAx.legend(
        bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
        loc="lower left",
        mode="expand",
        fontsize="x-small",
        ncol=1,
    )
    if "saveas" in kwargs:
        saveas = kwargs["saveas"]
        plt.savefig(saveas)
    plt.show()
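
A minimal sketch, assuming a probe recording (e.g. Neuropixels via OpenEphys) with a continuous.dat LFP file and kilosort output on disk:

```python
trial.plotSpectrogramByDepth(nchannels=384, nseconds=60,
                             channels=[20, 120, 220, 320],
                             frequencies=[2, 8, 30], frequencyIncrement=2,
                             saveas="/path/to/specto.png")  # placeholder path
```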

power_spectrum(**kwargs)

Gets the power spectrum.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def power_spectrum(self, **kwargs):
    """
    Gets the power spectrum.

    Args:
        **kwargs: Additional keyword arguments for the function.
    """
    p = self.EEGCalcs.calcEEGPowerSpectrum()
    self.makePowerSpectrum(p[0], p[1], p[2], p[3], p[4], **kwargs)
    plt.show()
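
A minimal sketch; keyword arguments are forwarded to `makePowerSpectrum()`:

```python
trial.load_lfp()  # populate trial.EEGCalcs first
trial.power_spectrum(max_freq=40)
```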

rate_map(cluster, channel, **kwargs)

Gets the rate map for the specified cluster(s) and channel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cluster` | int \| list | The cluster(s) to get the rate map for. | required |
| `channel` | int | The channel number. | required |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def rate_map(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the rate map for the specified cluster(s) and channel.

    Args:
        cluster (int | list): The cluster(s) to get the rate map for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeRateMap,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeRateMap(ts, **kwargs)
    plt.show()
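
A minimal sketch with assumed cluster/channel numbers; the other wrapper methods below (sac, speed_v_hd, speed_v_rate, spike_path) follow the same calling pattern:

```python
trial.rate_map(cluster=1, channel=4)          # a single rate map
trial.rate_map(cluster=[1, 2, 3], channel=4)  # one subplot per cluster
```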

sac(cluster, channel, **kwargs)

Gets the spatial autocorrelation for the specified cluster(s) and channel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cluster` | int \| list | The cluster(s) to get the spatial autocorrelation for. | required |
| `channel` | int | The channel number. | required |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def sac(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the spatial autocorrelation for the specified cluster(s) and
    channel.

    Args:
        cluster (int | list): The cluster(s) to get the spatial
            autocorrelation for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeSAC,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeSAC(ts, **kwargs)
    plt.show()

show_SAC(A, inDict, ax=None, **kwargs)

Displays the result of performing a spatial autocorrelation (SAC) on a grid cell.

Uses the dictionary containing measures of the grid cell SAC to make a pretty picture

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `A` | array | The spatial autocorrelogram. | required |
| `inDict` | dict | The dictionary calculated in getMeasures. | required |
| `ax` | axes | If given the plot will get drawn in these axes. Default None. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| axes | matplotlib.axes: The axes with the plot. |

See Also

ephysiopy.common.binning.RateMap.autoCorr2D()
ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()

Source code in ephysiopy/visualise/plotting.py
@stripAxes
def show_SAC(
    self, A: np.array, inDict: dict, ax: matplotlib.axes = None, **kwargs
) -> matplotlib.axes:
    """
    Displays the result of performing a spatial autocorrelation (SAC)
    on a grid cell.

    Uses the dictionary containing measures of the grid cell SAC to
    make a pretty picture

    Args:
        A (np.array): The spatial autocorrelogram.
        inDict (dict): The dictionary calculated in getMeasures.
        ax (matplotlib.axes, optional): If given the plot will get drawn
            in these axes. Default None.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        matplotlib.axes: The axes with the plot.

    See Also:
        ephysiopy.common.binning.RateMap.autoCorr2D()
        ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    Am = A.copy()
    Am[~inDict["dist_to_centre"]] = np.nan
    Am = np.ma.masked_invalid(np.atleast_2d(Am))
    x, y = np.meshgrid(np.arange(0, np.shape(A)[1]),
                       np.arange(0, np.shape(A)[0]))
    vmax = np.nanmax(np.ravel(A))
    ax.pcolormesh(x, y, A, cmap=grey_cmap, edgecolors="face",
                  vmax=vmax, shading="auto")
    import copy

    cmap = copy.copy(jet_cmap)
    cmap.set_bad("w", 0)
    ax.pcolormesh(x, y, Am, cmap=cmap,
                  edgecolors="face", vmax=vmax, shading="auto")
    # horizontal green line at 3 o'clock
    _y = (np.shape(A)[0] / 2, np.shape(A)[0] / 2)
    _x = (np.shape(A)[1] / 2, np.shape(A)[0])
    ax.plot(_x, _y, c="g")
    mag = inDict["scale"] * 0.5
    th = np.linspace(0, inDict["orientation"], 50)
    from ephysiopy.common.utils import rect

    [x, y] = rect(mag, th, deg=1)
    # angle subtended by orientation
    ax.plot(
        x + (inDict["dist_to_centre"].shape[1] / 2),
        (inDict["dist_to_centre"].shape[0] / 2) - y,
        "r",
        **kwargs
    )
    # plot lines from centre to peaks above middle
    for p in inDict["closest_peak_coords"]:
        if p[0] <= inDict["dist_to_centre"].shape[0] / 2:
            ax.plot(
                (inDict["dist_to_centre"].shape[1] / 2, p[1]),
                (inDict["dist_to_centre"].shape[0] / 2, p[0]),
                "k",
                **kwargs
            )
    ax.invert_yaxis()
    all_ax = ax.axes
    all_ax.set_aspect("equal")
    all_ax.set_xlim((0.5, inDict["dist_to_centre"].shape[1] - 1.5))
    all_ax.set_ylim((inDict["dist_to_centre"].shape[0] - 0.5, -0.5))
    return ax
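
`show_SAC` is normally called for you by `makeSAC()`; a direct-use sketch mirroring that method's body:

```python
import numpy as np
from ephysiopy.common.gridcell import SAC

idx = trial.getSpikePosIndices(spk_ts)        # spike times -> pos sample indices
spk_weights = np.bincount(idx, minlength=trial.npos)
sac = trial.RateMap.getSAC(spk_weights)
measures = SAC().getMeasures(sac)
ax = trial.show_SAC(sac, measures)
```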

speed_v_hd(cluster, channel, **kwargs)

Gets the speed versus head direction plot for the specified cluster(s) and channel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cluster` | int \| list | The cluster(s) to get the speed versus head direction plot for. | required |
| `channel` | int | The channel number. | required |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def speed_v_hd(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the speed versus head direction plot for the specified cluster(s)
    and channel.

    Args:
        cluster (int | list): The cluster(s) to get the speed versus head
            direction plot for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeSpeedVsHeadDirectionPlot,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeSpeedVsHeadDirectionPlot(ts, **kwargs)
    plt.show()

speed_v_rate(cluster, channel, **kwargs)

Gets the speed versus rate plot for the specified cluster(s) and channel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cluster` | int \| list | The cluster(s) to get the speed versus rate plot for. | required |
| `channel` | int | The channel number. | required |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def speed_v_rate(self, cluster: int | list, channel: int, **kwargs):
    """
    Gets the speed versus rate plot for the specified cluster(s) and
    channel.

    Args:
        cluster (int | list): The cluster(s) to get the speed versus rate
            plot for.
        channel (int): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeSpeedVsRatePlot,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        ts = self.get_spike_times(channel, cluster)
        self.makeSpeedVsRatePlot(ts, **kwargs)
    plt.show()

spike_path(cluster=None, channel=None, **kwargs)

Gets the spike path for the specified cluster(s) and channel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `cluster` | int \| list \| None | The cluster(s) to get the spike path for. | `None` |
| `channel` | int \| None | The channel number. | `None` |
| `**kwargs` |  | Additional keyword arguments for the function. | `{}` |
Source code in ephysiopy/visualise/plotting.py
def spike_path(self, cluster=None, channel=None, **kwargs):
    """
    Gets the spike path for the specified cluster(s) and channel.

    Args:
        cluster (int | list | None): The cluster(s) to get the spike path
            for.
        channel (int | None): The channel number.
        **kwargs: Additional keyword arguments for the function.
    """
    if isinstance(cluster, list):
        self._plot_multiple_clusters(self.makeSpikePathPlot,
                                     cluster,
                                     channel,
                                     **kwargs)
    else:
        if channel is not None and cluster is not None:
            ts = self.get_spike_times(channel, cluster)
        else:
            ts = None
        self.makeSpikePathPlot(ts, **kwargs)
    plt.show()
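
A minimal sketch; with cluster and channel omitted only the animal's path is drawn:

```python
trial.spike_path()                      # path only
trial.spike_path(cluster=2, channel=1)  # path plus that cluster's spikes
```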

Binning up data

RateMap

Bases: object

Bins up positional data (xy, head direction etc) and produces rate maps of the relevant kind. This is a generic class meant to be independent of any particular recording format.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `xy` | ndarray | The xy data, usually given as a 2 x n sample numpy array. | `None` |
| `hdir` | ndarray | The head direction data, usually a 1 x n sample numpy array. | `None` |
| `speed` | ndarray | Similar to hdir. | `None` |
| `pos_weights` | ndarray | A 1D numpy array n samples long which is used to weight a particular position sample when binning data. For example, if there were 5 positions recorded and a cell spiked once in position 2 and 5 times in position 3 and nothing anywhere else then pos_weights looks like: [0 0 1 5 0]. In the case of binning up position this will be an array of mostly 1's unless there are some positions you want excluded for some reason. | `None` |
| `ppm` | int | Pixels per metre. Specifies how many camera pixels per metre so this, in combination with binsize, will determine how many bins there are in the rate map. Defaults to 430. | `430` |
| `xyInCms` | bool | Whether the positional data is in cms. Defaults to False. | `False` |
| `binsize` | int | How many cms on a side each bin is in a rate map OR the number of degrees per bin in the case of directional binning. Defaults to 3. | `3` |
| `smooth_sz` | int | The width of the smoothing kernel for smoothing rate maps. Defaults to 5. | `5` |
Notes

There are several instance variables you can set, see below.
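
A minimal construction sketch with fake data (purely illustrative values), ahead of the full source below:

```python
import numpy as np
from ephysiopy.common.binning import RateMap, VariableToBin

n = 1000
xy = np.random.rand(2, n) * 400          # fake 2 x n pixel coordinates
hdir = np.random.rand(n) * 360           # fake head direction in degrees
spk_weights = np.random.poisson(0.2, n)  # fake spike count per pos sample

R = RateMap(xy=xy, hdir=hdir, ppm=430, binsize=3, smooth_sz=5)
rmap = R.getMap(spk_weights)             # rmap[0]: rates, rmap[1]: bin edges
hd_map = R.getMap(spk_weights, varType=VariableToBin.DIR)  # HD tuning curve
```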

Source code in ephysiopy/common/binning.py
class RateMap(object):
    """
    Bins up positional data (xy, head direction etc) and produces rate maps
    of the relevant kind. This is a generic class meant to be independent of
    any particular recording format.

    Args:
        xy (ndarray): The xy data, usually given as a 2 x n sample numpy array.
        hdir (ndarray): The head direction data, usually a 1 x n sample numpy array.
        speed (ndarray): Similar to hdir.
        pos_weights (ndarray): A 1D numpy array n samples long which is used to weight a particular
            position sample when binning data. For example, if there were 5
            positions recorded and a cell spiked once in position 2 and 5 times
            in position 3 and nothing anywhere else then pos_weights looks like:
            [0 0 1 5 0]
            In the case of binning up position this will be an array of mostly 1's
            unless there are some positions you want excluded for some reason.
        ppm (int, optional): Pixels per metre. Specifies how many camera pixels per metre so this,
            in combination with binsize, will determine how many bins there are
            in the rate map. Defaults to 430.
        xyInCms (bool, optional): Whether the positional data is in cms. Defaults to False.
        binsize (int, optional): How many cms on a side each bin is in a rate map OR the number of
            degrees per bin in the case of directional binning. Defaults to 3.
        smooth_sz (int, optional): The width of the smoothing kernel for smoothing rate maps. Defaults to 5.

    Notes:
        There are several instance variables you can set, see below.
    """

    def __init__(
        self,
        xy: np.array = None,
        hdir: np.array = None,
        speed: np.array = None,
        pos_weights: np.array = None,
        pos_times: np.array = None,
        ppm: int = 430,
        xyInCms: bool = False,
        binsize: int = 3,
        smooth_sz: int = 5,
    ):
        self.xy = xy
        self.dir = hdir
        self.speed = speed
        self._pos_weights = pos_weights
        self._pos_times = pos_times
        self._pos_time_splits = None
        self._spike_weights = None
        self._ppm = ppm  # pixels per metre
        self._binsize = binsize
        self._inCms = xyInCms
        self._nBins = None
        self._binedges = None  # has setter and getter - see below
        self._x_lims = None
        self._y_lims = None
        self._smooth_sz = smooth_sz
        self._smoothingType = "gaussian"  # 'boxcar' or 'gaussian'
        self.whenToSmooth = "before"  # or 'after'
        self._var2Bin = VariableToBin.XY
        self._mapType = MapType.RATE
        self._calcBinEdges()

    @property
    def inCms(self):
        # Whether the units are in cms or not
        return self._inCms

    @inCms.setter
    def inCms(self, value):
        self._inCms = value

    @property
    def ppm(self):
        # Get the current pixels per metre (ppm)
        return self._ppm

    @ppm.setter
    def ppm(self, value):
        self._ppm = value
        # self._binedges = self._calcBinEdges(self.binsize)

    @property
    def var2Bin(self):
        return self._var2Bin

    @var2Bin.setter
    def var2Bin(self, value):
        self._var2Bin = value

    @property
    def mapType(self):
        return self._mapType

    @mapType.setter
    def mapType(self, value):
        self._mapType = value

    @property
    def nBins(self):
        '''
        The number of bins for each dim
        '''
        if self.binsize:
            return len(self._binedges[0]), len(self._binedges[1])
        else:
            return None

    @nBins.setter
    def nBins(self, value):
        '''
        Sets the number of bins
        '''
        if self.var2Bin == VariableToBin.XY:
            x_lims, y_lims = self._getXYLimits()
            if isinstance(value, int):
                value = [value]
            if len(value) == 1:
                _x, bs_x = np.linspace(x_lims[0],
                                       x_lims[1],
                                       int(value[0]),
                                       retstep=True)
                _y, bs_y = np.linspace(y_lims[0],
                                       y_lims[1],
                                       int(value[0]),
                                       retstep=True)
            elif len(value) == 2:
                _x, bs_x = np.linspace(x_lims[0],
                                       x_lims[1],
                                       int(value[0]),
                                       retstep=True)
                _y, bs_y = np.linspace(y_lims[0],
                                       y_lims[1],
                                       int(value[1]),
                                       retstep=True)
            self._binedges = _y, _x
            self.binsize = np.mean([bs_x, bs_y])
        elif self.var2Bin == VariableToBin.DIR:
            self._binedges, binsize = np.linspace(0,
                                                  360 + self.binsize,
                                                  value[0],
                                                  retstep=True)
            self.binsize = binsize
        elif self.var2Bin == VariableToBin.SPEED:
            maxspeed = np.max(self.speed)
            self._binedges, binsize = np.linspace(0,
                                                  maxspeed,
                                                  value[0],
                                                  retstep=True)
            self.binsize = binsize

    @property
    def binedges(self):
        return self._binedges

    @binedges.setter
    def binedges(self, value):
        self._binedges = value

    @property
    def x_lims(self):
        return self._x_lims

    @x_lims.setter
    def x_lims(self, value):
        self._x_lims = value

    @property
    def y_lims(self):
        return self._y_lims

    @y_lims.setter
    def y_lims(self, value):
        self._y_lims = value

    @property
    def pos_weights(self):
        """
        The 'weights' used as an argument to np.histogram* for binning up
        position
        Mostly this is just an array of 1's equal to the length of the pos
        data, but usefully can be adjusted when masking data in the trial
        by
        """
        if self._pos_weights is None:
            self._pos_weights = np.ones(self.xy.shape[1])
        return self._pos_weights

    @pos_weights.setter
    def pos_weights(self, value):
        self._pos_weights = value

    @property
    def pos_times(self):
        return self._pos_times

    @pos_times.setter
    def pos_times(self, value):
        self._pos_times = value

    @property
    def pos_time_splits(self):
        return self._pos_times

    @pos_time_splits.setter
    def pos_time_splits(self, value):
        self._pos_times = value

    @property
    def spike_weights(self):
        return self._spike_weights

    @spike_weights.setter
    def spike_weights(self, value):
        self._spike_weights = value

    @property
    def binsize(self):
        # The number of cms per bin of the binned up map
        return self._binsize

    @binsize.setter
    def binsize(self, value):
        self._binsize = value
        self._binedges = self._calcBinEdges(value)

    @property
    def smooth_sz(self):
        # The size of the smoothing window applied to the binned data
        return self._smooth_sz

    @smooth_sz.setter
    def smooth_sz(self, value):
        self._smooth_sz = value

    @property
    def smoothingType(self):
        # The type of smoothing to do - legal values are 'boxcar' or 'gaussian'
        return self._smoothingType

    @smoothingType.setter
    def smoothingType(self, value):
        self._smoothingType = value

    def _getXYLimits(self):
        '''
        Gets the min/max of the x/y data
        '''
        x_lims = getattr(self, "x_lims", None)
        y_lims = getattr(self, "y_lims", None)
        if x_lims is None:
            x_lims = (np.nanmin(self.xy[0]), np.nanmax(self.xy[0]))
        if y_lims is None:
            y_lims = (np.nanmin(self.xy[1]), np.nanmax(self.xy[1]))
        self.x_lims = x_lims
        self.y_lims = y_lims
        return x_lims, y_lims

    def _calcBinDims(self):
        try:
            self._binDims = [len(b) for b in self._binedges]
        except TypeError:
            self._binDims = len(self._binedges)

    def _calcBinEdges(self, binsize: int = 3) -> tuple:
        """
        Aims to get the right number of bins for the variable to be binned

        Args:
            binsize (int, optional): The number of cms per bin for XY OR degrees for DIR OR cm/s for SPEED. Defaults to 3.

        Returns:
            tuple: each member an array of bin edges
        """
        if self.var2Bin.value == VariableToBin.DIR.value:
            self.binedges = np.arange(0, 360 + binsize, binsize)
        elif self.var2Bin.value == VariableToBin.SPEED.value:
            maxspeed = np.max(self.speed)
            # assume min speed = 0
            self.binedges = np.arange(0, maxspeed, binsize)
        elif self.var2Bin == VariableToBin.XY:
            x_lims, y_lims = self._getXYLimits()
            nxbins = int(np.ceil((x_lims[1]-x_lims[0])/binsize))
            nybins = int(np.ceil((y_lims[1]-y_lims[0])/binsize))
            _x = np.linspace(x_lims[0], x_lims[1], nxbins)
            _y = np.linspace(y_lims[0], y_lims[1], nybins)
            self.binedges = _y, _x
        elif self.var2Bin == VariableToBin.XY_TIME:
            if self._pos_time_splits is None:
                raise ValueError("Need pos times to bin up XY_TIME")
            x_lims, y_lims = self._getXYLimits()
            nxbins = int(np.ceil((x_lims[1]-x_lims[0])/binsize))
            nybins = int(np.ceil((y_lims[1]-y_lims[0])/binsize))
            _x = np.linspace(x_lims[0], x_lims[1], nxbins)
            _y = np.linspace(y_lims[0], y_lims[1], nybins)
            be = [_y, _x]
            be.append(self.pos_time_splits)
            self.binedges = be
        self._calcBinDims()
        return self.binedges

    def getSpatialSparsity(self,
                           spkWeights,
                           sample_rate=50,
                           **kwargs):
        """
        Gets the spatial sparsity measure - closer to 1 means
        sparser firing field.

        References:
            Skaggs, W.E., McNaughton, B.L., Wilson, M.A. & Barnes, C.A.
            Theta phase precession in hippocampal neuronal populations
            and the compression of temporal sequences.
            Hippocampus 6, 149–172 (1996).
        """
        self.var2Bin = VariableToBin.XY
        self._calcBinEdges()
        sample = self.xy
        keep_these = np.isfinite(sample[0])
        pos, _ = self._binData(sample,
                               self._binedges,
                               self.pos_weights,
                               keep_these)
        npos = len(self.dir)
        p_i = np.count_nonzero(pos) / npos / sample_rate
        spk, _ = self._binData(sample,
                               self._binedges,
                               spkWeights,
                               keep_these)
        res = 1-(np.nansum(p_i*spk)**2) / np.nansum(p_i*spk**2)
        return res

    def getMap(self, spkWeights,
               varType=VariableToBin.XY,
               mapType=MapType.RATE,
               smoothing=True,
               **kwargs):
        """
        Bins up the variable type varType and returns a tuple of
        (rmap, binnedPositionDir) or
        (rmap, binnedPositionX, binnedPositionY)

        Args:
            spkWeights (array_like): Shape equal to number of positions samples captured and consists of
                position weights. For example, if there were 5 positions
                recorded and a cell spiked once in position 2 and 5 times in
                position 3 and nothing anywhere else then pos_weights looks
                like: [0 0 1 5 0]
            varType (Enum value - see VariableToBin defined at top of this file): The variable to bin up. Legal values are: XY, DIR and SPEED
            mapType (enum value - see MapType defined at top of this file): If RATE then the binned up spikes are divided by varType.
                Otherwise return binned up position. Options are RATE or POS
            smoothing (bool, optional): Whether to smooth the data or not. Defaults to True.

        Returns:
            binned_data, binned_pos (tuple): This is either a 2-tuple or a 3-tuple depending on whether binned
                pos (mapType 'pos') or binned spikes (mapType 'rate') is asked
                for respectively
        """
        if varType.value == VariableToBin.DIR.value:
            sample = self.dir
            keep_these = np.isfinite(sample)
        elif varType.value == VariableToBin.SPEED.value:
            sample = self.speed
            keep_these = np.isfinite(sample)
        elif varType.value == VariableToBin.XY.value:
            sample = self.xy
            keep_these = np.isfinite(sample[0])
        elif varType.value == VariableToBin.XY_TIME.value:
            sample = np.concatenate((np.atleast_2d(self.xy),
                                    np.atleast_2d(self.pos_times)))
            keep_these = np.isfinite(self.xy[0])
        else:
            raise ValueError("Unrecognized variable to bin.")
        assert sample is not None

        self.var2Bin = varType
        self._spike_weights = spkWeights
        self._calcBinEdges(self.binsize)

        binned_pos, binned_pos_edges = self._binData(
            sample,
            self._binedges,
            self.pos_weights,
            keep_these)
        nanIdx = binned_pos == 0

        if mapType.value == MapType.POS.value:  # return binned up position
            if smoothing:
                if varType.value == VariableToBin.DIR.value:
                    binned_pos = self._circPadSmooth(
                        binned_pos, n=self.smooth_sz)
                else:
                    binned_pos = blurImage(binned_pos,
                                           self.smooth_sz,
                                           ftype=self.smoothingType,
                                           **kwargs)
            return binned_pos, binned_pos_edges

        binned_spk, _ = self._binData(
            sample, self._binedges, spkWeights, keep_these)
        if mapType.value == MapType.SPK.value:
            return binned_spk
        # _binData returns a tuple of the binned data and the bin
        # edges; only the binned spikes are needed from here on
        if "after" in self.whenToSmooth:
            rmap = binned_spk / binned_pos
            if varType.value == VariableToBin.DIR.value:
                rmap = self._circPadSmooth(rmap, self.smooth_sz)
            else:
                rmap = blurImage(rmap,
                                 self.smooth_sz,
                                 ftype=self.smoothingType,
                                 **kwargs)
        else:  # default case
            if not smoothing:
                return binned_spk / binned_pos, binned_pos_edges
            if varType.value == VariableToBin.DIR.value:
                binned_pos = self._circPadSmooth(binned_pos, self.smooth_sz)
                binned_spk = self._circPadSmooth(binned_spk, self.smooth_sz)
                rmap = binned_spk / binned_pos
            else:
                binned_pos = blurImage(binned_pos,
                                       self.smooth_sz,
                                       ftype=self.smoothingType,
                                       **kwargs)
                if binned_spk.ndim == 2:
                    pass
                elif binned_spk.ndim == 1:
                    binned_spk_tmp = np.zeros(
                        [binned_spk.shape[0], binned_spk.shape[0], 1]
                    )
                    for i in range(binned_spk.shape[0]):
                        binned_spk_tmp[i, :, :] = binned_spk[i]
                    binned_spk = binned_spk_tmp
                binned_spk = blurImage(
                    binned_spk,
                    self.smooth_sz,
                    ftype=self.smoothingType,
                    **kwargs)
                rmap = binned_spk / binned_pos
                if rmap.ndim <= 2:
                    rmap[nanIdx] = np.nan

        return rmap, binned_pos_edges

    def getSAC(self, spkWeights, **kwargs):
        '''
        Returns the SAC - convenience function
        '''
        rmap = self.getMap(spkWeights=spkWeights, **kwargs)
        nodwell = ~np.isfinite(rmap[0])
        return self.autoCorr2D(rmap[0], nodwell)

    def _binData(self, var, bin_edges, weights, good_indices=None):
        """
        Bins data taking account of possible multi-dimensionality

        Args:
            var (array_like): The variable to bin
            bin_edges (array_like): The edges of the data - see numpys histogramdd for more
            weights (array_like): The weights attributed to the samples in var
            good_indices (array_like): Valid indices (i.e. not nan and not infinite)

        Returns:
            ndhist (2-tuple): A two-tuple of the binned variable and
                the bin edges

        Notes:
            This breaks compatibility with numpy's histogramdd.
            In the 2d histogram case below I swap the axes around so that x and y
            are binned in the 'normal' format i.e. so x appears horizontally and y
            vertically.
            Multi-binning issue is dealt with awkwardly through checking
            the dimensionality of the weights array.
            'normally' this would be 1 dim but when multiple clusters are being
            binned it will be 2 dim.
            In that case np.apply_along_axis functionality is applied.
            The spike weights in that case might be created like so:

            >>> spk_W = np.zeros(shape=[len(trial.nClusters), trial.npos])
            >>> for i, cluster in enumerate(trial.clusters):
            >>>     x1 = trial.getClusterIdx(cluster)
            >>>     spk_W[i, :] = np.bincount(x1, minlength=trial.npos)

            This can then be fed into this fcn something like so:

            >>> rng = np.array((np.ma.min(
                trial.POS.xy, 1).data, np.ma.max(trial.POS.xy, 1).data))
            >>> h = _binData(
                var=trial.POS.xy, bin_edges=np.array([64, 64]),
                weights=spk_W, rng=rng)

            Returned will be a tuple containing the binned up data and
            the bin edges for x and y (obv this will be the same for all
            entries of h)
        """
        if weights is None:
            weights = np.ones_like(var)
        dims = weights.ndim
        if dims == 1 and var.ndim == 1:
            var = var[np.newaxis, :]
            # if self.var2Bin != VariableToBin.XY and len(bin_edges) != 1:
            #     bin_edges = self._calcBinEdges(self.binsize)
            bin_edges = bin_edges[np.newaxis, :]
        elif dims > 1 and var.ndim == 1:
            var = var[np.newaxis, :]
            bin_edges = bin_edges[np.newaxis, :]
        else:
            var = np.flipud(var)
        weights = np.atleast_2d(weights)  # needed for list comp below
        var = np.array(var.data.T.tolist())
        ndhist = [np.histogramdd(
            sample=var[good_indices],
            bins=bin_edges,
            weights=np.ravel(w[good_indices])) for w in weights]
        if np.shape(weights)[0] == 1:
            return ndhist[0][0], ndhist[0][1]
        else:
            tmp = [d[0] for d in ndhist]
            tmp = np.array(tmp)
            return tmp, ndhist[0][1]

    def _circPadSmooth(self, var, n=3, ny=None):
        """
        Smooths a vector by convolving with a gaussian
        Mirror reflects the start and end of the vector to
        deal with edge effects

        Args:
            var (array_like): The vector to smooth
            n, ny (int): Size of the smoothing (sigma in gaussian)

        Returns:
            array_like: The smoothed vector with shape the same as var
        """

        tn = len(var)
        t2 = int(np.floor(tn / 2))
        var = np.concatenate((var[t2:tn], var, var[0:t2]))
        if ny is None:
            ny = n
        x, y = np.mgrid[-n: n + 1, 0 - ny: ny + 1]
        g = np.exp(-(x**2 / float(n) + y**2 / float(ny)))
        if np.ndim(var) == 1:
            g = g[n, :]
        g = g / g.sum()
        improc = signal.convolve(var, g, mode="same")
        improc = improc[tn - t2: tn - t2 + tn]
        return improc

    def _circularStructure(self, radius):
        """
        Generates a circular binary structure for use with morphological
        operations such as ndimage.binary_dilation etc

        This is only used in this implementation for adaptively binning
        ratemaps for use with information theoretic measures (Skaggs etc)

        Args:
            radius (int): the size of the circular structure

        Returns:
            res (array_like): Binary structure with shape [(radius*2) + 1,(radius*2) + 1]

        See Also:
            RateMap.__adaptiveMap
        """
        from skimage.morphology import disk

        return disk(radius)

    def getAdaptiveMap(self, pos_binned, spk_binned, alpha=200):
        """
        Produces a ratemap that has been adaptively binned according to the
        algorithm described in Skaggs et al. (1996) [1]_.

        Args:
            pos_binned (array_like): The binned positional data. For example that returned from getMap
                above with mapType as 'pos'
            spk_binned (array_like): The binned spikes
            alpha (int, optional): A scaling parameter determining the amount of occupancy to aim at
                in each bin. Defaults to 200.

        Returns:
            The adaptively binned rate, spike and position maps. Use these to
            generate the Skaggs information measure.

        Notes:
            Positions with high rates mean proportionately less error than those
            with low rates, so this tries to even the playing field. This type
            of binning should be used for calculations of spatial info
            as with the skaggs_info method in the fieldcalcs class (see below)
            alpha is a scaling parameter that might need tweaking for different
            data sets.
            From the paper:
                The data [are] first binned
                into a 64 X 64 grid of spatial locations, and then the firing rate
                at each point in this grid was calculated by expanding a circle
                around the point until the following criterion was met:
                    Nspks > alpha / (Nocc^2 * r^2)
                where Nspks is the number of spikes emitted in a circle of radius
                r (in bins), Nocc is the number of occupancy samples, alpha is the
                scaling parameter
                The firing rate in the given bin is then calculated as:
                    sample_rate * (Nspks / Nocc)

        References:
            .. [1] W. E. Skaggs, B. L. McNaughton, K. M. Gothard & E. J. Markus
                "An Information-Theoretic Approach to Deciphering the Hippocampal
                Code"
                Neural Information Processing Systems, 1993.
        """
        #  assign output arrays
        smthdPos = np.zeros_like(pos_binned)
        smthdSpk = np.zeros_like(spk_binned)
        smthdRate = np.zeros_like(pos_binned)
        idx = pos_binned == 0
        pos_binned[idx] = np.nan
        spk_binned[idx] = np.nan
        visited = np.zeros_like(pos_binned)
        visited[pos_binned > 0] = 1
        # array to check which bins have made it
        binCheck = np.isnan(pos_binned)
        r = 1
        while np.any(~binCheck):
            # create the filter kernel
            h = self._circularStructure(r)
            h[h >= np.max(h) / 3.0] = 1
            h[h != 1] = 0
            if h.shape >= pos_binned.shape:
                break
            # filter the arrays using astropys convolution
            filtPos = convolution.convolve(pos_binned, h, boundary=None)
            filtSpk = convolution.convolve(spk_binned, h, boundary=None)
            filtVisited = convolution.convolve(visited, h, boundary=None)
            # get the bins which made it through this iteration
            trueBins = alpha / (np.sqrt(filtSpk) * filtPos) <= r
            trueBins = np.logical_and(trueBins, ~binCheck)
            # insert values where true
            smthdPos[trueBins] = filtPos[trueBins] / filtVisited[trueBins]
            smthdSpk[trueBins] = filtSpk[trueBins] / filtVisited[trueBins]
            binCheck[trueBins] = True
            r += 1
        smthdRate = smthdSpk / smthdPos
        smthdRate[idx] = np.nan
        smthdSpk[idx] = np.nan
        smthdPos[idx] = np.nan
        return smthdRate, smthdSpk, smthdPos

    def autoCorr2D(self, A, nodwell, tol=1e-10):
        """
        Performs a spatial autocorrelation on the array A

        Args:
            A (array_like): Either 2 or 3D. In the former it is simply the binned up ratemap
                where the two dimensions correspond to x and y.
                If 3D then the first two dimensions are x
                and y and the third (last) dimension is a 'stack' of ratemaps
            nodwell (array_like): A boolean array corresponding to the bins in the ratemap that
                weren't visited. See Notes below.
            tol (float, optional): Values below this are set to zero to deal with very small values
                thrown up by the fft. Default 1e-10

        Returns:
            sac (array_like): The spatial autocorrelation in the relevant dimensionality

        Notes:
            The nodwell input can usually be generated by:

            >>> nodwell = ~np.isfinite(A)
        """

        assert np.ndim(A) == 2
        m, n = np.shape(A)
        o = 1
        x = np.reshape(A, (m, n, o))
        nodwell = np.reshape(nodwell, (m, n, o))
        x[nodwell] = 0
        # [Step 1] Obtain FFTs of x, the sum of squares and bins visited
        Fx = np.fft.fft(np.fft.fft(x, 2 * m - 1, axis=0), 2 * n - 1, axis=1)
        FsumOfSquares_x = np.fft.fft(
            np.fft.fft(np.power(x, 2), 2 * m - 1, axis=0), 2 * n - 1, axis=1
        )
        Fn = np.fft.fft(
            np.fft.fft(np.invert(nodwell).astype(int), 2 * m - 1, axis=0),
            2 * n - 1,
            axis=1,
        )
        # [Step 2] Multiply the relevant transforms and invert to obtain the
        # equivalent convolutions
        rawCorr = np.fft.fftshift(
            np.real(np.fft.ifft(
                np.fft.ifft(Fx * np.conj(Fx), axis=1), axis=0)),
            axes=(0, 1),
        )
        sums_x = np.fft.fftshift(
            np.real(np.fft.ifft(
                np.fft.ifft(np.conj(Fx) * Fn, axis=1), axis=0)),
            axes=(0, 1),
        )
        sumOfSquares_x = np.fft.fftshift(
            np.real(
                np.fft.ifft(
                    np.fft.ifft(Fn * np.conj(FsumOfSquares_x), axis=1), axis=0)
            ),
            axes=(0, 1),
        )
        N = np.fft.fftshift(
            np.real(np.fft.ifft(
                np.fft.ifft(Fn * np.conj(Fn), axis=1), axis=0)),
            axes=(0, 1),
        )
        # [Step 3] Account for rounding errors.
        rawCorr[np.abs(rawCorr) < tol] = 0
        sums_x[np.abs(sums_x) < tol] = 0
        sumOfSquares_x[np.abs(sumOfSquares_x) < tol] = 0
        N = np.round(N)
        N[N <= 1] = np.nan
        # [Step 4] Compute correlation matrix
        mapStd = np.sqrt((sumOfSquares_x * N) - sums_x**2)
        mapCovar = (rawCorr * N) - sums_x * \
            sums_x[::-1, :, :][:, ::-1, :][:, :, :]

        return np.squeeze(
            mapCovar / mapStd / mapStd[::-1, :, :][:, ::-1, :][:, :, :])

    def crossCorr2D(self, A, B, A_nodwell, B_nodwell, tol=1e-10):
        """
        Performs a spatial crosscorrelation between the arrays A and B

        Args:
            A, B (array_like): Either 2 or 3D. In the former it is simply the binned up ratemap
                where the two dimensions correspond to x and y.
                If 3D then the first two dimensions are x
                and y and the third (last) dimension is a 'stack' of ratemaps
            A_nodwell, B_nodwell (array_like): A boolean array corresponding to the bins in the ratemap that
                weren't visited. See Notes below.
            tol (float, optional): Values below this are set to zero to deal with very small values
                thrown up by the fft. Default 1e-10

        Returns:
            sac (array_like): The spatial crosscorrelation in the relevant dimensionality

        Notes:
            The nodwell input can usually be generated by:

            >>> nodwell = ~np.isfinite(A)
        """
        if np.ndim(A) != np.ndim(B):
            raise ValueError("Both arrays must have the same dimensionality")
        assert np.ndim(A) == 2
        ma, na = np.shape(A)
        mb, nb = np.shape(B)
        oa = ob = 1
        A = np.reshape(A, (ma, na, oa))
        B = np.reshape(B, (mb, nb, ob))
        A_nodwell = np.reshape(A_nodwell, (ma, na, oa))
        B_nodwell = np.reshape(B_nodwell, (mb, nb, ob))
        A[A_nodwell] = 0
        B[B_nodwell] = 0
        # [Step 1] Obtain FFTs of x, the sum of squares and bins visited
        Fa = np.fft.fft(np.fft.fft(A, 2 * mb - 1, axis=0), 2 * nb - 1, axis=1)
        FsumOfSquares_a = np.fft.fft(
            np.fft.fft(np.power(A, 2), 2 * mb - 1, axis=0), 2 * nb - 1, axis=1
        )
        Fn_a = np.fft.fft(
            np.fft.fft(np.invert(A_nodwell).astype(int), 2 * mb - 1, axis=0),
            2 * nb - 1,
            axis=1,
        )
        Fb = np.fft.fft(np.fft.fft(B, 2 * ma - 1, axis=0), 2 * na - 1, axis=1)
        FsumOfSquares_b = np.fft.fft(
            np.fft.fft(np.power(B, 2), 2 * ma - 1, axis=0), 2 * na - 1, axis=1
        )
        Fn_b = np.fft.fft(
            np.fft.fft(np.invert(B_nodwell).astype(int), 2 * ma - 1, axis=0),
            2 * na - 1,
            axis=1,
        )
        # [Step 2] Multiply the relevant transforms and invert to obtain the
        # equivalent convolutions
        rawCorr = np.fft.fftshift(
            np.real(np.fft.ifft(np.fft.ifft(Fa * np.conj(Fb), axis=1), axis=0))
        )
        sums_a = np.fft.fftshift(
            np.real(np.fft.ifft(np.fft.ifft(
                Fa * np.conj(Fn_b), axis=1), axis=0))
        )
        sums_b = np.fft.fftshift(
            np.real(np.fft.ifft(np.fft.ifft(
                Fn_a * np.conj(Fb), axis=1), axis=0))
        )
        sumOfSquares_a = np.fft.fftshift(
            np.real(
                np.fft.ifft(
                    np.fft.ifft(
                        FsumOfSquares_a * np.conj(Fn_b), axis=1), axis=0
                )
            )
        )
        sumOfSquares_b = np.fft.fftshift(
            np.real(
                np.fft.ifft(
                    np.fft.ifft(
                        Fn_a * np.conj(FsumOfSquares_b), axis=1), axis=0
                )
            )
        )
        N = np.fft.fftshift(
            np.real(np.fft.ifft(np.fft.ifft(
                Fn_a * np.conj(Fn_b), axis=1), axis=0))
        )
        # [Step 3] Account for rounding errors.
        rawCorr[np.abs(rawCorr) < tol] = 0
        sums_a[np.abs(sums_a) < tol] = 0
        sums_b[np.abs(sums_b) < tol] = 0
        sumOfSquares_a[np.abs(sumOfSquares_a) < tol] = 0
        sumOfSquares_b[np.abs(sumOfSquares_b) < tol] = 0
        N = np.round(N)
        N[N <= 1] = np.nan
        # [Step 4] Compute correlation matrix
        mapStd_a = np.sqrt((sumOfSquares_a * N) - sums_a**2)
        mapStd_b = np.sqrt((sumOfSquares_b * N) - sums_b**2)
        mapCovar = (rawCorr * N) - sums_a * sums_b

        return np.squeeze(mapCovar / (mapStd_a * mapStd_b))

    def tWinSAC(
        self,
        xy,
        spkIdx,
        ppm=365,
        winSize=10,
        pos_sample_rate=50,
        nbins=71,
        boxcar=5,
        Pthresh=100,
        downsampfreq=50,
        plot=False,
    ):
        """
        Temporal windowed spatial autocorrelation.

        Args:
            xy (array_like): The position data
            spkIdx (array_like): The indices in xy where the cell fired
            ppm (int, optional): The camera pixels per metre. Default 365
            winSize (int, optional): The window size for the temporal search
            pos_sample_rate (int, optional): The rate at which position was sampled. Default 50
            nbins (int, optional): The number of bins for creating the resulting ratemap. Default 71
            boxcar (int, optional): The size of the smoothing kernel to smooth ratemaps. Default 5
            Pthresh (int, optional): The cut-off for values in the ratemap; values < Pthresh become nans. Default 100
            downsampfreq (int, optional): How much to downsample. Default 50
            plot (bool, optional): Whether to show a plot of the result. Default False

        Returns:
            H (array_like): The temporal windowed SAC
        """
        # [Stage 0] Get some numbers
        xy = xy / ppm * 100
        n_samps = xy.shape[1]
        n_spks = len(spkIdx)
        winSizeBins = np.min([winSize * pos_sample_rate, n_samps])
        # factor by which positions are downsampled
        downsample = np.ceil(pos_sample_rate / downsampfreq)
        Pthresh = Pthresh / downsample  # take account of downsampling

        # [Stage 1] Calculate number of spikes in the window for each spikeInd
        # (ignoring spike itself)
        # 1a. Loop preparation
        nSpikesInWin = np.zeros(n_spks, dtype=int)

        # 1b. Keep looping until we have dealt with all spikes
        for i, s in enumerate(spkIdx):
            t = np.searchsorted(spkIdx, (s, s + winSizeBins))
            nSpikesInWin[i] = len(spkIdx[t[0]: t[1]]) - 1  # ignore ith spike

        # [Stage 2] Prepare for main loop
        # 2a. Work out offset indices to be used when storing spike data
        off_spike = np.cumsum([nSpikesInWin])
        off_spike = np.pad(off_spike, (1, 0), "constant", constant_values=(0))

        # 2b. Work out number of downsampled pos bins in window and
        # offset indices for storing data
        nPosInWindow = np.minimum(winSizeBins, n_samps - spkIdx)
        nDownsampInWin = np.floor((nPosInWindow - 1) / downsample) + 1

        off_dwell = np.cumsum(nDownsampInWin.astype(int))
        off_dwell = np.pad(off_dwell, (1, 0), "constant", constant_values=(0))

        # 2c. Pre-allocate dwell and spike arrays, singles for speed
        dwell = np.zeros((2, off_dwell[-1]), dtype=np.single) * np.nan
        spike = np.zeros((2, off_spike[-1]), dtype=np.single) * np.nan

        filled_pvals = 0
        filled_svals = 0

        for i in range(n_spks):
            # calculate dwell displacements
            winInd_dwell = np.arange(
                spkIdx[i] + 1,
                np.minimum(spkIdx[i] + winSizeBins, n_samps),
                downsample,
                dtype=int,
            )
            WL = len(winInd_dwell)
            dwell[:, filled_pvals: filled_pvals + WL] = np.rot90(
                np.array(np.rot90(xy[:, winInd_dwell]) - xy[:, spkIdx[i]])
            )
            filled_pvals = filled_pvals + WL
            # calculate spike displacements
            winInd_spks = (
                i + np.nonzero(spkIdx[i + 1: n_spks] <
                               spkIdx[i] + winSizeBins)[0]
            )
            WL = len(winInd_spks)
            spike[:, filled_svals: filled_svals + WL] = np.rot90(
                np.array(
                    np.rot90(xy[:, spkIdx[winInd_spks]]) - xy[:, spkIdx[i]])
            )
            filled_svals = filled_svals + WL

        dwell = np.delete(dwell, np.isnan(dwell).nonzero()[1], axis=1)
        spike = np.delete(spike, np.isnan(spike).nonzero()[1], axis=1)

        dwell = np.hstack((dwell, -dwell))
        spike = np.hstack((spike, -spike))

        dwell_min = np.min(dwell, axis=1)
        dwell_max = np.max(dwell, axis=1)

        binsize = (dwell_max[1] - dwell_min[1]) / nbins

        dwell = np.round(
            (dwell - np.ones_like(dwell) * dwell_min[:, np.newaxis]) / binsize
        )
        spike = np.round(
            (spike - np.ones_like(spike) * dwell_min[:, np.newaxis]) / binsize
        )

        binsize = np.max(dwell, axis=1).astype(int)
        binedges = np.array(((-0.5, -0.5), binsize + 0.5)).T
        Hp = np.histogram2d(dwell[0, :], dwell[1, :],
                            range=binedges, bins=binsize)[0]
        Hs = np.histogram2d(spike[0, :], spike[1, :],
                            range=binedges, bins=binsize)[0]

        # reverse y,x order
        Hp = np.swapaxes(Hp, 1, 0)
        Hs = np.swapaxes(Hs, 1, 0)

        fHp = blurImage(Hp, boxcar)
        fHs = blurImage(Hs, boxcar)

        H = fHs / fHp
        H[Hp < Pthresh] = np.nan

        return H

    @cache
    def _create_boundary_distance_lookup(self,
                                         arena_boundary: MultiLineString,
                                         degs_per_bin: float,
                                         xy_binsize: float,
                                         **kwargs):
        # Now we generate lines radiating out from a point as a
        # multilinestring geometry collection - this looks like a star
        # with 360/degs_per_bin points. We will move this to each valid
        # location in the position map
        # and then calculate the distance to the nearest intersection with the
        # arena boundary.
        # get the arena boundaries to figure out the radius of the arena,
        # regardless of its actual shape
        x1, y1, x2, y2 = arena_boundary.bounds
        radius = max(x2-x1, y2-y1)/2
        startpoint = Point((x1+radius, y1+radius))
        endpoint = Point([x2, y1+radius])
        angles = np.arange(0, 360, degs_per_bin)
        lines = MultiLineString(
            [rotate(LineString([startpoint, endpoint]), ang, origin=startpoint)
             for ang in angles])
        prepare(lines)
        # arena centre
        cx = x1 + radius
        cy = y1 + radius
        # get the position map and the valid locations within it
        pos_map, (ybin_edges, xbin_edges) = self.getMap(np.ones_like(self.dir),
                                                        varType=VariableToBin.XY,
                                                        mapType=MapType.POS,
                                                        smoothing=False)
        yvalid, xvalid = np.nonzero(~np.isnan(pos_map))

        # preallocate the array to hold distances
        distances = np.full(
            (len(xbin_edges), len(ybin_edges), len(angles)), np.nan)

        # Now iterate through valid locations in the pos map and calculate the
        # distances and the indices of the lines that intersect with the
        # arena boundary. The indices are equivalent to the angle of the
        # line in the lines geometry collection. This iteration is a bit slow
        # but it will only need to be done once per session as it's creating
        # a lookup table for the distances
        for xi, yi in zip(xvalid, yvalid):
            i_point = Point((xbin_edges[xi]+xy_binsize,
                             ybin_edges[yi]+xy_binsize))
            ipx, ipy = i_point.xy
            new_point = Point(cx-ipx[0], cy-ipy[0])
            t_arena = translate(arena_boundary, -new_point.x, -new_point.y)
            prepare(t_arena)
            di = [(startpoint.distance(t_arena.intersection(line)), idx)
                  for idx, line in enumerate(lines.geoms) if
                  t_arena.intersects(line)]
            d, i = zip(*di)
            distances[xi, yi, i] = d
        return distances

    def get_egocentric_boundary_map(self,
                                    spk_weights,
                                    degs_per_bin: float = 3,
                                    xy_binsize: float = 2.5,
                                    arena_type: str = "circle",
                                    return_dists: bool = False,
                                    return_raw_spk: bool = False,
                                    return_raw_occ: bool = False) -> namedtuple:
        """
        Helps construct dwell time/spike counts maps with respect to boundaries at given egocentric directions and distances.

        Note:
            For the directional input, the 0 degree reference is horizontal pointing East and moves counter-clockwise.
        """
        assert self.dir is not None, "No direction data available"
        # initially do some binning to get valid locations
        # (some might be nans due to
        # arena shape and/or poor sampling) and then digitize
        # the x and y positions
        # and the angular positions
        self.binsize = xy_binsize  # this will trigger a
        # re-calculation of the bin edges

        angles = np.arange(0, 360, degs_per_bin)

        # Use the shapely package to specify some geometry for the arena
        # boundary and the lines radiating out
        # from the current location of the animal. The geometry for the
        # arena should be user specified but for now I'll just use a circle
        if arena_type == "circle":
            radius = 50
            circle_centre = Point(
                np.nanmin(self.xy[0])+radius, np.nanmin(self.xy[1])+radius)
            arena_boundary = circle_centre.buffer(radius).boundary
        # now we have a circle with its centre at the centre of the arena
        # i.e. the circle defines the arena edges. Calling .boundary on the
        # circle geometry actually gives us a 65-gon polygon
        distances = self._create_boundary_distance_lookup(
            arena_boundary, degs_per_bin, xy_binsize)
        # iterate through the digitized locations (x/y and angular), using the
        # lookup table to get the distances to the arena boundary and then
        # increment the appropriate bin in the egocentric boundary map
        good_idx = np.isfinite(self.xy[0])
        xy_by_heading, _ = np.histogramdd([self.xy[0][good_idx],
                                           self.xy[1][good_idx],
                                           self.dir[good_idx]],
                                          bins=distances.shape,
                                          weights=self.pos_weights[good_idx])
        spk_xy_by_hd, _ = np.histogramdd([self.xy[0][good_idx],
                                          self.xy[1][good_idx],
                                          self.dir[good_idx]],
                                         bins=distances.shape,
                                         weights=spk_weights[good_idx])
        assert xy_by_heading.shape == distances.shape
        distlist = []
        anglist = []
        spkdists = []
        spkangs = []
        for i_bin in np.ndindex(distances.shape[:2]):
            i_dist = distances[i_bin]
            valid_dist = np.isfinite(i_dist)
            nonzero_bincounts = np.nonzero(xy_by_heading[i_bin])[0]
            nonzero_spkbins = np.nonzero(spk_xy_by_hd[i_bin])[0]
            for i_angle in nonzero_bincounts:
                ego_angles = np.roll(angles, i_angle)[valid_dist]
                n_repeats = xy_by_heading[i_bin][i_angle]
                ego_angles_repeats = np.repeat(ego_angles, n_repeats)
                dist_repeats = np.repeat(i_dist[valid_dist], n_repeats)
                distlist.append(dist_repeats)
                anglist.append(ego_angles_repeats)
                if i_angle in nonzero_spkbins:
                    n_repeats = spk_xy_by_hd[i_bin][i_angle]
                    ego_angles_repeats = np.repeat(ego_angles, n_repeats)
                    dist_repeats = np.repeat(i_dist[valid_dist], n_repeats)
                    spkdists.append(dist_repeats)
                    spkangs.append(ego_angles_repeats)
        flat_angs = flatten_list(anglist)
        flat_dists = flatten_list(distlist)
        flat_spk_dists = flatten_list(spkdists)
        flat_spk_angs = flatten_list(spkangs)
        bins = [int(radius/xy_binsize), len(angles)]
        ego_boundary_occ, _, _ = np.histogram2d(x=flat_dists, y=flat_angs,
                                                bins=bins)
        ego_boundary_spk, _, _ = np.histogram2d(x=flat_spk_dists,
                                                y=flat_spk_angs,
                                                bins=bins)
        kernel = convolution.Gaussian2DKernel(5, x_size=3, y_size=5)
        sm_occ = convolution.convolve(ego_boundary_occ,
                                      kernel,
                                      boundary='extend')
        sm_spk = convolution.convolve(ego_boundary_spk,
                                      kernel,
                                      boundary='extend')
        ego_boundary_map = sm_spk / sm_occ
        EgoMap = namedtuple("EgoMap", ['rmap', 'occ', 'spk', 'dists'],
                            defaults=None)
        em = EgoMap(None, None, None, None)
        em = em._replace(rmap=ego_boundary_map)
        if return_dists:
            em = em._replace(dists=distances)
        if return_raw_occ:
            em = em._replace(occ=ego_boundary_occ)
        if return_raw_spk:
            em = em._replace(spk=ego_boundary_spk)
        return em

    def getAllSpikeWeights(self,
                           spike_times: np.ndarray,
                           spike_clusters: np.ndarray,
                           pos_times: np.ndarray,
                           **kwargs):
        """
        Args:
            spike_times (np.ndarray): Spike times in seconds
            spike_clusters (np.ndarray): Cluster identity vector
            pos_times (np.ndarray): The times at which position was captured in seconds

        Returns:
            np.ndarray: The bincounts with respect to position for each cluster. Shape of returned array will be nClusters x npos
        """
        assert len(spike_clusters) == len(spike_times)
        clusters = np.unique(spike_clusters)
        npos = len(self.dir)
        idx = np.searchsorted(pos_times, spike_times) - 1
        weights = [np.bincount(idx[spike_clusters == c], minlength=npos)
                   for c in clusters]
        return np.array(weights)

    def _splitStackedCorrelations(self, binned_data: list) -> tuple:
        '''
        Takes the list of np.histogramdd results computed in
        doStackedCorrelations() and splits it into two arrays,
        returned as a 2-tuple
        '''
        result = [(s[0][:, :, 0], s[0][:, :, 1]) for s in binned_data]
        result = np.array(result)
        return np.squeeze(result[:, 0, :, :]), np.squeeze(result[:, 1, :, :])

    def doStackedCorrelations(self,
                              spkW: np.ndarray,
                              times: np.ndarray,
                              splits: np.ndarray,
                              var2bin: Enum = VariableToBin.XY,
                              maptype: Enum = MapType.RATE,
                              **kwargs):
        """
        Bins a spatial variable (xy, dir etc) together with a temporal
        one using np.histogramdd, splitting the data into two temporal
        halves based on the bin edges in 'splits', so that correlations
        can be run between the two halves. This is done for all of the
        clusters that have spike weights in 'spkW', which should be the
        result of calling getAllSpikeWeights(). Returns a 2-tuple of
        the smoothed rate maps for the two halves.

        Args:
            spkW (np.ndarray): The result of calling getAllSpikeWeights()
            times (np.ndarray): Position times in seconds
            splits (np.ndarray): Where to split the data in seconds. Will
                typically take the form (0, 100, 200) for
                example which will give a split between 0-100
                and 100-200 seconds
            var2bin (Enum): The spatial variable to bin up
            maptype (Enum): The type of map to produce
        """
        if var2bin.value == VariableToBin.DIR.value:
            sample = self.dir
        elif var2bin.value == VariableToBin.SPEED.value:
            sample = self.speed
        elif var2bin.value == VariableToBin.XY.value:
            sample = self.xy
        else:
            raise ValueError("Unrecognized variable to bin.")
        assert sample is not None
        self.pos_time_splits = splits

        sample = np.concatenate((np.atleast_2d(sample),
                                np.atleast_2d(times)))
        edges = [b for b in self._binedges][::-1]
        edges.append(splits)
        # bin pos
        bp, bpe = np.histogramdd(sample.T, bins=edges)
        map1_pos, map2_pos = np.squeeze(bp[:, :, 0]), np.squeeze(bp[:, :, 1])
        # smooth position
        map1_pos = blurImage(map1_pos, 7, ftype='gaussian')
        map2_pos = blurImage(map2_pos, 7, ftype='gaussian')
        # bin spk - ie the histogram is weighted by spike count
        # in bin i
        spk = [np.histogramdd(sample.T, bins=edges, weights=w)
               for w in spkW]
        map1_spk, map2_spk = self._splitStackedCorrelations(spk)
        map1_sm_spk = np.array([blurImage(m, 7, ftype='gaussian')
                                for m in map1_spk])
        map2_sm_spk = np.array([blurImage(m, 7, ftype='gaussian')
                                for m in map2_spk])
        map1_rmaps = map1_sm_spk / map1_pos
        map2_rmaps = map2_sm_spk / map2_pos
        return map1_rmaps, map2_rmaps
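
Before the per-member reference below, here is a minimal sketch of the central getMap workflow. R, npos and spk_weights are hypothetical names (R being an already-constructed RateMap instance); VariableToBin and MapType are the enums defined at the top of this module.

import numpy as np

# Hypothetical: R is an existing RateMap instance and npos is the
# number of position samples in the trial.
spk_weights = np.zeros(npos)
spk_weights[2] = 1   # one spike at position sample 2
spk_weights[3] = 5   # five spikes at position sample 3

# Smoothed x/y rate map and the bin edges used
rmap, binned_pos_edges = R.getMap(spk_weights,
                                  varType=VariableToBin.XY,
                                  mapType=MapType.RATE)

# Unsmoothed occupancy map over the same bins
occ, _ = R.getMap(spk_weights, mapType=MapType.POS, smoothing=False)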

nBins property writable

The number of bins for each dim

pos_weights property writable

The 'weights' used as an argument to np.histogram* for binning up position. Mostly this is just an array of 1's equal in length to the pos data, but it can usefully be adjusted when masking data in the trial, e.g. by setting the weights of masked samples to 0.
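
For example, position samples can be excluded from all subsequent maps by zeroing their weights. A sketch, with R a hypothetical RateMap instance:

w = R.pos_weights    # defaults to an array of ones, one per position sample
w[100:200] = 0       # mask out samples 100-199
R.pos_weights = w    # masked samples now contribute nothing when binning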

autoCorr2D(A, nodwell, tol=1e-10)

Performs a spatial autocorrelation on the array A

Parameters:

    A (array_like): Either 2 or 3D. In the former it is simply the binned up ratemap where the two dimensions correspond to x and y. If 3D then the first two dimensions are x and y and the third (last) dimension is a 'stack' of ratemaps. Required.
    nodwell (array_like): A boolean array corresponding to the bins in the ratemap that weren't visited. See Notes below. Required.
    tol (float, optional): Values below this are set to zero to deal with very small values thrown up by the fft. Default 1e-10.

Returns:

    sac (array_like): The spatial autocorrelation in the relevant dimensionality.

Notes

The nodwell input can usually be generated by:

nodwell = ~np.isfinite(A)
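
A minimal usage sketch, with R and spk_weights hypothetical as above; this is essentially what the getSAC convenience method wraps:

rmap, _ = R.getMap(spk_weights)      # smoothed x/y rate map
nodwell = ~np.isfinite(rmap)         # unvisited bins
sac = R.autoCorr2D(rmap, nodwell)    # 2D spatial autocorrelogram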


crossCorr2D(A, B, A_nodwell, B_nodwell, tol=1e-10)

Performs a spatial crosscorrelation between the arrays A and B

Parameters:

    A, B (array_like): Either 2 or 3D. In the former it is simply the binned up ratemap where the two dimensions correspond to x and y. If 3D then the first two dimensions are x and y and the third (last) dimension is a 'stack' of ratemaps. Required.
    A_nodwell, B_nodwell (array_like): A boolean array corresponding to the bins in the ratemap that weren't visited. See Notes below. Required.
    tol (float, optional): Values below this are set to zero to deal with very small values thrown up by the fft. Default 1e-10.

Returns:

    sac (array_like): The spatial crosscorrelation in the relevant dimensionality.

Notes

The nodwell input can usually be generated by:

nodwell = ~np.isfinite(A)
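
A sketch of comparing two rate maps, e.g. from two halves of a trial; spk_weights_a and spk_weights_b are hypothetical per-half spike weight vectors:

rmap_a, _ = R.getMap(spk_weights_a)
rmap_b, _ = R.getMap(spk_weights_b)
xc = R.crossCorr2D(rmap_a, rmap_b,
                   ~np.isfinite(rmap_a),
                   ~np.isfinite(rmap_b))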


doStackedCorrelations(spkW, times, splits, var2bin=VariableToBin.XY, maptype=MapType.RATE, **kwargs)

Bins a spatial variable (xy, dir etc) together with a temporal one using np.histogramdd, splitting the data into two temporal halves based on the bin edges in 'splits', so that correlations can be run between the two halves. This is done for all of the clusters that have spike weights in 'spkW', which should be the result of calling getAllSpikeWeights(). Returns a 2-tuple of the smoothed rate maps for the two halves.

Parameters:

    spkW (ndarray): The result of calling getAllSpikeWeights(). Required.
    times (ndarray): Position times in seconds. Required.
    splits (ndarray): Where to split the data in seconds. Will typically take the form (0, 100, 200), which gives a split between 0-100 and 100-200 seconds. Required.
    var2bin (Enum): The spatial variable to bin up. Default VariableToBin.XY.
    maptype (Enum): The type of map to produce. Default MapType.RATE.
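
A sketch, with R hypothetical as above and hypothetical spike_times, spike_clusters and pos_times arrays (all in seconds); note that the x/y bin edges must already have been computed, e.g. by a prior call to getMap:

spkW = R.getAllSpikeWeights(spike_times, spike_clusters, pos_times)
splits = np.array([0, 100, 200])    # 0-100 s vs 100-200 s
maps1, maps2 = R.doStackedCorrelations(spkW, pos_times, splits)
# maps1[i] and maps2[i] are the two half-trial rate maps for cluster i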

getAdaptiveMap(pos_binned, spk_binned, alpha=200)

Produces a ratemap that has been adaptively binned according to the algorithm described in Skaggs et al. (1996) [1]_.

Parameters:

    pos_binned (array_like): The binned positional data, for example that returned from getMap above with mapType as 'pos'. Required.
    spk_binned (array_like): The binned spikes. Required.
    alpha (int, optional): A scaling parameter determining the amount of occupancy to aim at in each bin. Default 200.

Returns:

    The adaptively binned rate, spike and position maps. Use these to generate the Skaggs information measure.

Notes

Positions with high rates mean proportionately less error than those with low rates, so this tries to even the playing field. This type of binning should be used for calculations of spatial info as with the skaggs_info method in the fieldcalcs class. alpha is a scaling parameter that might need tweaking for different data sets.

From the paper: the data are first binned into a 64 x 64 grid of spatial locations, and then the firing rate at each point in this grid was calculated by expanding a circle around the point until the following criterion was met:

    Nspks > alpha / (Nocc^2 * r^2)

where Nspks is the number of spikes emitted in a circle of radius r (in bins), Nocc is the number of occupancy samples and alpha is the scaling parameter. The firing rate in the given bin is then calculated as:

    sample_rate * (Nspks / Nocc)

References

.. [1] W. E. Skaggs, B. L. McNaughton, K. M. Gothard & E. J. Markus "An Information-Theoretic Approach to Deciphering the Hippocampal Code" Neural Information Processing Systems, 1993.
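
A usage sketch: the inputs should be the raw (unsmoothed) occupancy and spike-count maps, since the adaptive binning does its own smoothing. R and spk_weights are hypothetical as above; note the SPK branch of getMap returns just the binned counts:

occ, _ = R.getMap(spk_weights, mapType=MapType.POS, smoothing=False)
spk = R.getMap(spk_weights, mapType=MapType.SPK)
rate, adaptive_spk, adaptive_pos = R.getAdaptiveMap(occ, spk, alpha=200)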

Source code in ephysiopy/common/binning.py
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
def getAdaptiveMap(self, pos_binned, spk_binned, alpha=200):
    """
    Produces a ratemap that has been adaptively binned according to the
    algorithm described in Skaggs et al., 1996) [1]_.

    Args:
        pos_binned (array_like): The binned positional data. For example that returned from getMap
            above with mapType as 'pos'
        spk_binned (array_like): The binned spikes
        alpha (int, optional): A scaling parameter determing the amount of occupancy to aim at
            in each bin. Defaults to 200.

    Returns:
        Returns adaptively binned spike and pos maps. Use to generate Skaggs
        information measure

    Notes:
        Positions with high rates mean proportionately less error than those
        with low rates, so this tries to even the playing field. This type
        of binning should be used for calculations of spatial info
        as with the skaggs_info method in the fieldcalcs class (see below)
        alpha is a scaling parameter that might need tweaking for different
        data sets.
        From the paper:
            The data [are] first binned
            into a 64 X 64 grid of spatial locations, and then the firing rate
            at each point in this grid was calculated by expanding a circle
            around the point until the following criterion was met:
                Nspks > alpha / (Nocc^2 * r^2)
            where Nspks is the number of spikes emitted in a circle of radius
            r (in bins), Nocc is the number of occupancy samples, alpha is the
            scaling parameter
            The firing rate in the given bin is then calculated as:
                sample_rate * (Nspks / Nocc)

    References:
        .. [1] W. E. Skaggs, B. L. McNaughton, K. M. Gothard & E. J. Markus
            "An Information-Theoretic Approach to Deciphering the Hippocampal
            Code"
            Neural Information Processing Systems, 1993.
    """
    #  assign output arrays
    smthdPos = np.zeros_like(pos_binned)
    smthdSpk = np.zeros_like(spk_binned)
    smthdRate = np.zeros_like(pos_binned)
    idx = pos_binned == 0
    pos_binned[idx] = np.nan
    spk_binned[idx] = np.nan
    visited = np.zeros_like(pos_binned)
    visited[pos_binned > 0] = 1
    # array to check which bins have made it
    binCheck = np.isnan(pos_binned)
    r = 1
    while np.any(~binCheck):
        # create the filter kernel
        h = self._circularStructure(r)
        h[h >= np.max(h) / 3.0] = 1
        h[h != 1] = 0
        # tuple comparison is lexicographic; compare the dimensions element-wise
        if np.any(np.array(h.shape) >= np.array(pos_binned.shape)):
            break
        # filter the arrays using astropys convolution
        filtPos = convolution.convolve(pos_binned, h, boundary=None)
        filtSpk = convolution.convolve(spk_binned, h, boundary=None)
        filtVisited = convolution.convolve(visited, h, boundary=None)
        # get the bins which made it through this iteration
        trueBins = alpha / (np.sqrt(filtSpk) * filtPos) <= r
        trueBins = np.logical_and(trueBins, ~binCheck)
        # insert values where true
        smthdPos[trueBins] = filtPos[trueBins] / filtVisited[trueBins]
        smthdSpk[trueBins] = filtSpk[trueBins] / filtVisited[trueBins]
        binCheck[trueBins] = True
        r += 1
    smthdRate = smthdSpk / smthdPos
    smthdRate[idx] = np.nan
    smthdSpk[idx] = np.nan
    smthdPos[idx] = np.nan
    return smthdRate, smthdSpk, smthdPos
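
Example

A minimal sketch of driving getAdaptiveMap from getMap output. It assumes rm is an already-constructed instance of the class these methods belong to (defined in ephysiopy/common/binning.py) and that spk_weights is the per-position spike count vector described under getMap below.

from ephysiopy.common.binning import VariableToBin, MapType

# bin occupancy and spikes without smoothing; getAdaptiveMap smooths itself
pos_binned, _ = rm.getMap(spk_weights, varType=VariableToBin.XY,
                          mapType=MapType.POS, smoothing=False)
spk_binned = rm.getMap(spk_weights, varType=VariableToBin.XY,
                       mapType=MapType.SPK)
rate, spk, pos = rm.getAdaptiveMap(pos_binned, spk_binned, alpha=200)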

getAllSpikeWeights(spike_times, spike_clusters, pos_times, **kwargs)

Parameters:

- spike_times (np.ndarray, required): Spike times in seconds
- spike_clusters (np.ndarray, required): Cluster identity vector
- pos_times (np.ndarray, required): The times at which position was captured in seconds

Returns:

np.ndarray: The bincounts with respect to position for each cluster. Shape of the returned array will be nClusters x npos

Source code in ephysiopy/common/binning.py
def getAllSpikeWeights(self,
                       spike_times: np.ndarray,
                       spike_clusters: np.ndarray,
                       pos_times: np.ndarray,
                       **kwargs):
    """
    Args:
        spike_times (np.ndarray): Spike times in seconds
        spike_clusters (np.ndarray): Cluster identity vector
        pos_times (np.ndarray): The times at which position was captured in seconds

    Returns:
        np.ndarray: The bincounts with respect to position for each cluster. Shape of returned array will be nClusters x npos
    """
    assert len(spike_clusters) == len(spike_times)
    clusters = np.unique(spike_clusters)
    npos = len(self.dir)
    idx = np.searchsorted(pos_times, spike_times) - 1
    weights = [np.bincount(idx[spike_clusters == c], minlength=npos)
               for c in clusters]
    return np.array(weights)
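
Example

A small sketch with toy values, assuming rm is a constructed instance of the binning class and holds the position data these spikes refer to.

import numpy as np

spike_times = np.array([0.11, 0.50, 0.52, 1.93])  # seconds
spike_clusters = np.array([1, 2, 1, 1])           # cluster id per spike
pos_times = np.arange(0, 2.0, 0.02)               # 50 Hz position stamps
weights = rm.getAllSpikeWeights(spike_times, spike_clusters, pos_times)
# one row per cluster; row length matches rm's number of position samples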

getMap(spkWeights, varType=VariableToBin.XY, mapType=MapType.RATE, smoothing=True, **kwargs)

Bins up the variable type varType and returns a tuple of (rmap, binnedPositionDir) or (rmap, binnedPositionX, binnedPositionY)

Parameters:

- spkWeights (array_like, required): Shape equal to the number of position samples captured; consists of position weights. For example, if there were 5 positions recorded and a cell spiked once in position 2 and 5 times in position 3 and nothing anywhere else, then pos_weights looks like: [0 0 1 5 0]
- varType (enum value - see VariableToBin defined at the top of this file, default XY): The variable to bin up. Legal values are XY, DIR and SPEED
- mapType (enum value - see MapType defined at the top of this file, default RATE): If RATE then the binned up spikes are divided by varType. Otherwise return binned up position. Options are RATE or POS
- smoothing (bool, default True): Whether to smooth the data or not

Returns:

binned_data, binned_pos (tuple): Either a 2-tuple or a 3-tuple depending on whether binned pos (mapType 'pos') or binned spikes (mapType 'rate') is asked for, respectively

Source code in ephysiopy/common/binning.py
def getMap(self, spkWeights,
           varType=VariableToBin.XY,
           mapType=MapType.RATE,
           smoothing=True,
           **kwargs):
    """
    Bins up the variable type varType and returns a tuple of
    (rmap, binnedPositionDir) or
    (rmap, binnedPositionX, binnedPositionY)

    Args:
        spkWeights (array_like): Shape equal to the number of position samples captured and consists of
            position weights. For example, if there were 5 positions
            recorded and a cell spiked once in position 2 and 5 times in
            position 3 and nothing anywhere else then pos_weights looks
            like: [0 0 1 5 0]
        varType (Enum value - see VariableToBin defined at top of this file): The variable to bin up. Legal values are: XY, DIR and SPEED
        mapType (enum value - see MapType defined at top of this file): If RATE then the binned up spikes are divided by varType.
            Otherwise return binned up position. Options are RATE or POS
        smoothing (bool, optional): Whether to smooth the data or not. Defaults to True.

    Returns:
        binned_data, binned_pos (tuple): This is either a 2-tuple or a 3-tuple depending on whether binned
            pos (mapType 'pos') or binned spikes (mapType 'rate') is asked
            for respectively
    """
    if varType.value == VariableToBin.DIR.value:
        sample = self.dir
        keep_these = np.isfinite(sample)
    elif varType.value == VariableToBin.SPEED.value:
        sample = self.speed
        keep_these = np.isfinite(sample)
    elif varType.value == VariableToBin.XY.value:
        sample = self.xy
        keep_these = np.isfinite(sample[0])
    elif varType.value == VariableToBin.XY_TIME.value:
        sample = np.concatenate((np.atleast_2d(self.xy),
                                np.atleast_2d(self.pos_times)))
        keep_these = np.isfinite(self.xy[0])
    else:
        raise ValueError("Unrecognized variable to bin.")
    assert sample is not None

    self.var2Bin = varType
    self._spike_weights = spkWeights
    self._calcBinEdges(self.binsize)

    binned_pos, binned_pos_edges = self._binData(
        sample,
        self._binedges,
        self.pos_weights,
        keep_these)
    nanIdx = binned_pos == 0

    if mapType.value == MapType.POS.value:  # return binned up position
        if smoothing:
            if varType.value == VariableToBin.DIR.value:
                binned_pos = self._circPadSmooth(
                    binned_pos, n=self.smooth_sz)
            else:
                binned_pos = blurImage(binned_pos,
                                       self.smooth_sz,
                                       ftype=self.smoothingType,
                                       **kwargs)
        return binned_pos, binned_pos_edges

    binned_spk, _ = self._binData(
        sample, self._binedges, spkWeights, keep_these)
    if mapType.value == MapType.SPK.value:
        return binned_spk
    # binned_spk is returned as a tuple of the binned data and the bin
    # edges
    if "after" in self.whenToSmooth:
        rmap = binned_spk / binned_pos
        if varType.value == VariableToBin.DIR.value:
            rmap = self._circPadSmooth(rmap, self.smooth_sz)
        else:
            rmap = blurImage(rmap,
                             self.smooth_sz,
                             ftype=self.smoothingType,
                             **kwargs)
    else:  # default case
        if not smoothing:
            return binned_spk / binned_pos, binned_pos_edges
        if varType.value == VariableToBin.DIR.value:
            binned_pos = self._circPadSmooth(binned_pos, self.smooth_sz)
            binned_spk = self._circPadSmooth(binned_spk, self.smooth_sz)
            rmap = binned_spk / binned_pos
        else:
            binned_pos = blurImage(binned_pos,
                                   self.smooth_sz,
                                   ftype=self.smoothingType,
                                   **kwargs)
            if binned_spk.ndim == 2:
                pass
            elif binned_spk.ndim == 1:
                binned_spk_tmp = np.zeros(
                    [binned_spk.shape[0], binned_spk.shape[0], 1]
                )
                for i in range(binned_spk.shape[0]):
                    binned_spk_tmp[i, :, :] = binned_spk[i]
                binned_spk = binned_spk_tmp
            binned_spk = blurImage(
                binned_spk,
                self.smooth_sz,
                ftype=self.smoothingType,
                **kwargs)
            rmap = binned_spk / binned_pos
            if rmap.ndim <= 2:
                rmap[nanIdx] = np.nan

    return rmap, binned_pos_edges
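
Example

A sketch of the two most common calls, assuming rm is a constructed instance and spk_weights is the per-position spike count vector described above.

from ephysiopy.common.binning import VariableToBin, MapType

# smoothed 2D ratemap plus the bin edges used
rmap, bin_edges = rm.getMap(spk_weights, varType=VariableToBin.XY,
                            mapType=MapType.RATE)
# head-direction tuning curve instead
hd_map, hd_edges = rm.getMap(spk_weights, varType=VariableToBin.DIR)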

getSAC(spkWeights, **kwargs)

Returns the SAC - convenience function

Source code in ephysiopy/common/binning.py
def getSAC(self, spkWeights, **kwargs):
    '''
    Returns the SAC - convenience function
    '''
    rmap = self.getMap(spkWeights=spkWeights, **kwargs)
    nodwell = ~np.isfinite(rmap[0])
    return self.autoCorr2D(rmap[0], nodwell)
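
Example

A sketch, assuming rm and spk_weights as above; the returned SAC can be fed straight into grid_field_props (see Field calculations below).

from ephysiopy.common.fieldcalcs import grid_field_props

sac = rm.getSAC(spk_weights)
props = grid_field_props(sac)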

getSpatialSparsity(spkWeights, sample_rate=50, **kwargs)

Gets the spatial sparsity measure - closer to 1 means sparser firing field.

References

Skaggs, W.E., McNaughton, B.L., Wilson, M.A. & Barnes, C.A. Theta phase precession in hippocampal neuronal populations and the compression of temporal sequences. Hippocampus 6, 149–172 (1996).

Source code in ephysiopy/common/binning.py
def getSpatialSparsity(self,
                       spkWeights,
                       sample_rate=50,
                       **kwargs):
    """
    Gets the spatial sparsity measure - closer to 1 means
    sparser firing field.

    References:
        Skaggs, W.E., McNaughton, B.L., Wilson, M.A. & Barnes, C.A.
        Theta phase precession in hippocampal neuronal populations
        and the compression of temporal sequences.
        Hippocampus 6, 149–172 (1996).
    """
    self.var2Bin = VariableToBin.XY
    self._calcBinEdges()
    sample = self.xy
    keep_these = np.isfinite(sample[0])
    pos, _ = self._binData(sample,
                           self._binedges,
                           self.pos_weights,
                           keep_these)
    npos = len(self.dir)
    p_i = np.count_nonzero(pos) / npos / sample_rate
    spk, _ = self._binData(sample,
                           self._binedges,
                           spkWeights,
                           keep_these)
    res = 1-(np.nansum(p_i*spk)**2) / np.nansum(p_i*spk**2)
    return res
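
Example

A sketch, assuming rm and spk_weights as above and position sampled at 50 Hz.

sparsity = rm.getSpatialSparsity(spk_weights, sample_rate=50)
# values near 1 indicate a spatially sparse (tightly confined) field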

get_egocentric_boundary_map(spk_weights, degs_per_bin=3, xy_binsize=2.5, arena_type='circle', return_dists=False, return_raw_spk=False, return_raw_occ=False)

Helps construct dwell time/spike counts maps with respect to boundaries at given egocentric directions and distances.

Note

For the directional input, the 0 degree reference is horizontal pointing East and moves counter-clockwise.

Source code in ephysiopy/common/binning.py
def get_egocentric_boundary_map(self,
                                spk_weights,
                                degs_per_bin: float = 3,
                                xy_binsize: float = 2.5,
                                arena_type: str = "circle",
                                return_dists: bool = False,
                                return_raw_spk: bool = False,
                                return_raw_occ: bool = False) -> namedtuple:
    """
    Helps construct dwell time/spike counts maps with respect to boundaries at given egocentric directions and distances.

    Note:
        For the directional input, the 0 degree reference is horizontal pointing East and moves counter-clockwise.
    """
    assert self.dir is not None, "No direction data available"
    # initially do some binning to get valid locations
    # (some might be nans due to
    # arena shape and/or poor sampling) and then digitize
    # the x and y positions
    # and the angular positions
    self.binsize = xy_binsize  # this will trigger a
    # re-calculation of the bin edges

    angles = np.arange(0, 360, degs_per_bin)

    # Use the shapely package to specify some geometry for the arena
    # boundary and the lines radiating out
    # from the current location of the animal. The geometry for the
    # arena should be user specified but for now I'll just use a circle
    if arena_type == "circle":
        radius = 50
        circle_centre = Point(
            np.nanmin(self.xy[0])+radius, np.nanmin(self.xy[1])+radius)
        arena_boundary = circle_centre.buffer(radius).boundary
    # now we have a circle with its centre at the centre of the arena
    # i.e. the circle defines the arena edges. Calling .boundary on the
    # circle geometry actually gives us a 65-gon polygon
    distances = self._create_boundary_distance_lookup(
        arena_boundary, degs_per_bin, xy_binsize)
    # iterate through the digitized locations (x/y and angular), using the
    # lookup table to get the distances to the arena boundary and then
    # increment the appropriate bin in the egocentric boundary map
    good_idx = np.isfinite(self.xy[0])
    xy_by_heading, _ = np.histogramdd([self.xy[0][good_idx],
                                       self.xy[1][good_idx],
                                       self.dir[good_idx]],
                                      bins=distances.shape,
                                      weights=self.pos_weights[good_idx])
    spk_xy_by_hd, _ = np.histogramdd([self.xy[0][good_idx],
                                      self.xy[1][good_idx],
                                      self.dir[good_idx]],
                                     bins=distances.shape,
                                     weights=spk_weights[good_idx])
    assert xy_by_heading.shape == distances.shape
    distlist = []
    anglist = []
    spkdists = []
    spkangs = []
    for i_bin in np.ndindex(distances.shape[:2]):
        i_dist = distances[i_bin]
        valid_dist = np.isfinite(i_dist)
        nonzero_bincounts = np.nonzero(xy_by_heading[i_bin])[0]
        nonzero_spkbins = np.nonzero(spk_xy_by_hd[i_bin])[0]
        for i_angle in nonzero_bincounts:
            ego_angles = np.roll(angles, i_angle)[valid_dist]
            n_repeats = xy_by_heading[i_bin][i_angle]
            ego_angles_repeats = np.repeat(ego_angles, n_repeats)
            dist_repeats = np.repeat(i_dist[valid_dist], n_repeats)
            distlist.append(dist_repeats)
            anglist.append(ego_angles_repeats)
            if i_angle in nonzero_spkbins:
                n_repeats = spk_xy_by_hd[i_bin][i_angle]
                ego_angles_repeats = np.repeat(ego_angles, n_repeats)
                dist_repeats = np.repeat(i_dist[valid_dist], n_repeats)
                spkdists.append(dist_repeats)
                spkangs.append(ego_angles_repeats)
    flat_angs = flatten_list(anglist)
    flat_dists = flatten_list(distlist)
    flat_spk_dists = flatten_list(spkdists)
    flat_spk_angs = flatten_list(spkangs)
    bins = [int(radius/xy_binsize), len(angles)]
    ego_boundary_occ, _, _ = np.histogram2d(x=flat_dists, y=flat_angs,
                                            bins=bins)
    ego_boundary_spk, _, _ = np.histogram2d(x=flat_spk_dists,
                                            y=flat_spk_angs,
                                            bins=bins)
    kernel = convolution.Gaussian2DKernel(5, x_size=3, y_size=5)
    sm_occ = convolution.convolve(ego_boundary_occ,
                                  kernel,
                                  boundary='extend')
    sm_spk = convolution.convolve(ego_boundary_spk,
                                  kernel,
                                  boundary='extend')
    ego_boundary_map = sm_spk / sm_occ
    EgoMap = namedtuple("EgoMap", ['rmap', 'occ', 'spk', 'dists'],
                        defaults=None)
    em = EgoMap(None, None, None, None)
    em = em._replace(rmap=ego_boundary_map)
    if return_dists:
        em = em._replace(dists=distances)
    if return_raw_occ:
        em = em._replace(occ=ego_boundary_occ)
    if return_raw_spk:
        em = em._replace(spk=ego_boundary_spk)
    return em
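
Example

A sketch, assuming rm (with valid xy and dir data) and spk_weights as above; the circular-arena setting matches what the current implementation supports.

em = rm.get_egocentric_boundary_map(spk_weights,
                                    degs_per_bin=3,
                                    xy_binsize=2.5,
                                    arena_type="circle",
                                    return_dists=True)
rate_map = em.rmap   # distance-by-angle egocentric boundary ratemap
lookup = em.dists    # boundary-distance lookup (only if return_dists=True)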

tWinSAC(xy, spkIdx, ppm=365, winSize=10, pos_sample_rate=50, nbins=71, boxcar=5, Pthresh=100, downsampfreq=50, plot=False)

Temporal windowed spatial autocorrelation.

Parameters:

- xy (array_like, required): The position data
- spkIdx (array_like, required): The indices in xy where the cell fired
- ppm (int, default 365): The camera pixels per metre
- winSize (int, default 10): The window size for the temporal search
- pos_sample_rate (int, default 50): The rate at which position was sampled
- nbins (int, default 71): The number of bins for creating the resulting ratemap
- boxcar (int, default 5): The size of the smoothing kernel to smooth ratemaps
- Pthresh (int, default 100): The cut-off for values in the ratemap; values < Pthresh become nans
- downsampfreq (int, default 50): How much to downsample
- plot (bool, default False): Whether to show a plot of the result

Returns:

H (array_like): The temporal windowed SAC

Source code in ephysiopy/common/binning.py
def tWinSAC(
    self,
    xy,
    spkIdx,
    ppm=365,
    winSize=10,
    pos_sample_rate=50,
    nbins=71,
    boxcar=5,
    Pthresh=100,
    downsampfreq=50,
    plot=False,
):
    """
    Temporal windowed spatial autocorrelation.

    Args:
        xy (array_like): The position data
        spkIdx (array_like): The indices in xy where the cell fired
        ppm (int, optional): The camera pixels per metre. Default 365
        winSize (int, optional): The window size for the temporal search
        pos_sample_rate (int, optional): The rate at which position was sampled. Default 50
        nbins (int, optional): The number of bins for creating the resulting ratemap. Default 71
        boxcar (int, optional): The size of the smoothing kernel to smooth ratemaps. Default 5
        Pthresh (int, optional): The cut-off for values in the ratemap; values < Pthresh become nans. Default 100
        downsampfreq (int, optional): How much to downsample. Default 50
        plot (bool, optional): Whether to show a plot of the result. Default False

    Returns:
        H (array_like): The temporal windowed SAC
    """
    # [Stage 0] Get some numbers
    xy = xy / ppm * 100
    n_samps = xy.shape[1]
    n_spks = len(spkIdx)
    winSizeBins = np.min([winSize * pos_sample_rate, n_samps])
    # factor by which positions are downsampled
    downsample = np.ceil(pos_sample_rate / downsampfreq)
    Pthresh = Pthresh / downsample  # take account of downsampling

    # [Stage 1] Calculate number of spikes in the window for each spikeInd
    # (ignoring spike itself)
    # 1a. Loop preparation
    nSpikesInWin = np.zeros(n_spks, dtype=int)

    # 1b. Keep looping until we have dealt with all spikes
    for i, s in enumerate(spkIdx):
        t = np.searchsorted(spkIdx, (s, s + winSizeBins))
        nSpikesInWin[i] = len(spkIdx[t[0]: t[1]]) - 1  # ignore ith spike

    # [Stage 2] Prepare for main loop
    # 2a. Work out offset indices to be used when storing spike data
    off_spike = np.cumsum(nSpikesInWin)
    off_spike = np.pad(off_spike, (1, 0), "constant", constant_values=(0))

    # 2b. Work out number of downsampled pos bins in window and
    # offset indices for storing data
    nPosInWindow = np.minimum(winSizeBins, n_samps - spkIdx)
    nDownsampInWin = np.floor((nPosInWindow - 1) / downsample) + 1

    off_dwell = np.cumsum(nDownsampInWin.astype(int))
    off_dwell = np.pad(off_dwell, (1, 0), "constant", constant_values=(0))

    # 2c. Pre-allocate dwell and spike arrays, singles for speed
    dwell = np.zeros((2, off_dwell[-1]), dtype=np.single) * np.nan
    spike = np.zeros((2, off_spike[-1]), dtype=np.single) * np.nan

    filled_pvals = 0
    filled_svals = 0

    for i in range(n_spks):
        # calculate dwell displacements
        winInd_dwell = np.arange(
            spkIdx[i] + 1,
            np.minimum(spkIdx[i] + winSizeBins, n_samps),
            downsample,
            dtype=int,
        )
        WL = len(winInd_dwell)
        dwell[:, filled_pvals: filled_pvals + WL] = np.rot90(
            np.array(np.rot90(xy[:, winInd_dwell]) - xy[:, spkIdx[i]])
        )
        filled_pvals = filled_pvals + WL
        # calculate spike displacements
        winInd_spks = (
            i + np.nonzero(spkIdx[i + 1: n_spks] <
                           spkIdx[i] + winSizeBins)[0]
        )
        WL = len(winInd_spks)
        spike[:, filled_svals: filled_svals + WL] = np.rot90(
            np.array(
                np.rot90(xy[:, spkIdx[winInd_spks]]) - xy[:, spkIdx[i]])
        )
        filled_svals = filled_svals + WL

    dwell = np.delete(dwell, np.isnan(dwell).nonzero()[1], axis=1)
    spike = np.delete(spike, np.isnan(spike).nonzero()[1], axis=1)

    dwell = np.hstack((dwell, -dwell))
    spike = np.hstack((spike, -spike))

    dwell_min = np.min(dwell, axis=1)
    dwell_max = np.max(dwell, axis=1)

    binsize = (dwell_max[1] - dwell_min[1]) / nbins

    dwell = np.round(
        (dwell - np.ones_like(dwell) * dwell_min[:, np.newaxis]) / binsize
    )
    spike = np.round(
        (spike - np.ones_like(spike) * dwell_min[:, np.newaxis]) / binsize
    )

    binsize = np.max(dwell, axis=1).astype(int)
    binedges = np.array(((-0.5, -0.5), binsize + 0.5)).T
    Hp = np.histogram2d(dwell[0, :], dwell[1, :],
                        range=binedges, bins=binsize)[0]
    Hs = np.histogram2d(spike[0, :], spike[1, :],
                        range=binedges, bins=binsize)[0]

    # reverse y,x order
    Hp = np.swapaxes(Hp, 1, 0)
    Hs = np.swapaxes(Hs, 1, 0)

    fHp = blurImage(Hp, boxcar)
    fHs = blurImage(Hs, boxcar)

    H = fHs / fHp
    H[Hp < Pthresh] = np.nan

    return H
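
Example

A sketch, assuming rm as above; xy is a 2 x nSamples position array and spk_pos_idx (a hypothetical name) holds the position indices at which the cell fired.

H = rm.tWinSAC(xy, spk_pos_idx, ppm=365, winSize=10,
               pos_sample_rate=50, nbins=71, boxcar=5)
# H is the temporal windowed SAC; nan where sampling fell below Pthresh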

Field calculations

border_score(A, B=None, shape='square', fieldThresh=0.3, smthKernSig=3, circumPrc=0.2, binSize=3.0, minArea=200, debug=False)

Calculates a border score totally dis-similar to that calculated in Solstad et al (2008)

Parameters:

- A (array_like, required): Should be the ratemap
- B (array_like, default None): This should be a boolean mask where True (1) is equivalent to the presence of a border and False (0) is equivalent to 'open space'. Naively this will be the edges of the ratemap but could be used to take account of boundary insertions/creations to check tuning to multiple environmental boundaries. When the mask is None, a mask is created that has 1's at the edges of the ratemap, i.e. it is assumed that occupancy = environmental shape
- shape (str, default 'square'): Description of environment shape. Currently only 'square' or 'circle' accepted. Used to calculate the proportion of the environmental boundaries to examine for firing
- fieldThresh (float, default 0.3): Between 0 and 1, this is the percentage amount of the maximum firing rate to remove from the ratemap (i.e. to remove noise)
- smthKernSig (float, default 3): The sigma value used in smoothing the ratemap (again!) with a gaussian kernel
- circumPrc (float, default 0.2): The percentage amount of the circumference of the environment that the field needs to be to count as long enough to make it through
- binSize (float, default 3.0): Bin size in cm
- minArea (float, default 200): Min area for a field to be considered
- debug (bool, default False): If True then some plots and text will be output

Returns:

float: the border score

Notes

If the cell is a border cell (BVC) then we know that it should fire at a fixed distance from a given boundary (possibly more than one). In essence this algorithm estimates the amount of variance in this distance i.e. if the cell is a border cell this number should be small. This is achieved by first doing a bunch of morphological operations to isolate individual fields in the ratemap (similar to the code used in phasePrecession.py - see the partitionFields method therein). These partitioned fields are then thinned out (using skimage's skeletonize) to a single pixel wide field which will lie more or less in the middle of the (highly smoothed) sub-field. It is the variance in distance from the nearest boundary along this pseudo-iso-line that is the boundary measure

Other things to note are that the pixel-wide field has to have some minimum length. In the case of a circular environment this is set to 20% of the circumference; in the case of a square environment this is at least half the length of the longest side

Source code in ephysiopy/common/fieldcalcs.py
def border_score(
    A,
    B=None,
    shape="square",
    fieldThresh=0.3,
    smthKernSig=3,
    circumPrc=0.2,
    binSize=3.0,
    minArea=200,
    debug=False,
):
    """
    Calculates a border score totally dis-similar to that calculated in
    Solstad et al (2008)

    Args:
        A (array_like): Should be the ratemap
        B (array_like): This should be a boolean mask where True (1)
            is equivalent to the presence of a border and False (0)
            is equivalent to 'open space'. Naively this will be the
            edges of the ratemap but could be used to take account of
            boundary insertions/ creations to check tuning to multiple
            environmental boundaries. Default None: when the mask is
            None then a mask is created that has 1's at the edges of the
            ratemap i.e. it is assumed that occupancy = environmental
            shape
        shape (str): description of environment shape. Currently
            only 'square' or 'circle' accepted. Used to calculate the
            proportion of the environmental boundaries to examine for
            firing
        fieldThresh (float): Between 0 and 1 this is the percentage
            amount of the maximum firing rate
            to remove from the ratemap (i.e. to remove noise)
        smthKernSig (float): the sigma value used in smoothing the ratemap
            (again!) with a gaussian kernel
        circumPrc (float): The percentage amount of the circumference
            of the environment that the field needs to be to count
            as long enough to make it through
        binSize (float): bin size in cm
        minArea (float): min area for a field to be considered
        debug (bool): If True then some plots and text will be output

    Returns:
        float: the border score

    Notes:
        If the cell is a border cell (BVC) then we know that it should
        fire at a fixed distance from a given boundary (possibly more
        than one). In essence this algorithm estimates the amount of
        variance in this distance i.e. if the cell is a border cell this
        number should be small. This is achieved by first doing a bunch of
        morphological operations to isolate individual fields in the
        ratemap (similar to the code used in phasePrecession.py - see
        the partitionFields method therein). These partitioned fields are then
        thinned out (using skimage's skeletonize) to a single pixel
        wide field which will lie more or less in the middle of the
        (highly smoothed) sub-field. It is the variance in distance from the
        nearest boundary along this pseudo-iso-line that is the boundary
        measure

        Other things to note are that the pixel-wide field has to have some
        minimum length. In the case of a circular environment this is set to
        20% of the circumference; in the case of a square environment
        this is at least half the length of the longest side
    """
    # need to know borders of the environment so we can see if a field
    # touches the edges, and the perimeter length of the environment
    # deal with square or circles differently
    borderMask = np.zeros_like(A)
    A_rows, A_cols = np.shape(A)
    if "circle" in shape:
        radius = np.max(np.array(np.shape(A))) / 2.0
        dist_mask = skimage.morphology.disk(radius)
        if np.shape(dist_mask) > np.shape(A):
            dist_mask = dist_mask[1 : A_rows + 1, 1 : A_cols + 1]
        tmp = np.zeros([A_rows + 2, A_cols + 2])
        tmp[1:-1, 1:-1] = dist_mask
        dists = ndimage.distance_transform_bf(tmp)
        dists = dists[1:-1, 1:-1]
        borderMask = np.logical_xor(dists <= 0, dists < 2)
        # open up the border mask a little
        borderMask = skimage.morphology.binary_dilation(
            borderMask, skimage.morphology.disk(1)
        )
    elif "square" in shape:
        borderMask[0:3, :] = 1
        borderMask[-3:, :] = 1
        borderMask[:, 0:3] = 1
        borderMask[:, -3:] = 1
        tmp = np.zeros([A_rows + 2, A_cols + 2])
        dist_mask = np.ones_like(A)
        tmp[1:-1, 1:-1] = dist_mask
        dists = ndimage.distance_transform_bf(tmp)
        # remove edges to make same shape as input ratemap
        dists = dists[1:-1, 1:-1]
    A[np.isnan(A)] = 0
    # get some morphological info about the fields in the ratemap
    # start image processing:
    # get some markers
    # NB I've tried a variety of techniques to optimise this part and the
    # best seems to be the local adaptive thresholding technique which)
    # smooths locally with a gaussian - see the skimage docs for more
    idx = A >= np.nanmax(np.ravel(A)) * fieldThresh
    A_thresh = np.zeros_like(A)
    A_thresh[idx] = A[idx]

    # label these markers so each blob has a unique id
    labels, nFields = ndimage.label(A_thresh)
    # remove small objects
    min_size = int(minArea / binSize) - 1
    # remove_small_objects returns a filtered copy; it must be captured
    labels = skimage.morphology.remove_small_objects(
        labels, min_size=min_size, connectivity=2)
    labels = skimage.segmentation.relabel_sequential(labels)[0]
    nFields = np.max(labels)
    if nFields == 0:
        return np.nan
    # Iterate over the labelled parts of the array labels calculating
    # how much of the total circumference of the environment edge it
    # covers

    fieldAngularCoverage = np.zeros([1, nFields]) * np.nan
    fractionOfPixelsOnBorder = np.zeros([1, nFields]) * np.nan
    fieldsToKeep = np.zeros_like(A).astype(bool)
    for i in range(1, nFields + 1):
        fieldMask = np.logical_and(labels == i, borderMask)

        # check the angle subtended by the fieldMask
        if np.sum(fieldMask.astype(int)) > 0:
            s = skimage.measure.regionprops(
                fieldMask.astype(int), intensity_image=A_thresh
            )[0]
            x = s.coords[:, 0] - (A_cols / 2.0)
            y = s.coords[:, 1] - (A_rows / 2.0)
            subtended_angle = np.rad2deg(np.ptp(np.arctan2(x, y)))
            if subtended_angle > (360 * circumPrc):
                pixelsOnBorder = np.count_nonzero(fieldMask) / float(
                    np.count_nonzero(labels == i)
                )
                fractionOfPixelsOnBorder[:, i - 1] = pixelsOnBorder
                if pixelsOnBorder > 0.5:
                    fieldAngularCoverage[0, i - 1] = subtended_angle

            fieldsToKeep = np.logical_or(fieldsToKeep, labels == i)
    fieldAngularCoverage = fieldAngularCoverage / 360.0
    rateInField = A[fieldsToKeep]
    # normalize firing rate in the field to sum to 1
    rateInField = rateInField / np.nansum(rateInField)
    dist2WallInField = dists[fieldsToKeep]
    Dm = np.dot(dist2WallInField, rateInField)
    if "circle" in shape:
        Dm = Dm / radius
    elif "square" in shape:
        Dm = Dm / (np.max(np.shape(A)) / 2.0)
    borderScore = (fractionOfPixelsOnBorder - Dm) / (fractionOfPixelsOnBorder + Dm)
    return np.max(borderScore)
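
Example

A self-contained sketch on a synthetic ratemap with firing concentrated along one wall; real ratemaps would come from getMap above.

import numpy as np
from ephysiopy.common.fieldcalcs import border_score

rng = np.random.default_rng(42)
A = rng.random((40, 40)) * 0.1
A[:, :4] += 2.0  # a band of firing along the west wall
score = border_score(A, shape="square", binSize=3.0)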

calc_angs(points)

Calculates the angles for all triangles in a Delaunay tessellation of the peak points in the ratemap

Source code in ephysiopy/common/fieldcalcs.py
def calc_angs(points):
    """
    Calculates the angles for all triangles in a Delaunay tessellation of
    the peak points in the ratemap
    """

    # calculate the lengths of the sides of the triangles
    tri = spatial.Delaunay(points)
    angs = []
    for s in tri.simplices:
        A = tri.points[s[1]] - tri.points[s[0]]
        B = tri.points[s[2]] - tri.points[s[1]]
        C = tri.points[s[0]] - tri.points[s[2]]
        for e1, e2 in ((A, -B), (B, -C), (C, -A)):
            num = np.dot(e1, e2)
            denom = np.linalg.norm(e1) * np.linalg.norm(e2)
            angs.append(np.arccos(num / denom) * 180 / np.pi)
    return np.array(angs).T
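
Example

A toy sketch: four peak coordinates give two Delaunay triangles, so six angles are returned.

import numpy as np
from ephysiopy.common.fieldcalcs import calc_angs

points = np.array([[0, 0], [10, 0], [5, 8], [15, 8]])
angs = calc_angs(points)  # angles in degrees, one per triangle corner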

coherence(smthd_rate, unsmthd_rate)

Calculates the coherence of a receptive field via correlation of smoothed and unsmoothed ratemaps

Source code in ephysiopy/common/fieldcalcs.py
def coherence(smthd_rate, unsmthd_rate):
    """calculates coherence of receptive field via correlation of smoothed
    and unsmoothed ratemaps
    """
    smthd = smthd_rate.ravel()
    unsmthd = unsmthd_rate.ravel()
    si = ~np.isnan(smthd)
    ui = ~np.isnan(unsmthd)
    idx = si & ui  # keep bins that are finite in both maps
    coherence = np.corrcoef(unsmthd[idx], smthd[idx])
    return coherence[1, 0]
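
Example

A sketch using scipy to produce the smoothed map; any smoothing consistent with your ratemap pipeline would do.

import numpy as np
from scipy import ndimage
from ephysiopy.common.fieldcalcs import coherence

rng = np.random.default_rng(0)
unsmoothed = rng.random((32, 32))
smoothed = ndimage.gaussian_filter(unsmoothed, sigma=2)
c = coherence(smoothed, unsmoothed)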

corr_maps(map1, map2, maptype='normal')

Correlates two ratemaps together, ignoring areas that have zero sampling

Source code in ephysiopy/common/fieldcalcs.py
def corr_maps(map1, map2, maptype="normal"):
    """
    correlates two ratemaps together ignoring areas that have zero sampling
    """
    if map1.shape > map2.shape:
        map2 = skimage.transform.resize(map2, map1.shape, mode="reflect")
    elif map1.shape < map2.shape:
        map1 = skimage.transform.resize(map1, map2.shape, mode="reflect")
    map1 = map1.flatten()
    map2 = map2.flatten()
    if "normal" in maptype:
        # valid bins were sampled (> 0) and are not nan
        valid_map1 = np.logical_and(map1 > 0, ~np.isnan(map1))
        valid_map2 = np.logical_and(map2 > 0, ~np.isnan(map2))
    elif "grid" in maptype:
        valid_map1 = ~np.isnan(map1)
        valid_map2 = ~np.isnan(map2)
    valid = np.logical_and(valid_map1, valid_map2)
    r = np.corrcoef(map1[valid], map2[valid])
    return r[1][0]
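
Example

A sketch: a map correlated against a noisy copy of itself should score close to 1.

import numpy as np
from ephysiopy.common.fieldcalcs import corr_maps

rng = np.random.default_rng(0)
map1 = rng.random((30, 30))
map2 = map1 + rng.normal(scale=0.1, size=map1.shape)
r = corr_maps(map1, map2, maptype="normal")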

deform_SAC(A, circleXY=None, ellipseXY=None)

Deforms a SAC that is non-circular to be more circular

Basically a blatant attempt to improve grid scores, possibly introduced in a paper by Matt Nolan...

Parameters:

- A (array_like, required): The SAC
- circleXY (array_like, default None): The xy coordinates defining a circle
- ellipseXY (array_like, default None): The xy coordinates defining an ellipse

Returns:

deformed_sac (array_like): The SAC deformed to be more circular

See Also

ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props
skimage.transform.AffineTransform
skimage.transform.warp
skimage.exposure.rescale_intensity

Source code in ephysiopy/common/fieldcalcs.py
def deform_SAC(A, circleXY=None, ellipseXY=None):
    """
    Deforms a SAC that is non-circular to be more circular

    Basically a blatant attempt to improve grid scores, possibly
    introduced in a paper by Matt Nolan...

    Args:
        A (array_like): The SAC
        circleXY (array_like, optional): The xy coordinates defining a circle.
        Default None.
        ellipseXY (array_like, optional): The xy coordinates defining an
        ellipse. Default None.

    Returns:
        deformed_sac (array_like): The SAC deformed to be more circular

    See Also:
        ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props
        skimage.transform.AffineTransform
        skimage.transform.warp
        skimage.exposure.rescale_intensity
    """
    if circleXY is None or ellipseXY is None:
        SAC_stats = grid_field_props(A)
        circleXY = SAC_stats["circleXY"]
        ellipseXY = SAC_stats["ellipseXY"]
        # The ellipse detection stuff might have failed, if so
        # return the original SAC
        if circleXY is None:
            warnings.warn("Ellipse detection failed. Returning original SAC")
            return A

    tform = skimage.transform.AffineTransform()
    tform.estimate(ellipseXY, circleXY)

    """
    the transformation algorithms used here crop values < 0 to 0. Need to
    rescale the SAC values before doing the deformation and then rescale
    again so the values assume the same range as in the unadulterated SAC
    """
    A[np.isnan(A)] = 0
    SACmin = np.nanmin(A.flatten())
    SACmax = np.nanmax(A.flatten())  # should be 1 if autocorr
    AA = A + 1
    deformedSAC = skimage.transform.warp(
        AA / np.nanmax(AA.flatten()), inverse_map=tform.inverse, cval=0
    )
    return skimage.exposure.rescale_intensity(deformedSAC, out_range=(SACmin, SACmax))
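
Example

A sketch, assuming sac is a spatial autocorrelogram (e.g. from getSAC above). Passing no circle/ellipse lets deform_SAC derive them via grid_field_props internally.

from ephysiopy.common.fieldcalcs import deform_SAC

round_sac = deform_SAC(sac)  # falls back to the original SAC if the
                             # ellipse fit fails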

field_lims(A)

Returns a labelled matrix of the ratemap A. Uses anything greater than 20% of the peak rate to select a field. The data is heavily smoothed first.

Parameters:

- A (np.array, required): The ratemap

Returns:

label (np.array): The labelled ratemap

Source code in ephysiopy/common/fieldcalcs.py
def field_lims(A):
    """
    Returns a labelled matrix of the ratemap A.
    Uses anything greater than 20% of the peak rate to select a field.
    Data is heavily smoothed.

    Args:
        A (np.array): The ratemap

    Returns:
        label (np.array): The labelled ratemap
    """
    nan_idx = np.isnan(A)
    A[nan_idx] = 0
    h = int(np.max(A.shape) / 2)
    sm_rmap = blurImage(A, h, ftype="gaussian")
    thresh = np.max(sm_rmap.ravel()) * 0.2  # select area > 20% of peak
    distance = ndimage.distance_transform_edt(sm_rmap > thresh)
    peak_idx = skimage.feature.peak_local_max(
        distance, exclude_border=False, labels=sm_rmap > thresh
    )
    mask = np.zeros_like(distance, dtype=bool)
    mask[tuple(peak_idx.T)] = True
    label = ndimage.label(mask)[0]
    w = watershed(image=-distance, markers=label, mask=sm_rmap > thresh)
    label = ndimage.label(w)[0]
    return label
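
Example

A sketch on a synthetic ratemap; the returned label image has one integer id per detected field, 0 elsewhere.

import numpy as np
from ephysiopy.common.fieldcalcs import field_lims

rng = np.random.default_rng(0)
A = rng.random((40, 40))
labels = field_lims(A)
n_fields = labels.max()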

field_props(A, min_dist=5, neighbours=2, prc=50, plot=False, ax=None, tri=False, verbose=True, **kwargs)

Returns a dictionary of properties of the field(s) in a ratemap A

Parameters:

- A (array_like, required): A ratemap (but could be any image)
- min_dist (float, default 5): The separation (in bins) between fields for measures such as field distance to make sense. Used to partition the image into separate fields in the call to feature.peak_local_max
- neighbours (int, default 2): The number of fields to consider as neighbours to any given field
- prc (float, default 50): Percent of fields to consider
- ax (matplotlib.Axes, default None): User supplied axis. If None a new figure window is created
- tri (bool, default False): Whether to do Delaunay triangulation between fields and add it to the plot
- verbose (bool, default True): Dumps the properties to the console
- plot (bool, default False): Whether to plot some output - currently consists of the ratemap A, the fields of which are outlined in a black contour

Returns:

result (dict): The properties of the field(s) in the input ratemap A

Source code in ephysiopy/common/fieldcalcs.py
def field_props(
    A,
    min_dist=5,
    neighbours=2,
    prc=50,
    plot=False,
    ax=None,
    tri=False,
    verbose=True,
    **kwargs,
):
    """
    Returns a dictionary of properties of the field(s) in a ratemap A

    Args:
        A (array_like): a ratemap (but could be any image)
        min_dist (float): the separation (in bins) between fields for measures
            such as field distance to make sense. Used to
            partition the image into separate fields in the call to
            feature.peak_local_max
        neighbours (int): the number of fields to consider as neighbours to
            any given field. Defaults to 2
        prc (float): percent of fields to consider
        ax (matplotlib.Axes): user supplied axis. If None a new figure window
        is created
        tri (bool): whether to do Delaunay triangulation between fields
            and add to plot
        verbose (bool): dumps the properties to the console
        plot (bool): whether to plot some output - currently consists of the
            ratemap A, the fields of which are outlined in a black
            contour. Default False

    Returns:
        result (dict): The properties of the field(s) in the input ratemap A
    """

    from skimage.measure import find_contours
    from sklearn.neighbors import NearestNeighbors

    nan_idx = np.isnan(A)
    Ac = A.copy()
    Ac[np.isnan(A)] = 0
    # smooth Ac more to remove local irregularities
    n = ny = 5
    x, y = np.mgrid[-n : n + 1, -ny : ny + 1]
    g = np.exp(-(x**2 / float(n) + y**2 / float(ny)))
    g = g / g.sum()
    Ac = signal.convolve(Ac, g, mode="same")

    peak_idx, field_labels = _get_field_labels(Ac, **kwargs)

    nFields = np.max(field_labels)
    if neighbours > nFields:
        print(
            "neighbours value of {0} > the {1} peaks found".format(neighbours, nFields)
        )
        print("Reducing neighbours to number of peaks found")
        neighbours = nFields
    sub_field_mask = np.zeros((nFields, Ac.shape[0], Ac.shape[1]))
    sub_field_props = skimage.measure.regionprops(field_labels, intensity_image=Ac)
    sub_field_centroids = []
    sub_field_size = []

    for sub_field in sub_field_props:
        tmp = np.zeros(Ac.shape).astype(bool)
        tmp[sub_field.coords[:, 0], sub_field.coords[:, 1]] = True
        tmp2 = Ac > sub_field.max_intensity * (prc / float(100))
        sub_field_mask[sub_field.label - 1, :, :] = np.logical_and(tmp2, tmp)
        sub_field_centroids.append(sub_field.centroid)
        sub_field_size.append(sub_field.area)  # in bins
    sub_field_mask = np.sum(sub_field_mask, 0)
    contours = skimage.measure.find_contours(sub_field_mask, 0.5)
    # find the nearest neighbors to the peaks of each sub-field
    nbrs = NearestNeighbors(n_neighbors=neighbours, algorithm="ball_tree").fit(peak_idx)
    distances, _ = nbrs.kneighbors(peak_idx)
    mean_field_distance = np.mean(distances[:, 1:neighbours])

    nValid_bins = np.sum(~nan_idx)
    # calculate the amount of out of field firing
    A_non_field = np.zeros_like(A) * np.nan
    A_non_field[~sub_field_mask.astype(bool)] = A[~sub_field_mask.astype(bool)]
    A_non_field[nan_idx] = np.nan
    out_of_field_firing_prc = (
        np.count_nonzero(A_non_field > 0) / float(nValid_bins)
    ) * 100
    Ac[np.isnan(A)] = np.nan
    """
    get some stats about the field ellipticity
    """
    ellipse_ratio = np.nan
    _, central_field, _ = limit_to_one(A, prc=50)

    contour_coords = find_contours(central_field, 0.5)
    from skimage.measure import EllipseModel

    E = EllipseModel()
    E.estimate(contour_coords[0])
    ellipse_axes = E.params[2:4]
    ellipse_ratio = np.min(ellipse_axes) / np.max(ellipse_axes)

    """ using the peak_idx values calculate the angles of the triangles that
    make up a delaunay tesselation of the space if the calc_angles arg is
    in kwargs
    """
    if "calc_angs" in kwargs.keys():
        angs = calc_angs(peak_idx)
    else:
        angs = None

    props = {
        "Ac": Ac,
        "Peak_rate": np.nanmax(A),
        "Mean_rate": np.nanmean(A),
        "Field_size": np.mean(sub_field_size),
        "Pct_bins_with_firing": (np.sum(sub_field_mask) / nValid_bins) * 100,
        "Out_of_field_firing_prc": out_of_field_firing_prc,
        "Dist_between_fields": mean_field_distance,
        "Num_fields": float(nFields),
        "Sub_field_mask": sub_field_mask,
        "Smoothed_map": Ac,
        "field_labels": field_labels,
        "Peak_idx": peak_idx,
        "angles": angs,
        "contours": contours,
        "ellipse_ratio": ellipse_ratio,
    }

    if verbose:
        print(
            "\nPercentage of bins with firing: {:.2%}".format(
                np.sum(sub_field_mask) / nValid_bins
            )
        )
        print(
            "Percentage out of field firing: {:.2%}".format(
                np.count_nonzero(A_non_field > 0) / float(nValid_bins)
            )
        )
        print("Peak firing rate: {:.3} Hz".format(np.nanmax(A)))
        print("Mean firing rate: {:.3} Hz".format(np.nanmean(A)))
        print("Number of fields: {0}".format(nFields))
        print("Mean field size: {:.5} cm".format(np.mean(sub_field_size)))
        print(
            "Mean inter-peak distance between \
            fields: {:.4} cm".format(
                mean_field_distance
            )
        )
    return props
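
Example

A sketch, assuming A is a 2D ratemap (np.nan at unvisited bins), e.g. from getMap above.

from ephysiopy.common.fieldcalcs import field_props

props = field_props(A, min_dist=5, neighbours=2,
                    plot=False, verbose=False)
print(props["Num_fields"], props["Peak_rate"], props["ellipse_ratio"])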

get_circular_regions(A, **kwargs)

Returns a list of images which are expanding circular regions centred on the middle of the image out to the image edge. Used for calculating the grid score of each image to find the one with the max grid score. Based on some Moser paper I can't recall.

Parameters:

- A (np.ndarray, required): The SAC

Other Parameters:

- min_radius (int): The smallest radius circle to start with

Source code in ephysiopy/common/fieldcalcs.py
def get_circular_regions(A: np.ndarray, **kwargs) -> list:
    """
    Returns a list of images which are expanding circular
    regions centred on the middle of the image out to the
    image edge. Used for calculating the grid score of each
    image to find the one with the max grid score. Based on
    some Moser paper I can't recall.

    Args:
        A (np.ndarray): The SAC

    Keyword Args:
        min_radius (int): The smallest radius circle to start with
    """
    from skimage.measure import CircleModel, grid_points_in_poly

    min_radius = 5
    if "min_radius" in kwargs.keys():
        min_radius = kwargs["min_radius"]

    centre = tuple([d // 2 for d in np.shape(A)])
    max_radius = min(tuple(np.subtract(np.shape(A), centre)))
    t = np.linspace(0, 2 * np.pi, 51)
    circle = CircleModel()

    result = []
    for radius in range(min_radius, max_radius):
        circle.params = [*centre, radius]
        xy = circle.predict_xy(t)
        mask = grid_points_in_poly(np.shape(A), xy)
        im = A.copy()
        im[~mask] = np.nan
        result.append(im)
    return result
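
Example

A sketch, assuming sac is a spatial autocorrelogram; each expanding region is scored with gridness (the same module-level function grid_field_props uses) and the best is kept.

import numpy as np
from ephysiopy.common.fieldcalcs import get_circular_regions, gridness

regions = get_circular_regions(sac, min_radius=5)
scores = [gridness(r)[0] for r in regions]
best = regions[int(np.nanargmax(scores))]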

global_threshold(A, prc=50, min_dist=5)

Globally thresholds a ratemap and counts number of fields found

Source code in ephysiopy/common/fieldcalcs.py
def global_threshold(A, prc=50, min_dist=5):
    """
    Globally thresholds a ratemap and counts number of fields found
    """
    Ac = A.copy()
    Ac[np.isnan(A)] = 0
    n = ny = 5
    x, y = np.mgrid[-n : n + 1, -ny : ny + 1]
    g = np.exp(-(x**2 / float(n) + y**2 / float(ny)))
    g = g / g.sum()
    Ac = signal.convolve(Ac, g, mode="same")
    maxRate = np.nanmax(np.ravel(Ac))
    Ac[Ac < maxRate * (prc / float(100))] = 0
    peak_idx = skimage.feature.peak_local_max(
        Ac, min_distance=min_dist, exclude_border=False
    )
    peak_mask = np.zeros_like(Ac, dtype=bool)
    peak_mask[tuple(peak_idx.T)] = True
    peak_labels = skimage.measure.label(peak_mask, connectivity=2)
    field_labels = watershed(image=-Ac, markers=peak_labels)
    nFields = np.max(field_labels)
    return nFields
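
Example

A sketch on a synthetic ratemap: count the fields that survive thresholding at 50% of the peak rate.

import numpy as np
from ephysiopy.common.fieldcalcs import global_threshold

rng = np.random.default_rng(0)
A = rng.random((40, 40))
n_fields = global_threshold(A, prc=50, min_dist=5)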

grid_field_props(A, maxima='centroid', allProps=True, **kwargs)

Extracts various measures from a spatial autocorrelogram

Parameters:

- A (array_like, required): The spatial autocorrelogram (SAC)
- maxima (str, default 'centroid'): The method used to detect the peaks in the SAC. Legal values are 'single' and 'centroid'
- allProps (bool, default True): Whether to return a dictionary that contains the attempt to fit an ellipse around the edges of the central six peaks. See below

Returns:

props (dict): A dictionary containing measures of the SAC. Keys include:

- gridness score
- scale
- orientation
- coordinates of the peaks (nominally 6) closest to the SAC centre
- a binary mask around the extent of the 6 central fields
- values of the rotation procedure used to calculate gridness
- ellipse axes and angle (if allProps is True and the fit worked)

Notes

The output from this method can be used as input to the show() method of this class. When it is, the plot produced will be a lot more informative.

See Also

ephysiopy.common.binning.autoCorr2D()

Source code in ephysiopy/common/fieldcalcs.py
def grid_field_props(A, maxima="centroid", allProps=True, **kwargs):
    """
    Extracts various measures from a spatial autocorrelogram

    Args:
        A (array_like): The spatial autocorrelogram (SAC)
        maxima (str, optional): The method used to detect the peaks in the SAC.
            Legal values are 'single' and 'centroid'. Default 'centroid'
        allProps (bool, optional): Whether to return a dictionary that
            contains the attempt to fit an ellipse around the edges of the
            central six peaks. See below. Default True

    Returns:
        props (dict): A dictionary containing measures of the SAC.
        Keys include:
            * gridness score
            * scale
            * orientation
            * coordinates of the peaks (nominally 6) closest to SAC centre
            * a binary mask around the extent of the 6 central fields
            * values of the rotation procedure used to calculate gridness
            * ellipse axes and angle (if allProps is True and the fit worked)

    Notes:
        The output from this method can be used as input to the show() method
        of this class.
        When it is, the plot produced will be a lot more informative.

    See Also:
        ephysiopy.common.binning.autoCorr2D()
    """
    A_tmp = A.copy()
    A_tmp[~np.isfinite(A)] = -1
    A_tmp[A_tmp <= 0] = -1
    A_sz = np.array(np.shape(A))
    # [STAGE 1] find peaks & identify 7 closest to centre
    if "min_distance" in kwargs:
        min_distance = kwargs.pop("min_distance")
    else:
        min_distance = np.ceil(np.min(A_sz / 2) / 8.0).astype(int)

    peak_idx, field_labels = _get_field_labels(A_tmp, neighbours=7, **kwargs)
    # a fcn for the labeled_comprehension function that returns
    # linear indices in A where the values in A for each label are
    # greater than half the max in that labeled region

    def fn(val, pos):
        return pos[val > (np.max(val) / 2)]

    nLbls = np.max(field_labels)
    indices = ndimage.labeled_comprehension(
        A_tmp, field_labels, np.arange(0, nLbls), fn, np.ndarray, 0, True
    )
    # turn linear indices into coordinates
    coords = [np.unravel_index(i, np.shape(A)) for i in indices]
    half_peak_labels = np.zeros_like(A)
    for peak_id, coord in enumerate(coords):
        xc, yc = coord
        half_peak_labels[xc, yc] = peak_id

    # Get some statistics about the labeled regions
    # fieldPerim = bwperim(half_peak_labels)
    lbl_range = np.arange(0, nLbls)
    # meanRInLabel = ndimage.mean(A, half_peak_labels, lbl_range)
    # nPixelsInLabel = np.bincount(np.ravel(half_peak_labels.astype(int)))
    # sumRInLabel = ndimage.sum_labels(A, half_peak_labels, lbl_range)
    # maxRInLabel = ndimage.maximum(A, half_peak_labels, lbl_range)
    peak_coords = ndimage.maximum_position(A, half_peak_labels, lbl_range)

    # Get some distance and morphology measures
    centre = np.floor(np.array(np.shape(A)) / 2)
    centred_peak_coords = peak_coords - centre
    peak_dist_to_centre = np.hypot(centred_peak_coords.T[0], centred_peak_coords.T[1])
    closest_peak_idx = np.argsort(peak_dist_to_centre)
    central_peak_label = closest_peak_idx[0]
    closest_peak_idx = closest_peak_idx[1 : np.min((7, len(closest_peak_idx) - 1))]
    # closest_peak_idx should now the indices of the labeled 6 peaks
    # surrounding the central peak at the image centre
    scale = np.median(peak_dist_to_centre[closest_peak_idx])
    orientation = grid_orientation(centred_peak_coords, closest_peak_idx)

    central_pt = peak_coords[central_peak_label]
    x = np.linspace(-central_pt[0], central_pt[0], A_sz[0])
    y = np.linspace(-central_pt[1], central_pt[1], A_sz[1])
    xv, yv = np.meshgrid(x, y, indexing="ij")
    dist_to_centre = np.hypot(xv, yv)
    # get the max distance of the half-peak width labeled fields
    # from the centre of the image
    max_dist_from_centre = 0
    for peak_id, _coords in enumerate(coords):
        if peak_id in closest_peak_idx:
            xc, yc = _coords
            if np.any(xc) and np.any(yc):
                xc = xc - np.floor(A_sz[0] / 2)
                yc = yc - np.floor(A_sz[1] / 2)
                d = np.max(np.hypot(xc, yc))
                if d > max_dist_from_centre:
                    max_dist_from_centre = d

    # Set the outer bits and the central region of the SAC to nans
    # getting ready for the correlation procedure
    dist_to_centre[np.abs(dist_to_centre) > max_dist_from_centre] = 0
    dist_to_centre[half_peak_labels == central_peak_label] = 0
    dist_to_centre[dist_to_centre != 0] = 1
    dist_to_centre = dist_to_centre.astype(bool)
    sac_middle = A.copy()
    sac_middle[~dist_to_centre] = np.nan

    if "step" in kwargs.keys():
        step = kwargs.pop("step")
    else:
        step = 30
    try:
        gridscore, rotationCorrVals, rotationArr = gridness(sac_middle, step=step)
    except Exception:
        gridscore, rotationCorrVals, rotationArr = np.nan, np.nan, np.nan

    im_centre = central_pt

    if allProps:
        # attempt to fit an ellipse around the outer edges of the nearest
        # peaks to the centre of the SAC. First find the outer edges for
        # the closest peaks using a ndimages labeled_comprehension
        try:

            def fn2(val, pos):
                xc, yc = np.unravel_index(pos, A_sz)
                xc = xc - np.floor(A_sz[0] / 2)
                yc = yc - np.floor(A_sz[1] / 2)
                idx = np.argmax(np.hypot(xc, yc))
                return xc[idx], yc[idx]

            ellipse_coords = ndimage.labeled_comprehension(
                A, half_peak_labels, closest_peak_idx, fn2, tuple, 0, True
            )

            ellipse_fit_coords = np.array([(x, y) for x, y in ellipse_coords])
            from skimage.measure import EllipseModel

            E = EllipseModel()
            E.estimate(ellipse_fit_coords)
            im_centre = E.params[0:2]
            ellipse_axes = E.params[2:4]
            ellipse_angle = E.params[-1]
            ellipseXY = E.predict_xy(np.linspace(0, 2 * np.pi, 50), E.params)

            # get the min containing circle given the ellipse minor axis
            from skimage.measure import CircleModel

            _params = [im_centre, np.min(ellipse_axes)]
            circleXY = CircleModel().predict_xy(
                np.linspace(0, 2 * np.pi, 50), params=_params
            )
        except (TypeError, ValueError):  # non-iterable x and y
            ellipse_axes = None
            ellipse_angle = (None, None)
            ellipseXY = None
            circleXY = None

    # collect all the following keywords into a dict for output
    closest_peak_coords = np.array(peak_coords)[closest_peak_idx]
    dictKeys = (
        "gridscore",
        "scale",
        "orientation",
        "closest_peak_coords",
        "dist_to_centre",
        "ellipse_axes",
        "ellipse_angle",
        "ellipseXY",
        "circleXY",
        "im_centre",
        "rotationArr",
        "rotationCorrVals",
    )
    outDict = dict.fromkeys(dictKeys, np.nan)
    for thiskey in outDict.keys():
        outDict[thiskey] = locals()[thiskey]
        # neat trick: locals is a dict holding all locally scoped variables
    return outDict
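
For illustration, a minimal usage sketch (a hedged example: the SAC is a
random stand-in here; in practice it would come from something like
ephysiopy.common.binning.RateMap.autoCorr2D(), and the enclosing function is
assumed to be importable as fieldcalcs.grid_field_props, as the
SAC.getMeasures wrapper further below suggests):

import numpy as np
from ephysiopy.common.fieldcalcs import grid_field_props

# stand-in for a real spatial autocorrelogram of a grid cell ratemap
sac = np.random.rand(64, 64)
props = grid_field_props(sac, maxima="centroid", allProps=True)
print(props["gridscore"], props["scale"], props["orientation"])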

grid_orientation(peakCoords, closestPeakIdx)

Calculates the orientation angle of a grid field.

The orientation angle is the angle of the first peak working counter-clockwise from 3 o'clock

Parameters:

    peakCoords (array_like, required): The peak coordinates as pairs of xy
    closestPeakIdx (array_like, required): A 1D array of the indices in
        peakCoords of the peaks closest to the centre of the SAC

Returns:

    peak_orientation (float): The first value in an array of the angles of
        the peaks in the SAC working counter-clockwise from a line extending
        from the middle of the SAC to 3 o'clock.

Source code in ephysiopy/common/fieldcalcs.py
def grid_orientation(peakCoords, closestPeakIdx):
    """
    Calculates the orientation angle of a grid field.

    The orientation angle is the angle of the first peak working
    counter-clockwise from 3 o'clock

    Args:
        peakCoords (array_like): The peak coordinates as pairs of xy
        closestPeakIdx (array_like): A 1D array of the indices in peakCoords
        of the peaks closest to the centre of the SAC

    Returns:
        peak_orientation (float): The first value in an array of the angles of
        the peaks in the SAC working counter-clockwise from a line
        extending from the middle of the SAC to 3 o'clock.
    """
    if len(peakCoords) < 3 or closestPeakIdx.size == 0:
        return np.nan
    else:
        from ephysiopy.common.utils import polar

        # Assume that the first entry in peakCoords is
        # the central peak of the SAC
        peaks = peakCoords[closestPeakIdx]
        peaks = peaks - peakCoords[closestPeakIdx[0]]
        theta = polar(peaks[:, 1], -peaks[:, 0], deg=1)[1]
        return np.sort(theta.compress(theta >= 0))[0]
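
As a sketch of the angle convention only (plain numpy, not the library
function; the peak offsets are invented for an idealised, unrotated grid):

import numpy as np

# hypothetical xy offsets of the six surrounding peaks from the SAC centre
peaks = np.array([[10, 0], [5, 8.7], [-5, 8.7],
                  [-10, 0], [-5, -8.7], [5, -8.7]])
theta = np.degrees(np.arctan2(peaks[:, 1], peaks[:, 0]))
# the orientation is the smallest non-negative angle, measured
# counter-clockwise from 3 o'clock
print(np.sort(theta[theta >= 0])[0])  # 0.0 for this unrotated grid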

gridness(image, step=30)

Calculates the gridness score in a grid cell SAC.

Briefly, the data in image is rotated in step amounts and each rotated array is correlated with the original. The maximum of the values at 30, 90 and 150 degrees is then subtracted from the minimum of the values at 60, 120 and 180 degrees to give the grid score.

Parameters:

    image (array_like, required): The spatial autocorrelogram
    step (int, optional): The amount to rotate the SAC in each step of the
        rotational correlation procedure. Default 30

Returns:

    gridmeasures (3-tuple): The gridscore, the correlation values at each
        step and the rotational array

Notes

    The correlation performed is a Pearson's R. Some rescaling of the values
    in image is performed following rotation.

See Also

    skimage.transform.rotate : for how the rotation of image is done
    skimage.exposure.rescale_intensity : for the rescaling following rotation

Source code in ephysiopy/common/fieldcalcs.py
def gridness(image, step=30):
    """
    Calculates the gridness score in a grid cell SAC.

    Briefly, the data in `image` is rotated in `step` amounts and
    each rotated array is correlated with the original.
    The maximum of the values at 30, 90 and 150 degrees
    is then subtracted from the minimum of the values at 60, 120
    and 180 degrees to give the grid score.

    Args:
        image (array_like): The spatial autocorrelogram
        step (int, optional): The amount to rotate the SAC in each step of the
        rotational correlation procedure

    Returns:
        gridmeasures (3-tuple): The gridscore, the correlation values at each
        `step` and the rotational array

    Notes:
        The correlation performed is a Pearson's R. Some rescaling of the
        values in `image` is performed following rotation.

    See Also:
        skimage.transform.rotate : for how the rotation of `image` is done
        skimage.exposure.rescale_intensity : for the rescaling following
        rotation
    """
    # TODO: add options in here for whether the full range of correlations
    # are wanted or whether a reduced set is wanted (i.e. at the 30-tuples)
    from collections import OrderedDict

    rotationalCorrVals = OrderedDict.fromkeys(np.arange(0, 181, step), np.nan)
    rotationArr = np.zeros(len(rotationalCorrVals)) * np.nan
    # autoCorrMiddle needs to be rescaled or the image rotation falls down
    # as values are cropped to lie between 0 and 1.0
    in_range = (np.nanmin(image), np.nanmax(image))
    out_range = (0, 1)
    import skimage

    autoCorrMiddleRescaled = skimage.exposure.rescale_intensity(
        image, in_range=in_range, out_range=out_range
    )
    origNanIdx = np.isnan(autoCorrMiddleRescaled.ravel())
    for idx, angle in enumerate(rotationalCorrVals.keys()):
        rotatedA = skimage.transform.rotate(
            autoCorrMiddleRescaled, angle=angle, cval=np.nan, order=3
        )
        # ignore nans
        rotatedNanIdx = np.isnan(rotatedA.ravel())
        allNans = np.logical_or(origNanIdx, rotatedNanIdx)
        # get the correlation between the original and rotated images
        rotationalCorrVals[angle] = stats.pearsonr(
            autoCorrMiddleRescaled.ravel()[~allNans], rotatedA.ravel()[~allNans]
        )[0]
        rotationArr[idx] = rotationalCorrVals[angle]
    gridscore = np.min((rotationalCorrVals[60], rotationalCorrVals[120])) - np.max(
        (rotationalCorrVals[150], rotationalCorrVals[30], rotationalCorrVals[90])
    )
    return gridscore, rotationalCorrVals, rotationArr
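
As a worked example of the final scoring step (correlation values invented
for illustration; this mirrors the last few lines of the function):

# hypothetical rotational correlation values at each angle
corr = {30: 0.1, 60: 0.8, 90: -0.2, 120: 0.7, 150: 0.05, 180: 0.9}
gridscore = min(corr[60], corr[120]) - max(corr[150], corr[30], corr[90])
print(gridscore)  # 0.7 - 0.1 = 0.6; note the value at 180 is not used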

kldiv(X, pvect1, pvect2, variant=None)

Calculates the Kullback-Leibler or Jensen-Shannon divergence between two distributions.

Parameters:

    X (array_like, required): Vector of M variable values
    P1 (array_like, required): Length-M vector of probabilities representing
        distribution 1
    P2 (array_like, required): Length-M vector of probabilities representing
        distribution 2
    sym (str, optional): If 'sym', returns a symmetric variant of the
        Kullback-Leibler divergence, given by [KL(P1,P2)+KL(P2,P1)]/2
    js (str, optional): If 'js', returns the Jensen-Shannon divergence,
        given by [KL(P1,Q)+KL(P2,Q)]/2, where Q = (P1+P2)/2

Returns:

    float: The Kullback-Leibler divergence or Jensen-Shannon divergence

Notes

    The Kullback-Leibler divergence is given by:

    .. math:: KL(P1, P2) = \sum_x P1(x) \log_2(P1(x)/P2(x))

    If X contains duplicate values, there will be a warning message, and
    these values will be treated as distinct values. (I.e., the actual
    values do not enter into the computation, but the probabilities for the
    two duplicate values will be considered as probabilities corresponding
    to two unique values.) The elements of probability vectors P1 and P2
    must each sum to 1 +/- .00001.

    This function is taken from one on the Mathworks file exchange.

See Also

    Cover, T.M. and J.A. Thomas. "Elements of Information Theory," Wiley, 1991.

    https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

Source code in ephysiopy/common/fieldcalcs.py
def kldiv(X, pvect1, pvect2, variant=None):
    """
    Calculates the Kullback-Leibler or Jensen-Shannon divergence between
    two distributions.

    Args:
        X (array_like): Vector of M variable values
        P1 (array_like): Length-M vector of probabilities representing
        distribution 1
        P2 (array_like): Length-M vector of probabilities representing
        distribution 2
        sym (str, optional): If 'sym', returns a symmetric variant of the
            Kullback-Leibler divergence, given by [KL(P1,P2)+KL(P2,P1)]/2
        js (str, optional): If 'js', returns the Jensen-Shannon divergence,
        given by
            [KL(P1,Q)+KL(P2,Q)]/2, where Q = (P1+P2)/2

    Returns:
        float: The Kullback-Leibler divergence or Jensen-Shannon divergence

    Notes:
        The Kullback-Leibler divergence is given by:

        .. math:: KL(P1, P2) = \sum_x P1(x) \log_2(P1(x)/P2(x))

        If X contains duplicate values, there will be a warning message,
        and these values will be treated as distinct values.  (I.e., the
        actual values do not enter into the computation, but the probabilities
        for the two duplicate values will be considered as probabilities
        corresponding to two unique values.)
        The elements of probability vectors P1 and P2 must
        each sum to 1 +/- .00001.

        This function is taken from one on the Mathworks file exchange

    See Also:
        Cover, T.M. and J.A. Thomas. "Elements of Information Theory," Wiley,
        1991.

        https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    """

    if len(np.unique(X)) != np.size(X):  # i.e. X contains duplicate values
        warnings.warn(
            "X contains duplicate values. Treated as distinct values.", UserWarning
        )
    if (
        not np.equal(np.shape(X), np.shape(pvect1)).all()
        or not np.equal(np.shape(X), np.shape(pvect2)).all()
    ):
        raise ValueError("Inputs are not the same size")
    if (np.abs(np.sum(pvect1) - 1) > 0.00001) or (np.abs(np.sum(pvect2) - 1) > 0.00001):
        print(f"Probabilities sum to {np.abs(np.sum(pvect1))} for pvect1")
        print(f"Probabilities sum to {np.abs(np.sum(pvect2))} for pvect2")
        warnings.warn("Probabilities don" "t sum to 1.", UserWarning)
    if variant:
        if variant == "js":
            logqvect = np.log2((pvect2 + pvect1) / 2)
            KL = 0.5 * (
                np.nansum(pvect1 * (np.log2(pvect1) - logqvect))
                + np.sum(pvect2 * (np.log2(pvect2) - logqvect))
            )
            return KL
        elif variant == "sym":
            KL1 = np.nansum(pvect1 * (np.log2(pvect1) - np.log2(pvect2)))
            KL2 = np.nansum(pvect2 * (np.log2(pvect2) - np.log2(pvect1)))
            KL = (KL1 + KL2) / 2
            return KL
        else:
            warnings.warn("Last argument not recognised", UserWarning)
    KL = np.nansum(pvect1 * (np.log2(pvect1) - np.log2(pvect2)))
    return KL
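
A minimal usage sketch (toy probability vectors; the variant strings follow
the function above):

import numpy as np
from ephysiopy.common.fieldcalcs import kldiv

X = np.arange(4)
P1 = np.array([0.1, 0.2, 0.3, 0.4])
P2 = np.array([0.25, 0.25, 0.25, 0.25])
print(kldiv(X, P1, P2))         # KL(P1, P2) in bits
print(kldiv(X, P1, P2, "sym"))  # symmetric variant
print(kldiv(X, P1, P2, "js"))   # Jensen-Shannon divergence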

kldiv_dir(polarPlot)

Returns a KL divergence for directional firing: a measure of directionality. Calculates the KL divergence between a smoothed ratemap (which probably should be smoothed, otherwise information-theoretic measures don't 'care' about the position of bins relative to one another) and a pure circular distribution. The larger the divergence, the more tendency the cell has to fire when the animal faces a specific direction.

Parameters:

    polarPlot (1D-array, required): The binned and smoothed directional
        ratemap

Returns:

    klDivergence (float): The divergence of the 1D-array from a uniform
        circular distribution

Source code in ephysiopy/common/fieldcalcs.py
def kldiv_dir(polarPlot):
    """
    Returns a KL divergence for directional firing: a measure of
    directionality. Calculates the KL divergence between a smoothed ratemap
    (which probably should be smoothed, otherwise information-theoretic
    measures don't 'care' about the position of bins relative to one another)
    and a pure circular distribution.
    The larger the divergence, the more tendency the cell has to fire when
    the animal faces a specific direction.

    Args:
        polarPlot (1D-array): The binned and smoothed directional ratemap

    Returns:
        klDivergence (float): The divergence of the 1D-array from a uniform
        circular distribution
    """

    __inc = 0.00001
    polarPlot = np.atleast_2d(polarPlot)
    polarPlot[np.isnan(polarPlot)] = __inc
    polarPlot[polarPlot == 0] = __inc
    normdPolar = polarPlot / float(np.nansum(polarPlot))
    nDirBins = polarPlot.shape[1]
    compCirc = np.ones_like(polarPlot) / float(nDirBins)
    X = np.arange(0, nDirBins)
    kldivergence = kldiv(np.atleast_2d(X), normdPolar, compCirc)
    return kldivergence
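
A minimal usage sketch (a toy 60-bin directional ratemap; bin count and
firing values are invented):

import numpy as np
from ephysiopy.common.fieldcalcs import kldiv_dir

polarPlot = np.ones(60)
polarPlot[25:35] = 10  # strong firing over a narrow range of directions
print(kldiv_dir(polarPlot))    # larger = more directional
print(kldiv_dir(np.ones(60)))  # a flat ratemap gives ~0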

limit_to_one(A, prc=50, min_dist=5)

Processes a multi-peaked ratemap (i.e. a grid cell) and returns a matrix where the multi-peaked ratemap consists of a single peaked field that is a) not connected to the border and b) close to the middle of the ratemap

Source code in ephysiopy/common/fieldcalcs.py
def limit_to_one(A, prc=50, min_dist=5):
    """
    Processes a multi-peaked ratemap (i.e. a grid cell) and returns a matrix
    where the multi-peaked ratemap consists of a single peaked field that is
    a) not connected to the border and b) close to the middle of the
    ratemap
    """
    Ac = A.copy()
    Ac[np.isnan(A)] = 0
    # smooth Ac more to remove local irregularities
    n = ny = 5
    x, y = np.mgrid[-n : n + 1, -ny : ny + 1]
    g = np.exp(-(x**2 / float(n) + y**2 / float(ny)))
    g = g / g.sum()
    Ac = signal.convolve(Ac, g, mode="same")
    # remove really small values
    Ac[Ac < 1e-10] = 0
    peak_idx = skimage.feature.peak_local_max(
        Ac, min_distance=min_dist, exclude_border=False
    )
    peak_mask = np.zeros_like(Ac, dtype=bool)
    peak_mask[tuple(peak_idx.T)] = True
    peak_labels = skimage.measure.label(peak_mask, connectivity=2)
    field_labels = watershed(image=-Ac, markers=peak_labels)
    nFields = np.max(field_labels)
    sub_field_mask = np.zeros((nFields, Ac.shape[0], Ac.shape[1]))
    labelled_sub_field_mask = np.zeros_like(sub_field_mask)
    sub_field_props = skimage.measure.regionprops(field_labels, intensity_image=Ac)
    sub_field_centroids = []
    sub_field_size = []

    for sub_field in sub_field_props:
        tmp = np.zeros(Ac.shape).astype(bool)
        tmp[sub_field.coords[:, 0], sub_field.coords[:, 1]] = True
        tmp2 = Ac > sub_field.max_intensity * (prc / float(100))
        sub_field_mask[sub_field.label - 1, :, :] = np.logical_and(tmp2, tmp)
        labelled_sub_field_mask[sub_field.label - 1, np.logical_and(tmp2, tmp)] = (
            sub_field.label
        )
        sub_field_centroids.append(sub_field.centroid)
        sub_field_size.append(sub_field.area)  # in bins
    sub_field_mask = np.sum(sub_field_mask, 0)
    middle = np.round(np.array(A.shape) / 2)
    normd_dists = sub_field_centroids - middle
    field_dists_from_middle = np.hypot(normd_dists[:, 0], normd_dists[:, 1])
    central_field_idx = np.argmin(field_dists_from_middle)
    central_field = np.squeeze(labelled_sub_field_mask[central_field_idx, :, :])
    # collapse the labelled mask down to a 2d array
    labelled_sub_field_mask = np.sum(labelled_sub_field_mask, 0)
    # clear the border
    cleared_mask = skimage.segmentation.clear_border(central_field)
    # check we've still got stuff in the matrix or fail
    if ~np.any(cleared_mask):
        print("No fields were detected away from edges so nothing returned")
        return None, None, None
    else:
        central_field_props = sub_field_props[central_field_idx]
    return central_field_props, central_field, central_field_idx
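
A usage sketch on a toy two-field ratemap (the Gaussian bumps stand in for
real firing fields):

import numpy as np
from ephysiopy.common.fieldcalcs import limit_to_one

y, x = np.mgrid[0:64, 0:64]
rmap = np.exp(-((x - 32) ** 2 + (y - 32) ** 2) / 50.0)  # central field
rmap += np.exp(-((x - 5) ** 2 + (y - 5) ** 2) / 50.0)   # field near the edge
props, central_field, idx = limit_to_one(rmap, prc=50)
if props is not None:  # None if no field survives clearing the border
    print(props.centroid)  # centroid of the field closest to the middle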

local_threshold(A, prc=50, min_dist=5)

Locally thresholds a ratemap, taking only the surrounding prc percent around any local peak

Source code in ephysiopy/common/fieldcalcs.py
def local_threshold(A, prc=50, min_dist=5):
    """
    Locally thresholds a ratemap, taking only the surrounding prc percent
    around any local peak
    """
    Ac = A.copy()
    nanidx = np.isnan(Ac)
    Ac[nanidx] = 0
    # smooth Ac more to remove local irregularities
    n = ny = 5
    x, y = np.mgrid[-n : n + 1, -ny : ny + 1]
    g = np.exp(-(x**2 / float(n) + y**2 / float(ny)))
    g = g / g.sum()
    Ac = signal.convolve(Ac, g, mode="same")
    Ac_r = skimage.exposure.rescale_intensity(
        Ac, in_range="image", out_range=(0, 1000)
    ).astype(np.int32)
    peak_idx = skimage.feature.peak_local_max(
        Ac_r, min_distance=min_dist, exclude_border=False
    )
    peak_mask = np.zeros_like(Ac, dtype=bool)
    peak_mask[tuple(peak_idx.T)] = True
    peak_labels = skimage.measure.label(peak_mask, connectivity=2)
    field_labels = watershed(image=-Ac, markers=peak_labels)
    nFields = np.max(field_labels)
    sub_field_mask = np.zeros((nFields, Ac.shape[0], Ac.shape[1]))
    sub_field_props = skimage.measure.regionprops(field_labels, intensity_image=Ac)
    sub_field_centroids = []
    sub_field_size = []

    for sub_field in sub_field_props:
        tmp = np.zeros(Ac.shape).astype(bool)
        tmp[sub_field.coords[:, 0], sub_field.coords[:, 1]] = True
        tmp2 = Ac > sub_field.max_intensity * (prc / float(100))
        sub_field_mask[sub_field.label - 1, :, :] = np.logical_and(tmp2, tmp)
        sub_field_centroids.append(sub_field.centroid)
        sub_field_size.append(sub_field.area)  # in bins
    sub_field_mask = np.sum(sub_field_mask, 0)
    A_out = np.zeros_like(A)
    A_out[sub_field_mask.astype(bool)] = A[sub_field_mask.astype(bool)]
    A_out[nanidx] = np.nan
    return A_out
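
A usage sketch on a toy single-field ratemap (values invented):

import numpy as np
from ephysiopy.common.fieldcalcs import local_threshold

y, x = np.mgrid[0:32, 0:32]
rmap = np.exp(-((x - 16) ** 2 + (y - 16) ** 2) / 20.0)
A_out = local_threshold(rmap, prc=50)
# bins outside the thresholded field are zeroed; NaNs are preserved
print(np.count_nonzero(A_out), rmap.size)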

skaggs_info(ratemap, dwelltimes, **kwargs)

Calculates Skaggs information measure

Parameters:

    ratemap (array_like, required): The binned up ratemap
    dwelltimes (array_like, required): Must be same size as ratemap

Returns:

    bits_per_spike (float): Skaggs information score

Notes

    THIS DATA SHOULD UNDERGO ADAPTIVE BINNING
    See getAdaptiveMap() in binning class

    Returns Skaggs et al's estimate of spatial information in bits per spike:

    .. math:: I = \sum_x p(x) r(x) \log_2(r(x) / \bar{r})

Source code in ephysiopy/common/fieldcalcs.py
def skaggs_info(ratemap, dwelltimes, **kwargs):
    """
    Calculates Skaggs information measure

    Args:
        ratemap (array_like): The binned up ratemap
        dwelltimes (array_like): Must be same size as ratemap

    Returns:
        bits_per_spike (float): Skaggs information score

    Notes:
        THIS DATA SHOULD UNDERGO ADAPTIVE BINNING
        See getAdaptiveMap() in binning class

        Returns Skaggs et al's estimate of spatial information
        in bits per spike:

        .. math:: I = \sum_x p(x) r(x) \log_2(r(x) / \bar{r})
    """
    if "sample_rate" in kwargs:
        sample_rate = kwargs["sample_rate"]
    else:
        sample_rate = 50

    dwelltimes = dwelltimes / sample_rate  # assumed sample rate of 50Hz
    if ratemap.ndim > 1:
        ratemap = np.reshape(ratemap, (np.prod(np.shape(ratemap)), 1))
        dwelltimes = np.reshape(dwelltimes, (np.prod(np.shape(dwelltimes)), 1))
    duration = np.nansum(dwelltimes)
    meanrate = np.nansum(ratemap * dwelltimes) / duration
    if meanrate <= 0.0:
        bits_per_spike = np.nan
        return bits_per_spike
    p_x = dwelltimes / duration
    p_r = ratemap / meanrate
    dum = p_x * ratemap
    ind = np.nonzero(dum)[0]
    bits_per_spike = np.nansum(dum[ind] * np.log2(p_r[ind]))
    bits_per_spike = bits_per_spike / meanrate
    return bits_per_spike
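
A worked toy example (uniform dwell and one small field; at the assumed
50 Hz position sampling, 50 samples per bin is 1 s of dwell per bin):

import numpy as np
from ephysiopy.common.fieldcalcs import skaggs_info

ratemap = np.zeros((10, 10))
ratemap[4:6, 4:6] = 5.0              # a 4-bin field firing at 5 Hz
dwelltimes = np.ones((10, 10)) * 50  # 50 pos samples = 1 s per bin
print(skaggs_info(ratemap, dwelltimes))  # ~4.64 bits per spike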

Grid cells

SAC

Bases: object

Spatial AutoCorrelation (SAC) class

Source code in ephysiopy/common/gridcell.py
class SAC(object):
    """
    Spatial AutoCorrelation (SAC) class
    """

    def getMeasures(
            self, A, maxima='centroid',
            allProps=True, **kwargs):
        """
        Extracts various measures from a spatial autocorrelogram

        Args:
            A (array_like): The spatial autocorrelogram (SAC)
            maxima (str, optional): The method used to detect the peaks in the
            SAC.
                Legal values are 'single' and 'centroid'. Default 'centroid'
            field_extent_method (int, optional): The method used to delimit
            the regions of interest in the SAC
                Legal values:
                * 1 - uses half height of the ROI peak to limit field extent
                * 2 - uses a watershed method to limit field extent
                Default 2
            allProps (bool, optional): Whether to return a dictionary that
            contains the attempt to fit an ellipse around the edges of the
            central six peaks. See below. Default True

        Returns:
            props (dict): A dictionary containing measures of the SAC. 
            Keys include:
                * gridness score
                * scale
                * orientation
                * the coordinates of the peaks closest to the SAC centre
                * a binary mask that defines the extent of the 6 central fields
                * values of the rotation procedure used to calculate gridness
                * ellipse axes and angle (if allProps is True and it worked)

        Notes:
            In order to maintain backward compatibility this is a wrapper for
            ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()

        See Also:
            ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()
        """
        return fieldcalcs.grid_field_props(
            A, maxima, allProps, **kwargs)

    def get_basic_gridscore(self, A: np.ndarray, step: int = 30, **kwargs):
        '''
        Rotates the image A in step amounts, correlating each rotated image
        with the original. The maximum of the values at 30, 90 and 150 degrees
        is then subtracted from the minimum of the values at 60, 120
        and 180 degrees to give the grid score.
        '''
        from ephysiopy.common.fieldcalcs import gridness
        return gridness(A, step)[0]

    def get_expanding_circle_gridscore(self, A: np.ndarray, **kwargs):
        '''
        Calculates the gridscore for each circular sub-region of image A
        where the circles are centred on the image centre and expanded to
        the edge of the image. The maximum of the get_basic_gridscore() for
        each of these circular sub-regions is returned as the gridscore
        '''

        from ephysiopy.common.fieldcalcs import get_circular_regions
        images = get_circular_regions(A, **kwargs)
        gridscores = [self.get_basic_gridscore(im) for im in images]
        return max(gridscores)

    def get_deformed_sac_gridscore(self, A: np.ndarray, **kwargs):
        '''
        Deforms a non-circular SAC into a circular SAC (circular meaning
        the ellipse drawn around the edges of the 6 nearest peaks to the
        SAC centre) and returns get_basic_gridscore() calculated on the
        deformed (or re-formed?!) SAC
        '''
        from ephysiopy.common.fieldcalcs import deform_SAC
        deformed_SAC = deform_SAC(A)
        return self.get_basic_gridscore(deformed_SAC)

    def show(self, A, inDict, ax=None, **kwargs):
        """
        Displays the result of performing a spatial autocorrelation (SAC)
        on a grid cell.

        Uses the dictionary containing measures of the grid cell SAC to
        make a pretty picture

        Args:
            A (array_like): The spatial autocorrelogram
            inDict (dict): The dictionary calculated in getMeasures
            ax (matplotlib.axes._subplots.AxesSubplot, optional): If given,
            the plot will be drawn in these axes. Default None

        Returns:
            fig (matplotlib.Figure instance): The Figure on which the SAC is
            shown

        See Also:
            ephysiopy.common.binning.RateMap.autoCorr2D()
            ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()
        """
        from ephysiopy.visualise.plotting import FigureMaker
        F = FigureMaker()
        F.show_SAC(A, inDict, ax, **kwargs)

getMeasures(A, maxima='centroid', allProps=True, **kwargs)

Extracts various measures from a spatial autocorrelogram

Parameters:

    A (array_like, required): The spatial autocorrelogram (SAC)
    maxima (str, optional): The method used to detect the peaks in the SAC.
        Legal values are 'single' and 'centroid'. Default 'centroid'
    field_extent_method (int, optional): The method used to delimit the
        regions of interest in the SAC. Legal values:
        * 1 - uses half height of the ROI peak to limit field extent
        * 2 - uses a watershed method to limit field extent
        Default 2
    allProps (bool, optional): Whether to return a dictionary that contains
        the attempt to fit an ellipse around the edges of the central six
        peaks. See below. Default True

Returns:

    props (dict): A dictionary containing measures of the SAC. Keys include:
        * gridness score
        * scale
        * orientation
        * the coordinates of the peaks closest to the SAC centre
        * a binary mask that defines the extent of the 6 central fields
        * values of the rotation procedure used to calculate gridness
        * ellipse axes and angle (if allProps is True and it worked)

Notes

    In order to maintain backward compatibility this is a wrapper for
    ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()

See Also

    ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()

Source code in ephysiopy/common/gridcell.py
def getMeasures(
        self, A, maxima='centroid',
        allProps=True, **kwargs):
    """
    Extracts various measures from a spatial autocorrelogram

    Args:
        A (array_like): The spatial autocorrelogram (SAC)
        maxima (str, optional): The method used to detect the peaks in the
        SAC.
            Legal values are 'single' and 'centroid'. Default 'centroid'
        field_extent_method (int, optional): The method used to delimit
        the regions of interest in the SAC
            Legal values:
            * 1 - uses half height of the ROI peak to limit field extent
            * 2 - uses a watershed method to limit field extent
            Default 2
        allProps (bool, optional): Whether to return a dictionary that
        contains the attempt to fit an ellipse around the edges of the
        central six peaks. See below. Default True

    Returns:
        props (dict): A dictionary containing measures of the SAC. 
        Keys include:
            * gridness score
            * scale
            * orientation
            * the coordinates of the peaks closest to the SAC centre
            * a binary mask that defines the extent of the 6 central fields
            * values of the rotation procedure used to calculate gridness
            * ellipse axes and angle (if allProps is True and it worked)

    Notes:
        In order to maintain backward compatibility this is a wrapper for
        ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()

    See Also:
        ephysiopy.common.ephys_generic.FieldCalcs.grid_field_props()
    """
    return fieldcalcs.grid_field_props(
        A, maxima, allProps, **kwargs)

get_basic_gridscore(A, step=30, **kwargs)

Rotates the image A in step amounts, correlating each rotated image with the original. The maximum of the values at 30, 90 and 150 degrees is then subtracted from the minimum of the values at 60, 120 and 180 degrees to give the grid score.

Source code in ephysiopy/common/gridcell.py
def get_basic_gridscore(self, A: np.ndarray, step: int = 30, **kwargs):
    '''
    Rotates the image A in step amounts, correlating each rotated image
    with the original. The maximum of the values at 30, 90 and 150 degrees
    is then subtracted from the minimum of the values at 60, 120
    and 180 degrees to give the grid score.
    '''
    from ephysiopy.common.fieldcalcs import gridness
    return gridness(A, step)[0]

get_deformed_sac_gridscore(A, **kwargs)

Deforms a non-circular SAC into a circular SAC (circular meaning the ellipse drawn around the edges of the 6 nearest peaks to the SAC centre) and returns get_basic_gridscore() calculated on the deformed (or re-formed?!) SAC

Source code in ephysiopy/common/gridcell.py
def get_deformed_sac_gridscore(self, A: np.ndarray, **kwargs):
    '''
    Deforms a non-circular SAC into a circular SAC (circular meaning
    the ellipse drawn around the edges of the 6 nearest peaks to the
    SAC centre) and returns get_basic_gridscore() calculated on the
    deformed (or re-formed?!) SAC
    '''
    from ephysiopy.common.fieldcalcs import deform_SAC
    deformed_SAC = deform_SAC(A)
    return self.get_basic_gridscore(deformed_SAC)

get_expanding_circle_gridscore(A, **kwargs)

Calculates the gridscore for each circular sub-region of image A where the circles are centred on the image centre and expanded to the edge of the image. The maximum of the get_basic_gridscore() for each of these circular sub-regions is returned as the gridscore

Source code in ephysiopy/common/gridcell.py
73
def get_expanding_circle_gridscore(self, A: np.ndarray, **kwargs):
    '''
    Calculates the gridscore for each circular sub-region of image A
    where the circles are centred on the image centre and expanded to
    the edge of the image. The maximum of the get_basic_gridscore() for
    each of these circular sub-regions is returned as the gridscore
    '''

    from ephysiopy.common.fieldcalcs import get_circular_regions
    images = get_circular_regions(A, **kwargs)
    gridscores = [self.get_basic_gridscore(im) for im in images]
    return max(gridscores)
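
A sketch comparing the gridscore variants on the same array (the random
array is only a stand-in to show the calls; real scores need a real SAC):

import numpy as np
from ephysiopy.common.gridcell import SAC

S = SAC()
sac = np.random.rand(64, 64)  # stand-in for a real SAC
print(S.get_basic_gridscore(sac))
print(S.get_expanding_circle_gridscore(sac))
# the deformed variant needs well-defined peaks, so it may fail on noise
print(S.get_deformed_sac_gridscore(sac))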

show(A, inDict, ax=None, **kwargs)

Displays the result of performing a spatial autocorrelation (SAC) on a grid cell.

Uses the dictionary containing measures of the grid cell SAC to make a pretty picture

Parameters:

    A (array_like, required): The spatial autocorrelogram
    inDict (dict, required): The dictionary calculated in getMeasures
    ax (matplotlib.axes._subplots.AxesSubplot, optional): If given, the plot
        will be drawn in these axes. Default None

Returns:

    fig (matplotlib.Figure instance): The Figure on which the SAC is shown

See Also

    ephysiopy.common.binning.RateMap.autoCorr2D()
    ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()

Source code in ephysiopy/common/gridcell.py
def show(self, A, inDict, ax=None, **kwargs):
    """
    Displays the result of performing a spatial autocorrelation (SAC)
    on a grid cell.

    Uses the dictionary containing measures of the grid cell SAC to
    make a pretty picture

    Args:
        A (array_like): The spatial autocorrelogram
        inDict (dict): The dictionary calculated in getMeasures
        ax (matplotlib.axes._subplots.AxesSubplot, optional): If given,
        the plot will be drawn in these axes. Default None

    Returns:
        fig (matplotlib.Figure instance): The Figure on which the SAC is
        shown

    See Also:
        ephysiopy.common.binning.RateMap.autoCorr2D()
        ephysiopy.common.ephys_generic.FieldCalcs.getMeasures()
    """
    from ephysiopy.visualise.plotting import FigureMaker
    F = FigureMaker()
    F.show_SAC(A, inDict, ax, **kwargs)
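
A minimal end-to-end sketch (the SAC is assumed to come from
RateMap.autoCorr2D(); here a random stand-in shows only the call sequence):

import numpy as np
from ephysiopy.common.gridcell import SAC

S = SAC()
sac = np.random.rand(64, 64)  # stand-in for a real SAC
measures = S.getMeasures(sac)
S.show(sac, measures)  # plots the SAC annotated with the measures dict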

Phase coding

phasePrecession2D

Bases: object

Performs phase precession analysis for single unit data

Mostly a total rip-off of code written by Ali Jeewajee for his paper on 2D phase precession in place and grid cells [1]_

.. [1] Jeewajee A, Barry C, Douchamps V, Manson D, Lever C, Burgess N. Theta phase precession of grid and place cell firing in open environments. Philos Trans R Soc Lond B Biol Sci. 2013 Dec 23;369(1635):20120532. doi: 10.1098/rstb.2012.0532.

Parameters:

    lfp_sig (np.array, required): The LFP signal against which cells might
        precess...
    lfp_fs (int, required): The sampling frequency of the LFP signal
    xy (np.array, required): The position data as 2 x num_position_samples
    spike_ts (np.array, required): The times in samples at which the cell
        fired
    pos_ts (np.array, required): The times in samples at which position was
        captured
    pp_config (dict, optional): Contains parameters for running the analysis.
        See the phase_precession_config dict in ephysiopy.common.eegcalcs.
        Default phase_precession_config
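
A construction sketch (all inputs are random placeholders; meaningful
results need a real LFP trace, tracked positions and spike times from a
loaded trial):

import numpy as np
from ephysiopy.common.phasecoding import phasePrecession2D

lfp_fs = 250                                  # LFP sample rate (Hz)
lfp_sig = np.random.randn(lfp_fs * 600)       # placeholder: 10 min of LFP
pos_ts = np.arange(0, 600, 1 / 50.0)          # 50 Hz position timestamps (s)
xy = np.random.rand(2, len(pos_ts)) * 100     # placeholder 2 x n positions
spike_ts = np.sort(np.random.choice(pos_ts, 500))  # placeholder spike times

pp = phasePrecession2D(lfp_sig, lfp_fs, xy, spike_ts, pos_ts)
pp.performRegression()  # partitions fields, gathers props, runs regressions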
Source code in ephysiopy/common/phasecoding.py
class phasePrecession2D(object):
    """
    Performs phase precession analysis for single unit data

    Mostly a total rip-off of code written by Ali Jeewajee for his paper on
    2D phase precession in place and grid cells [1]_

    .. [1] Jeewajee A, Barry C, Douchamps V, Manson D, Lever C, Burgess N.
        Theta phase precession of grid and place cell firing in open
        environments.
        Philos Trans R Soc Lond B Biol Sci. 2013 Dec 23;369(1635):20120532.
        doi: 10.1098/rstb.2012.0532.

    Args:
        lfp_sig (np.array): The LFP signal against which cells might precess...
        lfp_fs (int): The sampling frequency of the LFP signal
        xy (np.array): The position data as 2 x num_position_samples
        spike_ts (np.array): The times in samples at which the cell fired
        pos_ts (np.array): The times in samples at which position was captured
        pp_config (dict): Contains parameters for running the analysis.
            See phase_precession_config dict in ephysiopy.common.eegcalcs
    """

    def __init__(
        self,
        lfp_sig: np.array,
        lfp_fs: int,
        xy: np.array,
        spike_ts: np.array,
        pos_ts: np.array,
        pp_config: dict = phase_precession_config,
    ):
        # Set up the parameters
        # this sets a bunch of member attributes from the pp_config dict
        self.update_config(pp_config)
        self._pos_ts = pos_ts

        # Create a dict to hold the stats values
        stats_dict = {
            "values": None,
            "pha": None,
            "slope": None,
            "intercept": None,
            "cor": None,
            "p": None,
            "cor_boot": None,
            "p_shuffled": None,
            "ci": None,
            "reg": None,
        }
        # Create a dict of regressors to hold stat values
        # for each regressor
        from collections import defaultdict

        self.regressors = {}
        self.regressors = defaultdict(lambda: stats_dict.copy(), self.regressors)
        regressor_keys = [
            "spk_numWithinRun",
            "pos_exptdRate_cum",
            "pos_instFR",
            "pos_timeInRun",
            "pos_d_cum",
            "pos_d_meanDir",
            "pos_d_currentdir",
            "spk_thetaBatchLabelInRun",
        ]
        [self.regressors[k] for k in regressor_keys]
        # each of the regressors in regressor_keys is a key with a value
        # of stats_dict

        self.k = 1000
        self.alpha = 0.05
        self.hyp = 0
        self.conf = True

        # Process the EEG data a bit...
        self.eeg = lfp_sig
        L = LFPOscillations(lfp_sig, lfp_fs)
        filt_sig, phase, _, _ = L.getFreqPhase(lfp_sig, [6, 12], 2)
        self.filteredEEG = filt_sig
        self.phase = phase
        self.phaseAdj = None

        self.update_position(xy, self.ppm, cm=self.convert_xy_2_cm)
        self.update_rate_map()

        spk_times_in_pos_samples = self.getSpikePosIndices(spike_ts)
        spk_weights = np.bincount(spk_times_in_pos_samples, minlength=len(self.pos_ts))
        self.spk_times_in_pos_samples = spk_times_in_pos_samples
        self.spk_weights = spk_weights

        self.spike_ts = spike_ts

    @property
    def pos_ts(self):
        return self._pos_ts

    @pos_ts.setter
    def pos_ts(self, value):
        self._pos_ts = value

    @property
    def xy(self):
        return self.PosData.xy

    @xy.setter
    def xy(self, value):
        self.PosData.xy = value

    def update_config(self, pp_config):
        [setattr(self, k, pp_config[k]) for k in pp_config.keys()]

    def update_position(self, xy, ppm: float, cm: bool):
        P = PosCalcsGeneric(
            xy[0, :],
            xy[1, :],
            ppm=ppm,
            convert2cm=cm,
        )
        P.postprocesspos(tracker_params={"AxonaBadValue": 1023})
        # ... do the ratemap creation here once
        self.PosData = P

    def update_rate_map(self):
        R = RateMap(
            self.PosData.xy,
            self.PosData.dir,
            self.PosData.speed,
            xyInCms=self.convert_xy_2_cm,
        )
        R.binsize = self.cms_per_bin
        R.smooth_sz = self.field_smoothing_kernel_len
        R.ppm = self.ppm
        self.RateMap = R  # this will be used a fair bit below

    def getSpikePosIndices(self, spk_times: np.array):
        pos_times = getattr(self, "pos_ts")
        idx = np.searchsorted(pos_times, spk_times)
        idx[idx == len(pos_times)] = idx[idx == len(pos_times)] - 1
        return idx

    def performRegression(self, laserEvents=None, **kwargs):
        """
        Wrapper function for doing the actual regression which has multiple
        stages.

        Specifically here we partition fields into sub-fields, get a bunch of
        information about the position, spiking and theta data and then
        do the actual regression.

        Args:
            tetrode (int): The tetrode to examine
            cluster (int): The cluster to examine
            laserEvents (array_like, optional): The on times for laser events
            if present. Default is None

        See Also:
            ephysiopy.common.eegcalcs.phasePrecession.partitionFields()
            ephysiopy.common.eegcalcs.phasePrecession.getPosProps()
            ephysiopy.common.eegcalcs.phasePrecession.getThetaProps()
            ephysiopy.common.eegcalcs.phasePrecession.getSpikeProps()
            ephysiopy.common.eegcalcs.phasePrecession._ppRegress()
        """

        # Partition fields
        peaksXY, _, labels, _ = self.partitionFields(plot=True)

        # split into runs
        posD, runD = self.getPosProps(
            labels, peaksXY, laserEvents=laserEvents, plot=True
        )

        # get theta cycles, amplitudes, phase etc
        self.getThetaProps()

        # get the indices of spikes for various metrics such as
        # theta cycle, run etc
        spkD = self.getSpikeProps(
            posD["runLabel"], runD["meanDir"], runD["runDurationInPosBins"]
        )

        # Do the regressions
        regress_dict = self._ppRegress(spkD, plot=True)

        self.plotPPRegression(regress_dict)

    def partitionFields(self, ftype="g", plot=False, **kwargs):
        """
        Partitions fields.

        Partitions spikes into fields by finding the watersheds around the
        peaks of a super-smoothed ratemap

        Args:
            spike_ts (np.array): The ratemap to partition
            ftype (str): 'p' or 'g' denoting place or grid cells
              - not implemented yet
            plot (bool): Whether to produce a debugging plot or not

        Returns:
            peaksXY (array_like): The xy coordinates of the peak rates in
            each field
            peaksRate (array_like): The peak rates in peaksXY
            labels (numpy.ndarray): An array of the labels corresponding to
            each field (starting at 1)
            rmap (numpy.ndarray): The ratemap of the tetrode / cluster
        """
        rmap, (xe, ye) = self.RateMap.getMap(self.spk_weights)
        nan_idx = np.isnan(rmap)
        rmap[nan_idx] = 0
        # start image processing:
        # get some markers
        from ephysiopy.common import fieldcalcs

        markers = fieldcalcs.local_threshold(rmap, prc=self.field_threshold_percent)
        # clear the edges / any invalid positions again
        markers[nan_idx] = 0
        # label these markers so each blob has a unique id
        labels = ndimage.label(markers)[0]
        # labels is now a labelled int array from 0 to however many fields have
        # been detected
        # get the number of spikes in each field - NB this is done against a
        # flattened array so we need to figure out which count corresponds to
        # which particular field id using np.unique
        fieldId, _ = np.unique(labels, return_index=True)
        fieldId = fieldId[1::]
        # TODO: come back to this as may need to know field id ordering
        peakCoords = np.array(
            ndimage.maximum_position(rmap, labels=labels, index=fieldId)
        ).astype(int)
        # COMCoords = np.array(
        #     ndimage.center_of_mass(
        #         rmap, labels=labels, index=fieldId)
        # ).astype(int)
        peaksXY = np.vstack((xe[peakCoords[:, 0]], ye[peakCoords[:, 1]])).T
        # find the peak rate at each of the centre of the detected fields to
        # subsequently threshold the field at some fraction of the peak value
        peakRates = rmap[peakCoords[:, 0], peakCoords[:, 1]]
        fieldThresh = peakRates * self.field_threshold
        rmFieldMask = np.zeros_like(rmap)
        for fid in fieldId:
            f = labels[peakCoords[fid - 1, 0], peakCoords[fid - 1, 1]]
            rmFieldMask[labels == f] = rmap[labels == f] > fieldThresh[f - 1]
        labels[~rmFieldMask.astype(bool)] = 0
        # peakBinInds = np.ceil(peakCoords)
        # re-order some vars to get into same format as fieldLabels
        peakLabels = labels[peakCoords[:, 0], peakCoords[:, 1]]
        peaksXY = peaksXY[peakLabels - 1, :]
        peaksRate = peakRates[peakLabels - 1]
        # peakBinInds = peakBinInds[peakLabels-1, :]
        # peaksXY = peakCoords - np.min(xy, 1)

        # if ~np.isnan(self.area_threshold):
        #     # TODO: this needs fixing so sensible values are used and the
        #     # modified bool array is propagated correctly ie makes
        #     # sense to have a function that applies a bool array to whatever
        #     # arrays are used as output and call it in a couple of places
        #     # areaInBins = self.area_threshold * self.binsPerCm
        #     lb = ndimage.label(markers)[0]
        #     rp = skimage.measure.regionprops(lb)
        #     for reg in rp:
        #         print(reg.filled_area)
        #     markers = skimage.morphology.remove_small_objects(
        #         lb, min_size=4000, connectivity=4, in_place=True)
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(211)
            ax.pcolormesh(
                ye, xe, rmap, cmap=matplotlib.colormaps["jet"], edgecolors="face"
            )
            ax.set_title("Smoothed ratemap + peaks")
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_aspect("equal")
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            ax.plot(peaksXY[:, 1], peaksXY[:, 0], "ko")
            ax.set_ylim(ylim)
            ax.set_xlim(xlim)

            ax = fig.add_subplot(212)
            ax.imshow(labels, interpolation="nearest", origin="lower")
            ax.set_title("Labelled restricted fields")
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_aspect("equal")

        return peaksXY, peaksRate, labels, rmap

    def getPosProps(self, labels, peaksXY, laserEvents=None, plot=False, **kwargs):
        """
        Uses the output of partitionFields and returns vectors the same
        length as pos.

        Args:
            tetrode, cluster (int): The tetrode / cluster to examine
            peaksXY (array_like): The x-y coords of the peaks in the ratemap
            laserEvents (array_like): The position indices of on events
            (laser on)

        Returns:
            pos_dict, run_dict (dict): Contains a whole bunch of information
            for the whole trial and also on a run-by-run basis (run_dict).
            See the end of this function for all the key / value pairs.
        """

        spikeTS = self.spike_ts  # in seconds
        xy = self.RateMap.xy
        xydir = self.RateMap.dir
        spd = self.RateMap.speed
        spkPosInd = np.ceil(spikeTS * self.pos_sample_rate).astype(int)
        spkPosInd[spkPosInd > len(xy.T)] = len(xy.T) - 1
        nPos = xy.shape[1]
        xy_old = xy.copy()
        xydir = np.squeeze(xydir)
        xydir_old = xydir.copy()

        rmap, (x_bins_in_pixels, y_bins_in_pixels) = self.RateMap.getMap(
            self.spk_weights
        )
        xe = x_bins_in_pixels
        ye = y_bins_in_pixels

        # The large number of bins combined with the super-smoothed ratemap
        # will lead to fields labelled with lots of small holes in. Fill those
        # gaps in here and calculate the perimeter of the fields based on that
        # labelled image
        labels, n_labels = ndimage.label(ndimage.binary_fill_holes(labels))

        rmap[np.isnan(rmap)] = 0
        xBins = np.digitize(xy[0], ye[:-1])
        yBins = np.digitize(xy[1], xe[:-1])
        fieldLabel = labels[yBins - 1, xBins - 1]
        fl_counts, fl_bins = np.histogram(fieldLabel, bins=np.unique(labels))
        for i, fl in enumerate(fl_bins[1::]):
            print("Field {} has {} samples".format(i, fl_counts[i]))

        fieldPerimMask = bwperim(labels)
        fieldPerimYBins, fieldPerimXBins = np.nonzero(fieldPerimMask)
        fieldPerimX = ye[fieldPerimXBins]
        fieldPerimY = xe[fieldPerimYBins]
        fieldPerimXY = np.vstack((fieldPerimX, fieldPerimY))
        peaksXYBins = np.array(
            ndimage.maximum_position(rmap, labels=labels, index=np.unique(labels)[1::])
        ).astype(int)
        peakY = xe[peaksXYBins[:, 0]]
        peakX = ye[peaksXYBins[:, 1]]
        peaksXY = np.vstack((peakX, peakY)).T

        posRUnsmthd = np.zeros((nPos)) * np.nan
        posAngleFromPeak = np.zeros_like(posRUnsmthd) * np.nan
        perimAngleFromPeak = np.zeros_like(fieldPerimMask) * np.nan
        for i, peak in enumerate(peaksXY):
            i = i + 1
            # grab each fields perim coords and the pos samples within it
            y_ind, x_ind = np.nonzero(fieldPerimMask == i)
            thisFieldPerim = np.array([xe[x_ind], ye[y_ind]])
            if thisFieldPerim.any():
                this_xy = xy[:, fieldLabel == i]
                # calculate angle from the field peak for each point on the
                # perim and each pos sample that lies within the field
                thisPerimAngle = np.arctan2(
                    thisFieldPerim[1, :] - peak[1], thisFieldPerim[0, :] - peak[0]
                )
                thisPosAngle = np.arctan2(
                    this_xy[1, :] - peak[1], this_xy[0, :] - peak[0]
                )
                posAngleFromPeak[fieldLabel == i] = thisPosAngle

                perimAngleFromPeak[fieldPerimMask == i] = thisPerimAngle
                # for each pos sample calculate which point on the perim is
                # most colinear with the field centre - see _circ_abs for more
                thisAngleDf = circ_abs(
                    thisPerimAngle[:, np.newaxis] - thisPosAngle[np.newaxis, :]
                )
                thisPerimInd = np.argmin(thisAngleDf, 0)
                # calculate the distance to the peak from pos and the min perim
                # point and calculate the ratio (r - see Outputs for method)
                tmp = this_xy.T - peak.T
                distFromPos2Peak = np.hypot(tmp[:, 0], tmp[:, 1])
                tmp = thisFieldPerim[:, thisPerimInd].T - peak.T
                distFromPerim2Peak = np.hypot(tmp[:, 0], tmp[:, 1])
                posRUnsmthd[fieldLabel == i] = distFromPos2Peak / distFromPerim2Peak
        # the skimage find_boundaries method combined with the labelled mask
        # can make some of the values in thisDistFromPos2Peak larger than
        # those in thisDistFromPerim2Peak, which means that some of the vals
        # in posRUnsmthd are larger than 1, which means the values in xy_new
        # later are wrong - so let's cap any value > 1 to 1. The same cap is
        # applied later to rho when calculating the angular values. Print out
        # a warning message letting the user know how many values were capped
        print(
            "\n\n{:.2%} posRUnsmthd values have been capped to 1\n\n".format(
                np.sum(posRUnsmthd >= 1) / posRUnsmthd.size
            )
        )
        posRUnsmthd[posRUnsmthd >= 1] = 1
        # label non-zero contiguous runs with a unique id
        runLabel = labelContigNonZeroRuns(fieldLabel)
        isRun = runLabel > 0
        runStartIdx = getLabelStarts(runLabel)
        runEndIdx = getLabelEnds(runLabel)
        # find runs that are too short, have low speed or too few spikes
        runsSansSpikes = np.ones(len(runStartIdx), dtype=bool)
        spkRunLabels = runLabel[spkPosInd] - 1
        runsSansSpikes[spkRunLabels[spkRunLabels > 0]] = False
        k = signal.windows.boxcar(self.speed_smoothing_window_len) / float(
            self.speed_smoothing_window_len
        )
        spdSmthd = signal.convolve(np.squeeze(spd), k, mode="same")
        runDurationInPosBins = runEndIdx - runStartIdx + 1
        runsMinSpeed = []
        runId = np.unique(runLabel)[1::]
        for run in runId:
            runsMinSpeed.append(np.min(spdSmthd[runLabel == run]))
        runsMinSpeed = np.array(runsMinSpeed)
        badRuns = np.logical_or(
            np.logical_or(
                runsMinSpeed < self.minimum_allowed_run_speed,
                runDurationInPosBins < self.minimum_allowed_run_duration,
            ),
            runsSansSpikes,
        )
        badRuns = np.squeeze(badRuns)
        runLabel = applyFilter2Labels(~badRuns, runLabel)
        runStartIdx = runStartIdx[~badRuns]
        runEndIdx = runEndIdx[~badRuns]  # + 1
        runsMinSpeed = runsMinSpeed[~badRuns]
        runDurationInPosBins = runDurationInPosBins[~badRuns]
        isRun = runLabel > 0

        # calculate mean and std direction for each run
        runComplexMnDir = np.squeeze(np.zeros(runStartIdx.shape, dtype=complex))
        np.add.at(
            runComplexMnDir,
            runLabel[isRun] - 1,
            np.exp(1j * (xydir[isRun] * (np.pi / 180))),
        )
        meanDir = np.angle(runComplexMnDir)  # circ mean
        tortuosity = 1 - np.abs(runComplexMnDir) / runDurationInPosBins

        # calculate the angular distance between each run's mean direction
        # and the pos's direction to the peak centre
        posPhiUnSmthd = np.ones_like(fieldLabel) * np.nan
        posPhiUnSmthd[isRun] = posAngleFromPeak[isRun] - meanDir[runLabel[isRun] - 1]

        # smooth r and phi in cartesian space
        # convert to cartesian coords first
        posXUnSmthd, posYUnSmthd = pol2cart(posRUnsmthd, posPhiUnSmthd)
        posXYUnSmthd = np.vstack((posXUnSmthd, posYUnSmthd))

        # filter each run with filter of appropriate length
        filtLen = np.squeeze(
            np.floor((runEndIdx - runStartIdx + 1) * self.ifr_smoothing_constant)
        )
        xy_new = np.zeros_like(xy_old) * np.nan
        for i in range(len(runStartIdx)):
            if filtLen[i] > 2:
                filt = signal.firwin(
                    int(filtLen[i] - 1),
                    cutoff=self.spatial_lowpass_cutoff / self.pos_sample_rate * 2,
                    window="blackman",
                )
                xy_new[:, runStartIdx[i] : runEndIdx[i]] = signal.filtfilt(
                    filt, [1], posXYUnSmthd[:, runStartIdx[i] : runEndIdx[i]], axis=1
                )

        r, phi = cart2pol(xy_new[0], xy_new[1])
        r[r > 1] = 1

        # calculate the direction of the smoothed data
        xydir_new = np.arctan2(np.diff(xy_new[1]), np.diff(xy_new[0]))
        xydir_new = np.append(xydir_new, xydir_new[-1])
        xydir_new[runEndIdx] = xydir_new[runEndIdx - 1]

        # project the distance value onto the current direction
        d_currentdir = r * np.cos(xydir_new - phi)

        # calculate the cumulative distance travelled on each run
        dr = np.sqrt(np.diff(np.power(r, 2), 1))
        d_cumulative = labelledCumSum(np.insert(dr, 0, 0), runLabel)

        # calculate cumulative sum of the expected normalised firing rate
        exptdRate_cumulative = labelledCumSum(1 - r, runLabel)

        # direction projected onto the run mean direction is just the x coord
        d_meandir = xy_new[0]

        # smooth binned spikes to get an instantaneous firing rate
        # set up the smoothing kernel
        kernLenInBins = int(np.round(self.ifr_kernel_len * self.bins_per_second))
        kernSig = self.ifr_kernel_sigma * self.bins_per_second
        k = signal.windows.gaussian(kernLenInBins, kernSig)
        # get a count of spikes to smooth over
        spkCount = np.bincount(spkPosInd, minlength=nPos)
        # apply the smoothing kernel
        instFiringRate = signal.convolve(spkCount, k, mode="same")
        instFiringRate[~isRun] = np.nan

        # find time spent within run
        time = np.ones(nPos)
        time = labelledCumSum(time, runLabel)
        timeInRun = time / self.pos_sample_rate

        fieldNum = fieldLabel[runStartIdx]
        mnSpd = np.squeeze(np.zeros(fieldNum.shape))  # float, so np.add.at can add speeds
        np.add.at(mnSpd, runLabel[isRun] - 1, spd[isRun])
        nPts = np.bincount(runLabel[isRun] - 1, minlength=len(mnSpd))
        mnSpd = mnSpd / nPts
        centralPeripheral = np.squeeze(np.zeros(fieldNum.shape))
        np.add.at(centralPeripheral, runLabel[isRun] - 1, xy_new[1, isRun])
        centralPeripheral = centralPeripheral / nPts
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(221)
            ax.plot(xy_new[0], xy_new[1])
            ax.set_title("Unit circle x-y")
            ax.set_aspect("equal")
            ax.set_xlim([-1, 1])
            ax.set_ylim([-1, 1])
            ax = fig.add_subplot(222)
            ax.plot(fieldPerimX, fieldPerimY, "k.")
            ax.set_title("Field perim and laser on events")
            ax.plot(xy[0, fieldLabel > 0], xy[1, fieldLabel > 0], "y.")
            if laserEvents is not None:
                validOns = np.setdiff1d(laserEvents, np.nonzero(~np.isnan(r))[0])
                ax.plot(xy[0, validOns], xy[1, validOns], "rx")
            ax.set_aspect("equal")
            angleCMInd = np.round(perimAngleFromPeak / np.pi * 180) + 180
            angleCMInd[angleCMInd == 0] = 360
            im = np.zeros(fieldPerimMask.shape)
            im[fieldPerimMask] = angleCMInd[fieldPerimMask]
            imM = np.ma.MaskedArray(im, mask=~fieldPerimMask, copy=True)
            #############################################
            # create custom colormap
            cmap = plt.colormaps["jet_r"]
            cmaplist = [cmap(i) for i in range(cmap.N)]
            cmaplist[0] = (1, 1, 1, 1)
            cmap = cmap.from_list("Runvals cmap", cmaplist, cmap.N)
            bounds = np.linspace(0, 1.0, 100)
            norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
            # add the runs through the fields
            runVals = np.zeros_like(im)
            runVals[yBins[isRun] - 1, xBins[isRun] - 1] = r[isRun]
            ax = fig.add_subplot(223)
            imm = ax.imshow(
                runVals, cmap=cmap, norm=norm, origin="lower", interpolation="nearest"
            )
            plt.colorbar(imm, orientation="horizontal")
            ax.set_aspect("equal")
            # add a custom colorbar for colors in runVals

            # create a custom colormap for the plot
            cmap = matplotlib.colormaps["hsv"]
            cmaplist = [cmap(i) for i in range(cmap.N)]
            cmaplist[0] = (1, 1, 1, 1)
            cmap = cmap.from_list("Perim cmap", cmaplist, cmap.N)
            bounds = np.linspace(0, 360, cmap.N)
            norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)

            imm = ax.imshow(
                imM, cmap=cmap, norm=norm, origin="lower", interpolation="nearest"
            )
            plt.colorbar(imm)
            ax.set_title("Runs by distance and angle")
            ax.plot(peaksXYBins[:, 1], peaksXYBins[:, 0], "ko")
            ax.set_xlim(0, im.shape[1])
            ax.set_ylim(0, im.shape[0])
            #############################################
            ax = fig.add_subplot(224)
            ax.imshow(rmap, origin="lower", interpolation="nearest")
            ax.set_aspect("equal")
            ax.set_title("Smoothed ratemap")

        # update the regressor dict from __init__ with relevant values
        self.regressors["pos_exptdRate_cum"]["values"] = exptdRate_cumulative
        self.regressors["pos_instFR"]["values"] = instFiringRate
        self.regressors["pos_timeInRun"]["values"] = timeInRun
        self.regressors["pos_d_cum"]["values"] = d_cumulative
        self.regressors["pos_d_meanDir"]["values"] = d_meandir
        self.regressors["pos_d_currentdir"]["values"] = d_currentdir
        posKeys = (
            "xy",
            "xydir",
            "r",
            "phi",
            "xy_old",
            "xydir_old",
            "fieldLabel",
            "runLabel",
            "d_currentdir",
            "d_cumulative",
            "exptdRate_cumulative",
            "d_meandir",
            "instFiringRate",
            "timeInRun",
            "fieldPerimMask",
            "perimAngleFromPeak",
            "posAngleFromPeak",
        )
        runsKeys = (
            "runStartIdx",
            "runEndIdx",
            "runDurationInPosBins",
            "runsMinSpeed",
            "meanDir",
            "tortuosity",
            "mnSpd",
            "centralPeripheral",
        )
        posDict = dict.fromkeys(posKeys, np.nan)
        # neat trick: locals is a dict that holds all locally scoped variables
        for thiskey in posDict.keys():
            posDict[thiskey] = locals()[thiskey]
        runsDict = dict.fromkeys(runsKeys, np.nan)
        for thiskey in runsDict.keys():
            runsDict[thiskey] = locals()[thiskey]
        return posDict, runsDict

    def getThetaProps(self, **kwargs):
        spikeTS = self.spike_ts
        phase = self.phase
        filteredEEG = self.filteredEEG
        oldAmplt = filteredEEG.copy()
        # get indices of spikes into eeg
        spkEEGIdx = np.ceil(spikeTS * self.lfp_sample_rate).astype(int)
        spkEEGIdx[spkEEGIdx >= len(phase)] = len(phase) - 1
        spkCount = np.bincount(spkEEGIdx, minlength=len(phase))
        spkPhase = phase[spkEEGIdx]
        minSpikingPhase = getPhaseOfMinSpiking(spkPhase)
        phaseAdj = fixAngle(
            phase - minSpikingPhase * (np.pi / 180) + self.allowed_min_spike_phase
        )
        isNegFreq = np.diff(np.unwrap(phaseAdj)) < 0
        isNegFreq = np.append(isNegFreq, isNegFreq[-1])
        # get start of theta cycles as points where diff > pi
        phaseDf = np.diff(phaseAdj)
        cycleStarts = phaseDf[1::] < -np.pi
        cycleStarts = np.append(cycleStarts, True)
        cycleStarts = np.insert(cycleStarts, 0, True)
        cycleStarts[isNegFreq] = False
        cycleLabel = np.cumsum(cycleStarts)

        # calculate power and find low-power cycles
        power = np.power(filteredEEG, 2)
        cycleTotValidPow = np.bincount(
            cycleLabel[~isNegFreq], weights=power[~isNegFreq]
        )
        cycleValidBinCount = np.bincount(cycleLabel[~isNegFreq])
        cycleValidMnPow = cycleTotValidPow / cycleValidBinCount
        powRejectThresh = np.percentile(
            cycleValidMnPow, self.min_power_percent_threshold
        )
        cycleHasBadPow = cycleValidMnPow < powRejectThresh

        # find cycles too long or too short
        cycleTotBinCount = np.bincount(cycleLabel)
        cycleHasBadLen = np.logical_or(
            cycleTotBinCount > self.allowed_theta_len[1],
            cycleTotBinCount < self.allowed_theta_len[0],
        )

        # remove data calculated as 'bad'
        isBadCycle = np.logical_or(cycleHasBadLen, cycleHasBadPow)
        isInBadCycle = isBadCycle[cycleLabel]
        isBad = np.logical_or(isInBadCycle, isNegFreq)
        phaseAdj[isBad] = np.nan
        self.phaseAdj = phaseAdj
        ampAdj = filteredEEG.copy()
        ampAdj[isBad] = np.nan
        cycleLabel[isBad] = 0
        self.cycleLabel = cycleLabel
        out = {
            "phase": phaseAdj,
            "amp": ampAdj,
            "cycleLabel": cycleLabel,
            "oldPhase": phase.copy(),
            "oldAmplt": oldAmplt,
            "spkCount": spkCount,
        }
        return out

    def getSpikeProps(self, runLabel, meanDir, durationInPosBins):

        spikeTS = self.spike_ts
        xy = self.RateMap.xy
        phase = self.phaseAdj
        cycleLabel = self.cycleLabel
        spkEEGIdx = np.ceil(spikeTS * self.lfp_sample_rate)
        spkEEGIdx[spkEEGIdx >= len(phase)] = len(phase) - 1
        spkEEGIdx = spkEEGIdx.astype(int)
        spkPosIdx = np.ceil(spikeTS * self.pos_sample_rate)
        spkPosIdx[spkPosIdx >= xy.shape[1]] = xy.shape[1] - 1
        spkRunLabel = runLabel[spkPosIdx.astype(int)]
        thetaCycleLabel = cycleLabel[spkEEGIdx.astype(int)]

        # build a mask that is True for the first spike in each theta cycle
        firstInTheta = thetaCycleLabel[:-1] != thetaCycleLabel[1::]
        firstInTheta = np.insert(firstInTheta, 0, True)
        lastInTheta = firstInTheta[1::]
        # calculate two kinds of numbering for spikes in a run
        numWithinRun = labelledCumSum(np.ones_like(spkPosIdx), spkRunLabel)
        thetaBatchLabelInRun = labelledCumSum(firstInTheta.astype(float), spkRunLabel)

        spkCount = np.bincount(spkRunLabel[spkRunLabel > 0], minlength=len(meanDir))
        rateInPosBins = spkCount[1::] / durationInPosBins.astype(float)
        # update the regressor dict from __init__ with relevant values
        self.regressors["spk_numWithinRun"]["values"] = numWithinRun
        self.regressors["spk_thetaBatchLabelInRun"]["values"] = thetaBatchLabelInRun
        spkKeys = (
            "spikeTS",
            "spkPosIdx",
            "spkEEGIdx",
            "spkRunLabel",
            "thetaCycleLabel",
            "firstInTheta",
            "lastInTheta",
            "numWithinRun",
            "thetaBatchLabelInRun",
            "spkCount",
            "rateInPosBins",
        )
        spkDict = dict.fromkeys(spkKeys, np.nan)
        for thiskey in spkDict.keys():
            spkDict[thiskey] = locals()[thiskey]
        return spkDict

    def _ppRegress(self, spkDict, whichSpk="first", plot=False, **kwargs):

        phase = self.phaseAdj
        newSpkRunLabel = spkDict["spkRunLabel"].copy()
        # TODO: need code to deal with splitting the data based on a group of
        # variables
        spkUsed = newSpkRunLabel > 0
        if "first" in whichSpk:
            spkUsed[~spkDict["firstInTheta"]] = False
        elif "last" in whichSpk:
            if len(spkDict["lastInTheta"]) < len(spkDict["spkRunLabel"]):
                spkDict["lastInTheta"] = np.insert(spkDict["lastInTheta"], -1, False)
            spkUsed[~spkDict["lastInTheta"]] = False
        spkPosIdxUsed = spkDict["spkPosIdx"].astype(int)
        # copy self.regressors and update with spk/ pos of interest
        regressors = self.regressors.copy()
        for k in regressors.keys():
            if k.startswith("spk_"):
                regressors[k]["values"] = regressors[k]["values"][spkUsed]
            elif k.startswith("pos_"):
                regressors[k]["values"] = regressors[k]["values"][
                    spkPosIdxUsed[spkUsed]
                ]
        phase = phase[spkDict["spkEEGIdx"][spkUsed]]
        phase = phase.astype(np.double)
        if "mean" in whichSpk:
            goodPhase = ~np.isnan(phase)
            cycleLabels = spkDict["thetaCycleLabel"][spkUsed]
            sz = np.max(cycleLabels)
            cycleComplexPhase = np.squeeze(np.zeros(sz, dtype=complex))
            np.add.at(
                cycleComplexPhase,
                cycleLabels[goodPhase] - 1,
                np.exp(1j * phase[goodPhase]),
            )
            phase = np.angle(cycleComplexPhase)
            spkCountPerCycle = np.bincount(cycleLabels[goodPhase], minlength=sz)
            for k in regressors.keys():
                regressors[k]["values"] = (
                    np.bincount(
                        cycleLabels[goodPhase],
                        weights=regressors[k]["values"][goodPhase],
                        minlength=sz,
                    )
                    / spkCountPerCycle
                )

        goodPhase = ~np.isnan(phase)
        for k in regressors.keys():
            print(f"Doing regression: {k}")
            goodRegressor = ~np.isnan(regressors[k]["values"])
            reg = regressors[k]["values"][np.logical_and(goodRegressor, goodPhase)]
            pha = phase[np.logical_and(goodRegressor, goodPhase)]
            regressors[k]["slope"],
            regressors[k]["intercept"] = circRegress(reg, pha)
            regressors[k]["pha"] = pha
            mnx = np.mean(reg)
            reg = reg - mnx
            mxx = np.max(np.abs(reg)) + np.spacing(1)
            reg = reg / mxx
            theta = np.mod(np.abs(regressors[k]["slope"]) * reg, 2 * np.pi)
            rho, p, rho_boot, p_shuff, ci = circCircCorrTLinear(
                theta, pha, self.k, self.alpha, self.hyp, self.conf
            )
            regressors[k]["reg"] = reg
            regressors[k]["cor"] = rho
            regressors[k]["p"] = p
            regressors[k]["cor_boot"] = rho_boot
            regressors[k]["p_shuffled"] = p_shuff
            regressors[k]["ci"] = ci

        if plot:
            fig = plt.figure()

            ax = fig.add_subplot(2, 1, 1)
            ax.plot(regressors["pos_d_currentdir"]["values"], phase, "k.")
            ax.plot(regressors["pos_d_currentdir"]["values"], phase + 2 * np.pi, "k.")
            slope = regressors["pos_d_currentdir"]["slope"]
            intercept = regressors["pos_d_currentdir"]["intercept"]
            mm = (0, -2 * np.pi, 2 * np.pi, 4 * np.pi)
            for m in mm:
                ax.plot(
                    (-1, 1), (-slope + intercept + m, slope + intercept + m), "r", lw=3
                )
            ax.set_xlim(-1, 1)
            ax.set_ylim(-np.pi, 3 * np.pi)
            ax.set_title("pos_d_currentdir")
            ax.set_ylabel("Phase")

            ax = fig.add_subplot(2, 1, 2)
            ax.plot(regressors["pos_d_meanDir"]["values"], phase, "k.")
            ax.plot(regressors["pos_d_meanDir"]["values"], phase + 2 * np.pi, "k.")
            slope = regressors["pos_d_meanDir"]["slope"]
            intercept = regressors["pos_d_meanDir"]["intercept"]
            mm = (0, -2 * np.pi, 2 * np.pi, 4 * np.pi)
            for m in mm:
                ax.plot(
                    (-1, 1), (-slope + intercept + m, slope + intercept + m), "r", lw=3
                )
            ax.set_xlim(-1, 1)
            ax.set_ylim(-np.pi, 3 * np.pi)
            ax.set_title("pos_d_meanDir")
            ax.set_ylabel("Phase")
            ax.set_xlabel("Normalised position")
        self.reg_phase = phase
        return regressors

    def plotPPRegression(self, regressorDict, regressor2plot="pos_d_cum", ax=None):

        t = self.getLFPPhaseValsForSpikeTS()
        x = self.RateMap.xy[0, self.spk_times_in_pos_samples]
        from ephysiopy.common import fieldcalcs

        rmap, (xe, _) = self.RateMap.getMap(self.spk_weights)
        label = fieldcalcs.field_lims(rmap)
        xInField = xe[label.nonzero()[1]]
        mask = np.logical_and(x > np.min(xInField), x < np.max(xInField))
        x = x[mask]
        t = t[mask]
        # keep x between -1 and +1
        mnx = np.mean(x)
        xn = x - mnx
        mxx = np.max(np.abs(xn))
        x = xn / mxx
        # keep tn between 0 and 2pi
        t = np.remainder(t, 2 * np.pi)
        slope, intercept = circRegress(x, t)
        rho, p, rho_boot, p_shuff, ci = circCircCorrTLinear(x, t)
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.plot(x, t, ".", color="k")
        ax.plot(x, t + 2 * np.pi, ".", color="k")
        mm = (0, -2 * np.pi, 2 * np.pi, 4 * np.pi)
        for m in mm:
            ax.plot((-1, 1), (-slope + intercept + m, slope + intercept + m), "r", lw=3)
        ax.set_xlim((-1, 1))
        ax.set_ylim((-np.pi, 3 * np.pi))
        return {
            "slope": slope,
            "intercept": intercept,
            "rho": rho,
            "p": p,
            "rho_boot": rho_boot,
            "p_shuff": p_shuff,
            "ci": ci,
        }

    def getLFPPhaseValsForSpikeTS(self):
        ts = self.spk_times_in_pos_samples * (
            self.lfp_sample_rate / self.pos_sample_rate
        )
        ts_idx = np.array(np.floor(ts), dtype=int)
        return self.phase[ts_idx]

getPosProps(labels, peaksXY, laserEvents=None, plot=False, **kwargs)

Uses the output of partitionFields and returns vectors the same length as pos.

Parameters:

    labels (numpy.ndarray, required): The labelled fields, as returned by partitionFields
    peaksXY (array_like, required): The x-y coords of the peaks in the ratemap
    laserEvents (array_like, default None): The position indices of laser on events

Returns:

    pos_dict, run_dict (dict): Contain information for the whole trial
    (pos_dict) and on a run-by-run basis (run_dict). See the end of this
    function for all the key / value pairs.
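
A minimal usage sketch; the instance name pp and the call ordering are assumptions based on the See Also entries for this class:

# hypothetical sketch: pp is an already-constructed phase precession object
peaksXY, peaksRate, labels, rmap = pp.partitionFields()
posD, runD = pp.getPosProps(labels, peaksXY, laserEvents=None, plot=False)
posD["r"]        # normalised distance of each pos sample from its field peak
runD["meanDir"]  # circular mean heading of each accepted run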

Source code in ephysiopy/common/phasecoding.py
def getPosProps(self, labels, peaksXY, laserEvents=None, plot=False, **kwargs):
    """
    Uses the output of partitionFields and returns vectors the same
    length as pos.

    Args:
        labels (numpy.ndarray): The labelled fields, as returned by
            partitionFields
        peaksXY (array_like): The x-y coords of the peaks in the ratemap
        laserEvents (array_like): The position indices of laser on
            events

    Returns:
        pos_dict, run_dict (dict): Contains a whole bunch of information
        for the whole trial and also on a run-by-run basis (run_dict).
        See the end of this function for all the key / value pairs.
    """

    spikeTS = self.spike_ts  # in seconds
    xy = self.RateMap.xy
    xydir = self.RateMap.dir
    spd = self.RateMap.speed
    spkPosInd = np.ceil(spikeTS * self.pos_sample_rate).astype(int)
    spkPosInd[spkPosInd >= len(xy.T)] = len(xy.T) - 1
    nPos = xy.shape[1]
    xy_old = xy.copy()
    xydir = np.squeeze(xydir)
    xydir_old = xydir.copy()

    rmap, (x_bins_in_pixels, y_bins_in_pixels) = self.RateMap.getMap(
        self.spk_weights
    )
    xe = x_bins_in_pixels
    ye = y_bins_in_pixels

    # The large number of bins combined with the super-smoothed ratemap
    # will lead to fields labelled with lots of small holes in them. Fill
    # those gaps here and calculate the perimeter of the fields based on
    # that labelled image
    labels, n_labels = ndimage.label(ndimage.binary_fill_holes(labels))

    rmap[np.isnan(rmap)] = 0
    xBins = np.digitize(xy[0], ye[:-1])
    yBins = np.digitize(xy[1], xe[:-1])
    fieldLabel = labels[yBins - 1, xBins - 1]
    fl_counts, fl_bins = np.histogram(fieldLabel, bins=np.unique(labels))
    for i, fl in enumerate(fl_bins[1::]):
        print("Field {} has {} samples".format(i, fl_counts[i]))

    fieldPerimMask = bwperim(labels)
    fieldPerimYBins, fieldPerimXBins = np.nonzero(fieldPerimMask)
    fieldPerimX = ye[fieldPerimXBins]
    fieldPerimY = xe[fieldPerimYBins]
    fieldPerimXY = np.vstack((fieldPerimX, fieldPerimY))
    peaksXYBins = np.array(
        ndimage.maximum_position(rmap, labels=labels, index=np.unique(labels)[1::])
    ).astype(int)
    peakY = xe[peaksXYBins[:, 0]]
    peakX = ye[peaksXYBins[:, 1]]
    peaksXY = np.vstack((peakX, peakY)).T

    posRUnsmthd = np.zeros((nPos)) * np.nan
    posAngleFromPeak = np.zeros_like(posRUnsmthd) * np.nan
    perimAngleFromPeak = np.zeros_like(fieldPerimMask) * np.nan
    for i, peak in enumerate(peaksXY):
        i = i + 1
        # grab each field's perim coords and the pos samples within it
        y_ind, x_ind = np.nonzero(fieldPerimMask == i)
        thisFieldPerim = np.array([xe[x_ind], ye[y_ind]])
        if thisFieldPerim.any():
            this_xy = xy[:, fieldLabel == i]
            # calculate angle from the field peak for each point on the
            # perim and each pos sample that lies within the field
            thisPerimAngle = np.arctan2(
                thisFieldPerim[1, :] - peak[1], thisFieldPerim[0, :] - peak[0]
            )
            thisPosAngle = np.arctan2(
                this_xy[1, :] - peak[1], this_xy[0, :] - peak[0]
            )
            posAngleFromPeak[fieldLabel == i] = thisPosAngle

            perimAngleFromPeak[fieldPerimMask == i] = thisPerimAngle
            # for each pos sample calculate which point on the perim is
            # most collinear with the field centre - see _circ_abs for more
            thisAngleDf = circ_abs(
                thisPerimAngle[:, np.newaxis] - thisPosAngle[np.newaxis, :]
            )
            thisPerimInd = np.argmin(thisAngleDf, 0)
            # calculate the distance to the peak from pos and the min perim
            # point and calculate the ratio (r - see Outputs for method)
            tmp = this_xy.T - peak.T
            distFromPos2Peak = np.hypot(tmp[:, 0], tmp[:, 1])
            tmp = thisFieldPerim[:, thisPerimInd].T - peak.T
            distFromPerim2Peak = np.hypot(tmp[:, 0], tmp[:, 1])
            posRUnsmthd[fieldLabel == i] = distFromPos2Peak / distFromPerim2Peak
    # the skimage find_boundaries method combined with the labelled mask
    # can make some of the values in distFromPos2Peak larger than those
    # in distFromPerim2Peak, which means that some of the values in
    # posRUnsmthd are larger than 1. That in turn makes the values in
    # xy_new later on wrong, so cap any value > 1 to 1. The same cap is
    # applied later to rho when calculating the angular values. Print a
    # warning message letting the user know how many values were capped
    print(
        "\n\n{:.2%} posRUnsmthd values have been capped to 1\n\n".format(
            np.sum(posRUnsmthd >= 1) / posRUnsmthd.size
        )
    )
    posRUnsmthd[posRUnsmthd >= 1] = 1
    # label non-zero contiguous runs with a unique id
    runLabel = labelContigNonZeroRuns(fieldLabel)
    isRun = runLabel > 0
    runStartIdx = getLabelStarts(runLabel)
    runEndIdx = getLabelEnds(runLabel)
    # find runs that are too short, have low speed or too few spikes
    runsSansSpikes = np.ones(len(runStartIdx), dtype=bool)
    spkRunLabels = runLabel[spkPosInd] - 1
    runsSansSpikes[spkRunLabels[spkRunLabels >= 0]] = False  # run 1 maps to index 0
    k = signal.windows.boxcar(self.speed_smoothing_window_len) / float(
        self.speed_smoothing_window_len
    )
    spdSmthd = signal.convolve(np.squeeze(spd), k, mode="same")
    runDurationInPosBins = runEndIdx - runStartIdx + 1
    runsMinSpeed = []
    runId = np.unique(runLabel)[1::]
    for run in runId:
        runsMinSpeed.append(np.min(spdSmthd[runLabel == run]))
    runsMinSpeed = np.array(runsMinSpeed)
    badRuns = np.logical_or(
        np.logical_or(
            runsMinSpeed < self.minimum_allowed_run_speed,
            runDurationInPosBins < self.minimum_allowed_run_duration,
        ),
        runsSansSpikes,
    )
    badRuns = np.squeeze(badRuns)
    runLabel = applyFilter2Labels(~badRuns, runLabel)
    runStartIdx = runStartIdx[~badRuns]
    runEndIdx = runEndIdx[~badRuns]  # + 1
    runsMinSpeed = runsMinSpeed[~badRuns]
    runDurationInPosBins = runDurationInPosBins[~badRuns]
    isRun = runLabel > 0

    # calculate mean and std direction for each run
    runComplexMnDir = np.squeeze(np.zeros(runStartIdx.shape, dtype=complex))
    np.add.at(
        runComplexMnDir,
        runLabel[isRun] - 1,
        np.exp(1j * (xydir[isRun] * (np.pi / 180))),
    )
    meanDir = np.angle(runComplexMnDir)  # circ mean
    tortuosity = 1 - np.abs(runComplexMnDir) / runDurationInPosBins

    # calculate the angular distance between each run's mean direction
    # and the pos's direction to the peak centre
    posPhiUnSmthd = np.ones_like(fieldLabel) * np.nan
    posPhiUnSmthd[isRun] = posAngleFromPeak[isRun] - meanDir[runLabel[isRun] - 1]

    # smooth r and phi in cartesian space
    # convert to cartesian coords first
    posXUnSmthd, posYUnSmthd = pol2cart(posRUnsmthd, posPhiUnSmthd)
    posXYUnSmthd = np.vstack((posXUnSmthd, posYUnSmthd))

    # filter each run with filter of appropriate length
    filtLen = np.squeeze(
        np.floor((runEndIdx - runStartIdx + 1) * self.ifr_smoothing_constant)
    )
    xy_new = np.zeros_like(xy_old) * np.nan
    for i in range(len(runStartIdx)):
        if filtLen[i] > 2:
            filt = signal.firwin(
                int(filtLen[i] - 1),
                cutoff=self.spatial_lowpass_cutoff / self.pos_sample_rate * 2,
                window="blackman",
            )
            xy_new[:, runStartIdx[i] : runEndIdx[i]] = signal.filtfilt(
                filt, [1], posXYUnSmthd[:, runStartIdx[i] : runEndIdx[i]], axis=1
            )

    r, phi = cart2pol(xy_new[0], xy_new[1])
    r[r > 1] = 1

    # calculate the direction of the smoothed data
    xydir_new = np.arctan2(np.diff(xy_new[1]), np.diff(xy_new[0]))
    xydir_new = np.append(xydir_new, xydir_new[-1])
    xydir_new[runEndIdx] = xydir_new[runEndIdx - 1]

    # project the distance value onto the current direction
    d_currentdir = r * np.cos(xydir_new - phi)

    # calculate the cumulative distance travelled on each run
    dr = np.sqrt(np.diff(np.power(r, 2), 1))
    d_cumulative = labelledCumSum(np.insert(dr, 0, 0), runLabel)

    # calculate cumulative sum of the expected normalised firing rate
    exptdRate_cumulative = labelledCumSum(1 - r, runLabel)

    # direction projected onto the run mean direction is just the x coord
    d_meandir = xy_new[0]

    # smooth binned spikes to get an instantaneous firing rate
    # set up the smoothing kernel
    kernLenInBins = int(np.round(self.ifr_kernel_len * self.bins_per_second))
    kernSig = self.ifr_kernel_sigma * self.bins_per_second
    k = signal.windows.gaussian(kernLenInBins, kernSig)
    # get a count of spikes to smooth over
    spkCount = np.bincount(spkPosInd, minlength=nPos)
    # apply the smoothing kernel
    instFiringRate = signal.convolve(spkCount, k, mode="same")
    instFiringRate[~isRun] = np.nan

    # find time spent within run
    time = np.ones(nPos)
    time = labelledCumSum(time, runLabel)
    timeInRun = time / self.pos_sample_rate

    fieldNum = fieldLabel[runStartIdx]
    mnSpd = np.squeeze(np.zeros(fieldNum.shape))  # float, so np.add.at can add speeds
    np.add.at(mnSpd, runLabel[isRun] - 1, spd[isRun])
    nPts = np.bincount(runLabel[isRun] - 1, minlength=len(mnSpd))
    mnSpd = mnSpd / nPts
    centralPeripheral = np.squeeze(np.zeros(fieldNum.shape))
    np.add.at(centralPeripheral, runLabel[isRun] - 1, xy_new[1, isRun])
    centralPeripheral = centralPeripheral / nPts
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(221)
        ax.plot(xy_new[0], xy_new[1])
        ax.set_title("Unit circle x-y")
        ax.set_aspect("equal")
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax = fig.add_subplot(222)
        ax.plot(fieldPerimX, fieldPerimY, "k.")
        ax.set_title("Field perim and laser on events")
        ax.plot(xy[0, fieldLabel > 0], xy[1, fieldLabel > 0], "y.")
        if laserEvents is not None:
            validOns = np.setdiff1d(laserEvents, np.nonzero(~np.isnan(r))[0])
            ax.plot(xy[0, validOns], xy[1, validOns], "rx")
        ax.set_aspect("equal")
        angleCMInd = np.round(perimAngleFromPeak / np.pi * 180) + 180
        angleCMInd[angleCMInd == 0] = 360
        im = np.zeros(fieldPerimMask.shape)
        im[fieldPerimMask] = angleCMInd[fieldPerimMask]
        imM = np.ma.MaskedArray(im, mask=~fieldPerimMask, copy=True)
        #############################################
        # create custom colormap
        cmap = plt.colormaps["jet_r"]
        cmaplist = [cmap(i) for i in range(cmap.N)]
        cmaplist[0] = (1, 1, 1, 1)
        cmap = cmap.from_list("Runvals cmap", cmaplist, cmap.N)
        bounds = np.linspace(0, 1.0, 100)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
        # add the runs through the fields
        runVals = np.zeros_like(im)
        runVals[yBins[isRun] - 1, xBins[isRun] - 1] = r[isRun]
        ax = fig.add_subplot(223)
        imm = ax.imshow(
            runVals, cmap=cmap, norm=norm, origin="lower", interpolation="nearest"
        )
        plt.colorbar(imm, orientation="horizontal")
        ax.set_aspect("equal")
        # add a custom colorbar for colors in runVals

        # create a custom colormap for the plot
        cmap = matplotlib.colormaps["hsv"]
        cmaplist = [cmap(i) for i in range(cmap.N)]
        cmaplist[0] = (1, 1, 1, 1)
        cmap = cmap.from_list("Perim cmap", cmaplist, cmap.N)
        bounds = np.linspace(0, 360, cmap.N)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)

        imm = ax.imshow(
            imM, cmap=cmap, norm=norm, origin="lower", interpolation="nearest"
        )
        plt.colorbar(imm)
        ax.set_title("Runs by distance and angle")
        ax.plot(peaksXYBins[:, 1], peaksXYBins[:, 0], "ko")
        ax.set_xlim(0, im.shape[1])
        ax.set_ylim(0, im.shape[0])
        #############################################
        ax = fig.add_subplot(224)
        ax.imshow(rmap, origin="lower", interpolation="nearest")
        ax.set_aspect("equal")
        ax.set_title("Smoothed ratemap")

    # update the regressor dict from __init__ with relevant values
    self.regressors["pos_exptdRate_cum"]["values"] = exptdRate_cumulative
    self.regressors["pos_instFR"]["values"] = instFiringRate
    self.regressors["pos_timeInRun"]["values"] = timeInRun
    self.regressors["pos_d_cum"]["values"] = d_cumulative
    self.regressors["pos_d_meanDir"]["values"] = d_meandir
    self.regressors["pos_d_currentdir"]["values"] = d_currentdir
    posKeys = (
        "xy",
        "xydir",
        "r",
        "phi",
        "xy_old",
        "xydir_old",
        "fieldLabel",
        "runLabel",
        "d_currentdir",
        "d_cumulative",
        "exptdRate_cumulative",
        "d_meandir",
        "instFiringRate",
        "timeInRun",
        "fieldPerimMask",
        "perimAngleFromPeak",
        "posAngleFromPeak",
    )
    runsKeys = (
        "runStartIdx",
        "runEndIdx",
        "runDurationInPosBins",
        "runsMinSpeed",
        "meanDir",
        "tortuosity",
        "mnSpd",
        "centralPeripheral",
    )
    posDict = dict.fromkeys(posKeys, np.nan)
    # neat trick: locals is a dict that holds all locally scoped variables
    for thiskey in posDict.keys():
        posDict[thiskey] = locals()[thiskey]
    runsDict = dict.fromkeys(runsKeys, np.nan)
    for thiskey in runsDict.keys():
        runsDict[thiskey] = locals()[thiskey]
    return posDict, runsDict

partitionFields(ftype='g', plot=False, **kwargs)

Partitions fields.

Partitions spikes into fields by finding the watersheds around the peaks of a super-smoothed ratemap

Parameters:

    ftype (str, default 'g'): 'p' or 'g' denoting place or grid cells - not implemented yet
    plot (bool, default False): Whether to produce a debugging plot or not

Returns:

    peaksXY (array_like): The xy coordinates of the peak rates in each field
    peaksRate (array_like): The peak rates in peaksXY
    labels (numpy.ndarray): An array of the labels corresponding to each field (starting at 1)
    rmap (numpy.ndarray): The ratemap of the tetrode / cluster
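
A minimal usage sketch; the instance name pp is an assumption:

# hypothetical sketch: pp is an already-constructed phase precession object
peaksXY, peaksRate, labels, rmap = pp.partitionFields(ftype="g", plot=False)
labels.max()   # number of detected fields (labels run from 1 upwards)
peaksRate      # one peak firing rate per detected field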

Source code in ephysiopy/common/phasecoding.py
def partitionFields(self, ftype="g", plot=False, **kwargs):
    """
    Partitions fields.

    Partitions spikes into fields by finding the watersheds around the
    peaks of a super-smoothed ratemap

    Args:
        ftype (str): 'p' or 'g' denoting place or grid cells
          - not implemented yet
        plot (bool): Whether to produce a debugging plot or not

    Returns:
        peaksXY (array_like): The xy coordinates of the peak rates in
        each field
        peaksRate (array_like): The peak rates in peaksXY
        labels (numpy.ndarray): An array of the labels corresponding to
        each field (starting at 1)
        rmap (numpy.ndarray): The ratemap of the tetrode / cluster
    """
    rmap, (xe, ye) = self.RateMap.getMap(self.spk_weights)
    nan_idx = np.isnan(rmap)
    rmap[nan_idx] = 0
    # start image processing:
    # get some markers
    from ephysiopy.common import fieldcalcs

    markers = fieldcalcs.local_threshold(rmap, prc=self.field_threshold_percent)
    # clear the edges / any invalid positions again
    markers[nan_idx] = 0
    # label these markers so each blob has a unique id
    labels = ndimage.label(markers)[0]
    # labels is now a labelled int array from 0 to however many fields have
    # been detected
    # get the number of spikes in each field - NB this is done against a
    # flattened array so we need to figure out which count corresponds to
    # which particular field id using np.unique
    fieldId, _ = np.unique(labels, return_index=True)
    fieldId = fieldId[1::]
    # TODO: come back to this as may need to know field id ordering
    peakCoords = np.array(
        ndimage.maximum_position(rmap, labels=labels, index=fieldId)
    ).astype(int)
    # COMCoords = np.array(
    #     ndimage.center_of_mass(
    #         rmap, labels=labels, index=fieldId)
    # ).astype(int)
    peaksXY = np.vstack((xe[peakCoords[:, 0]], ye[peakCoords[:, 1]])).T
    # find the peak rate at each of the centre of the detected fields to
    # subsequently threshold the field at some fraction of the peak value
    peakRates = rmap[peakCoords[:, 0], peakCoords[:, 1]]
    fieldThresh = peakRates * self.field_threshold
    rmFieldMask = np.zeros_like(rmap)
    for fid in fieldId:
        f = labels[peakCoords[fid - 1, 0], peakCoords[fid - 1, 1]]
        rmFieldMask[labels == f] = rmap[labels == f] > fieldThresh[f - 1]
    labels[~rmFieldMask.astype(bool)] = 0
    # peakBinInds = np.ceil(peakCoords)
    # re-order some vars to get into same format as fieldLabels
    peakLabels = labels[peakCoords[:, 0], peakCoords[:, 1]]
    peaksXY = peaksXY[peakLabels - 1, :]
    peaksRate = peakRates[peakLabels - 1]
    # peakBinInds = peakBinInds[peakLabels-1, :]
    # peaksXY = peakCoords - np.min(xy, 1)

    # if ~np.isnan(self.area_threshold):
    #     # TODO: this needs fixing so sensible values are used and the
    #     # modified bool array is propagated correctly ie makes
    #     # sense to have a function that applies a bool array to whatever
    #     # arrays are used as output and call it in a couple of places
    #     # areaInBins = self.area_threshold * self.binsPerCm
    #     lb = ndimage.label(markers)[0]
    #     rp = skimage.measure.regionprops(lb)
    #     for reg in rp:
    #         print(reg.filled_area)
    #     markers = skimage.morphology.remove_small_objects(
    #         lb, min_size=4000, connectivity=4, in_place=True)
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(211)
        ax.pcolormesh(
            ye, xe, rmap, cmap=matplotlib.colormaps["jet"], edgecolors="face"
        )
        ax.set_title("Smoothed ratemap + peaks")
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_aspect("equal")
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.plot(peaksXY[:, 1], peaksXY[:, 0], "ko")
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)

        ax = fig.add_subplot(212)
        ax.imshow(labels, interpolation="nearest", origin="lower")
        ax.set_title("Labelled restricted fields")
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_aspect("equal")

    return peaksXY, peaksRate, labels, rmap

performRegression(laserEvents=None, **kwargs)

Wrapper function for doing the actual regression which has multiple stages.

Specifically here we partition fields into sub-fields, get a bunch of information about the position, spiking and theta data and then do the actual regression.

Parameters:

    laserEvents (array_like, default None): The on times for laser events, if present

See Also:

    ephysiopy.common.eegcalcs.phasePrecession.partitionFields()
    ephysiopy.common.eegcalcs.phasePrecession.getPosProps()
    ephysiopy.common.eegcalcs.phasePrecession.getThetaProps()
    ephysiopy.common.eegcalcs.phasePrecession.getSpikeProps()
    ephysiopy.common.eegcalcs.phasePrecession._ppRegress()
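
A minimal end-to-end sketch; the instance name pp is an assumption:

# hypothetical sketch: pp is an already-constructed phase precession object
pp.performRegression(laserEvents=None)
# internally this runs partitionFields -> getPosProps -> getThetaProps
# -> getSpikeProps -> _ppRegress and plots the resulting regressions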

Source code in ephysiopy/common/phasecoding.py
def performRegression(self, laserEvents=None, **kwargs):
    """
    Wrapper function for doing the actual regression which has multiple
    stages.

    Specifically here we partition fields into sub-fields, get a bunch of
    information about the position, spiking and theta data and then
    do the actual regression.

    Args:
        laserEvents (array_like, optional): The on times for laser events
            if present. Default is None

    See Also:
        ephysiopy.common.eegcalcs.phasePrecession.partitionFields()
        ephysiopy.common.eegcalcs.phasePrecession.getPosProps()
        ephysiopy.common.eegcalcs.phasePrecession.getThetaProps()
        ephysiopy.common.eegcalcs.phasePrecession.getSpikeProps()
        ephysiopy.common.eegcalcs.phasePrecession._ppRegress()
    """

    # Partition fields
    peaksXY, _, labels, _ = self.partitionFields(plot=True)

    # split into runs
    posD, runD = self.getPosProps(
        labels, peaksXY, laserEvents=laserEvents, plot=True
    )

    # get theta cycles, amplitudes, phase etc
    self.getThetaProps()

    # get the indices of spikes for various metrics such as
    # theta cycle, run etc
    spkD = self.getSpikeProps(
        posD["runLabel"], runD["meanDir"], runD["runDurationInPosBins"]
    )

    # Do the regressions
    regress_dict = self._ppRegress(spkD, plot=True)

    self.plotPPRegression(regress_dict)

applyFilter2Labels(M, x)

M is a logical mask specifying which label numbers to keep; x is an array of positive integer labels.

This method sets the undesired labels to 0 and renumbers the remaining labels 1 to n, where n is the number of Trues in M.

Source code in ephysiopy/common/phasecoding.py
def applyFilter2Labels(M, x):
    """
    M is a logical mask specifying which label numbers to keep
    x is an array of positive integer labels

    This method sets the undesired labels to 0 and renumbers the remaining
    labels 1 to n, where n is the number of Trues in M
    """
    newVals = M * np.cumsum(M)
    x[x > 0] = newVals[x[x > 0] - 1]
    return x
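
A worked example of the relabelling (illustrative values only):

import numpy as np

M = np.array([True, False, True])   # keep labels 1 and 3, drop label 2
x = np.array([0, 1, 1, 0, 2, 2, 3])
applyFilter2Labels(M, x)            # -> array([0, 1, 1, 0, 0, 0, 2])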

ccc(t, p)

Calculates correlation between two random circular variables

Source code in ephysiopy/common/phasecoding.py
def ccc(t, p):
    """
    Calculates correlation between two random circular variables
    """
    n = len(t)
    A = np.sum(np.cos(t) * np.cos(p))
    B = np.sum(np.sin(t) * np.sin(p))
    C = np.sum(np.cos(t) * np.sin(p))
    D = np.sum(np.sin(t) * np.cos(p))
    E = np.sum(np.cos(2 * t))
    F = np.sum(np.sin(2 * t))
    G = np.sum(np.cos(2 * p))
    H = np.sum(np.sin(2 * p))
    rho = 4 * (A * B - C * D) / np.sqrt((n**2 - E**2 - F**2) * (n**2 - G**2 - H**2))
    return rho
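
A quick sanity check with synthetic angles (illustrative values only):

import numpy as np

rng = np.random.default_rng(0)
t = rng.uniform(-np.pi, np.pi, 500)
# p tracks t with a little noise, so rho should be close to 1;
# independent t and p would give rho near 0
p = np.mod(t + rng.normal(0, 0.2, 500) + np.pi, 2 * np.pi) - np.pi
ccc(t, p)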

ccc_jack(t, p)

Function used to calculate jackknife estimates of correlation

Source code in ephysiopy/common/phasecoding.py
def ccc_jack(t, p):
    """
    Function used to calculate jackknife estimates of correlation
    """
    n = len(t) - 1
    A = np.cos(t) * np.cos(p)
    A = np.sum(A) - A
    B = np.sin(t) * np.sin(p)
    B = np.sum(B) - B
    C = np.cos(t) * np.sin(p)
    C = np.sum(C) - C
    D = np.sin(t) * np.cos(p)
    D = np.sum(D) - D
    E = np.cos(2 * t)
    E = np.sum(E) - E
    F = np.sin(2 * t)
    F = np.sum(F) - F
    G = np.cos(2 * p)
    G = np.sum(G) - G
    H = np.sin(2 * p)
    H = np.sum(H) - H
    rho = 4 * (A * B - C * D) / np.sqrt((n**2 - E**2 - F**2) * (n**2 - G**2 - H**2))
    return rho
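
ccc_jack returns one leave-one-out correlation estimate per sample. A sketch of how circCircCorrTLinear (below) turns these into a bias-corrected estimate, mirroring that code and reusing t and p from the ccc example above:

import numpy as np

n = len(t)
rho_jack = ccc_jack(t, p)                      # n leave-one-out estimates
rho_jack = n * ccc(t, p) - (n - 1) * rho_jack  # jackknife pseudo-values
rho_boot = np.mean(rho_jack)                   # bias-corrected estimate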

circCircCorrTLinear(theta, phi, k=1000, alpha=0.05, hyp=0, conf=True)

An almost direct copy from AJs Matlab fcn to perform correlation between 2 circular random variables.

Returns the correlation value (rho), p-value, bootstrapped correlation values, shuffled p values and correlation values.

Parameters:

    theta, phi (array_like, required): mx1 arrays containing circular data (radians) whose correlation is to be measured
    k (int, default 1000): number of permutations to use to calculate the p-value from randomisation and bootstrap estimation of confidence intervals. Leave empty to calculate the p-value analytically (NB confidence intervals will not be calculated).
    alpha (float, default 0.05): hypothesis test level e.g. 0.05, 0.01 etc.
    hyp (int, default 0): hypothesis to test; -1 / 0 / 1 (negatively correlated / correlated in either direction / positively correlated)
    conf (bool, default True): whether to calculate confidence intervals via jackknife or bootstrap
References

Fisher (1993), Statistical Analysis of Circular Data, Cambridge University Press, ISBN: 0 521 56890 0

Source code in ephysiopy/common/phasecoding.py
def circCircCorrTLinear(theta, phi, k=1000, alpha=0.05, hyp=0, conf=True):
    """
    An almost direct copy from AJs Matlab fcn to perform correlation
    between 2 circular random variables.

    Returns the correlation value (rho), p-value, bootstrapped correlation
    values, shuffled p values and correlation values.

    Args:
        theta, phi (array_like): mx1 array containing circular data (radians)
            whose correlation is to be measured
        k (int, optional): number of permutations to use to calculate p-value
            from randomisation and bootstrap estimation of confidence
            intervals.
            Leave empty to calculate p-value analytically (NB confidence
            intervals will not be calculated). Default is 1000.
        alpha (float, optional): hypothesis test level e.g. 0.05, 0.01 etc.
            Default is 0.05.
        hyp (int, optional): hypothesis to test; -1/ 0 / 1 (-ve correlated /
            correlated in either direction / positively correlated).
            Default is 0.
        conf (bool, optional): True or False to calculate confidence intervals
            via jackknife or bootstrap. Default is True.

    References:
        Fisher (1993), Statistical Analysis of Circular Data,
            Cambridge University Press, ISBN: 0 521 56890 0
    """
    theta = theta.ravel()
    phi = phi.ravel()

    if len(theta) != len(phi):
        raise ValueError("theta and phi must be the same length")

    # estimate correlation
    rho = ccc(theta, phi)
    n = len(theta)

    # derive p-values; NB the analytic p-value (for an empty k) is not
    # implemented here, so both default to NaN
    p = np.nan
    p_shuff = np.nan
    if k:
        p_shuff = shuffledPVal(theta, phi, rho, k, hyp)

    # estimate CIs for the correlation
    if n >= 25 and conf:
        # obtain jackknife estimates of rho and its ci's
        rho_jack = ccc_jack(theta, phi)
        rho_jack = n * rho - (n - 1) * rho_jack
        rho_boot = np.mean(rho_jack)
        rho_jack_std = np.std(rho_jack)
        z = norm.ppf(alpha / 2)
        ci = (
            rho_boot - (1 / np.sqrt(n)) * rho_jack_std * z,
            rho_boot + (1 / np.sqrt(n)) * rho_jack_std * z,
        )
    elif conf and k and n < 25 and n > 4:
        from sklearn.utils import resample

        # set up the bootstrapping parameters
        boot_samples = []
        for i in range(k):
            theta_sample = resample(theta, replace=True)
            phi_sample = resample(phi, replace=True)
            boot_samples.append(ccc(theta_sample, phi_sample))
        rho_boot = np.mean(boot_samples)
        # percentile confidence intervals; use separate variables for the
        # percentiles rather than reusing the p-value, and remember that
        # rho lies in [-1, 1]
        pct_lo = (alpha / 2.0) * 100
        lower = max(-1.0, np.percentile(boot_samples, pct_lo))
        pct_hi = (1.0 - alpha / 2.0) * 100
        upper = min(1.0, np.percentile(boot_samples, pct_hi))

        ci = (lower, upper)
    else:
        rho_boot = np.nan
        ci = np.nan

    return rho, p, rho_boot, p_shuff, ci
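
A minimal call with synthetic angles (illustrative values only):

import numpy as np

rng = np.random.default_rng(0)
theta = rng.uniform(0, 2 * np.pi, 100)
phi = np.mod(theta + rng.normal(0, 0.5, 100), 2 * np.pi)
rho, p, rho_boot, p_shuff, ci = circCircCorrTLinear(
    theta, phi, k=1000, alpha=0.05, hyp=0, conf=True
)
# rho: observed correlation; p_shuff: permutation p-value;
# ci: jackknife / bootstrap confidence interval (NaN when n is too small)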

circRegress(x, t)

Finds approximation to circular-linear regression for phase precession.

Parameters:

    x (list, required): n-by-1 list of in-field positions (linear variable)
    t (list, required): n-by-1 list of phases, in degrees (converted to radians)
Note

Neither x nor t can contain NaNs, must be paired (of equal length).

Source code in ephysiopy/common/phasecoding.py
def circRegress(x, t):
    """
    Finds approximation to circular-linear regression for phase precession.

    Args:
        x (list): n-by-1 list of in-field positions (linear variable)
        t (list): n-by-1 list of phases, in degrees (converted to radians)

    Note:
        Neither x nor t can contain NaNs, must be paired (of equal length).
    """
    # transform the linear co-variate to the range -1 to 1
    if not np.any(x) or not np.any(t):
        return x, t
    mnx = np.mean(x)
    xn = x - mnx
    mxx = np.max(np.fabs(xn))
    xn = xn / mxx
    # keep tn between 0 and 2pi
    tn = np.remainder(t, 2 * np.pi)
    # constrain max slope to give at most 720 degrees of phase precession
    # over the field
    max_slope = (2 * np.pi) / (np.max(xn) - np.min(xn))

    # perform slope optimisation and find intercept
    def _cost(m, x, t):
        return -np.abs(np.sum(np.exp(1j * (t - m * x)))) / len(t - m * x)

    slope = optimize.fminbound(_cost, -1 * max_slope, max_slope, args=(xn, tn))
    intercept = np.arctan2(
        np.sum(np.sin(tn - slope * xn)), np.sum(np.cos(tn - slope * xn))
    )
    intercept = intercept + ((0 - slope) * (mnx / mxx))
    slope = slope / mxx
    return slope, intercept
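
A round-trip check on synthetic phase data (the slope and noise values are illustrative only):

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 400)   # in-field positions, already in [-1, 1]
t = np.mod((np.pi / 2) * x + 1.0 + rng.normal(0, 0.3, 400), 2 * np.pi)
slope, intercept = circRegress(x, t)  # slope near pi/2, intercept near 1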

fixAngle(a)

Ensure angles lie between -pi and pi a must be in radians

Source code in ephysiopy/common/phasecoding.py
def fixAngle(a):
    """
    Ensure angles lie between -pi and pi
    a must be in radians
    """
    b = np.mod(a + np.pi, 2 * np.pi) - np.pi
    return b
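
A worked example of the wrapping:

import numpy as np

fixAngle(np.array([3 * np.pi / 2, -3 * np.pi / 2]))  # -> [-pi/2, pi/2]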

shuffledPVal(theta, phi, rho, k, hyp)

Calculates shuffled p-values for correlation

Source code in ephysiopy/common/phasecoding.py
def shuffledPVal(theta, phi, rho, k, hyp):
    """
    Calculates shuffled p-values for correlation
    """
    n = len(theta)
    idx = np.zeros((n, k))
    for i in range(k):
        idx[:, i] = np.random.permutation(np.arange(n))

    thetaPerms = theta[idx.astype(int)]

    A = np.dot(np.cos(phi), np.cos(thetaPerms))
    B = np.dot(np.sin(phi), np.sin(thetaPerms))
    C = np.dot(np.sin(phi), np.cos(thetaPerms))
    D = np.dot(np.cos(phi), np.sin(thetaPerms))
    E = np.sum(np.cos(2 * theta))
    F = np.sum(np.sin(2 * theta))
    G = np.sum(np.cos(2 * phi))
    H = np.sum(np.sin(2 * phi))

    rho_sim = 4 * (A * B - C * D) / np.sqrt((n**2 - E**2 - F**2) * (n**2 - G**2 - H**2))

    if hyp == 1:
        p_shuff = np.sum(rho_sim >= rho) / float(k)
    elif hyp == -1:
        p_shuff = np.sum(rho_sim <= rho) / float(k)
    elif hyp == 0:
        p_shuff = np.sum(np.fabs(rho_sim) > np.fabs(rho)) / float(k)
    else:
        p_shuff = np.nan

    return p_shuff
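
A sketch of a direct call (illustrative data only); for independent angles the shuffled p-value should be non-significant:

import numpy as np

rng = np.random.default_rng(0)
theta = rng.uniform(0, 2 * np.pi, 200)
phi = rng.uniform(0, 2 * np.pi, 200)
rho = ccc(theta, phi)
shuffledPVal(theta, phi, rho, k=1000, hyp=0)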

Rhythmicity

CosineDirectionalTuning

Bases: object

Produces output to do with a Welday et al. (2011)-like analysis of rhythmic firing, a la the oscillatory interference model
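
A construction sketch; every argument value below is an assumption for illustration:

# hypothetical usage: timestamps in seconds, positions in cm
cdt = CosineDirectionalTuning(
    spike_times, pos_times, spk_clusters, x, y, tracker_params={}
)
cdt.min_runlength = 0.5  # seconds
hd_bins = cdt.getDirectionalBinPerPosition(binwidth=45)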

Source code in ephysiopy/common/rhythmicity.py
class CosineDirectionalTuning(object):
    """
    Produces output to do with a Welday et al. (2011)-like analysis
    of rhythmic firing, a la the oscillatory interference model
    """

    def __init__(
        self,
        spike_times: np.array,
        pos_times: np.array,
        spk_clusters: np.array,
        x: np.array,
        y: np.array,
        tracker_params={},
    ):
        """
        Args:
            spike_times (1d np.array): Spike times
            pos_times (1d np.array): Position times
            spk_clusters (1d np.array): Spike clusters
            x and y (1d np.array): Position coordinates
            tracker_params (dict): From the PosTracker as created in
                                    OESettings.Settings.parse

        Note:
            All timestamps should be given in sub-millisecond accurate seconds
            and pos_xy in cms
        """
        self.spike_times = spike_times
        self.pos_times = pos_times
        self.spk_clusters = spk_clusters
        """
        There can be more spikes than pos samples in terms of sampling as the
        open-ephys buffer probably needs to finish writing and the camera has
        already stopped, so cut off any cluster indices and spike times
        that exceed the length of the pos indices
        """
        idx_to_keep = self.spike_times < self.pos_times[-1]
        self.spike_times = self.spike_times[idx_to_keep]
        self.spk_clusters = self.spk_clusters[idx_to_keep]
        self._pos_sample_rate = 30
        self._spk_sample_rate = 3e4
        self._pos_samples_for_spike = None
        self._min_runlength = 0.4  # in seconds
        self.posCalcs = PosCalcsGeneric(
            x, y, 230, cm=True, jumpmax=100, tracker_params=tracker_params
        )
        self.spikeCalcs = SpikeCalcsGeneric(spike_times, spk_clusters[0])
        self.spikeCalcs.spk_clusters = spk_clusters
        self.posCalcs.postprocesspos(tracker_params)
        xy = self.posCalcs.xy
        hdir = self.posCalcs.dir
        self.posCalcs.calcSpeed(xy)
        self._xy = xy
        self._hdir = hdir
        self._speed = self.posCalcs.speed
        # TEMPORARY FOR POWER SPECTRUM STUFF
        self.smthKernelWidth = 2
        self.smthKernelSigma = 0.1875
        self.sn2Width = 2
        self.thetaRange = [7, 11]
        self.xmax = 11

    @property
    def spk_sample_rate(self):
        return self._spk_sample_rate

    @spk_sample_rate.setter
    def spk_sample_rate(self, value):
        self._spk_sample_rate = value

    @property
    def pos_sample_rate(self):
        return self._pos_sample_rate

    @pos_sample_rate.setter
    def pos_sample_rate(self, value):
        self._pos_sample_rate = value

    @property
    def min_runlength(self):
        return self._min_runlength

    @min_runlength.setter
    def min_runlength(self, value):
        self._min_runlength = value

    @property
    def xy(self):
        return self._xy

    @xy.setter
    def xy(self, value):
        self._xy = value

    @property
    def hdir(self):
        return self._hdir

    @hdir.setter
    def hdir(self, value):
        self._hdir = value

    @property
    def speed(self):
        return self._speed

    @speed.setter
    def speed(self, value):
        self._speed = value

    @property
    def pos_samples_for_spike(self):
        return self._pos_samples_for_spike

    @pos_samples_for_spike.setter
    def pos_samples_for_spike(self, value):
        self._pos_samples_for_spike = value

    def _rolling_window(self, a: np.array, window: int):
        """
        Totally nabbed from SO:
        https://stackoverflow.com/questions/6811183/rolling-window-for-1d-arrays-in-numpy
        """
        shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
        strides = a.strides + (a.strides[-1],)
        return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

    def getPosIndices(self):
        self.pos_samples_for_spike = np.floor(
            self.spike_times * self.pos_sample_rate
        ).astype(int)

    def getClusterPosIndices(self, clust: int) -> np.array:
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        clust_pos_idx = self.pos_samples_for_spike[self.spk_clusters == clust]
        clust_pos_idx[clust_pos_idx >= len(self.pos_times)] = (
            len(self.pos_times) - 1
        )
        return clust_pos_idx

    def getClusterSpikeTimes(self, cluster: int):
        ts = self.spike_times[self.spk_clusters == cluster]
        if self.pos_samples_for_spike is None:
            self.getPosIndices()
        return ts

    def getDirectionalBinPerPosition(self, binwidth: int):
        """
        Direction is in degrees, as that is what is produced elsewhere
        in this package.

        Args:
            binwidth (int): The bin width in degrees

        Returns:
            A digitization of which directional bin each pos sample belongs to
        """

        bins = np.arange(0, 360, binwidth)
        return np.digitize(self.hdir, bins)

    def getDirectionalBinForCluster(self, cluster: int):
        b = self.getDirectionalBinPerPosition(45)
        cluster_pos = self.getClusterPosIndices(cluster)
        # idx_to_keep = cluster_pos < len(self.pos_times)
        # cluster_pos = cluster_pos[idx_to_keep]
        return b[cluster_pos]

    def getRunsOfMinLength(self):
        """
        Identifies runs at least self.min_runlength seconds long,
        which at a 30Hz pos sampling rate equals 12 samples, and
        returns the start and end indices at which
        each run occurred and the directional bin that run belongs to

        Returns:
            np.array: The start and end indices into pos samples of the run
                      and the directional bin to which it belongs
        """

        b = self.getDirectionalBinPerPosition(45)
        # nabbed from SO
        from itertools import groupby

        grouped_runs = [(k, sum(1 for i in g)) for k, g in groupby(b)]
        grouped_runs = np.array(grouped_runs)
        run_start_indices = np.cumsum(grouped_runs[:, 1]) - grouped_runs[:, 1]
        min_len_in_samples = int(self.pos_sample_rate * self.min_runlength)
        min_len_runs_mask = grouped_runs[:, 1] >= min_len_in_samples
        ret = np.array(
            [run_start_indices[min_len_runs_mask],
                grouped_runs[min_len_runs_mask, 1]]
        ).T
        # ret contains run length as last column
        ret = np.insert(ret, 1, np.sum(ret, 1), 1)
        ret = np.insert(ret, 2, grouped_runs[min_len_runs_mask, 0], 1)
        return ret[:, 0:3]

    def speedFilterRuns(self, runs: np.array, minspeed=5.0):
        """
        Given the runs identified in getRunsOfMinLength, filter for speed
        and return runs that meet the min speed criteria.

        The function goes over the runs with a moving window of length equal
        to self.min_runlength in samples and sees if any of those segments
        meet the speed criteria and splits them out into separate runs if true.

        NB For now this means the same spikes might get included in the
        autocorrelation procedure later as the
        moving window will use overlapping periods - can be modified later.

        Args:
            runs (3 x nRuns np.array): Generated from getRunsOfMinLength
            minspeed (float): Min running speed in cm/s for an epoch (minimum
                              epoch length defined previously
                              in getRunsOfMinLength as minlength, usually 0.4s)

        Returns:
            3 x nRuns np.array: A modified version of the "runs" input variable
        """
        minlength_in_samples = int(self.pos_sample_rate * self.min_runlength)
        run_list = runs.tolist()
        all_speed = np.array(self.speed)
        for start_idx, end_idx, dir_bin in run_list:
            this_runs_speed = all_speed[start_idx:end_idx]
            this_runs_runs = self._rolling_window(
                this_runs_speed, minlength_in_samples)
            run_mask = np.all(this_runs_runs > minspeed, 1)
            if np.any(run_mask):
                print("got one")

    """
    def testing(self, cluster: int):
        ts = self.getClusterSpikeTimes(cluster)
        pos_idx = self.getClusterPosIndices(cluster)

        dir_bins = self.getDirectionalBinPerPosition(45)
        cluster_dir_bins = dir_bins[pos_idx.astype(int)]

        from scipy.signal import periodogram, boxcar, filtfilt

        acorrs = []
        max_freqs = []
        max_idx = []
        isis = []

        acorr_range = np.array([-500, 500])
        for i in range(1, 9):
            this_bin_indices = cluster_dir_bins == i
            this_ts = ts[this_bin_indices]  # in seconds still so * 1000 for ms
            y = self.spikeCalcs.xcorr(this_ts*1000, Trange=acorr_range)
            isis.append(y)
            corr, acorr_bins = np.histogram(
                y[y != 0], bins=501, range=acorr_range)
            freqs, power = periodogram(corr, fs=200, return_onesided=True)
            # Smooth the power over +/- 1Hz
            b = boxcar(3)
            h = filtfilt(b, 3, power)
            # Square the amplitude first
            sqd_amp = h ** 2
            # Then find the mean power in the +/-1Hz band either side of that
            theta_band_max_idx = np.nonzero(
                sqd_amp == np.max(
                    sqd_amp[np.logical_and(freqs > 6, freqs < 11)]))[0][0]
            max_freq = freqs[theta_band_max_idx]
            acorrs.append(corr)
            max_freqs.append(max_freq)
            max_idx.append(theta_band_max_idx)
        return isis, acorrs, max_freqs, max_idx, acorr_bins

    def plotXCorrsByDirection(self, cluster: int):
        acorr_range = np.array([-500, 500])
        # plot_range = np.array([-400,400])
        nbins = 501
        isis, acorrs, max_freqs, max_idx, acorr_bins = self.testing(cluster)
        bin_labels = np.arange(0, 360, 45)
        fig, axs = plt.subplots(8)
        pts = []
        for i, a in enumerate(isis):
            axs[i].hist(
                a[a != 0], bins=nbins, range=acorr_range,
                color='k', histtype='stepfilled')
            # find the max of the first positive peak
            corr, _ = np.histogram(a[a != 0], bins=nbins, range=acorr_range)
            axs[i].set_xlim(acorr_range)
            axs[i].set_ylabel(str(bin_labels[i]))
            axs[i].set_yticklabels('')
            if i < 7:
                axs[i].set_xticklabels('')
            axs[i].spines['right'].set_visible(False)
            axs[i].spines['top'].set_visible(False)
            axs[i].spines['left'].set_visible(False)
        plt.show()
        return pts
    """

    def intrinsic_freq_autoCorr(
        self,
        spkTimes=None,
        posMask=None,
        maxFreq=25,
        acBinSize=0.002,
        acWindow=0.5,
        plot=True,
        **kwargs,
    ):
        """
        This is taken and adapted from ephysiopy.common.eegcalcs.EEGCalcs

        Args:
            spkTimes (np.array): Times in seconds of the cell's firing
            posMask (np.array): Boolean array corresponding to the length of
                                spkTimes where True is stuff to keep
            maxFreq (float): The maximum frequency to compute the power
                                spectrum out to
            acBinSize (float): The bin size of the autocorrelogram in seconds
            acWindow (float): The range of the autocorr in seconds

        Note:
            Make sure all times are in seconds
        """
        acBinsPerPos = 1.0 / self.pos_sample_rate / acBinSize
        acWindowSizeBins = np.round(acWindow / acBinSize)
        binCentres = np.arange(0.5, len(posMask) * acBinsPerPos) * acBinSize
        spkTrHist, _ = np.histogram(spkTimes, bins=binCentres)

        # split the single histogram into individual chunks
        splitIdx = np.nonzero(np.diff(posMask.astype(int)))[0] + 1
        splitMask = np.split(posMask, splitIdx)
        splitSpkHist = np.split(
            spkTrHist, (splitIdx * acBinsPerPos).astype(int))
        histChunks = []
        for i in range(len(splitSpkHist)):
            if np.all(splitMask[i]):
                if np.sum(splitSpkHist[i]) > 2:
                    if len(splitSpkHist[i]) > int(acWindowSizeBins) * 2:
                        histChunks.append(splitSpkHist[i])
        autoCorrGrid = np.zeros((int(acWindowSizeBins) + 1, len(histChunks)))
        chunkLens = []
        from scipy import signal

        print(f"num chunks = {len(histChunks)}")
        for i in range(len(histChunks)):
            lenThisChunk = len(histChunks[i])
            chunkLens.append(lenThisChunk)
            tmp = np.zeros(lenThisChunk * 2)
            tmp[lenThisChunk // 2: lenThisChunk //
                2 + lenThisChunk] = histChunks[i]
            tmp2 = signal.fftconvolve(
                tmp, histChunks[i][::-1], mode="valid"
            )  # the autocorrelation
            autoCorrGrid[:, i] = (
                tmp2[lenThisChunk // 2: lenThisChunk //
                     2 + int(acWindowSizeBins) + 1]
                / acBinsPerPos
            )

        totalLen = np.sum(chunkLens)
        autoCorrSum = np.nansum(autoCorrGrid, 1) / totalLen
        meanNormdAc = autoCorrSum[1::] - np.nanmean(autoCorrSum[1::])
        # return meanNormdAc
        out = self.power_spectrum(
            eeg=meanNormdAc,
            binWidthSecs=acBinSize,
            maxFreq=maxFreq,
            pad2pow=16,
            **kwargs,
        )
        out.update({"meanNormdAc": meanNormdAc})
        if plot:
            fig = plt.gcf()
            ax = fig.gca()
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            ax.imshow(
                autoCorrGrid,
                extent=[
                    maxFreq * 0.6,
                    maxFreq,
                    np.max(out["Power"]) * 0.6,
                    ax.get_ylim()[1],
                ],
            )
            ax.set_ylim(ylim)
            ax.set_xlim(xlim)
        return out

    def power_spectrum(
        self,
        eeg,
        plot=True,
        binWidthSecs=None,
        maxFreq=25,
        pad2pow=None,
        ymax=None,
        **kwargs,
    ):
        """
        Method used by eeg_power_spectra and intrinsic_freq_autoCorr.
        The input signal must already be mean-normalised.
        """

        # Get raw power spectrum
        nqLim = 1 / binWidthSecs / 2.0
        origLen = len(eeg)
        # if pad2pow is None:
        # 	fftLen = int(np.power(2, self._nextpow2(origLen)))
        # else:
        fftLen = int(np.power(2, pad2pow))
        fftHalfLen = int(fftLen / float(2) + 1)

        fftRes = np.fft.fft(eeg, fftLen)
        # get power density from fft and discard second half of spectrum
        _power = np.power(np.abs(fftRes), 2) / origLen
        power = np.delete(_power, np.s_[fftHalfLen::])
        power[1:-2] = power[1:-2] * 2

        # calculate freqs and crop spectrum to requested range
        freqs = nqLim * np.linspace(0, 1, fftHalfLen)
        freqs = freqs[freqs <= maxFreq].T
        power = power[0: len(freqs)]

        # smooth spectrum using gaussian kernel
        binsPerHz = (fftHalfLen - 1) / nqLim
        kernelLen = np.round(self.smthKernelWidth * binsPerHz)
        kernelSig = self.smthKernelSigma * binsPerHz
        from scipy import signal

        k = signal.windows.gaussian(kernelLen, kernelSig) / (kernelLen / 2 / 2)
        power_sm = signal.fftconvolve(power, k[::-1], mode="same")

        # calculate some metrics
        # find max in theta band
        spectrumMaskBand = np.logical_and(
            freqs > self.thetaRange[0], freqs < self.thetaRange[1]
        )
        bandMaxPower = np.max(power_sm[spectrumMaskBand])
        maxBinInBand = np.argmax(power_sm[spectrumMaskBand])
        bandFreqs = freqs[spectrumMaskBand]
        freqAtBandMaxPower = bandFreqs[maxBinInBand]
        # self.maxBinInBand = maxBinInBand
        # self.freqAtBandMaxPower = freqAtBandMaxPower
        # self.bandMaxPower = bandMaxPower

        # find power in small window around peak and divide by power in rest
        # of spectrum to get snr
        spectrumMaskPeak = np.logical_and(
            freqs > freqAtBandMaxPower - self.sn2Width / 2,
            freqs < freqAtBandMaxPower + self.sn2Width / 2,
        )
        s2n = np.nanmean(power_sm[spectrumMaskPeak]) / np.nanmean(
            power_sm[~spectrumMaskPeak]
        )
        self.freqs = freqs
        self.power_sm = power_sm
        self.spectrumMaskPeak = spectrumMaskPeak
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            if ymax is None:
                ymax = np.min([2 * np.max(power), np.max(power_sm)])
                if ymax == 0:
                    ymax = 1
            ax.plot(freqs, power, c=[0.9, 0.9, 0.9])
            # ax.hold(True)
            ax.plot(freqs, power_sm, "k", lw=2)
            ax.axvline(self.thetaRange[0], c="b", ls="--")
            ax.axvline(self.thetaRange[1], c="b", ls="--")
            _, stemlines, _ = ax.stem([freqAtBandMaxPower], [
                                      bandMaxPower], linefmt="r")
            # plt.setp(stemlines, 'linewidth', 2)
            ax.fill_between(
                freqs,
                0,
                power_sm,
                where=spectrumMaskPeak,
                color="r",
                alpha=0.25,
                zorder=25,
            )
            # ax.set_ylim(0, ymax)
            # ax.set_xlim(0, self.xmax)
            ax.set_xlabel("Frequency (Hz)")
            ax.set_ylabel("Power density (W/Hz)")
        out_dict = {
            "maxFreq": freqAtBandMaxPower,
            "Power": power_sm,
            "Freqs": freqs,
            "s2n": s2n,
            "Power_raw": power,
            "k": k,
            "kernelLen": kernelLen,
            "kernelSig": kernelSig,
            "binsPerHz": binsPerHz,
            "kernelLen": kernelLen,
        }
        return out_dict

spk_clusters = self.spk_clusters[idx_to_keep] instance-attribute

There can be more spikes than pos samples because the open-ephys buffer probably needs to finish writing after the camera has already stopped, so cut off any cluster indices and spike times that exceed the length of the pos indices

__init__(spike_times, pos_times, spk_clusters, x, y, tracker_params={})

Parameters:

    spike_times (1d np.array): Spike times (required)
    pos_times (1d np.array): Position times (required)
    spk_clusters (1d np.array): Spike clusters (required)
    x, y (1d np.array): Position coordinates (required)
    tracker_params (dict): From the PosTracker as created in
        OESettings.Settings.parse (default: {})

Note:

    All timestamps should be given in sub-millisecond accurate seconds
    and pos_xy in cms

Source code in ephysiopy/common/rhythmicity.py
def __init__(
    self,
    spike_times: np.array,
    pos_times: np.array,
    spk_clusters: np.array,
    x: np.array,
    y: np.array,
    tracker_params={},
):
    """
    Args:
        spike_times (1d np.array): Spike times
        pos_times (1d np.array): Position times
        spk_clusters (1d np.array): Spike clusters
        x and y (1d np.array): Position coordinates
        tracker_params (dict): From the PosTracker as created in
                                OESettings.Settings.parse

    Note:
        All timestamps should be given in sub-millisecond accurate seconds
        and pos_xy in cms
    """
    self.spike_times = spike_times
    self.pos_times = pos_times
    self.spk_clusters = spk_clusters
    """
    There can be more spikes than pos samples in terms of sampling as the
    open-ephys buffer probably needs to finish writing and the camera has
    already stopped, so cut of any cluster indices and spike times
    that exceed the length of the pos indices
    """
    idx_to_keep = self.spike_times < self.pos_times[-1]
    self.spike_times = self.spike_times[idx_to_keep]
    self.spk_clusters = self.spk_clusters[idx_to_keep]
    self._pos_sample_rate = 30
    self._spk_sample_rate = 3e4
    self._pos_samples_for_spike = None
    self._min_runlength = 0.4  # in seconds
    self.posCalcs = PosCalcsGeneric(
        x, y, 230, cm=True, jumpmax=100, tracker_params=tracker_params
    )
    self.spikeCalcs = SpikeCalcsGeneric(spike_times, spk_clusters[0])
    self.spikeCalcs.spk_clusters = spk_clusters
    self.posCalcs.postprocesspos(tracker_params)
    xy = self.posCalcs.xy
    hdir = self.posCalcs.dir
    self.posCalcs.calcSpeed(xy)
    self._xy = xy
    self._hdir = hdir
    self._speed = self.posCalcs.speed
    # TEMPORARY FOR POWER SPECTRUM STUFF
    self.smthKernelWidth = 2
    self.smthKernelSigma = 0.1875
    self.sn2Width = 2
    self.thetaRange = [7, 11]
    self.xmax = 11
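
A minimal usage sketch with synthetic data (the trial length, sample rates and random positions below are assumptions for illustration, and it presumes postprocesspos accepts an empty tracker_params dict):

import numpy as np
from ephysiopy.common.rhythmicity import CosineDirectionalTuning

n_secs = 60
pos_times = np.arange(0, n_secs, 1 / 30.0)              # 30 Hz camera, seconds
x = np.cumsum(np.random.randn(len(pos_times)))          # random-walk x (cm)
y = np.cumsum(np.random.randn(len(pos_times)))          # random-walk y (cm)
spike_times = np.sort(np.random.uniform(0, n_secs, 500))  # seconds
spk_clusters = np.zeros(len(spike_times), dtype=int)    # a single cluster, 0

cdt = CosineDirectionalTuning(spike_times, pos_times, spk_clusters, x, y)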

getDirectionalBinPerPosition(binwidth)

Direction is in degrees, as that is what is produced elsewhere in this package.

Parameters:

    binwidth (int): The bin width in degrees (required)

Returns:

    A digitization of which directional bin each pos sample belongs to

Source code in ephysiopy/common/rhythmicity.py
def getDirectionalBinPerPosition(self, binwidth: int):
    """
    Direction is in degrees, as that is what is produced elsewhere
    in this package.

    Args:
        binwidth (int): The bin width in degrees

    Returns:
        A digitization of which directional bin each pos sample belongs to
    """

    bins = np.arange(0, 360, binwidth)
    return np.digitize(self.hdir, bins)
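
The binning itself is plain np.digitize; a standalone illustration of how 45-degree bins are assigned (the head directions are made up):

import numpy as np

hdir = np.array([0.0, 10.0, 44.0, 45.0, 90.0, 315.0, 359.0])  # degrees
bins = np.arange(0, 360, 45)        # bin edges: 0, 45, ..., 315
print(np.digitize(hdir, bins))      # -> [1 1 1 2 3 8 8]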

getRunsOfMinLength()

Identifies runs at least self.min_runlength seconds long, which at a 30Hz pos sampling rate equals 12 samples, and returns the start and end indices at which each run occurred and the directional bin that run belongs to.

Returns:

    np.array: The start and end indices into pos samples of the run
    and the directional bin to which it belongs

Source code in ephysiopy/common/rhythmicity.py
def getRunsOfMinLength(self):
    """
    Identifies runs at least self.min_runlength seconds long,
    which at a 30Hz pos sampling rate equals 12 samples, and
    returns the start and end indices at which
    each run occurred and the directional bin that run belongs to

    Returns:
        np.array: The start and end indices into pos samples of the run
                  and the directional bin to which it belongs
    """

    b = self.getDirectionalBinPerPosition(45)
    # nabbed from SO
    from itertools import groupby

    grouped_runs = [(k, sum(1 for i in g)) for k, g in groupby(b)]
    grouped_runs = np.array(grouped_runs)
    run_start_indices = np.cumsum(grouped_runs[:, 1]) - grouped_runs[:, 1]
    min_len_in_samples = int(self.pos_sample_rate * self.min_runlength)
    min_len_runs_mask = grouped_runs[:, 1] >= min_len_in_samples
    ret = np.array(
        [run_start_indices[min_len_runs_mask],
            grouped_runs[min_len_runs_mask, 1]]
    ).T
    # ret contains run length as last column
    ret = np.insert(ret, 1, np.sum(ret, 1), 1)
    ret = np.insert(ret, 2, grouped_runs[min_len_runs_mask, 0], 1)
    return ret[:, 0:3]
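
The run detection reduces to a run-length encoding via itertools.groupby; a standalone sketch of that step on a made-up bin sequence:

import numpy as np
from itertools import groupby

b = np.array([1, 1, 1, 2, 2, 1, 1, 1, 1])   # directional bin per pos sample
grouped = np.array([(k, sum(1 for _ in g)) for k, g in groupby(b)])
starts = np.cumsum(grouped[:, 1]) - grouped[:, 1]
print(grouped)   # [[1 3] [2 2] [1 4]] -> (bin, run length) pairs
print(starts)    # [0 3 5] -> start index of each run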

intrinsic_freq_autoCorr(spkTimes=None, posMask=None, maxFreq=25, acBinSize=0.002, acWindow=0.5, plot=True, **kwargs)

This is taken and adapted from ephysiopy.common.eegcalcs.EEGCalcs

Parameters:

    spkTimes (np.array): Times in seconds of the cell's firing
        (default: None)
    posMask (np.array): Boolean array corresponding to the length of
        spkTimes where True is stuff to keep (default: None)
    maxFreq (float): The maximum frequency to compute the power spectrum
        out to (default: 25)
    acBinSize (float): The bin size of the autocorrelogram in seconds
        (default: 0.002)
    acWindow (float): The range of the autocorr in seconds (default: 0.5)

Note:

    Make sure all times are in seconds

Source code in ephysiopy/common/rhythmicity.py
def intrinsic_freq_autoCorr(
    self,
    spkTimes=None,
    posMask=None,
    maxFreq=25,
    acBinSize=0.002,
    acWindow=0.5,
    plot=True,
    **kwargs,
):
    """
    This is taken and adapted from ephysiopy.common.eegcalcs.EEGCalcs

    Args:
        spkTimes (np.array): Times in seconds of the cell's firing
        posMask (np.array): Boolean array corresponding to the length of
                            spkTimes where True is stuff to keep
        maxFreq (float): The maximum frequency to compute the power
                            spectrum out to
        acBinSize (float): The bin size of the autocorrelogram in seconds
        acWindow (float): The range of the autocorr in seconds

    Note:
        Make sure all times are in seconds
    """
    acBinsPerPos = 1.0 / self.pos_sample_rate / acBinSize
    acWindowSizeBins = np.round(acWindow / acBinSize)
    binCentres = np.arange(0.5, len(posMask) * acBinsPerPos) * acBinSize
    spkTrHist, _ = np.histogram(spkTimes, bins=binCentres)

    # split the single histogram into individual chunks
    splitIdx = np.nonzero(np.diff(posMask.astype(int)))[0] + 1
    splitMask = np.split(posMask, splitIdx)
    splitSpkHist = np.split(
        spkTrHist, (splitIdx * acBinsPerPos).astype(int))
    histChunks = []
    for i in range(len(splitSpkHist)):
        if np.all(splitMask[i]):
            if np.sum(splitSpkHist[i]) > 2:
                if len(splitSpkHist[i]) > int(acWindowSizeBins) * 2:
                    histChunks.append(splitSpkHist[i])
    autoCorrGrid = np.zeros((int(acWindowSizeBins) + 1, len(histChunks)))
    chunkLens = []
    from scipy import signal

    print(f"num chunks = {len(histChunks)}")
    for i in range(len(histChunks)):
        lenThisChunk = len(histChunks[i])
        chunkLens.append(lenThisChunk)
        tmp = np.zeros(lenThisChunk * 2)
        tmp[lenThisChunk // 2: lenThisChunk //
            2 + lenThisChunk] = histChunks[i]
        tmp2 = signal.fftconvolve(
            tmp, histChunks[i][::-1], mode="valid"
        )  # the autocorrelation
        autoCorrGrid[:, i] = (
            tmp2[lenThisChunk // 2: lenThisChunk //
                 2 + int(acWindowSizeBins) + 1]
            / acBinsPerPos
        )

    totalLen = np.sum(chunkLens)
    autoCorrSum = np.nansum(autoCorrGrid, 1) / totalLen
    meanNormdAc = autoCorrSum[1::] - np.nanmean(autoCorrSum[1::])
    # return meanNormdAc
    out = self.power_spectrum(
        eeg=meanNormdAc,
        binWidthSecs=acBinSize,
        maxFreq=maxFreq,
        pad2pow=16,
        **kwargs,
    )
    out.update({"meanNormdAc": meanNormdAc})
    if plot:
        fig = plt.gcf()
        ax = fig.gca()
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.imshow(
            autoCorrGrid,
            extent=[
                maxFreq * 0.6,
                maxFreq,
                np.max(out["Power"]) * 0.6,
                ax.get_ylim()[1],
            ],
        )
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
    return out
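
A hypothetical usage sketch, continuing the synthetic cdt object from the __init__ example above (the cluster id and speed threshold are assumptions; note that power_spectrum, called internally, plots by default):

spk_ts = cdt.getClusterSpikeTimes(0)        # spike times in seconds
pos_mask = cdt.speed > 5.0                  # keep only fast-running epochs
out = cdt.intrinsic_freq_autoCorr(spkTimes=spk_ts, posMask=pos_mask,
                                  maxFreq=25, plot=False)
print(out["maxFreq"], out["s2n"])           # peak theta frequency and SNR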

power_spectrum(eeg, plot=True, binWidthSecs=None, maxFreq=25, pad2pow=None, ymax=None, **kwargs)

Method used by eeg_power_spectra and intrinsic_freq_autoCorr. The input signal must already be mean-normalised.

Source code in ephysiopy/common/rhythmicity.py
def power_spectrum(
    self,
    eeg,
    plot=True,
    binWidthSecs=None,
    maxFreq=25,
    pad2pow=None,
    ymax=None,
    **kwargs,
):
    """
    Method used by eeg_power_spectra and intrinsic_freq_autoCorr.
    The input signal must already be mean-normalised.
    """

    # Get raw power spectrum
    nqLim = 1 / binWidthSecs / 2.0
    origLen = len(eeg)
    # if pad2pow is None:
    # 	fftLen = int(np.power(2, self._nextpow2(origLen)))
    # else:
    fftLen = int(np.power(2, pad2pow))
    fftHalfLen = int(fftLen / float(2) + 1)

    fftRes = np.fft.fft(eeg, fftLen)
    # get power density from fft and discard second half of spectrum
    _power = np.power(np.abs(fftRes), 2) / origLen
    power = np.delete(_power, np.s_[fftHalfLen::])
    power[1:-2] = power[1:-2] * 2

    # calculate freqs and crop spectrum to requested range
    freqs = nqLim * np.linspace(0, 1, fftHalfLen)
    freqs = freqs[freqs <= maxFreq].T
    power = power[0: len(freqs)]

    # smooth spectrum using gaussian kernel
    binsPerHz = (fftHalfLen - 1) / nqLim
    kernelLen = np.round(self.smthKernelWidth * binsPerHz)
    kernelSig = self.smthKernelSigma * binsPerHz
    from scipy import signal

    k = signal.windows.gaussian(kernelLen, kernelSig) / (kernelLen / 2 / 2)
    power_sm = signal.fftconvolve(power, k[::-1], mode="same")

    # calculate some metrics
    # find max in theta band
    spectrumMaskBand = np.logical_and(
        freqs > self.thetaRange[0], freqs < self.thetaRange[1]
    )
    bandMaxPower = np.max(power_sm[spectrumMaskBand])
    maxBinInBand = np.argmax(power_sm[spectrumMaskBand])
    bandFreqs = freqs[spectrumMaskBand]
    freqAtBandMaxPower = bandFreqs[maxBinInBand]
    # self.maxBinInBand = maxBinInBand
    # self.freqAtBandMaxPower = freqAtBandMaxPower
    # self.bandMaxPower = bandMaxPower

    # find power in small window around peak and divide by power in rest
    # of spectrum to get snr
    spectrumMaskPeak = np.logical_and(
        freqs > freqAtBandMaxPower - self.sn2Width / 2,
        freqs < freqAtBandMaxPower + self.sn2Width / 2,
    )
    s2n = np.nanmean(power_sm[spectrumMaskPeak]) / np.nanmean(
        power_sm[~spectrumMaskPeak]
    )
    self.freqs = freqs
    self.power_sm = power_sm
    self.spectrumMaskPeak = spectrumMaskPeak
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        if ymax is None:
            ymax = np.min([2 * np.max(power), np.max(power_sm)])
            if ymax == 0:
                ymax = 1
        ax.plot(freqs, power, c=[0.9, 0.9, 0.9])
        # ax.hold(True)
        ax.plot(freqs, power_sm, "k", lw=2)
        ax.axvline(self.thetaRange[0], c="b", ls="--")
        ax.axvline(self.thetaRange[1], c="b", ls="--")
        _, stemlines, _ = ax.stem([freqAtBandMaxPower], [
                                  bandMaxPower], linefmt="r")
        # plt.setp(stemlines, 'linewidth', 2)
        ax.fill_between(
            freqs,
            0,
            power_sm,
            where=spectrumMaskPeak,
            color="r",
            alpha=0.25,
            zorder=25,
        )
        # ax.set_ylim(0, ymax)
        # ax.set_xlim(0, self.xmax)
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("Power density (W/Hz)")
    out_dict = {
        "maxFreq": freqAtBandMaxPower,
        "Power": power_sm,
        "Freqs": freqs,
        "s2n": s2n,
        "Power_raw": power,
        "k": k,
        "kernelLen": kernelLen,
        "kernelSig": kernelSig,
        "binsPerHz": binsPerHz,
        "kernelLen": kernelLen,
    }
    return out_dict
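
A standalone check of the core spectral logic (a synthetic 8 Hz sine; the sample interval mirrors the default autocorrelogram bin size of 0.002 s):

import numpy as np

binWidthSecs = 0.002                     # 500 Hz effective sampling
t = np.arange(0, 2, binWidthSecs)
sig = np.sin(2 * np.pi * 8 * t)          # already mean-normalised
fftLen = 2 ** 16
power = np.abs(np.fft.fft(sig, fftLen)) ** 2 / len(sig)
freqs = np.fft.fftfreq(fftLen, d=binWidthSecs)
print(freqs[np.argmax(power[: fftLen // 2])])   # ~8.0 Hz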

speedFilterRuns(runs, minspeed=5.0)

Given the runs identified in getRunsOfMinLength, filter for speed and return runs that meet the min speed criteria.

The function goes over the runs with a moving window of length equal to self.min_runlength in samples and sees if any of those segments meet the speed criteria and splits them out into separate runs if true.

NB For now this means the same spikes might get included in the autocorrelation procedure later as the moving window will use overlapping periods - can be modified later.

Parameters:

    runs (3 x nRuns np.array): Generated from getRunsOfMinLength (required)
    minspeed (float): Min running speed in cm/s for an epoch (minimum
        epoch length defined previously in getRunsOfMinLength as
        minlength, usually 0.4s) (default: 5.0)

Returns:

    3 x nRuns np.array: A modified version of the "runs" input variable

Source code in ephysiopy/common/rhythmicity.py
def speedFilterRuns(self, runs: np.array, minspeed=5.0):
    """
    Given the runs identified in getRunsOfMinLength, filter for speed
    and return runs that meet the min speed criteria.

    The function goes over the runs with a moving window of length equal
    to self.min_runlength in samples and sees if any of those segments
    meet the speed criteria and splits them out into separate runs if true.

    NB For now this means the same spikes might get included in the
    autocorrelation procedure later as the
    moving window will use overlapping periods - can be modified later.

    Args:
        runs (3 x nRuns np.array): Generated from getRunsOfMinLength
        minspeed (float): Min running speed in cm/s for an epoch (minimum
                          epoch length defined previously
                          in getRunsOfMinLength as minlength, usually 0.4s)

    Returns:
        3 x nRuns np.array: A modified version of the "runs" input variable
    """
    minlength_in_samples = int(self.pos_sample_rate * self.min_runlength)
    run_list = runs.tolist()
    all_speed = np.array(self.speed)
    for start_idx, end_idx, dir_bin in run_list:
        this_runs_speed = all_speed[start_idx:end_idx]
        this_runs_runs = self._rolling_window(
            this_runs_speed, minlength_in_samples)
        run_mask = np.all(this_runs_runs > minspeed, 1)
        if np.any(run_mask):
            print("got one")

LFPOscillations

Bases: object

Provides methods for working with the LFP: examining nested oscillations (theta/gamma coupling), the modulation index of such phenomena, filtering out certain frequencies, extracting the instantaneous phase and amplitude, and so on

Source code in ephysiopy/common/rhythmicity.py
class LFPOscillations(object):
    """
    Provides methods for working with the LFP: examining nested
    oscillations (theta/gamma coupling), the modulation index of such
    phenomena, filtering out certain frequencies, extracting the
    instantaneous phase and amplitude, and so on

    """

    def __init__(self, sig, fs, **kwargs):
        self.sig = sig
        self.fs = fs

    def getFreqPhase(self, sig, band2filter: list, ford=3):
        """
        Uses the Hilbert transform to calculate the instantaneous phase and
        amplitude of the time series in sig.

        Args:
            sig (np.array): The signal to be analysed
            ford (int): The order for the Butterworth filter
            band2filter (list): The lower and upper frequencies to
                bandpass filter between
        """
        if sig is None:
            sig = self.sig
        band2filter = np.array(band2filter, dtype=float)

        b, a = signal.butter(ford, band2filter /
                             (self.fs / 2), btype="bandpass")

        filt_sig = signal.filtfilt(b, a, sig, padtype="odd")
        phase = np.angle(signal.hilbert(filt_sig))
        amplitude = np.abs(signal.hilbert(filt_sig))
        amplitude_filtered = signal.filtfilt(b, a, amplitude, padtype="odd")
        return filt_sig, phase, amplitude, amplitude_filtered

    def modulationindex(
        self,
        sig=None,
        nbins=20,
        forder=2,
        thetaband=[4, 8],
        gammaband=[30, 80],
        plot=True,
    ):
        """
        Calculates the modulation index of theta and gamma oscillations.
        Specifically this is the circular correlation between the phase of
        theta and the power of gamma.

        Args:
            sig (np.array): The LFP signal
            nbins (int): The number of bins in the circular range 0 to 2*pi
            forder (int): The order of the Butterworth filter
            thetaband (list): The lower/upper bounds of the theta freq range
            gammaband (list): The lower/upper bounds of the gamma freq range
            plot (bool): Whether to plot the result
        """
        if sig is None:
            sig = self.sig
        sig = sig - np.ma.mean(sig)
        if np.ma.is_masked(sig):
            sig = np.ma.compressed(sig)
        _, lowphase, _, _ = self.getFreqPhase(sig, thetaband, forder)
        _, _, highamp, _ = self.getFreqPhase(sig, gammaband, forder)
        inc = 2 * np.pi / nbins
        a = np.arange(-np.pi + inc / 2, np.pi, inc)
        dt = np.array([-inc / 2, inc / 2])
        pbins = a[:, np.newaxis] + dt[np.newaxis, :]
        amp = np.zeros((nbins))
        phaselen = np.arange(len(lowphase))
        for i in range(nbins):
            pts = np.nonzero(
                (lowphase >= pbins[i, 0]) * (lowphase < pbins[i, 1]) * phaselen
            )
            amp[i] = np.mean(highamp[pts])
        amp = amp / np.sum(amp)
        from ephysiopy.common.statscalcs import circ_r

        mi = circ_r(pbins[:, 1], amp)
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111, polar=True)
            w = np.pi / (nbins / 2)
            ax.bar(pbins[:, 1], amp, width=w)
            ax.set_title("Modulation index={0:.5f}".format(mi))
        return mi

    def plv(
        self,
        sig=None,
        forder=2,
        thetaband=[4, 8],
        gammaband=[30, 80],
        plot=True,
        **kwargs,
    ):
        """
        Computes the phase-amplitude coupling (PAC) of nested oscillations.
        More specifically this is the phase-locking value (PLV) between two
        nested oscillations in EEG data, in this case theta (default 4-8Hz)
        and gamma (defaults to 30-80Hz). A PLV of unity indicates perfect phase
        locking (here PAC) and a value of zero indicates no locking (no PAC).

        Args:
            sig (np.array): The LFP data itself. This is a 1-d array which
                can be masked or not
            forder (int): The order of the filter(s) applied to the data
            thetaband, gammaband (list/array): The ranges of values to
                bandpass filter for the theta and gamma bands
            plot (bool, optional): Whether to plot the resulting binned-up
                polar plot, which shows the amplitude of the gamma
                oscillation found at different phases of the theta
                oscillation. Default is True.

        Returns:
            plv (float): The value of the phase-amplitude coupling
        """

        if sig is None:
            sig = self.sig
        sig = sig - np.ma.mean(sig)
        if np.ma.is_masked(sig):
            sig = np.ma.compressed(sig)

        _, lowphase, _, _ = self.getFreqPhase(sig, thetaband, forder)
        _, _, _, highamp_f = self.getFreqPhase(sig, gammaband, forder)

        highampphase = np.angle(signal.hilbert(highamp_f))
        phasedf = highampphase - lowphase
        phasedf = np.exp(1j * phasedf)
        phasedf = np.angle(phasedf)
        from ephysiopy.common.statscalcs import circ_r

        plv = circ_r(phasedf)
        th = np.linspace(0.0, 2 * np.pi, 20, endpoint=False)
        h, _ = np.histogram(phasedf, bins=20)
        h = h / float(len(phasedf))

        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111, polar=True)
            w = np.pi / 10
            ax.bar(th, h, width=w, bottom=0.0)
        return plv, th, h

    def filterForLaser(self, sig=None, width=0.125, dip=15.0, stimFreq=6.66):
        """
        Attempts to filter out frequencies from optogenetic experiments where
        the frequency of laser stimulation was at 6.66Hz.

        Note:
            This method may not work as expected for each trial and might
            require tailoring. A potential improvement could be using mean
            power or a similar metric.
        """
        from scipy.signal import filtfilt, firwin, kaiserord

        nyq = self.fs / 2.0
        width = width / nyq
        N, beta = kaiserord(dip, width)
        print("N: {0}\nbeta: {1}".format(N, beta))
        upper = np.ceil(nyq / stimFreq)
        c = np.arange(stimFreq, upper * stimFreq, stimFreq)
        dt = np.array([-0.125, 0.125])
        cutoff_hz = dt[:, np.newaxis] + c[np.newaxis, :]
        cutoff_hz = cutoff_hz.ravel()
        cutoff_hz = np.append(cutoff_hz, nyq - 1)
        cutoff_hz.sort()
        cutoff_hz_nyq = cutoff_hz / nyq
        taps = firwin(N, cutoff_hz_nyq, window=("kaiser", beta))
        if sig is None:
            sig = self.sig
        fx = filtfilt(taps, [1.0], sig)
        return fx

    def spike_phase_plot(self, cluster: int,
                         pos_data: PosCalcsGeneric,
                         KSdata: KiloSortSession,
                         lfp_data: EEGCalcsGeneric) -> None:
        '''
        Produces a plot of the phase of theta at which each spike was
        emitted. Each spike is plotted at the x-y location the animal
        occupied when it fired, with the marker colour corresponding to
        the theta phase at which it fired.
        '''
        _, phase, _, _ = self.getFreqPhase(
            lfp_data.sig, [6, 12])
        cluster_times = KSdata.spk_times[KSdata.spk_clusters == cluster]
        # cluster_times in samples (@30000Hz)
        # get indices into the phase vector
        phase_idx = np.array(cluster_times/(3e4/self.fs), dtype=int)
            # It's possible that there are indices at or beyond the length
            # of the phase vector so let's clamp them to the last index
            bad_idx = np.nonzero(phase_idx >= len(phase))[0]
            phase_idx[bad_idx] = len(phase) - 1
        # get indices into the position data
        pos_idx = np.array(cluster_times/(3e4/pos_data.sample_rate), dtype=int)
        bad_idx = np.nonzero(pos_idx >= len(pos_data.xyTS))[0]
        pos_idx[bad_idx] = len(pos_data.xyTS) - 1

filterForLaser(sig=None, width=0.125, dip=15.0, stimFreq=6.66)

Attempts to filter out frequencies from optogenetic experiments where the frequency of laser stimulation was at 6.66Hz.

Note

This method may not work as expected for each trial and might require tailoring. A potential improvement could be using mean power or a similar metric.

Source code in ephysiopy/common/rhythmicity.py
def filterForLaser(self, sig=None, width=0.125, dip=15.0, stimFreq=6.66):
    """
    Attempts to filter out frequencies from optogenetic experiments where
    the frequency of laser stimulation was at 6.66Hz.

    Note:
        This method may not work as expected for each trial and might
        require tailoring. A potential improvement could be using mean
        power or a similar metric.
    """
    from scipy.signal import filtfilt, firwin, kaiserord

    nyq = self.fs / 2.0
    width = width / nyq
    N, beta = kaiserord(dip, width)
    print("N: {0}\nbeta: {1}".format(N, beta))
    upper = np.ceil(nyq / stimFreq)
    c = np.arange(stimFreq, upper * stimFreq, stimFreq)
    dt = np.array([-0.125, 0.125])
    cutoff_hz = dt[:, np.newaxis] + c[np.newaxis, :]
    cutoff_hz = cutoff_hz.ravel()
    cutoff_hz = np.append(cutoff_hz, nyq - 1)
    cutoff_hz.sort()
    cutoff_hz_nyq = cutoff_hz / nyq
    taps = firwin(N, cutoff_hz_nyq, window=("kaiser", beta))
    if sig is None:
        sig = self.sig
    fx = filtfilt(taps, [1.0], sig)
    return fx
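
The filter is a comb of narrow stop-bands centred on the stimulation harmonics; a standalone sketch of how those band edges are assembled, assuming fs = 250 Hz:

import numpy as np

nyq, stimFreq = 125.0, 6.66
c = np.arange(stimFreq, np.ceil(nyq / stimFreq) * stimFreq, stimFreq)
cutoffs = np.sort((np.array([-0.125, 0.125])[:, None] + c[None, :]).ravel())
print(cutoffs[:4])   # [ 6.535  6.785 13.195 13.445] -> a stop-band per harmonic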

getFreqPhase(sig, band2filter, ford=3)

Uses the Hilbert transform to calculate the instantaneous phase and amplitude of the time series in sig.

Parameters:

    sig (np.array): The signal to be analysed (required)
    band2filter (list): The lower and upper frequencies to bandpass
        filter between (required)
    ford (int): The order for the Butterworth filter (default: 3)
Source code in ephysiopy/common/rhythmicity.py
def getFreqPhase(self, sig, band2filter: list, ford=3):
    """
    Uses the Hilbert transform to calculate the instantaneous phase and
    amplitude of the time series in sig.

    Args:
        sig (np.array): The signal to be analysed
        ford (int): The order for the Butterworth filter
        band2filter (list): The lower and upper frequencies to
            bandpass filter between
    """
    if sig is None:
        sig = self.sig
    band2filter = np.array(band2filter, dtype=float)

    b, a = signal.butter(ford, band2filter /
                         (self.fs / 2), btype="bandpass")

    filt_sig = signal.filtfilt(b, a, sig, padtype="odd")
    phase = np.angle(signal.hilbert(filt_sig))
    amplitude = np.abs(signal.hilbert(filt_sig))
    amplitude_filtered = signal.filtfilt(b, a, amplitude, padtype="odd")
    return filt_sig, phase, amplitude, amplitude_filtered
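
A minimal sketch: instantaneous theta phase and amplitude of a noisy synthetic 8 Hz signal (the sampling rate is an assumption):

import numpy as np
from ephysiopy.common.rhythmicity import LFPOscillations

fs = 250
t = np.arange(0, 10, 1 / fs)
sig = np.sin(2 * np.pi * 8 * t) + 0.1 * np.random.randn(len(t))
lfp = LFPOscillations(sig, fs)
_, phase, amp, _ = lfp.getFreqPhase(sig, [6, 10], ford=3)
print(phase.min(), phase.max())   # instantaneous phase in radians, ~[-pi, pi]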

modulationindex(sig=None, nbins=20, forder=2, thetaband=[4, 8], gammaband=[30, 80], plot=True)

Calculates the modulation index of theta and gamma oscillations. Specifically this is the circular correlation between the phase of theta and the power of gamma.

Parameters:

    sig (np.array): The LFP signal (default: None)
    nbins (int): The number of bins in the circular range 0 to 2*pi
        (default: 20)
    forder (int): The order of the Butterworth filter (default: 2)
    thetaband (list): The lower/upper bounds of the theta freq range
        (default: [4, 8])
    gammaband (list): The lower/upper bounds of the gamma freq range
        (default: [30, 80])
    plot (bool): Whether to plot the result (default: True)
Source code in ephysiopy/common/rhythmicity.py
def modulationindex(
    self,
    sig=None,
    nbins=20,
    forder=2,
    thetaband=[4, 8],
    gammaband=[30, 80],
    plot=True,
):
    """
    Calculates the modulation index of theta and gamma oscillations.
    Specifically this is the circular correlation between the phase of
    theta and the power of gamma.

    Args:
        sig (np.array): The LFP signal
        nbins (int): The number of bins in the circular range 0 to 2*pi
        forder (int): The order of the Butterworth filter
        thetaband (list): The lower/upper bounds of the theta freq range
        gammaband (list): The lower/upper bounds of the gamma freq range
        plot (bool): Whether to plot the result
    """
    if sig is None:
        sig = self.sig
    sig = sig - np.ma.mean(sig)
    if np.ma.is_masked(sig):
        sig = np.ma.compressed(sig)
    _, lowphase, _, _ = self.getFreqPhase(sig, thetaband, forder)
    _, _, highamp, _ = self.getFreqPhase(sig, gammaband, forder)
    inc = 2 * np.pi / nbins
    a = np.arange(-np.pi + inc / 2, np.pi, inc)
    dt = np.array([-inc / 2, inc / 2])
    pbins = a[:, np.newaxis] + dt[np.newaxis, :]
    amp = np.zeros((nbins))
    phaselen = np.arange(len(lowphase))
    for i in range(nbins):
        pts = np.nonzero(
            (lowphase >= pbins[i, 0]) * (lowphase < pbins[i, 1]) * phaselen
        )
        amp[i] = np.mean(highamp[pts])
    amp = amp / np.sum(amp)
    from ephysiopy.common.statscalcs import circ_r

    mi = circ_r(pbins[:, 1], amp)
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        w = np.pi / (nbins / 2)
        ax.bar(pbins[:, 1], amp, width=w)
        ax.set_title("Modulation index={0:.5f}".format(mi))
    return mi
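
A sketch with synthetic coupling, where the gamma amplitude is driven by the theta phase so the index should come out well above zero (all signal parameters are assumptions):

import numpy as np
from ephysiopy.common.rhythmicity import LFPOscillations

fs = 250
t = np.arange(0, 60, 1 / fs)
theta = np.sin(2 * np.pi * 6 * t)
gamma = 0.3 * (1 + theta) * np.sin(2 * np.pi * 50 * t)  # amplitude follows theta
sig = theta + gamma + 0.05 * np.random.randn(len(t))
lfp = LFPOscillations(sig, fs)
mi = lfp.modulationindex(sig, plot=False)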

plv(sig=None, forder=2, thetaband=[4, 8], gammaband=[30, 80], plot=True, **kwargs)

Computes the phase-amplitude coupling (PAC) of nested oscillations. More specifically this is the phase-locking value (PLV) between two nested oscillations in EEG data, in this case theta (default 4-8Hz) and gamma (defaults to 30-80Hz). A PLV of unity indicates perfect phase locking (here PAC) and a value of zero indicates no locking (no PAC).

Parameters:

    sig (np.array): The LFP data itself, a 1-d array which can be
        masked or not (default: None)
    forder (int): The order of the filter(s) applied to the data
        (default: 2)
    thetaband, gammaband (list/array): The ranges of values to bandpass
        filter for the theta and gamma bands (defaults: [4, 8] and
        [30, 80])
    plot (bool): Whether to plot the resulting binned-up polar plot of
        gamma amplitude at different theta phases (default: True)

Returns:

    plv (float): The value of the phase-amplitude coupling

Source code in ephysiopy/common/rhythmicity.py
def plv(
    self,
    sig=None,
    forder=2,
    thetaband=[4, 8],
    gammaband=[30, 80],
    plot=True,
    **kwargs,
):
    """
    Computes the phase-amplitude coupling (PAC) of nested oscillations.
    More specifically this is the phase-locking value (PLV) between two
    nested oscillations in EEG data, in this case theta (default 4-8Hz)
    and gamma (defaults to 30-80Hz). A PLV of unity indicates perfect phase
    locking (here PAC) and a value of zero indicates no locking (no PAC).

    Args:
        sig (np.array): The LFP data itself. This is a 1-d array which
            can be masked or not
        forder (int): The order of the filter(s) applied to the data
        thetaband, gammaband (list/array): The ranges of values to
            bandpass filter for the theta and gamma bands
        plot (bool, optional): Whether to plot the resulting binned-up
            polar plot, which shows the amplitude of the gamma
            oscillation found at different phases of the theta
            oscillation. Default is True.

    Returns:
        plv (float): The value of the phase-amplitude coupling
    """

    if sig is None:
        sig = self.sig
    sig = sig - np.ma.mean(sig)
    if np.ma.is_masked(sig):
        sig = np.ma.compressed(sig)

    _, lowphase, _, _ = self.getFreqPhase(sig, thetaband, forder)
    _, _, _, highamp_f = self.getFreqPhase(sig, gammaband, forder)

    highampphase = np.angle(signal.hilbert(highamp_f))
    phasedf = highampphase - lowphase
    phasedf = np.exp(1j * phasedf)
    phasedf = np.angle(phasedf)
    from ephysiopy.common.statscalcs import circ_r

    plv = circ_r(phasedf)
    th = np.linspace(0.0, 2 * np.pi, 20, endpoint=False)
    h, _ = np.histogram(phasedf, bins=20)
    h = h / float(len(phasedf))

    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        w = np.pi / 10
        ax.bar(th, h, width=w, bottom=0.0)
    return plv, th, h
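
Continuing the synthetic coupled signal from the modulationindex example above:

plv_val, th, h = lfp.plv(sig, plot=False)   # phase-locking value in [0, 1]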

spike_phase_plot(cluster, pos_data, KSdata, lfp_data)

Produces a plot of the phase of theta at which each spike was emitted. Each spike is plotted at the x-y location the animal occupied when it fired, with the marker colour corresponding to the theta phase at which it fired.

Source code in ephysiopy/common/rhythmicity.py
def spike_phase_plot(self, cluster: int,
                     pos_data: PosCalcsGeneric,
                     KSdata: KiloSortSession,
                     lfp_data: EEGCalcsGeneric) -> None:
    '''
    Produces a plot of the phase of theta at which each spike was
    emitted. Each spike is plotted at the x-y location the animal
    occupied when it fired, with the marker colour corresponding to
    the theta phase at which it fired.
    '''
    _, phase, _, _ = self.getFreqPhase(
        lfp_data.sig, [6, 12])
    cluster_times = KSdata.spk_times[KSdata.spk_clusters == cluster]
    # cluster_times in samples (@30000Hz)
    # get indices into the phase vector
    phase_idx = np.array(cluster_times/(3e4/self.fs), dtype=int)
    # It's possible that there are indices at or beyond the length
    # of the phase vector so let's clamp them to the last index
    bad_idx = np.nonzero(phase_idx >= len(phase))[0]
    phase_idx[bad_idx] = len(phase) - 1
    # get indices into the position data
    pos_idx = np.array(cluster_times/(3e4/pos_data.sample_rate), dtype=int)
    bad_idx = np.nonzero(pos_idx >= len(pos_data.xyTS))[0]
    pos_idx[bad_idx] = len(pos_data.xyTS) - 1
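
The method stops after computing the indices; the scatter itself is not in the source. A hypothetical completion (the xy layout, colormap and marker size are assumptions, not part of the library):

import matplotlib.pyplot as plt
import numpy as np

# assumes pos_data.xy is a 2 x nPos array and that phase, phase_idx and
# pos_idx have been computed as above
fig, ax = plt.subplots()
sc = ax.scatter(pos_data.xy[0, pos_idx], pos_data.xy[1, pos_idx],
                c=phase[phase_idx], cmap="hsv",
                vmin=-np.pi, vmax=np.pi, s=6)
fig.colorbar(sc, ax=ax, label="theta phase (rad)")
plt.show()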

Spike calculations

SpikeCalcsAxona

Bases: SpikeCalcsGeneric

Replaces SpikeCalcs from ephysiopy.axona.spikecalcs

Source code in ephysiopy/common/spikecalcs.py
class SpikeCalcsAxona(SpikeCalcsGeneric):
    """
    Replaces SpikeCalcs from ephysiopy.axona.spikecalcs
    """

    def half_amp_dur(self, waveforms):
        """
        Calculates the half amplitude duration of a spike.

        Args:
            waveforms (ndarray): An nSpikes x nElectrodes x nSamples array.

        Returns:
            had (float): The half-amplitude duration for the channel
                (electrode) that has the strongest (highest amplitude)
                signal. Units are ms.
        """
        from scipy import optimize

        best_chan = np.argmax(np.max(np.mean(waveforms, 0), 1))
        mn_wvs = np.mean(waveforms, 0)
        wvs = mn_wvs[best_chan, :]
        half_amp = np.max(wvs) / 2
        half_amp = np.zeros_like(wvs) + half_amp
        t = np.linspace(0, 1 / 1000.0, 50)
        # create functions from the data using PiecewisePolynomial
        from scipy.interpolate import BPoly

        p1 = BPoly.from_derivatives(t, wvs[:, np.newaxis])
        p2 = BPoly.from_derivatives(t, half_amp[:, np.newaxis])
        xs = np.r_[t, t]
        xs.sort()
        x_min = xs.min()
        x_max = xs.max()
        x_mid = xs[:-1] + np.diff(xs) / 2
        roots = set()
        for val in x_mid:
            root, infodict, ier, mesg = optimize.fsolve(
                lambda x: p1(x) - p2(x), val, full_output=True
            )
            if ier == 1 and x_min < root < x_max:
                roots.add(root[0])
        roots = list(roots)
        if len(roots) > 1:
            r = np.abs(np.diff(roots[0:2]))[0]
        else:
            r = np.nan
        return r

    def p2t_time(self, waveforms):
        """
        The peak to trough time of a spike in ms

        Args:
            waveforms (ndarray): The waveforms to be analysed

        Returns:
            p2t (float): The mean peak-to-trough time for the channel
                (electrode) that has the strongest (highest amplitude) signal.
                Units are ms.
        """
        best_chan = np.argmax(np.max(np.mean(waveforms, 0), 1))
        tP = get_param(waveforms, param="tP")
        tT = get_param(waveforms, param="tT")
        mn_tP = np.mean(tP, 0)
        mn_tT = np.mean(tT, 0)
        p2t = np.abs(mn_tP[best_chan] - mn_tT[best_chan])
        return p2t * 1000

    def plotClusterSpace(self, waveforms, param="Amp", clusts=None, bins=256, **kwargs):
        """
        Assumes the waveform data is signed 8-bit ints
        TODO: aspect of plot boxes in ImageGrid not right as scaled by range of
        values now
        """
        from itertools import combinations

        import matplotlib.colors as colors
        from mpl_toolkits.axes_grid1 import ImageGrid

        from ephysiopy.axona.tintcolours import colours as tcols

        self.scaling = np.full(4, 15)

        amps = get_param(waveforms, param=param)
        cmap = np.tile(tcols[0], (bins, 1))
        cmap[0] = (1, 1, 1)
        cmap = colors.ListedColormap(cmap)
        cmap._init()
        alpha_vals = np.ones(cmap.N + 3)
        alpha_vals[0] = 0
        cmap._lut[:, -1] = alpha_vals
        cmb = combinations(range(4), 2)
        if "fig" in kwargs:
            fig = kwargs["fig"]
        else:
            fig = plt.figure(figsize=(8, 6))
        grid = ImageGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0.1, aspect=False)
        clustCMap0 = np.tile(tcols[0], (bins, 1))
        clustCMap0[0] = (1, 1, 1)
        clustCMap0 = colors.ListedColormap(clustCMap0)
        clustCMap0._init()
        clustCMap0._lut[:, -1] = alpha_vals
        for i, c in enumerate(cmb):
            h, ye, xe = np.histogram2d(
                amps[:, c[0]],
                amps[:, c[1]],
                range=((-128, 127), (-128, 127)),
                bins=bins,
            )
            x, y = np.meshgrid(xe[0:-1], ye[0:-1])
            grid[i].pcolormesh(
                x, y, h, cmap=clustCMap0, shading="nearest", edgecolors="face"
            )
            h, ye, xe = np.histogram2d(
                amps[:, c[0]],
                amps[:, c[1]],
                range=((-128, 127), (-128, 127)),
                bins=bins,
            )
            clustCMap = np.tile(tcols[1], (bins, 1))
            clustCMap[0] = (1, 1, 1)
            clustCMap = colors.ListedColormap(clustCMap)
            clustCMap._init()
            clustCMap._lut[:, -1] = alpha_vals
            grid[i].pcolormesh(
                x, y, h, cmap=clustCMap, shading="nearest", edgecolors="face"
            )
            s = str(c[0] + 1) + " v " + str(c[1] + 1)
            grid[i].text(
                0.05,
                0.95,
                s,
                va="top",
                ha="left",
                size="small",
                color="k",
                transform=grid[i].transAxes,
            )
            grid[i].set_xlim(xe.min(), xe.max())
            grid[i].set_ylim(ye.min(), ye.max())
        plt.setp([a.get_xticklabels() for a in grid], visible=False)
        plt.setp([a.get_yticklabels() for a in grid], visible=False)
        return fig

half_amp_dur(waveforms)

Calculates the half amplitude duration of a spike.

Parameters:

waveforms (ndarray): An nSpikes x nElectrodes x nSamples array. Required.

Returns:

had (float): The half-amplitude duration for the channel (electrode) that has the strongest (highest amplitude) signal. Units are ms.

Source code in ephysiopy/common/spikecalcs.py (lines 958-1000)
def half_amp_dur(self, waveforms):
    """
    Calculates the half amplitude duration of a spike.

    Args:
        waveforms (np.ndarray): An nSpikes x nElectrodes x nSamples array.

    Returns:
        had (float): The half-amplitude duration for the channel
            (electrode) that has the strongest (highest amplitude)
            signal. Units are ms.
    """
    from scipy import optimize

    best_chan = np.argmax(np.max(np.mean(waveforms, 0), 1))
    mn_wvs = np.mean(waveforms, 0)
    wvs = mn_wvs[best_chan, :]
    half_amp = np.max(wvs) / 2
    half_amp = np.zeros_like(wvs) + half_amp
    t = np.linspace(0, 1 / 1000.0, 50)
    # create functions from the data using PiecewisePolynomial
    from scipy.interpolate import BPoly

    p1 = BPoly.from_derivatives(t, wvs[:, np.newaxis])
    p2 = BPoly.from_derivatives(t, half_amp[:, np.newaxis])
    xs = np.r_[t, t]
    xs.sort()
    x_min = xs.min()
    x_max = xs.max()
    x_mid = xs[:-1] + np.diff(xs) / 2
    roots = set()
    for val in x_mid:
        root, infodict, ier, mesg = optimize.fsolve(
            lambda x: p1(x) - p2(x), val, full_output=True
        )
        if ier == 1 and x_min < root < x_max:
            roots.add(root[0])
    roots = list(roots)
    if len(roots) > 1:
        r = np.abs(np.diff(roots[0:2]))[0]
    else:
        r = np.nan
    return r
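
For intuition, here is a small self-contained sketch of the same idea (an illustration, not ephysiopy code): find where a mean waveform crosses half its peak amplitude and measure the width between the first two crossings. The toy Gaussian waveform and the 1 ms span are assumptions.

import numpy as np

t = np.linspace(0, 1, 50)                    # 1 ms of sample times (assumed)
wv = np.exp(-((t - 0.4) ** 2) / (2 * 0.01))  # toy mean waveform, sigma = 0.1 ms
half = wv.max() / 2.0
above = wv >= half
crossings = np.flatnonzero(np.diff(above.astype(int)) != 0)

def cross_time(i):
    # linearly interpolate between the samples straddling a crossing
    return t[i] + (half - wv[i]) * (t[i + 1] - t[i]) / (wv[i + 1] - wv[i])

had = cross_time(crossings[1]) - cross_time(crossings[0]) if crossings.size >= 2 else np.nan
print(had)  # ~0.235 ms: the FWHM of the toy Gaussian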

p2t_time(waveforms)

The peak to trough time of a spike in ms

Parameters:

waveforms (ndarray): An nSpikes x nElectrodes x nSamples array. Required.

Returns:

p2t (float): The mean peak-to-trough time for the channel (electrode) that has the strongest (highest amplitude) signal. Units are ms.

Source code in ephysiopy/common/spikecalcs.py (lines 1002-1020)
def p2t_time(self, waveforms):
    """
    The peak to trough time of a spike in ms

    Args:
        waveforms (np.ndarray): An nSpikes x nElectrodes x nSamples array of spike waveforms

    Returns:
        p2t (float): The mean peak-to-trough time for the channel
            (electrode) that has the strongest (highest amplitude) signal.
            Units are ms.
    """
    best_chan = np.argmax(np.max(np.mean(waveforms, 0), 1))
    tP = get_param(waveforms, param="tP")
    tT = get_param(waveforms, param="tT")
    mn_tP = np.mean(tP, 0)
    mn_tT = np.mean(tT, 0)
    p2t = np.abs(mn_tP[best_chan] - mn_tT[best_chan])
    return p2t * 1000
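
A quick standalone illustration of the underlying calculation (toy waveform; the 50 kHz sampling rate is an assumption):

import numpy as np

fs = 50_000.0                                       # sampling rate in Hz (assumed)
t = np.arange(50) / fs                              # sample times in seconds
wv = np.sin(2 * np.pi * 1000 * t) * np.hanning(50)  # toy mean waveform
p2t_ms = abs(t[np.argmax(wv)] - t[np.argmin(wv)]) * 1000
print(p2t_ms)  # peak-to-trough interval of the toy waveform, in ms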

plotClusterSpace(waveforms, param='Amp', clusts=None, bins=256, **kwargs)

Assumes the waveform data is signed 8-bit ints. TODO: the aspect of the plot boxes in ImageGrid is not right, as they are now scaled by the range of values.

Source code in ephysiopy/common/spikecalcs.py (lines 1022-1096)
def plotClusterSpace(self, waveforms, param="Amp", clusts=None, bins=256, **kwargs):
    """
    Assumes the waveform data is signed 8-bit ints
    TODO: aspect of plot boxes in ImageGrid not right as scaled by range of
    values now
    """
    from itertools import combinations

    import matplotlib.colors as colors
    from mpl_toolkits.axes_grid1 import ImageGrid

    from ephysiopy.axona.tintcolours import colours as tcols

    self.scaling = np.full(4, 15)

    amps = get_param(waveforms, param=param)
    cmap = np.tile(tcols[0], (bins, 1))
    cmap[0] = (1, 1, 1)
    cmap = colors.ListedColormap(cmap)
    cmap._init()
    alpha_vals = np.ones(cmap.N + 3)
    alpha_vals[0] = 0
    cmap._lut[:, -1] = alpha_vals
    cmb = combinations(range(4), 2)
    if "fig" in kwargs:
        fig = kwargs["fig"]
    else:
        fig = plt.figure(figsize=(8, 6))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0.1, aspect=False)
    clustCMap0 = np.tile(tcols[0], (bins, 1))
    clustCMap0[0] = (1, 1, 1)
    clustCMap0 = colors.ListedColormap(clustCMap0)
    clustCMap0._init()
    clustCMap0._lut[:, -1] = alpha_vals
    for i, c in enumerate(cmb):
        h, ye, xe = np.histogram2d(
            amps[:, c[0]],
            amps[:, c[1]],
            range=((-128, 127), (-128, 127)),
            bins=bins,
        )
        x, y = np.meshgrid(xe[0:-1], ye[0:-1])
        grid[i].pcolormesh(
            x, y, h, cmap=clustCMap0, shading="nearest", edgecolors="face"
        )
        h, ye, xe = np.histogram2d(
            amps[:, c[0]],
            amps[:, c[1]],
            range=((-128, 127), (-128, 127)),
            bins=bins,
        )
        clustCMap = np.tile(tcols[1], (bins, 1))
        clustCMap[0] = (1, 1, 1)
        clustCMap = colors.ListedColormap(clustCMap)
        clustCMap._init()
        clustCMap._lut[:, -1] = alpha_vals
        grid[i].pcolormesh(
            x, y, h, cmap=clustCMap, shading="nearest", edgecolors="face"
        )
        s = str(c[0] + 1) + " v " + str(c[1] + 1)
        grid[i].text(
            0.05,
            0.95,
            s,
            va="top",
            ha="left",
            size="small",
            color="k",
            transform=grid[i].transAxes,
        )
        grid[i].set_xlim(xe.min(), xe.max())
        grid[i].set_ylim(ye.min(), ye.max())
    plt.setp([a.get_xticklabels() for a in grid], visible=False)
    plt.setp([a.get_yticklabels() for a in grid], visible=False)
    return fig
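
At its core the plot is a 2D histogram of per-channel spike amplitudes for each pair of tetrode channels. A minimal standalone sketch with synthetic int8-range amplitudes (an illustration, not ephysiopy code):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
amps = rng.normal(0, 30, size=(5000, 4)).clip(-128, 127)  # toy amplitudes, 4 channels
h, xedges, yedges = np.histogram2d(
    amps[:, 0], amps[:, 1], range=((-128, 127), (-128, 127)), bins=256
)
plt.pcolormesh(xedges[:-1], yedges[:-1], h.T, shading="nearest")
plt.xlabel("channel 1 amplitude")
plt.ylabel("channel 2 amplitude")
plt.show()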

SpikeCalcsGeneric

Bases: object

Deals with the processing and analysis of spike data. There should be one instance of this class per cluster in the recording session. NB this differs from previous versions of this class where there was one instance per recording session and clusters were selected by passing in the cluster id to the methods.

Parameters:

spike_times (array_like): The times of spikes in the trial, in seconds. Required.
cluster (int): The cluster identifier. Required.
waveforms (np.array, optional): An nSpikes x nChannels x nSamples array. Default: None.
Source code in ephysiopy/common/spikecalcs.py (lines 308-803)
class SpikeCalcsGeneric(object):
    """
    Deals with the processing and analysis of spike data.
    There should be one instance of this class per cluster in the
    recording session. NB this differs from previous versions of this
    class where there was one instance per recording session and clusters
    were selected by passing in the cluster id to the methods.

    Args:
        spike_times (array_like): The times of spikes in the trial in seconds
        waveforms (np.array, optional): An nSpikes x nChannels x nSamples array

    """

    def __init__(
        self,
        spike_times: np.ndarray,
        cluster: int,
        waveforms: np.ndarray = None,
        **kwargs
    ):
        self.spike_times = spike_times  # IN SECONDS
        self._waves = waveforms
        self.cluster = cluster
        self._event_ts = None  # the times that events occurred IN SECONDS
        # window, in seconds, either side of the stimulus, to examine
        self._event_window = np.array((-0.050, 0.100))
        self._stim_width = None  # the width, in ms, of the stimulus
        # used to increase / decrease size of bins in psth
        self._secs_per_bin = 0.001
        self._sample_rate = 30000
        self._duration = None
        # these values should be specific to OE data
        self._pre_spike_samples = 16
        self._post_spike_samples = 34
        # values from running KS
        self._ksmeta = KSMetaTuple(None, None, None, None)
        # update the __dict__ attribute with the kwargs
        self.__dict__.update(kwargs)

    @property
    def sample_rate(self):
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value):
        self._sample_rate = value

    @property
    def pre_spike_samples(self):
        return self._pre_spike_samples

    @pre_spike_samples.setter
    def pre_spike_samples(self, value):
        self._pre_spike_samples = int(value)

    @property
    def post_spike_samples(self):
        return self._post_spike_samples

    @post_spike_samples.setter
    def post_spike_samples(self, value):
        self._post_spike_samples = int(value)

    def waveforms(self, channel_id: Sequence = None):
        if self._waves is not None:
            if channel_id is None:
                return self._waves[:, :, :]
            else:
                if isinstance(channel_id, int):
                    channel_id = [channel_id]
                return self._waves[:, channel_id, :]
        else:
            return None

    @property
    def n_spikes(self):
        """
        Returns the number of spikes in the cluster

        Returns:
            int: The number of spikes in the cluster
        """
        return len(self.spike_times)

    @property
    def event_ts(self):
        return self._event_ts

    @event_ts.setter
    def event_ts(self, value):
        self._event_ts = value

    @property
    def duration(self):
        return self._duration

    @duration.setter
    def duration(self, value):
        self._duration = value

    @property
    def KSMeta(self):
        return self._ksmeta

    def update_KSMeta(self, value: dict):
        """
        Takes in a TemplateModel instance from a phy session and
        parses out the relevant metrics for the cluster and places
        into the namedtuple KSMeta
        """
        metavals = []
        for f in KSMetaTuple._fields:
            if f in value.keys():
                metavals.append(value[f][self.cluster])
            else:
                metavals.append(None)
        self._ksmeta = KSMetaTuple(*metavals)

    @property
    def event_window(self):
        return self._event_window

    @event_window.setter
    def event_window(self, value):
        self._event_window = value

    @property
    def stim_width(self):
        return self._stim_width

    @stim_width.setter
    def stim_width(self, value):
        self._stim_width = value

    @property
    def secs_per_bin(self):
        return self._secs_per_bin

    @secs_per_bin.setter
    def secs_per_bin(self, value):
        self._secs_per_bin = value

    def acorr(self, Trange: np.ndarray = None) -> tuple:
        """
        Calculates the autocorrelogram of a spike train

        Args:
            Trange (np.ndarray): The range of times to calculate the
                autocorrelogram over

        Returns:
            counts (np.ndarray): The autocorrelogram
            bins (np.ndarray): The bins used to calculate the
                autocorrelogram
        """
        return xcorr(self.spike_times, Trange=Trange)

    def trial_mean_fr(self) -> float:
        # Returns the trial mean firing rate for the cluster
        if self.duration is None:
            raise IndexError("No duration provided, give me one!")
        return self.n_spikes / self.duration

    def mean_isi_range(self, isi_range: int) -> float:
        """
        Calculates the mean of the autocorrelation from 0 to n milliseconds
        Used to help classify a neuron's type (principal, interneuron etc.)

        Args:
            isi_range (int): The range in ms to calculate the mean over

        Returns:
            float: The mean of the autocorrelogram between 0 and n milliseconds
        """
        bins = 201
        trange = np.array((-500, 500))
        counts, bins = self.acorr(Trange=trange)
        mask = np.logical_and(bins > 0, bins < isi_range)
        return np.mean(counts[mask[1:]])

    def mean_waveform(self, channel_id: Sequence = None):
        """
        Returns the mean waveform and standard deviation for a given spike
        train on a particular channel

        Args:
            channel_id (Sequence, optional): The channel(s) to return the
                mean waveform for. Defaults to all channels.

        Returns:
            mn_wvs (ndarray): The mean waveforms, usually 4x50 for tetrode
                                recordings
            std_wvs (ndarray): The standard deviations of the waveforms,
                                usually 4x50 for tetrode recordings
        """
        x = self.waveforms(channel_id)
        if x is not None:
            return np.mean(x, axis=0), np.std(x, axis=0)
        else:
            return None

    def psth(self, **kwargs):
        """
        Calculate the PSTH of event_ts against the spiking of a cell

        Returns:
            x, y (list): The list of time differences between the spikes of
                            the cluster and the events (x) and the trials (y)
        """
        if self._event_ts is None:
            raise Exception("Need some event timestamps! Aborting")
        event_ts = self.event_ts
        event_ts.sort()
        if isinstance(event_ts, list):
            event_ts = np.array(event_ts)

        irange = event_ts[:, np.newaxis] + self.event_window[np.newaxis, :]
        dts = np.searchsorted(self.spike_times, irange)
        x = []
        y = []
        for i, t in enumerate(dts):
            tmp = self.spike_times[t[0] : t[1]] - event_ts[i]
            x.extend(tmp)
            y.extend(np.repeat(i, len(tmp)))
        return x, y

    def psch(self, bin_width_secs: float) -> np.ndarray:
        """
        Calculate the peri-stimulus *count* histogram of a cell's spiking
        against event times.

        Args:
            bin_width_secs (float): The width of each bin in seconds.

        Returns:
            result (np.ndarray): Rows are counts of spikes per bin_width_secs.
            Size of columns ranges from self.event_window[0] to
            self.event_window[1] with bin_width_secs steps;
            so x is count, y is "event".
        """
        if self._event_ts is None:
            raise Exception("Need some event timestamps! Aborting")
        event_ts = self.event_ts
        event_ts.sort()
        if isinstance(event_ts, list):
            event_ts = np.array(event_ts)

        irange = event_ts[:, np.newaxis] + self.event_window[np.newaxis, :]
        dts = np.searchsorted(self.spike_times, irange)
        bins = np.arange(self.event_window[0], self.event_window[1], bin_width_secs)
        result = np.zeros(shape=(len(bins) - 1, len(event_ts)))
        for i, t in enumerate(dts):
            tmp = self.spike_times[t[0] : t[1]] - event_ts[i]
            indices = np.digitize(tmp, bins=bins)
            counts = np.bincount(indices, minlength=len(bins))
            result[:, i] = counts[1:]
        return result

    def responds_to_stimulus(
        self,
        threshold: float,
        min_contiguous: int,
        return_activity: bool = False,
        return_magnitude: bool = False,
        **kwargs
    ) -> tuple:
        """
        Checks whether a cluster responds to a laser stimulus.

        Args:
            threshold (float): The amount of activity the cluster needs to go
                beyond to be classified as a responder (1.5 = 50% more or less
                than the baseline activity).
            min_contiguous (int): The number of contiguous samples in the
                post-stimulus period for which the cluster needs to be active
                beyond the threshold value to be classed as a responder.
            return_activity (bool): Whether to return the mean response curve.
            return_magnitude (int): Whether to return the magnitude of the
                response. NB this is either +1 for excited or -1 for inhibited.

        Returns:
            responds (bool): Whether the cell responds or not.
            OR
            tuple: responds (bool), normed_response_curve (np.ndarray).
            OR
            tuple: responds (bool), normed_response_curve (np.ndarray),
                response_magnitude (np.ndarray).
        """
        spk_count_by_trial = self.psch(self._secs_per_bin)
        firing_rate_by_trial = spk_count_by_trial / self.secs_per_bin
        mean_firing_rate = np.mean(firing_rate_by_trial, 1)
        # smooth with a moving average
        # check nothing in kwargs first
        if "window_len" in kwargs.keys():
            window_len = kwargs["window_len"]
        else:
            window_len = 5
        if "window" in kwargs.keys():
            window = kwargs["window"]
        else:
            window = "flat"
        if "flat" in window:
            kernel = Box1DKernel(window_len)
        if "gauss" in window:
            kernel = Gaussian1DKernel(1, x_size=window_len)
        if "do_smooth" in kwargs.keys():
            do_smooth = kwargs.get("do_smooth")
        else:
            do_smooth = True

        if do_smooth:
            smoothed_binned_spikes = convolve(mean_firing_rate, kernel, boundary="wrap")
        else:
            smoothed_binned_spikes = mean_firing_rate
        nbins = np.floor(np.sum(np.abs(self.event_window)) / self.secs_per_bin)
        bins = np.linspace(self.event_window[0], self.event_window[1], int(nbins))
        # normalize all activity by activity in the time before
        # the laser onset
        idx = bins < 0
        normd = min_max_norm(
            smoothed_binned_spikes,
            np.min(smoothed_binned_spikes[idx]),
            np.max(smoothed_binned_spikes[idx]),
        )
        # mask the array outside of a threshold value so that
        # only True values in the masked array are those that
        # exceed the threshold (positively or negatively)
        # the threshold provided to this function is expressed
        # as a % above / below unit normality so adjust that now
        # so it is expressed as a pre-stimulus firing rate mean
        # pre_stim_mean = np.mean(smoothed_binned_spikes[idx])
        # pre_stim_max = pre_stim_mean * threshold
        # pre_stim_min = pre_stim_mean * (threshold-1.0)
        # updated so threshold is double (+ or -) the pre-stim
        # norm (lies between )
        normd_masked = np.ma.masked_inside(normd, -threshold, 1 + threshold)
        # find the contiguous runs in the masked array
        # that are at least as long as the min_contiguous value
        # and classify this as a True response
        slices = np.ma.notmasked_contiguous(normd_masked)
        if slices and np.any(np.isfinite(normd)):
            # make sure that slices are within the first 25ms post-stim
            if ~np.any([s.start > 50 and s.start < 75 for s in slices]):
                if not return_activity:
                    return False
                else:
                    if return_magnitude:
                        return False, normd, 0
                    return False, normd
            max_runlength = max([len(normd_masked[s]) for s in slices])
            if max_runlength >= min_contiguous:
                if not return_activity:
                    return True
                else:
                    if return_magnitude:
                        sl = [
                            slc
                            for slc in slices
                            if (slc.stop - slc.start) == max_runlength
                        ]
                        mag = [-1 if np.mean(normd[sl[0]]) < 0 else 1][0]
                        return True, normd, mag
                    else:
                        return True, normd
        if not return_activity:
            return False
        else:
            if return_magnitude:
                return False, normd, 0
            return False, normd

    def theta_mod_idx(self):
        """
        Calculates a theta modulation index of a spike train based on the
        cell's autocorrelogram.

        Returns:
            thetaMod (float): The difference of the values at the first peak
            and trough of the autocorrelogram.
        """
        corr, _ = self.acorr()
        # Take the fft of the spike train autocorr (from -500 to +500ms)
        from scipy.signal import periodogram

        freqs, power = periodogram(corr, fs=200, return_onesided=True)
        # Smooth the power over +/- 1Hz
        b = signal.windows.boxcar(3)
        h = signal.filtfilt(b, 3, power)

        # Square the amplitude first
        sqd_amp = h**2
        # Then find the mean power in the +/-1Hz band either side of that
        theta_band_max_idx = np.nonzero(
            sqd_amp == np.max(sqd_amp[np.logical_and(freqs > 6, freqs < 11)])
        )[0][0]
        # Get the mean theta band power - mtbp
        mtbp = np.mean(sqd_amp[theta_band_max_idx - 1 : theta_band_max_idx + 1])
        # Find the mean amplitude in the 2-50Hz range
        other_band_idx = np.logical_and(freqs > 2, freqs < 50)
        # Get the mean in the other band - mobp
        mobp = np.mean(sqd_amp[other_band_idx])
        # Find the ratio of these two - this is the theta modulation index
        return (mtbp - mobp) / (mtbp + mobp)

    def theta_mod_idxV2(self):
        """
        This is a simpler alternative to the theta_mod_idx method in that it
        calculates the difference between the normalized temporal
        autocorrelogram at the trough between 50-70ms and the
        peak between 100-140ms over their sum (data is binned into 5ms bins)

        Measure used in Cacucci et al., 2004 and Kropff et al., 2015
        """
        corr, bins = self.acorr()
        # 'close' the right-hand bin
        bins = bins[0:-1]
        # normalise corr so max is 1.0
        corr = corr / float(np.max(corr))
        thetaAntiPhase = np.min(
            corr[np.logical_and(bins > 50 / 1000.0, bins < 70 / 1000.0)]
        )
        thetaPhase = np.max(
            corr[np.logical_and(bins > 100 / 1000.0, bins < 140 / 1000.0)]
        )
        return (thetaPhase - thetaAntiPhase) / (thetaPhase + thetaAntiPhase)

    def theta_band_max_freq(self):
        """
        Calculates the frequency with the maximum power in the theta band (6-12Hz)
        of a spike train's autocorrelogram.

        This function is used to look for differences in theta frequency in
        different running directions as per Blair.
        See Welday paper - https://doi.org/10.1523/jneurosci.0712-11.2011

        Returns:
            float: The frequency with the maximum power in the theta band.

        Raises:
            ValueError: If the input spike train is not valid.
        """
        corr, _ = self.acorr()
        # Take the fft of the spike train autocorr (from -500 to +500ms)
        from scipy.signal import periodogram

        freqs, power = periodogram(corr, fs=200, return_onesided=True)
        power_masked = np.ma.MaskedArray(power, np.logical_or(freqs < 6, freqs > 12))
        return freqs[np.argmax(power_masked)]

    def smooth_spike_train(self, npos, sigma=3.0, shuffle=None):
        """
        Returns a spike train the same length as num pos samples that has been
        smoothed in time with a gaussian kernel M in width and standard
        deviation equal to sigma.

        Args:
            npos (int): The number of position samples captured.
            sigma (float): The standard deviation of the gaussian used to
                smooth the spike train.
            shuffle (int, optional): The number of seconds to shift the spike
                train by. Default is None.

        Returns:
            smoothed_spikes (np.array): The smoothed spike train.
        """
        spk_hist = np.bincount(self.spike_times, minlength=npos)
        if shuffle is not None:
            spk_hist = np.roll(spk_hist, int(shuffle * 50))
        # smooth the spk_hist (which is a temporal histogram) with a 250ms
        # gaussian as with Kropff et al., 2015
        h = signal.windows.gaussian(13, sigma)
        h = h / float(np.sum(h))
        return signal.filtfilt(h.ravel(), 1, spk_hist)

    def contamination_percent(self, **kwargs) -> tuple:

        c, Qi, Q00, Q01, Ri = contamination_percent(self.spike_times, **kwargs)
        Q = min(Qi / (max(Q00, Q01)))  # this is a measure of refractoriness
        # this is a second measure of refractoriness (kicks in for very low
        # firing rates)
        R = min(Ri)
        return Q, R
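
A minimal construction sketch, assuming the class is importable from ephysiopy.common.spikecalcs (as the source path above suggests); the spike times are synthetic and the 600 s duration is illustrative:

import numpy as np
from ephysiopy.common.spikecalcs import SpikeCalcsGeneric

rng = np.random.default_rng(0)
spike_times = np.sort(rng.uniform(0, 600, 1500))  # spike times in seconds
sc = SpikeCalcsGeneric(spike_times, cluster=1)
sc.duration = 600            # seconds; needed by trial_mean_fr()
print(sc.n_spikes)           # 1500
print(sc.trial_mean_fr())    # 2.5 (Hz)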

n_spikes property

Returns the number of spikes in the cluster

Returns:

int: The number of spikes in the cluster.

acorr(Trange=None)

Calculates the autocorrelogram of a spike train

Parameters:

Trange (ndarray): The range of times to calculate the autocorrelogram over. Default: None.

Returns:

counts (ndarray): The autocorrelogram.
bins (ndarray): The bins used to calculate the autocorrelogram.

Source code in ephysiopy/common/spikecalcs.py (lines 451-465)
def acorr(self, Trange: np.ndarray = None) -> tuple:
    """
    Calculates the autocorrelogram of a spike train

    Args:
        Trange (np.ndarray): The range of times to calculate the
            autocorrelogram over

    Returns:
        counts (np.ndarray): The autocorrelogram
        bins (np.ndarray): The bins used to calculate the
            autocorrelogram
    """
    return xcorr(self.spike_times, Trange=Trange)
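
Usage sketch, continuing with the sc instance from above (the (-500, 500) Trange mirrors the value mean_isi_range uses internally):

counts, bins = sc.acorr(Trange=np.array((-500, 500)))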

mean_isi_range(isi_range)

Calculates the mean of the autocorrelation from 0 to n milliseconds. Used to help classify a neuron's type (principal, interneuron etc.).

Parameters:

isi_range (int): The range in ms to calculate the mean over. Required.

Returns:

float: The mean of the autocorrelogram between 0 and n milliseconds.

Source code in ephysiopy/common/spikecalcs.py (lines 473-488)
def mean_isi_range(self, isi_range: int) -> float:
    """
    Calculates the mean of the autocorrelation from 0 to n milliseconds
    Used to help classify a neuron's type (principal, interneuron etc.)

    Args:
        isi_range (int): The range in ms to calculate the mean over

    Returns:
        float: The mean of the autocorrelogram between 0 and n milliseconds
    """
    bins = 201
    trange = np.array((-500, 500))
    counts, bins = self.acorr(Trange=trange)
    mask = np.logical_and(bins > 0, bins < isi_range)
    return np.mean(counts[mask[1:]])

mean_waveform(channel_id=None)

Returns the mean waveform and standard deviation for a given spike train on a particular channel.

Parameters:

channel_id (Sequence, optional): The channel(s) to return the mean waveform for. Defaults to all channels.

Returns:

mn_wvs (ndarray): The mean waveforms, usually 4x50 for tetrode recordings.
std_wvs (ndarray): The standard deviations of the waveforms, usually 4x50 for tetrode recordings.

Source code in ephysiopy/common/spikecalcs.py (lines 490-508)
def mean_waveform(self, channel_id: Sequence = None):
    """
    Returns the mean waveform and standard deviation for a given spike
    train on a particular channel

    Args:
        channel_id (Sequence, optional): The channel(s) to return the
            mean waveform for. Defaults to all channels.

    Returns:
        mn_wvs (ndarray): The mean waveforms, usually 4x50 for tetrode
                            recordings
        std_wvs (ndarray): The standard deviations of the waveforms,
                            usually 4x50 for tetrode recordings
    """
    x = self.waveforms(channel_id)
    if x is not None:
        return np.mean(x, axis=0), np.std(x, axis=0)
    else:
        return None
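
If waveforms were supplied at construction, the per-channel mean and standard deviation fall straight out. A sketch with synthetic waveforms (the 100 x 4 x 50 shape is an assumption):

import numpy as np
from ephysiopy.common.spikecalcs import SpikeCalcsGeneric

rng = np.random.default_rng(1)
waves = rng.standard_normal((100, 4, 50))   # nSpikes x nChannels x nSamples
sc = SpikeCalcsGeneric(np.sort(rng.uniform(0, 600, 100)), cluster=1, waveforms=waves)
mn, sd = sc.mean_waveform()                 # both 4 x 50
mn0, sd0 = sc.mean_waveform(channel_id=0)   # 1 x 50 each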

psch(bin_width_secs)

Calculate the peri-stimulus count histogram of a cell's spiking against event times.

Parameters:

bin_width_secs (float): The width of each bin in seconds. Required.

Returns:

result (np.ndarray): Rows are counts of spikes per bin_width_secs. Columns range from self.event_window[0] to self.event_window[1] in bin_width_secs steps; so x is count, y is "event".

Source code in ephysiopy/common/spikecalcs.py (lines 538-569)
def psch(self, bin_width_secs: float) -> np.ndarray:
    """
    Calculate the peri-stimulus *count* histogram of a cell's spiking
    against event times.

    Args:
        bin_width_secs (float): The width of each bin in seconds.

    Returns:
        result (np.ndarray): Rows are counts of spikes per bin_width_secs.
        Size of columns ranges from self.event_window[0] to
        self.event_window[1] with bin_width_secs steps;
        so x is count, y is "event".
    """
    if self._event_ts is None:
        raise Exception("Need some event timestamps! Aborting")
    event_ts = self.event_ts
    event_ts.sort()
    if isinstance(event_ts, list):
        event_ts = np.array(event_ts)

    irange = event_ts[:, np.newaxis] + self.event_window[np.newaxis, :]
    dts = np.searchsorted(self.spike_times, irange)
    bins = np.arange(self.event_window[0], self.event_window[1], bin_width_secs)
    result = np.zeros(shape=(len(bins) - 1, len(event_ts)))
    for i, t in enumerate(dts):
        tmp = self.spike_times[t[0] : t[1]] - event_ts[i]
        indices = np.digitize(tmp, bins=bins)
        counts = np.bincount(indices, minlength=len(bins))
        result[:, i] = counts[1:]
    return result
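
Usage sketch with synthetic event times (event_ts must share the spike-time clock, in seconds):

sc.event_ts = np.arange(10.0, 500.0, 10.0)   # 49 toy stimulus onsets (s)
counts = sc.psch(bin_width_secs=0.001)
# counts has one column of 1 ms spike counts per event, rows spanning event_window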

psth(**kwargs)

Calculate the PSTH of event_ts against the spiking of a cell

Returns:

x, y (list): The list of time differences between the spikes of the cluster and the events (x) and the trials (y).

Source code in ephysiopy/common/spikecalcs.py (lines 510-536)
def psth(self, **kwargs):
    """
    Calculate the PSTH of event_ts against the spiking of a cell

    Returns:
        x, y (list): The list of time differences between the spikes of
                        the cluster and the events (x) and the trials (y)
    """
    if self._event_ts is None:
        raise Exception("Need some event timestamps! Aborting")
    event_ts = self.event_ts
    event_ts.sort()
    if isinstance(event_ts, list):
        event_ts = np.array(event_ts)

    irange = event_ts[:, np.newaxis] + self.event_window[np.newaxis, :]
    dts = np.searchsorted(self.spike_times, irange)
    x = []
    y = []
    for i, t in enumerate(dts):
        tmp = self.spike_times[t[0] : t[1]] - event_ts[i]
        x.extend(tmp)
        y.extend(np.repeat(i, len(tmp)))
    return x, y
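
Usage sketch for a raster, assuming event_ts has been set as in the psch example:

import matplotlib.pyplot as plt

x, y = sc.psth()
plt.scatter(x, y, s=1)        # one dot per spike: time from event vs trial index
plt.xlabel("time from event (s)")
plt.ylabel("trial")
plt.show()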

responds_to_stimulus(threshold, min_contiguous, return_activity=False, return_magnitude=False, **kwargs)

Checks whether a cluster responds to a laser stimulus.

Parameters:

threshold (float): The amount of activity the cluster needs to go beyond to be classified as a responder (1.5 = 50% more or less than the baseline activity). Required.
min_contiguous (int): The number of contiguous samples in the post-stimulus period for which the cluster needs to be active beyond the threshold value to be classed as a responder. Required.
return_activity (bool): Whether to return the mean response curve. Default: False.
return_magnitude (bool): Whether to return the magnitude of the response. NB this is either +1 for excited or -1 for inhibited. Default: False.

Returns:

responds (bool): Whether the cell responds or not; or
tuple: responds (bool), normed_response_curve (np.ndarray); or
tuple: responds (bool), normed_response_curve (np.ndarray), response_magnitude (np.ndarray).

Source code in ephysiopy/common/spikecalcs.py (lines 571-683)
def responds_to_stimulus(
    self,
    threshold: float,
    min_contiguous: int,
    return_activity: bool = False,
    return_magnitude: bool = False,
    **kwargs
) -> tuple:
    """
    Checks whether a cluster responds to a laser stimulus.

    Args:
        threshold (float): The amount of activity the cluster needs to go
            beyond to be classified as a responder (1.5 = 50% more or less
            than the baseline activity).
        min_contiguous (int): The number of contiguous samples in the
            post-stimulus period for which the cluster needs to be active
            beyond the threshold value to be classed as a responder.
        return_activity (bool): Whether to return the mean response curve.
        return_magnitude (int): Whether to return the magnitude of the
            response. NB this is either +1 for excited or -1 for inhibited.

    Returns:
        responds (bool): Whether the cell responds or not.
        OR
        tuple: responds (bool), normed_response_curve (np.ndarray).
        OR
        tuple: responds (bool), normed_response_curve (np.ndarray),
            response_magnitude (np.ndarray).
    """
    spk_count_by_trial = self.psch(self._secs_per_bin)
    firing_rate_by_trial = spk_count_by_trial / self.secs_per_bin
    mean_firing_rate = np.mean(firing_rate_by_trial, 1)
    # smooth with a moving average
    # check nothing in kwargs first
    if "window_len" in kwargs.keys():
        window_len = kwargs["window_len"]
    else:
        window_len = 5
    if "window" in kwargs.keys():
        window = kwargs["window"]
    else:
        window = "flat"
    if "flat" in window:
        kernel = Box1DKernel(window_len)
    if "gauss" in window:
        kernel = Gaussian1DKernel(1, x_size=window_len)
    if "do_smooth" in kwargs.keys():
        do_smooth = kwargs.get("do_smooth")
    else:
        do_smooth = True

    if do_smooth:
        smoothed_binned_spikes = convolve(mean_firing_rate, kernel, boundary="wrap")
    else:
        smoothed_binned_spikes = mean_firing_rate
    nbins = np.floor(np.sum(np.abs(self.event_window)) / self.secs_per_bin)
    bins = np.linspace(self.event_window[0], self.event_window[1], int(nbins))
    # normalize all activity by activity in the time before
    # the laser onset
    idx = bins < 0
    normd = min_max_norm(
        smoothed_binned_spikes,
        np.min(smoothed_binned_spikes[idx]),
        np.max(smoothed_binned_spikes[idx]),
    )
    # mask the array outside of a threshold value so that
    # only True values in the masked array are those that
    # exceed the threshold (positively or negatively)
    # the threshold provided to this function is expressed
    # as a % above / below unit normality so adjust that now
    # so it is expressed as a pre-stimulus firing rate mean
    # pre_stim_mean = np.mean(smoothed_binned_spikes[idx])
    # pre_stim_max = pre_stim_mean * threshold
    # pre_stim_min = pre_stim_mean * (threshold-1.0)
    # updated so threshold is double (+ or -) the pre-stim
    # norm (lies between )
    normd_masked = np.ma.masked_inside(normd, -threshold, 1 + threshold)
    # find the contiguous runs in the masked array
    # that are at least as long as the min_contiguous value
    # and classify this as a True response
    slices = np.ma.notmasked_contiguous(normd_masked)
    if slices and np.any(np.isfinite(normd)):
        # make sure that slices are within the first 25ms post-stim
        if ~np.any([s.start > 50 and s.start < 75 for s in slices]):
            if not return_activity:
                return False
            else:
                if return_magnitude:
                    return False, normd, 0
                return False, normd
        max_runlength = max([len(normd_masked[s]) for s in slices])
        if max_runlength >= min_contiguous:
            if not return_activity:
                return True
            else:
                if return_magnitude:
                    sl = [
                        slc
                        for slc in slices
                        if (slc.stop - slc.start) == max_runlength
                    ]
                    mag = [-1 if np.mean(normd[sl[0]]) < 0 else 1][0]
                    return True, normd, mag
                else:
                    return True, normd
    if not return_activity:
        return False
    else:
        if return_magnitude:
            return False, normd, 0
        return False, normd
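
Usage sketch (the threshold and min_contiguous values are illustrative, not recommendations; event_ts must be set first):

responds, curve, magnitude = sc.responds_to_stimulus(
    threshold=0.5,
    min_contiguous=5,
    return_activity=True,
    return_magnitude=True,
)
# responds: bool; curve: activity normalised to the pre-stimulus window;
# magnitude: +1 (excited), -1 (inhibited) or 0 (no response)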

smooth_spike_train(npos, sigma=3.0, shuffle=None)

Returns a spike train the same length as num pos samples that has been smoothed in time with a gaussian kernel M in width and standard deviation equal to sigma.

Parameters:

npos (int): The number of position samples captured. Required.
sigma (float): The standard deviation of the gaussian used to smooth the spike train. Default: 3.0.
shuffle (int, optional): The number of seconds to shift the spike train by. Default: None.

Returns:

smoothed_spikes (np.array): The smoothed spike train.

Source code in ephysiopy/common/spikecalcs.py (lines 770-794)
def smooth_spike_train(self, npos, sigma=3.0, shuffle=None):
    """
    Returns a spike train the same length as num pos samples that has been
    smoothed in time with a gaussian kernel M in width and standard
    deviation equal to sigma.

    Args:
        npos (int): The number of position samples captured.
        sigma (float): The standard deviation of the gaussian used to
            smooth the spike train.
        shuffle (int, optional): The number of seconds to shift the spike
            train by. Default is None.

    Returns:
        smoothed_spikes (np.array): The smoothed spike train.
    """
    spk_hist = np.bincount(self.spike_times, minlength=npos)
    if shuffle is not None:
        spk_hist = np.roll(spk_hist, int(shuffle * 50))
    # smooth the spk_hist (which is a temporal histogram) with a 250ms
    # gaussian as with Kropff et al., 2015
    h = signal.windows.gaussian(13, sigma)
    h = h / float(np.sum(h))
    return signal.filtfilt(h.ravel(), 1, spk_hist)
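
The smoothing step itself is just a zero-phase filter with a unit-area gaussian kernel. A standalone sketch of that step (toy spike indices; the 50 Hz position rate implied by the shuffle * 50 factor is an assumption about the source data):

import numpy as np
from scipy import signal

spk_pos_idx = np.array([3, 3, 10, 42, 97])    # toy position-sample indices of spikes
spk_hist = np.bincount(spk_pos_idx, minlength=100)
h = signal.windows.gaussian(13, 3.0)
h = h / h.sum()                               # normalise to unit area
smoothed = signal.filtfilt(h, 1, spk_hist)    # zero-phase smoothing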

theta_band_max_freq()

Calculates the frequency with the maximum power in the theta band (6-12Hz) of a spike train's autocorrelogram.

This function is used to look for differences in theta frequency in different running directions as per Blair. See Welday paper - https://doi.org/10.1523/jneurosci.0712-11.2011

Returns:

float: The frequency with the maximum power in the theta band.

Raises:

ValueError: If the input spike train is not valid.

Source code in ephysiopy/common/spikecalcs.py (lines 743-768)
def theta_band_max_freq(self):
    """
    Calculates the frequency with the maximum power in the theta band (6-12Hz)
    of a spike train's autocorrelogram.

    This function is used to look for differences in theta frequency in
    different running directions as per Blair.
    See Welday paper - https://doi.org/10.1523/jneurosci.0712-11.2011

    Returns:
        float: The frequency with the maximum power in the theta band.

    Raises:
        ValueError: If the input spike train is not valid.
    """
    corr, _ = self.acorr()
    # Take the fft of the spike train autocorr (from -500 to +500ms)
    from scipy.signal import periodogram

    freqs, power = periodogram(corr, fs=200, return_onesided=True)
    power_masked = np.ma.MaskedArray(power, np.logical_or(freqs < 6, freqs > 12))
    return freqs[np.argmax(power_masked)]
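
The masking approach generalises to any frequency band. A standalone sketch with a synthetic 8 Hz signal (the 200 Hz rate matches the fs used above):

import numpy as np
from scipy.signal import periodogram

fs = 200.0
t = np.arange(0, 5, 1 / fs)
sig = np.cos(2 * np.pi * 8 * t)               # toy autocorrelogram-like signal
freqs, power = periodogram(sig, fs=fs, return_onesided=True)
masked = np.ma.MaskedArray(power, np.logical_or(freqs < 6, freqs > 12))
print(freqs[np.argmax(masked)])               # 8.0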

theta_mod_idx()

Calculates a theta modulation index of a spike train based on the cell's autocorrelogram.

Returns:

thetaMod (float): The difference of the values at the first peak and trough of the autocorrelogram.

Source code in ephysiopy/common/spikecalcs.py (lines 685-719)
def theta_mod_idx(self):
    """
    Calculates a theta modulation index of a spike train based on the
    cell's autocorrelogram.

    Returns:
        thetaMod (float): The difference of the values at the first peak
        and trough of the autocorrelogram.
    """
    corr, _ = self.acorr()
    # Take the fft of the spike train autocorr (from -500 to +500ms)
    from scipy.signal import periodogram

    freqs, power = periodogram(corr, fs=200, return_onesided=True)
    # Smooth the power over +/- 1Hz
    b = signal.windows.boxcar(3)
    h = signal.filtfilt(b, 3, power)

    # Square the amplitude first
    sqd_amp = h**2
    # Then find the mean power in the +/-1Hz band either side of that
    theta_band_max_idx = np.nonzero(
        sqd_amp == np.max(sqd_amp[np.logical_and(freqs > 6, freqs < 11)])
    )[0][0]
    # Get the mean theta band power - mtbp
    mtbp = np.mean(sqd_amp[theta_band_max_idx - 1 : theta_band_max_idx + 1])
    # Find the mean amplitude in the 2-50Hz range
    other_band_idx = np.logical_and(freqs > 2, freqs < 50)
    # Get the mean in the other band - mobp
    mobp = np.mean(sqd_amp[other_band_idx])
    # Find the ratio of these two - this is the theta modulation index
    return (mtbp - mobp) / (mtbp + mobp)

theta_mod_idxV2()

This is a simpler alternative to the theta_mod_idx method in that it calculates the difference between the normalized temporal autocorrelogram at the trough between 50-70ms and the peak between 100-140ms over their sum (data is binned into 5ms bins)

Measure used in Cacucci et al., 2004 and Kropff et al., 2015.

Source code in ephysiopy/common/spikecalcs.py (lines 721-741)
def theta_mod_idxV2(self):
    """
    This is a simpler alternative to the theta_mod_idx method in that it
    calculates the difference between the normalized temporal
    autocorrelogram at the trough between 50-70ms and the
    peak between 100-140ms over their sum (data is binned into 5ms bins)

    Measure used in Cacucci et al., 2004 and Kropff et al., 2015
    """
    corr, bins = self.acorr()
    # 'close' the right-hand bin
    bins = bins[0:-1]
    # normalise corr so max is 1.0
    corr = corr / float(np.max(corr))
    thetaAntiPhase = np.min(
        corr[np.logical_and(bins > 50 / 1000.0, bins < 70 / 1000.0)]
    )
    thetaPhase = np.max(
        corr[np.logical_and(bins > 100 / 1000.0, bins < 140 / 1000.0)]
    )
    return (thetaPhase - thetaAntiPhase) / (thetaPhase + thetaAntiPhase)
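
In symbols, with trough value a (the minimum between 50-70 ms) and peak value p (the maximum between 100-140 ms) of the normalised autocorrelogram, the index is (p - a) / (p + a). For example, p = 0.8 and a = 0.2 give (0.8 - 0.2) / (0.8 + 0.2) = 0.6; the index lies in [-1, 1] whenever both values are non-negative.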

update_KSMeta(value)

Takes in a TemplateModel instance from a phy session and parses out the relevant metrics for the cluster and places into the namedtuple KSMeta

Source code in ephysiopy/common/spikecalcs.py (lines 413-425)
def update_KSMeta(self, value: dict):
    """
    Takes in a TemplateModel instance from a phy session and
    parses out the relevant metrics for the cluster and places
    into the namedtuple KSMeta
    """
    metavals = []
    for f in KSMetaTuple._fields:
        if f in value.keys():
            metavals.append(value[f][self.cluster])
        else:
            metavals.append(None)
    self._ksmeta = KSMetaTuple(*metavals)

SpikeCalcsOpenEphys

Bases: SpikeCalcsGeneric

Source code in ephysiopy/common/spikecalcs.py (lines 1099-1197)
class SpikeCalcsOpenEphys(SpikeCalcsGeneric):
    def __init__(self, spike_times, cluster, waveforms=None, **kwargs):
        super().__init__(spike_times, cluster, waveforms, **kwargs)
        self.n_samples = [-40, 41]
        self.TemplateModel = None

    def get_waveforms(
        self,
        cluster: int,
        cluster_data: KiloSortSession,
        n_waveforms: int = 2000,
        n_channels: int = 64,
        channel_range=None,
        **kwargs
    ) -> np.ndarray:
        """
        Returns waveforms for a cluster.

        Args:
            cluster (int): The cluster to return the waveforms for.
            cluster_data (KiloSortSession): The KiloSortSession object for the
                session that contains the cluster.
            n_waveforms (int, optional): The number of waveforms to return.
                Defaults to 2000.
            n_channels (int, optional): The number of channels in the
                recording. Defaults to 64.
        """
        # instantiate the TemplateModel - this is used to get the waveforms
        # for the cluster. TemplateModel encapsulates the results of KiloSort
        if self.TemplateModel is None:
            self.TemplateModel = TemplateModel(
                dir_path=os.path.join(cluster_data.fname_root),
                sample_rate=3e4,
                dat_path=os.path.join(cluster_data.fname_root, "continuous.dat"),
                n_channels_dat=n_channels,
            )
        # get the waveforms for the given cluster on the best channel only
        waveforms = self.TemplateModel.get_cluster_spike_waveforms(cluster)
        # get a random subset of the waveforms
        rng = np.random.default_rng()
        total_waves = waveforms.shape[0]
        n_waveforms = n_waveforms if n_waveforms < total_waves else total_waves
        waveforms_subset = rng.choice(waveforms, n_waveforms)
        # return the waveforms
        if channel_range is None:
            return np.squeeze(waveforms_subset[:, :, 0])
        else:
            if isinstance(channel_range, Sequence):
                return np.squeeze(waveforms_subset[:, :, channel_range])
            else:
                warnings.warn("Invalid channel_range sequence")

    def get_channel_depth_from_templates(self, pname: Path):
        """
        Determine depth of template as well as closest channel. Adopted from
        'templatePositionsAmplitudes' by N. Steinmetz
        (https://github.com/cortex-lab/spikes)
        """
        # Load inverse whitening matrix
        Winv = np.load(os.path.join(pname, "whitening_mat_inv.npy"))
        # Load templates
        templates = np.load(os.path.join(pname, "templates.npy"))
        # Load channel_map and positions
        channel_map = np.load(os.path.join(pname, "channel_map.npy"))
        channel_positions = np.load(os.path.join(pname, "channel_positions.npy"))
        map_and_pos = np.array([np.squeeze(channel_map), channel_positions[:, 1]])
        # unwhiten all the templates
        tempsUnW = np.zeros(np.shape(templates))
        for i in range(np.shape(templates)[0]):
            tempsUnW[i, :, :] = np.squeeze(templates[i, :, :]) @ Winv

        tempAmp = np.squeeze(np.max(tempsUnW, 1)) - np.squeeze(np.min(tempsUnW, 1))
        tempAmpsUnscaled = np.max(tempAmp, 1)
        # need to zero-out the potentially-many low values on distant channels
        threshVals = tempAmpsUnscaled * 0.3
        tempAmp[tempAmp < threshVals[:, None]] = 0
        # Compute the depth as a centre of mass
        templateDepths = np.sum(tempAmp * map_and_pos[1, :], -1) / np.sum(tempAmp, 1)
        maxChanIdx = np.argmin(
            np.abs((templateDepths[:, None] - map_and_pos[1, :].T)), 1
        )
        return templateDepths, maxChanIdx

    def get_template_id_for_cluster(self, pname: Path, cluster: int):
        """
        Determine the best channel (one with highest amplitude spikes)
        for a given cluster.
        """
        spike_templates = np.load(os.path.join(pname, "spike_templates.npy"))
        spike_times = np.load(os.path.join(pname, "spike_times.npy"))
        spike_clusters = np.load(os.path.join(pname, "spike_clusters.npy"))
        cluster_times = spike_times[spike_clusters == cluster]
        rez_mat = h5py.File(os.path.join(pname, "rez.mat"), "r")
        st3 = rez_mat["rez"]["st3"]
        st_spike_times = st3[0, :]
        idx = np.searchsorted(st_spike_times, cluster_times)
        template_idx, counts = np.unique(spike_templates[idx], return_counts=True)
        ind = np.argmax(counts)
        return template_idx[ind]

get_channel_depth_from_templates(pname)

Determine depth of template as well as closest channel. Adopted from 'templatePositionsAmplitudes' by N. Steinmetz (https://github.com/cortex-lab/spikes)

Source code in ephysiopy/common/spikecalcs.py (lines 1151-1180)
def get_channel_depth_from_templates(self, pname: Path):
    """
    Determine depth of template as well as closest channel. Adopted from
    'templatePositionsAmplitudes' by N. Steinmetz
    (https://github.com/cortex-lab/spikes)
    """
    # Load inverse whitening matrix
    Winv = np.load(os.path.join(pname, "whitening_mat_inv.npy"))
    # Load templates
    templates = np.load(os.path.join(pname, "templates.npy"))
    # Load channel_map and positions
    channel_map = np.load(os.path.join(pname, "channel_map.npy"))
    channel_positions = np.load(os.path.join(pname, "channel_positions.npy"))
    map_and_pos = np.array([np.squeeze(channel_map), channel_positions[:, 1]])
    # unwhiten all the templates
    tempsUnW = np.zeros(np.shape(templates))
    for i in range(np.shape(templates)[0]):
        tempsUnW[i, :, :] = np.squeeze(templates[i, :, :]) @ Winv

    tempAmp = np.squeeze(np.max(tempsUnW, 1)) - np.squeeze(np.min(tempsUnW, 1))
    tempAmpsUnscaled = np.max(tempAmp, 1)
    # need to zero-out the potentially-many low values on distant channels
    threshVals = tempAmpsUnscaled * 0.3
    tempAmp[tempAmp < threshVals[:, None]] = 0
    # Compute the depth as a centre of mass
    templateDepths = np.sum(tempAmp * map_and_pos[1, :], -1) / np.sum(tempAmp, 1)
    maxChanIdx = np.argmin(
        np.abs((templateDepths[:, None] - map_and_pos[1, :].T)), 1
    )
    return templateDepths, maxChanIdx
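
The depth estimate is an amplitude-weighted centre of mass over channel positions. A standalone numeric sketch (toy amplitudes; the channel depths in microns are assumed):

import numpy as np

chan_depth = np.array([0.0, 20.0, 40.0, 60.0])   # channel y-positions (um, assumed)
temp_amp = np.array([0.0, 1.0, 3.0, 1.0])        # toy per-channel template amplitude
depth = np.sum(temp_amp * chan_depth) / np.sum(temp_amp)
print(depth)                                     # 40.0
nearest = np.argmin(np.abs(depth - chan_depth))  # index of the closest channel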

get_template_id_for_cluster(pname, cluster)

Determine the best channel (one with highest amplitude spikes) for a given cluster.

Source code in ephysiopy/common/spikecalcs.py (lines 1182-1197)
def get_template_id_for_cluster(self, pname: Path, cluster: int):
    """
    Determine the best channel (one with highest amplitude spikes)
    for a given cluster.
    """
    spike_templates = np.load(os.path.join(pname, "spike_templates.npy"))
    spike_times = np.load(os.path.join(pname, "spike_times.npy"))
    spike_clusters = np.load(os.path.join(pname, "spike_clusters.npy"))
    cluster_times = spike_times[spike_clusters == cluster]
    rez_mat = h5py.File(os.path.join(pname, "rez.mat"), "r")
    st3 = rez_mat["rez"]["st3"]
    st_spike_times = st3[0, :]
    idx = np.searchsorted(st_spike_times, cluster_times)
    template_idx, counts = np.unique(spike_templates[idx], return_counts=True)
    ind = np.argmax(counts)
    return template_idx[ind]

get_waveforms(cluster, cluster_data, n_waveforms=2000, n_channels=64, channel_range=None, **kwargs)

Returns waveforms for a cluster.

Parameters:

cluster (int): The cluster to return the waveforms for. Required.
cluster_data (KiloSortSession): The KiloSortSession object for the session that contains the cluster. Required.
n_waveforms (int, optional): The number of waveforms to return. Defaults to 2000.
n_channels (int, optional): The number of channels in the recording. Defaults to 64.
Source code in ephysiopy/common/spikecalcs.py (lines 1105-1149)
def get_waveforms(
    self,
    cluster: int,
    cluster_data: KiloSortSession,
    n_waveforms: int = 2000,
    n_channels: int = 64,
    channel_range=None,
    **kwargs
) -> np.ndarray:
    """
    Returns waveforms for a cluster.

    Args:
        cluster (int): The cluster to return the waveforms for.
        cluster_data (KiloSortSession): The KiloSortSession object for the
            session that contains the cluster.
        n_waveforms (int, optional): The number of waveforms to return.
            Defaults to 2000.
        n_channels (int, optional): The number of channels in the
            recording. Defaults to 64.
        channel_range (Sequence, optional): Channel indices to return
            waveforms for. If None (the default), only the best channel
            is returned.
    """
    # instantiate the TemplateModel - this is used to get the waveforms
    # for the cluster. TemplateModel encapsulates the results of KiloSort
    if self.TemplateModel is None:
        self.TemplateModel = TemplateModel(
            dir_path=os.path.join(cluster_data.fname_root),
            sample_rate=3e4,
            dat_path=os.path.join(cluster_data.fname_root, "continuous.dat"),
            n_channels_dat=n_channels,
        )
    # get the waveforms for the given cluster on the best channel only
    waveforms = self.TemplateModel.get_cluster_spike_waveforms(cluster)
    # get a random subset of the waveforms
    rng = np.random.default_rng()
    total_waves = waveforms.shape[0]
    n_waveforms = n_waveforms if n_waveforms < total_waves else total_waves
    waveforms_subset = rng.choice(waveforms, n_waveforms)
    # return the waveforms
    if channel_range is None:
        return np.squeeze(waveforms_subset[:, :, 0])
    else:
        if isinstance(channel_range, Sequence):
            return np.squeeze(waveforms_subset[:, :, channel_range])
        else:
            warnings.warn("Invalid channel_range sequence")

SpikeCalcsProbe

Bases: SpikeCalcsGeneric

Encapsulates methods specific to probe-based recordings

Source code in ephysiopy/common/spikecalcs.py
1200
1201
1202
1203
1204
1205
1206
class SpikeCalcsProbe(SpikeCalcsGeneric):
    """
    Encapsulates methods specific to probe-based recordings
    """

    def __init__(self):
        pass

SpikeCalcsTetrode

Bases: SpikeCalcsGeneric

Encapsulates methods specific to the geometry inherent in tetrode-based recordings

Source code in ephysiopy/common/spikecalcs.py
class SpikeCalcsTetrode(SpikeCalcsGeneric):
    """
    Encapsulates methods specific to the geometry inherent in tetrode-based
    recordings
    """

    def __init__(self, spike_times, cluster, waveforms=None, **kwargs):
        super().__init__(spike_times, cluster, waveforms, **kwargs)

    def ifr_sp_corr(
        self,
        x1,
        speed,
        minSpeed=2.0,
        maxSpeed=40.0,
        sigma=3,
        shuffle=False,
        nShuffles=100,
        minTime=30,
        plot=False,
    ):
        """
        Calculates the correlation between the instantaneous firing rate and
        speed.

        Args:
            x1 (np.array): The indices of pos at which the cluster fired.
            speed (np.array): Instantaneous speed (1 x nSamples).
            minSpeed (float, optional): Speeds below this value are ignored.
                Defaults to 2.0 cm/s as with Kropff et al., 2015.
            maxSpeed (float, optional): Speeds above this value are ignored.
                Defaults to 40.0 cm/s.
            sigma (int, optional): The standard deviation of the gaussian used
                to smooth the spike train. Defaults to 3.
            shuffle (bool, optional): Whether to shuffle the spike train.
                Defaults to False.
            nShuffles (int, optional): The number of times to shuffle the
                spike train. Defaults to 100.
            minTime (int, optional): The minimum time for which the spike
                train should be considered. Defaults to 30.
            plot (bool, optional): Whether to plot the result.
                Defaults to False.
        """
        fig = None  # only created if plot is True
        speed = speed.ravel()
        posSampRate = 50
        nSamples = len(speed)
        # position is sampled at 50Hz and so is 'automatically' binned into
        # 20ms bins
        spk_hist = np.bincount(x1, minlength=nSamples)
        # smooth the spk_hist (which is a temporal histogram) with a 250ms
        # gaussian as with Kropff et al., 2015
        h = signal.windows.gaussian(13, sigma)
        h = h / float(np.sum(h))
        # filter for low speeds
        lowSpeedIdx = speed < minSpeed
        highSpeedIdx = speed > maxSpeed
        speed_filt = speed[~np.logical_or(lowSpeedIdx, highSpeedIdx)]
        spk_hist_filt = spk_hist[~np.logical_or(lowSpeedIdx, highSpeedIdx)]
        spk_sm = signal.filtfilt(h.ravel(), 1, spk_hist_filt)
        sm_spk_rate = spk_sm * posSampRate
        res = stats.pearsonr(sm_spk_rate, speed_filt)
        if plot:
            # do some fancy plotting stuff
            _, sp_bin_edges = np.histogram(speed_filt, bins=50)
            sp_dig = np.digitize(speed_filt, sp_bin_edges, right=True)
            spks_per_sp_bin = [
                spk_hist_filt[sp_dig == i] for i in range(len(sp_bin_edges))
            ]
            rate_per_sp_bin = []
            for x in spks_per_sp_bin:
                rate_per_sp_bin.append(np.mean(x) * posSampRate)
            rate_filter = signal.windows.gaussian(5, 1.0)
            rate_filter = rate_filter / np.sum(rate_filter)
            binned_spk_rate = signal.filtfilt(rate_filter, 1, rate_per_sp_bin)
            # instead of plotting a scatter plot of the firing rate at each
            # speed bin, plot a log normalised heatmap and overlay results

            spk_binning_edges = np.linspace(
                np.min(sm_spk_rate), np.max(sm_spk_rate), len(sp_bin_edges)
            )
            speed_mesh, spk_mesh = np.meshgrid(sp_bin_edges, spk_binning_edges)
            binned_rate, _, _ = np.histogram2d(
                speed_filt, sm_spk_rate, bins=[sp_bin_edges, spk_binning_edges]
            )
            # blur the binned rate a bit to make it look nicer
            from ephysiopy.common.utils import blurImage

            sm_binned_rate = blurImage(binned_rate, 5)
            fig = plt.figure()
            ax = fig.add_subplot(111)
            from matplotlib.colors import LogNorm

            speed_mesh = speed_mesh[:-1, :-1]
            spk_mesh = spk_mesh[:-1, :-1]
            ax.pcolormesh(
                speed_mesh,
                spk_mesh,
                sm_binned_rate,
                norm=LogNorm(),
                alpha=0.5,
                shading="nearest",
                edgecolors="None",
            )
            # overlay the smoothed binned rate against speed
            ax.plot(sp_bin_edges, binned_spk_rate, "r")
            # do the linear regression and plot the fit too
            # TODO: linear regression is broken ie not regressing the correct
            # variables
            lr = stats.linregress(speed_filt, sm_spk_rate)
            end_point = lr.intercept + ((sp_bin_edges[-1] - sp_bin_edges[0]) * lr.slope)
            ax.plot(
                [np.min(sp_bin_edges), np.max(sp_bin_edges)],
                [lr.intercept, end_point],
                "r--",
            )
            ax.set_xlim(np.min(sp_bin_edges), np.max(sp_bin_edges[-2]))
            ax.set_ylim(0, np.nanmax(binned_spk_rate) * 1.1)
            ax.set_ylabel("Firing rate(Hz)")
            ax.set_xlabel("Running speed(cm/s)")
            ax.set_title(
                "Intercept: {0:.3f}   Slope: {1:.5f}\nPearson: {2:.5f}".format(
                    lr.intercept, lr.slope, lr.rvalue
                )
            )
        # do some shuffling of the data to see if the result is significant
        if shuffle:
            # shift spikes by at least 30 seconds after trial start and
            # 30 seconds before trial end
            timeSteps = np.random.randint(
                30 * posSampRate, nSamples - (30 * posSampRate), nShuffles
            )
            shuffled_results = []
            for t in timeSteps:
                spk_count = np.roll(spk_hist, t)
                spk_count_filt = spk_count[~lowSpeedIdx]
                spk_count_sm = signal.filtfilt(h.ravel(), 1, spk_count_filt)
                shuffled_results.append(stats.pearsonr(spk_count_sm, speed_filt)[0])
            if plot:
                fig = plt.figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.hist(np.abs(shuffled_results), 20)
                ylims = ax.get_ylim()
                ax.vlines(res[0], ylims[0], ylims[1], "r")  # observed correlation
        if isinstance(fig, plt.Figure):
            return fig

ifr_sp_corr(x1, speed, minSpeed=2.0, maxSpeed=40.0, sigma=3, shuffle=False, nShuffles=100, minTime=30, plot=False)

Calculates the correlation between the instantaneous firing rate and speed.

Parameters:

Name Type Description Default
x1 array

The indices of pos at which the cluster fired.

required
speed array

Instantaneous speed (1 x nSamples).

required
minSpeed float

Speeds below this value are ignored. Defaults to 2.0 cm/s as with Kropff et al., 2015.

2.0
maxSpeed float

Speeds above this value are ignored. Defaults to 40.0 cm/s.

40.0
sigma int

The standard deviation of the gaussian used to smooth the spike train. Defaults to 3.

3
shuffle bool

Whether to shuffle the spike train. Defaults to False.

False
nShuffles int

The number of times to shuffle the spike train. Defaults to 100.

100
minTime int

The minimum time for which the spike train should be considered. Defaults to 30.

30
plot bool

Whether to plot the result. Defaults to False.

False
Source code in ephysiopy/common/spikecalcs.py
def ifr_sp_corr(
    self,
    x1,
    speed,
    minSpeed=2.0,
    maxSpeed=40.0,
    sigma=3,
    shuffle=False,
    nShuffles=100,
    minTime=30,
    plot=False,
):
    """
    Calculates the correlation between the instantaneous firing rate and
    speed.

    Args:
        x1 (np.array): The indices of pos at which the cluster fired.
        speed (np.array): Instantaneous speed (1 x nSamples).
        minSpeed (float, optional): Speeds below this value are ignored.
            Defaults to 2.0 cm/s as with Kropff et al., 2015.
        maxSpeed (float, optional): Speeds above this value are ignored.
            Defaults to 40.0 cm/s.
        sigma (int, optional): The standard deviation of the gaussian used
            to smooth the spike train. Defaults to 3.
        shuffle (bool, optional): Whether to shuffle the spike train.
            Defaults to False.
        nShuffles (int, optional): The number of times to shuffle the
            spike train. Defaults to 100.
        minTime (int, optional): The minimum time for which the spike
            train should be considered. Defaults to 30.
        plot (bool, optional): Whether to plot the result.
            Defaults to False.
    """
    fig = None  # only created if plot is True
    speed = speed.ravel()
    posSampRate = 50
    nSamples = len(speed)
    # position is sampled at 50Hz and so is 'automatically' binned into
    # 20ms bins
    spk_hist = np.bincount(x1, minlength=nSamples)
    # smooth the spk_hist (which is a temporal histogram) with a 250ms
    # gaussian as with Kropff et al., 2015
    h = signal.windows.gaussian(13, sigma)
    h = h / float(np.sum(h))
    # filter for low speeds
    lowSpeedIdx = speed < minSpeed
    highSpeedIdx = speed > maxSpeed
    speed_filt = speed[~np.logical_or(lowSpeedIdx, highSpeedIdx)]
    spk_hist_filt = spk_hist[~np.logical_or(lowSpeedIdx, highSpeedIdx)]
    spk_sm = signal.filtfilt(h.ravel(), 1, spk_hist_filt)
    sm_spk_rate = spk_sm * posSampRate
    res = stats.pearsonr(sm_spk_rate, speed_filt)
    if plot:
        # do some fancy plotting stuff
        _, sp_bin_edges = np.histogram(speed_filt, bins=50)
        sp_dig = np.digitize(speed_filt, sp_bin_edges, right=True)
        spks_per_sp_bin = [
            spk_hist_filt[sp_dig == i] for i in range(len(sp_bin_edges))
        ]
        rate_per_sp_bin = []
        for x in spks_per_sp_bin:
            rate_per_sp_bin.append(np.mean(x) * posSampRate)
        rate_filter = signal.windows.gaussian(5, 1.0)
        rate_filter = rate_filter / np.sum(rate_filter)
        binned_spk_rate = signal.filtfilt(rate_filter, 1, rate_per_sp_bin)
        # instead of plotting a scatter plot of the firing rate at each
        # speed bin, plot a log normalised heatmap and overlay results

        spk_binning_edges = np.linspace(
            np.min(sm_spk_rate), np.max(sm_spk_rate), len(sp_bin_edges)
        )
        speed_mesh, spk_mesh = np.meshgrid(sp_bin_edges, spk_binning_edges)
        binned_rate, _, _ = np.histogram2d(
            speed_filt, sm_spk_rate, bins=[sp_bin_edges, spk_binning_edges]
        )
        # blur the binned rate a bit to make it look nicer
        from ephysiopy.common.utils import blurImage

        sm_binned_rate = blurImage(binned_rate, 5)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        from matplotlib.colors import LogNorm

        speed_mesh = speed_mesh[:-1, :-1]
        spk_mesh = spk_mesh[:-1, :-1]
        ax.pcolormesh(
            speed_mesh,
            spk_mesh,
            sm_binned_rate,
            norm=LogNorm(),
            alpha=0.5,
            shading="nearest",
            edgecolors="None",
        )
        # overlay the smoothed binned rate against speed
        ax.plot(sp_bin_edges, binned_spk_rate, "r")
        # do the linear regression and plot the fit too
        # TODO: linear regression is broken ie not regressing the correct
        # variables
        lr = stats.linregress(speed_filt, sm_spk_rate)
        end_point = lr.intercept + ((sp_bin_edges[-1] - sp_bin_edges[0]) * lr.slope)
        ax.plot(
            [np.min(sp_bin_edges), np.max(sp_bin_edges)],
            [lr.intercept, end_point],
            "r--",
        )
        ax.set_xlim(np.min(sp_bin_edges), np.max(sp_bin_edges[-2]))
        ax.set_ylim(0, np.nanmax(binned_spk_rate) * 1.1)
        ax.set_ylabel("Firing rate(Hz)")
        ax.set_xlabel("Running speed(cm/s)")
        ax.set_title(
            "Intercept: {0:.3f}   Slope: {1:.5f}\nPearson: {2:.5f}".format(
                lr.intercept, lr.slope, lr.rvalue
            )
        )
    # do some shuffling of the data to see if the result is significant
    if shuffle:
        # shift spikes by at least 30 seconds after trial start and
        # 30 seconds before trial end
        timeSteps = np.random.randint(
            30 * posSampRate, nSamples - (30 * posSampRate), nShuffles
        )
        shuffled_results = []
        for t in timeSteps:
            spk_count = np.roll(spk_hist, t)
            spk_count_filt = spk_count[~lowSpeedIdx]
            spk_count_sm = signal.filtfilt(h.ravel(), 1, spk_count_filt)
            shuffled_results.append(stats.pearsonr(spk_count_sm, speed_filt)[0])
        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.hist(np.abs(shuffled_results), 20)
            ylims = ax.get_ylim()
            ax.vlines(res[0], ylims[0], ylims[1], "r")  # observed correlation
    if isinstance(fig, plt.Figure):
        return fig
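
A minimal synthetic-data sketch of calling this method (the constructor arguments shown are illustrative, not a documented recipe):

import numpy as np
from ephysiopy.common.spikecalcs import SpikeCalcsTetrode

n_pos = 50 * 600                                   # 10 minutes of position at 50 Hz
speed = np.abs(np.random.randn(n_pos)) * 10.0      # synthetic speeds in cm/s
spk_pos_idx = np.sort(np.random.choice(n_pos, 1000, replace=False))
S = SpikeCalcsTetrode(spk_pos_idx / 50.0, 1)       # spike times (s), cluster id
fig = S.ifr_sp_corr(spk_pos_idx, speed, plot=True)  # figure summarising the fit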

cluster_quality(waveforms=None, spike_clusters=None, cluster_id=None, fet=1)

Returns the L-ratio and Isolation Distance measures calculated on the principal components of the energy in a spike matrix.

Parameters:

Name Type Description Default
waveforms ndarray

The waveforms to be processed. If None, the function will return None.

None
spike_clusters ndarray

The spike clusters to be processed.

None
cluster_id int

The ID of the cluster to be processed.

None
fet int, default=1

The feature to be used in the PCA calculation.

1

Returns:

Name Type Description
tuple

A tuple containing the L-ratio and Isolation Distance of the cluster.

Raises:

Type Description
Exception

If an error occurs during the calculation of the L-ratio or Isolation Distance.

Source code in ephysiopy/common/spikecalcs.py
def cluster_quality(
    waveforms: np.ndarray = None,
    spike_clusters: np.ndarray = None,
    cluster_id: int = None,
    fet: int = 1,
):
    """
    Returns the L-ratio and Isolation Distance measures calculated
    on the principal components of the energy in a spike matrix.

    Args:
        waveforms (np.ndarray, optional): The waveforms to be processed.
            If None, the function will return None.
        spike_clusters (np.ndarray, optional): The spike clusters to be
            processed.
        cluster_id (int, optional): The ID of the cluster to be processed.
        fet (int, default=1): The feature to be used in the PCA calculation.

    Returns:
        tuple: A tuple containing the L-ratio and Isolation Distance of the
            cluster.

    Raises:
        Exception: If an error occurs during the calculation of the L-ratio or
            Isolation Distance.
    """
    if waveforms is None:
        return None
    nSpikes, nElectrodes, _ = waveforms.shape
    wvs = waveforms.copy()
    E = np.sqrt(np.nansum(waveforms**2, axis=2))
    zeroIdx = np.sum(E, 0) == 0  # electrodes with no energy at all
    E = E[:, ~zeroIdx]
    wvs = wvs[:, ~zeroIdx, :]
    normdWaves = (wvs.T / E.T).T
    PCA_m = get_param(normdWaves, "PCA", fet=fet)
    badIdx = np.sum(PCA_m, axis=0) == 0
    PCA_m = PCA_m[:, ~badIdx]
    # get mahalanobis distance
    idx = spike_clusters == cluster_id
    nClustSpikes = np.count_nonzero(idx)
    try:
        d = mahal(PCA_m, PCA_m[idx, :])
        # get the indices of the spikes not in the cluster
        M_noise = d[~idx]
        df = np.prod((fet, nElectrodes))
        from scipy import stats

        L = np.sum(1 - stats.chi2.cdf(M_noise, df))
        L_ratio = L / nClustSpikes
        # calculate isolation distance
        if nClustSpikes < nSpikes / 2:
            M_noise.sort()
            isolation_dist = M_noise[nClustSpikes]
        else:
            isolation_dist = np.nan
    except Exception:
        isolation_dist = L_ratio = np.nan
    return L_ratio, isolation_dist
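
A synthetic sketch of scoring one cluster (tetrode-shaped data, 4 electrodes x 50 samples; random waveforms, so the values carry no meaning here):

import numpy as np
from ephysiopy.common.spikecalcs import cluster_quality

waveforms = np.random.randn(1000, 4, 50)        # nSpikes x nElectrodes x nSamples
spike_clusters = np.random.randint(0, 3, 1000)  # three clusters
L_ratio, iso_dist = cluster_quality(waveforms, spike_clusters, cluster_id=1)
print(L_ratio, iso_dist)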

contamination_percent(x1, x2=None, **kwargs)

Computes the cross-correlogram between two sets of spikes and estimates how refractory the cross-correlogram is.

Parameters:

Name Type Description Default
x1 array

The first set of spikes.

required
x2 array

The second set of spikes.

required
kwargs

Anything that can be fed into xcorr above

Returns:

Name Type Description
Q float

a measure of refractoriness

R float

a second measure of refractoriness (kicks in for very low firing rates)

Notes

Taken from KiloSort's ccg.m

The contamination metrics are calculated based on an analysis of the 'shoulders' of the cross-correlogram. Specifically, the spike counts in the ranges +/-5-25ms and +/-250-500ms are compared for refractoriness

Source code in ephysiopy/common/spikecalcs.py
def contamination_percent(x1: np.ndarray, x2: np.ndarray = None, **kwargs) -> tuple:
    """
    Computes the cross-correlogram between two sets of spikes and
    estimates how refractory the cross-correlogram is.

    Args:
        x1 (np.array): The first set of spikes.
        x2 (np.array): The second set of spikes.

    kwargs:
        Anything that can be fed into xcorr above

    Returns:
        Q (float): a measure of refractoriness
        R (float): a second measure of refractoriness
                (kicks in for very low firing rates)

    Notes:
        Taken from KiloSort's ccg.m

        The contamination metrics are calculated based on
        an analysis of the 'shoulders' of the cross-correlogram.
        Specifically, the spike counts in the ranges +/-5-25ms and
        +/-250-500ms are compared for refractoriness
    """
    if x2 is None:
        x2 = x1.copy()
    c, b = xcorr(x1, x2, **kwargs)
    left = [[-0.05, -0.01]]
    right = [[0.01, 0.051]]
    far = [[-0.5, -0.249], [0.25, 0.501]]

    def get_shoulder(bins, vals):
        all = np.array([np.logical_and(bins >= i[0], bins < i[1]) for i in vals])
        return np.any(all, 0)

    inner_left = get_shoulder(b, left)
    inner_right = get_shoulder(b, right)
    outer = get_shoulder(b, far)

    tbin = 1000
    Tr = max(np.concatenate([x1, x2])) - min(np.concatenate([x1, x2]))

    def get_normd_shoulder(idx):
        return np.sum(c[idx[:-1]]) / (
            len(np.nonzero(idx)[0]) * tbin * len(x1) * len(x2) / Tr
        )

    Q00 = get_normd_shoulder(outer)
    Q01 = max(get_normd_shoulder(inner_left), get_normd_shoulder(inner_right))

    R00 = max(
        np.mean(c[outer[:-1]]), np.mean(c[inner_left[:-1]]), np.mean(c[inner_right[1:]])
    )

    middle_idx = np.nonzero(b == 0)[0]
    a = c[middle_idx]
    c[middle_idx] = 0
    Qi = np.zeros(10)
    Ri = np.zeros(10)
    # enumerate through the central range of the xcorr
    # saving the same calculation as done above
    for i, t in enumerate(np.linspace(0.001, 0.01, 10)):
        irange = [[-t, t]]
        chunk = get_shoulder(b, irange)
        # compute the same normalized ratio as above;
        # this should be 1 if there is no refractoriness
        Qi[i] = get_normd_shoulder(chunk)  # save the normd prob
        n = np.sum(c[chunk[:-1]]) / 2
        lam = R00 * i
        # this is tricky: we approximate the Poisson likelihood with a
        # gaussian of equal mean and variance
        # that allows us to integrate the probability that we would see <N
        # spikes in the center of the
        # cross-correlogram from a distribution with mean R00*i spikes
        p = 1 / 2 * (1 + erf((n - lam) / np.sqrt(2 * lam)))

        Ri[i] = p  # keep track of p for each bin size i

    c[middle_idx] = a  # restore the center value of the cross-correlogram
    return c, Qi, Q00, Q01, Ri
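
A sketch of running this on a single synthetic spike train (autocorrelogram case, x2=None):

import numpy as np
from ephysiopy.common.spikecalcs import contamination_percent

spike_times = np.sort(np.random.uniform(0, 600, 5000))  # 600 s of spikes (s)
c, Qi, Q00, Q01, Ri = contamination_percent(spike_times)
print(Q00, Q01)   # shoulder-based refractoriness measures described above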

get_param(waveforms, param='Amp', t=200, fet=1)

Returns the requested parameter from a spike train as a numpy array

Parameters:

Name Type Description Default
waveforms numpy array

Shape of array can be nSpikes x nSamples OR a nSpikes x nElectrodes x nSamples

required
param str

Valid values are:
'Amp' - peak-to-trough amplitude (default)
'P' - height of peak
'T' - depth of trough
'Vt' - height at time t
'tP' - time of peak (in seconds)
'tT' - time of trough (in seconds)
'PCA' - first fet principal components (defaults to 1)

'Amp'
t int

The time used for Vt

200
fet int

The number of principal components (use with param 'PCA')

1
Source code in ephysiopy/common/spikecalcs.py
def get_param(waveforms, param="Amp", t=200, fet=1):
    """
    Returns the requested parameter from a spike train as a numpy array

    Args:
        waveforms (numpy array): Shape of array can be nSpikes x nSamples
            OR
            a nSpikes x nElectrodes x nSamples
        param (str): Valid values are:
            'Amp' - peak-to-trough amplitude (default)
            'P' - height of peak
            'T' - depth of trough
            'Vt' - height at time t
            'tP' - time of peak (in seconds)
            'tT' - time of trough (in seconds)
            'PCA' - first n fet principal components (defaults to 1)
        t (int): The time used for Vt
        fet (int): The number of principal components
            (use with param 'PCA')
    """
    from scipy import interpolate
    from sklearn.decomposition import PCA

    if param == "Amp":
        return np.ptp(waveforms, axis=-1)
    elif param == "P":
        return np.max(waveforms, axis=-1)
    elif param == "T":
        return np.min(waveforms, axis=-1)
    elif param == "Vt":
        times = np.arange(0, 1000, 20)
        f = interpolate.interp1d(times, range(50), "nearest")
        if waveforms.ndim == 2:
            return waveforms[:, int(f(t))]
        elif waveforms.ndim == 3:
            return waveforms[:, :, int(f(t))]
    elif param == "tP":
        idx = np.argmax(waveforms, axis=-1)
        m = interpolate.interp1d([0, waveforms.shape[-1] - 1], [0, 1 / 1000.0])
        return m(idx)
    elif param == "tT":
        idx = np.argmin(waveforms, axis=-1)
        m = interpolate.interp1d([0, waveforms.shape[-1] - 1], [0, 1 / 1000.0])
        return m(idx)
    elif param == "PCA":
        pca = PCA(n_components=fet)
        if waveforms.ndim == 2:
            return pca.fit(waveforms).transform(waveforms).squeeze()
        elif waveforms.ndim == 3:
            out = np.zeros((waveforms.shape[0], waveforms.shape[1] * fet))
            st = np.arange(0, waveforms.shape[1] * fet, fet)
            en = np.arange(fet, fet + (waveforms.shape[1] * fet), fet)
            rng = np.vstack((st, en))
            for i in range(waveforms.shape[1]):
                if ~np.any(np.isnan(waveforms[:, i, :])):
                    A = np.squeeze(
                        pca.fit(waveforms[:, i, :].squeeze()).transform(
                            waveforms[:, i, :].squeeze()
                        )
                    )
                    if A.ndim < 2:
                        out[:, rng[0, i] : rng[1, i]] = np.atleast_2d(A).T
                    else:
                        out[:, rng[0, i] : rng[1, i]] = A
            return out
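
For instance, pulling peak-to-trough amplitudes and the first principal component per electrode from a batch of synthetic waveforms:

import numpy as np
from ephysiopy.common.spikecalcs import get_param

waveforms = np.random.randn(200, 4, 50)    # nSpikes x nElectrodes x nSamples
amps = get_param(waveforms, "Amp")         # (200, 4) peak-to-trough amplitudes
pcs = get_param(waveforms, "PCA", fet=1)   # (200, 4) first PC per electrode
print(amps.shape, pcs.shape)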

mahal(u, v)

Returns the squared Mahalanobis distance of each observation in u from the reference sample v (cf. MATLAB's mahal function).

Parameters:

Name Type Description Default
u ndarray

The observations to compute distances for (n_points x n_features).

required
v ndarray

The reference sample (n_samples x n_features).

required

Returns:

Name Type Description
d ndarray

The squared Mahalanobis distance of each row of u to the mean of v.

Source code in ephysiopy/common/spikecalcs.py
def mahal(u, v):
    """
    Returns the squared Mahalanobis distance of each observation in u
    from the reference sample v (cf. MATLAB's mahal function).

    Args:
        u (np.ndarray): The observations to compute distances for
            (n_points x n_features).
        v (np.ndarray): The reference sample (n_samples x n_features).

    Returns:
        d (np.ndarray): The squared Mahalanobis distance of each row of u
            to the mean of v.
    """
    u_sz = u.shape
    v_sz = v.shape
    if u_sz[1] != v_sz[1]:
        warnings.warn(
            "Input size mismatch: \
                        matrices must have same num of columns"
        )
    if v_sz[0] < v_sz[1]:
        warnings.warn("Too few rows: v must have more rows than columns")
    if np.any(np.imag(u)) or np.any(np.imag(v)):
        warnings.warn("No complex inputs are allowed")
    m = np.nanmean(v, axis=0)
    M = np.tile(m, reps=(u_sz[0], 1))
    C = v - np.tile(m, reps=(v_sz[0], 1))
    _, R = np.linalg.qr(C)
    ri = np.linalg.solve(R.T, (u - M).T)
    d = np.sum(ri * ri, 0).T * (v_sz[0] - 1)
    return d
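
A quick sketch: squared distances of two points from a 2-D Gaussian reference sample.

import numpy as np
from ephysiopy.common.spikecalcs import mahal

v = np.random.randn(1000, 2)             # reference sample
u = np.array([[0.0, 0.0], [5.0, 5.0]])   # points to test
print(mahal(u, v))   # the point at (5, 5) gets a much larger distance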

xcorr(x1, x2=None, Trange=None, binsize=0.001, **kwargs)

Calculates the ISIs in x1 or x1 vs x2 within a given range

Parameters:

Name Type Description Default
x1, x2 (array_like

The times of the spikes emitted by the cluster(s) in seconds

required
Trange array_like

Range of times to bin up in seconds Defaults to [-0.5, +0.5]

None
binsize float

The size of the bins in seconds

0.001

Returns:

Name Type Description
counts ndarray

The cross-correlogram of the spike trains x1 and x2

bins ndarray

The bins used to calculate the cross-correlogram

Source code in ephysiopy/common/spikecalcs.py
def xcorr(x1: np.ndarray, x2=None, Trange=None, binsize=0.001, **kwargs) -> tuple:
    """
    Calculates the ISIs in x1 or x1 vs x2 within a given range

    Args:
        x1, x2 (array_like): The times of the spikes emitted by the
                            cluster(s) in seconds
        Trange (array_like): Range of times to bin up in seconds
                                Defaults to [-0.5, +0.5]
        binsize (float): The size of the bins in seconds

    Returns:
        counts (np.ndarray): The cross-correlogram of the spike trains
            x1 and x2
        bins (np.ndarray): The bins used to calculate the cross-correlogram
    """
    if x2 is None:
        x2 = x1.copy()
    if Trange is None:
        Trange = np.array([-0.5, 0.5])
    if isinstance(Trange, list):
        Trange = np.array(Trange)
    y = []
    irange = x1[:, np.newaxis] + Trange[np.newaxis, :]
    dts = np.searchsorted(x2, irange)
    for i, t in enumerate(dts):
        y.extend((x2[t[0] : t[1]] - x1[i]))
    y = np.array(y, dtype=float)
    counts, bins = np.histogram(
        y[y != 0], bins=int(np.ptp(Trange) / binsize) + 1, range=Trange
    )
    return counts, bins
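
For example, an autocorrelogram with the default 1 ms bins over +/-500 ms:

import numpy as np
from ephysiopy.common.spikecalcs import xcorr

spike_times = np.sort(np.random.uniform(0, 100, 2000))  # seconds
counts, bins = xcorr(spike_times)
print(len(counts), len(bins))   # bins has one more entry than counts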

Statistics

V_test(angles, test_direction)

The V test assesses whether the observed angles cluster around a given direction, indicating a lack of randomness in the distribution. Also known as the modified Rayleigh test.

Parameters:

Name Type Description Default
angles array_like

Vector of angular values in degrees.

required
test_direction int

A single angular value in degrees.

required
Notes

For grouped data the length of the mean vector must be adjusted, and for axial data all angles must be doubled.

Source code in ephysiopy/common/statscalcs.py
def V_test(angles, test_direction):
    """
    The V test assesses whether the observed angles cluster around a
    given direction, indicating a lack of randomness in the
    distribution. Also known as the modified Rayleigh test.

    Args:
        angles (array_like): Vector of angular values in degrees.
        test_direction (int): A single angular value in degrees.

    Notes:
        For grouped data the length of the mean vector must be adjusted,
        and for axial data all angles must be doubled.
    """
    n = len(angles)
    x_hat = np.sum(np.cos(np.radians(angles))) / float(n)
    y_hat = np.sum(np.sin(np.radians(angles))) / float(n)
    r = np.sqrt(x_hat**2 + y_hat**2)
    theta_hat = np.degrees(np.arctan2(y_hat, x_hat))  # quadrant-safe mean direction
    v_squiggle = r * np.cos(
        np.radians(theta_hat) - np.radians(test_direction))
    V = np.sqrt(2 * n) * v_squiggle
    return V
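
For example, angles drawn around 90 degrees and tested against 90 should give a large positive V:

import numpy as np
from ephysiopy.common.statscalcs import V_test

angles = np.random.normal(90, 10, 100)   # degrees, clustered around 90
print(V_test(angles, 90))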

circ_r(alpha, w=None, d=0, axis=0)

Computes the mean resultant vector length for circular data.

Parameters:

Name Type Description Default
alpha array or list

Sample of angles in radians.

required
w array or list

Counts in the case of binned data. Must be same length as alpha.

None
d float

Spacing of bin centres for binned data; if supplied, correction factor is used to correct for bias in estimation of r, in radians.

0
axis int

The dimension along which to compute. Default is 0.

0

Returns:

Name Type Description
r float

The mean resultant vector length.

Source code in ephysiopy/common/statscalcs.py
def circ_r(alpha, w=None, d=0, axis=0):
    """
    Computes the mean resultant vector length for circular data.

    Args:
        alpha (array or list): Sample of angles in radians.
        w (array or list): Counts in the case of binned data.
            Must be same length as alpha.
        d (float, optional): Spacing of bin centres for binned data, in
            radians; if supplied, a correction factor is applied to correct
            for bias in the estimation of r.
        axis (int, optional): The dimension along which to compute.
            Default is 0.

    Returns:
        r (float): The mean resultant vector length.
    """

    if w is None:
        w = np.ones_like(alpha, dtype=float)
    # TODO: error check for size constancy
    r = np.sum(w * np.exp(1j * alpha))
    r = np.abs(r) / np.sum(w)
    if d != 0:
        c = d/2./np.sin(d/2.)
        r = c * r
    return r
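
For example, a concentrated sample gives r near 1 and a uniform one near 0:

import numpy as np
from ephysiopy.common.statscalcs import circ_r

tight = np.random.vonmises(0, 50, 1000)          # radians, highly concentrated
uniform = np.random.uniform(-np.pi, np.pi, 1000)
print(circ_r(tight), circ_r(uniform))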

duplicates_as_complex(x, already_sorted=False)

Finds duplicates in x

Parameters:

Name Type Description Default
x array_like

The list to find duplicates in.

required
already_sorted bool

Whether x is already sorted. Default False.

False

Returns:

Name Type Description
x array_like

A complex array where the complex part is the count of the number of duplicates of the real value.

Examples:

>>> x = [9.9, 9.9, 12.3, 15.2, 15.2, 15.2]
>>> ret = duplicates_as_complex(x)
>>> print(ret)
[9.9+0j, 9.9+1j, 12.3+0j, 15.2+0j, 15.2+1j, 15.2+2j]
Source code in ephysiopy/common/statscalcs.py
def duplicates_as_complex(x, already_sorted=False):
    """
    Finds duplicates in x

    Args:
        x (array_like): The list to find duplicates in.
        already_sorted (bool, optional): Whether x is already sorted.
            Default False.

    Returns:
        x (array_like): A complex array where the complex part is the count of
            the number of duplicates of the real value.

    Examples:
        >>> x = [9.9, 9.9, 12.3, 15.2, 15.2, 15.2]
        >>> ret = duplicates_as_complex(x)
        >>> print(ret)
        [9.9+0j, 9.9+1j, 12.3+0j, 15.2+0j, 15.2+1j, 15.2+2j]
    """

    if not already_sorted:
        x = np.sort(x)
    is_start = np.empty(len(x), dtype=bool)
    is_start[0], is_start[1:] = True, x[:-1] != x[1:]
    labels = np.cumsum(is_start)-1
    sub_idx = np.arange(len(x)) - np.nonzero(is_start)[0][labels]
    return x + 1j*sub_idx

mean_resultant_vector(angles)

Calculate the mean resultant length and direction for angles.

Parameters:

Name Type Description Default
angles array

Sample of angles in radians.

required

Returns:

Name Type Description
r float

The mean resultant vector length.

th float

The mean resultant vector direction.

Source code in ephysiopy/common/statscalcs.py
def mean_resultant_vector(angles):
    """
    Calculate the mean resultant length and direction for angles.

    Args:
        angles (np.array): Sample of angles in radians.

    Returns:
        r (float): The mean resultant vector length.
        th (float): The mean resultant vector direction.
    """
    S = np.sum(np.sin(angles)) * (1/float(len(angles)))
    C = np.sum(np.cos(angles)) * (1/float(len(angles)))
    r = np.hypot(S, C)
    th = np.arctan(S / C)
    if (C < 0):
        th = np.pi + th
    return r, th
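
For example, angles clustered near 45 degrees:

import numpy as np
from ephysiopy.common.statscalcs import mean_resultant_vector

angles = np.random.vonmises(np.pi / 4, 20, 500)  # radians
r, th = mean_resultant_vector(angles)
print(r, np.degrees(th))   # r near 1, direction near 45 degrees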

watsonWilliams(a, b)

The Watson-Williams F test tests whether a set of mean directions are equal given that the concentrations are unknown, but equal, given that the groups each follow a von Mises distribution.

Parameters:

Name Type Description Default
a, b (array_like

The directional samples

required

Returns:

Name Type Description
F_stat float

The F-statistic

Source code in ephysiopy/common/statscalcs.py
def watsonWilliams(a, b):
    """
    The Watson-Williams F test tests whether a set of mean directions are
    equal given that the concentrations are unknown, but equal, given that
    the groups each follow a von Mises distribution.

    Args:
        a, b (array_like): The directional samples

    Returns:
        F_stat (float): The F-statistic
    """

    n = len(a)
    m = len(b)
    N = n + m
    # v_1 = 1 # needed to do p-value lookup in table of critical values
    #  of F distribution
    # v_2 = N - 2 # needed to do p-value lookup in table of critical values
    # of F distribution
    C_1 = np.sum(np.cos(np.radians(a)))
    S_1 = np.sum(np.sin(np.radians(a)))
    C_2 = np.sum(np.cos(np.radians(b)))
    S_2 = np.sum(np.sin(np.radians(b)))
    C = C_1 + C_2
    S = S_1 + S_2
    R_1 = np.hypot(C_1, S_1)
    R_2 = np.hypot(C_2, S_2)
    R = np.hypot(C, S)
    R_hat = (R_1 + R_2) / float(N)
    from ephysiopy.common.mle_von_mises_vals import vals
    mle_von_mises = np.array(vals)
    mle_von_mises = np.sort(mle_von_mises, 0)
    k_hat = mle_von_mises[(np.abs(mle_von_mises[:, 0]-R_hat)).argmin(), 1]
    g = 1 + 3 / (8 * k_hat)  # standard Watson-Williams correction factor
    F = g * (N-2) * ((R_1 + R_2 - R) / (N - (R_1 + R_2)))
    return F
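
A sketch with two samples of directions in degrees (as the function expects):

import numpy as np
from ephysiopy.common.statscalcs import watsonWilliams

a = np.random.normal(40, 10, 60)   # degrees
b = np.random.normal(45, 10, 60)
print(watsonWilliams(a, b))        # F statistic; compare to F(1, N-2)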

watsonsU2(a, b)

Tests whether two samples from circular observations differ significantly from each other with regard to mean direction or angular variance.

Parameters:

Name Type Description Default
a, b (array_like

The two samples to be tested

required

Returns:

Name Type Description
U2 float

The test statistic

Notes

Both samples must come from a continuous distribution. In the case of grouping the class interval should not exceed 5. Taken from '100 Statistical Tests' G.J.Kanji, 2006 Sage Publications

Source code in ephysiopy/common/statscalcs.py
def watsonsU2(a, b):
    """
    Tests whether two samples from circular observations differ significantly 
    from each other with regard to mean direction or angular variance.

    Args:
        a, b (array_like): The two samples to be tested

    Returns:
        U2 (float): The test statistic

    Notes:
        Both samples must come from a continuous distribution. In the case of
        grouping the class interval should not exceed 5.
        Taken from '100 Statistical Tests' G.J.Kanji, 2006 Sage Publications
    """

    a = np.sort(np.ravel(a))
    b = np.sort(np.ravel(b))
    n_a = len(a)
    n_b = len(b)
    N = float(n_a + n_b)
    a_complex = duplicates_as_complex(a, True)
    b_complex = duplicates_as_complex(b, True)
    a_and_b = np.union1d(a_complex, b_complex)

    # get index for a
    a_ind = np.zeros(len(a_and_b), dtype=int)
    a_ind[np.searchsorted(a_and_b, a_complex)] = 1
    a_ind = np.cumsum(a_ind)

    # same for b
    b_ind = np.zeros(len(a_and_b), dtype=int)
    b_ind[np.searchsorted(a_and_b, b_complex)] = 1
    b_ind = np.cumsum(b_ind)

    d_k = (a_ind / float(n_a)) - (b_ind / float(n_b))

    d_k_sq = d_k ** 2

    U2 = ((n_a*n_b) / N**2) * (np.sum(d_k_sq) - ((np.sum(d_k)**2) / N))
    return U2
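
A sketch with two uniform circular samples (no real difference, so U2 should be small):

import numpy as np
from ephysiopy.common.statscalcs import watsonsU2

a = np.random.uniform(0, 360, 50)   # degrees
b = np.random.uniform(0, 360, 60)
print(watsonsU2(a, b))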

watsonsU2n(angles)

Tests whether the given distribution fits a random sample of angular values.

Parameters:

Name Type Description Default
angles array_like

The angular samples.

required

Returns:

Name Type Description
U2n float

The test statistic.

Notes

This test is suitable for both unimodal and the multimodal cases. It can be used as a test for randomness. Taken from '100 Statistical Tests' G.J.Kanji, 2006 Sage Publications.

Source code in ephysiopy/common/statscalcs.py
def watsonsU2n(angles):
    """
    Tests whether the given distribution fits a random sample of angular
    values.

    Args:
        angles (array_like): The angular samples.

    Returns:
        U2n (float): The test statistic.

    Notes:
        This test is suitable for both unimodal and the multimodal cases.
        It can be used as a test for randomness.
        Taken from '100 Statistical Tests' G.J.Kanji, 2006 Sage Publications.
    """

    angles = np.sort(angles)
    n = len(angles)
    Vi = angles / float(360)
    sum_Vi = np.sum(Vi)
    sum_sq_Vi = np.sum(Vi**2)
    Ci = (2 * np.arange(1, n+1)) - 1
    sum_Ci_Vi_ov_n = np.sum(Ci * Vi / n)
    V_bar = (1 / float(n)) * sum_Vi
    U2n = sum_sq_Vi - sum_Ci_Vi_ov_n + (
        n * (1/float(3) - (V_bar - 0.5)**2))
    test_vals = {
        '0.1': 0.152, '0.05': 0.187, '0.025': 0.221,
        '0.01': 0.267, '0.005': 0.302}
    for key, val in test_vals.items():
        if U2n > val:
            print('The Watsons U2 statistic is {0} which is greater than\n'
                  'the critical value of {1} at p={2}'.format(U2n, val, key))
        else:
            print('The Watsons U2 statistic is not '
                  'significant at p={0}'.format(key))
    return U2n
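
A sketch on a uniform sample of angles in degrees (the function prints its verdict at each significance level and returns the statistic):

import numpy as np
from ephysiopy.common.statscalcs import watsonsU2n

angles = np.random.uniform(0, 360, 100)
U2n = watsonsU2n(angles)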

Utility functions

blurImage(im, n, ny=None, ftype='boxcar', **kwargs)

Smooths a 2D image by convolving with a filter.

Parameters:

Name Type Description Default
im array_like

The array to smooth.

required
n, ny (int

The size of the smoothing kernel.

required
ftype str

The type of smoothing kernel. Either 'boxcar' or 'gaussian'.

'boxcar'

Returns:

Name Type Description
res array_like

The smoothed vector with shape the same as im.

Source code in ephysiopy/common/utils.py
def blurImage(im, n, ny=None, ftype='boxcar', **kwargs):
    """
    Smooths a 2D image by convolving with a filter.

    Args:
        im (array_like): The array to smooth.
        n, ny (int): The size of the smoothing kernel.
        ftype (str): The type of smoothing kernel.
            Either 'boxcar' or 'gaussian'.

    Returns:
        res (array_like): The smoothed vector with shape the same as im.
    """
    if 'stddev' in kwargs.keys():
        stddev = kwargs.pop('stddev')
    else:
        stddev = 5
    n = int(n)
    if not ny:
        ny = n
    else:
        ny = int(ny)
    ndims = im.ndim
    if 'box' in ftype:
        if ndims == 1:
            g = cnv.Box1DKernel(n)
        elif ndims == 2:
            g = cnv.Box2DKernel(n)
        elif ndims == 3:  # multidimensional binning
            g = cnv.Box2DKernel(n)
            g = np.atleast_3d(g).T
    elif 'gaussian' in ftype:
        if ndims == 1:
            g = cnv.Gaussian1DKernel(stddev, x_size=n)
        if ndims == 2:
            g = cnv.Gaussian2DKernel(stddev, x_size=n, y_size=ny)
        if ndims == 3:
            g = cnv.Gaussian2DKernel(stddev, x_size=n, y_size=ny)
            g = np.atleast_3d(g).T
    return cnv.convolve(im, g, boundary='extend')
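
For example, blurring a single bright pixel with a 5 x 5 gaussian kernel:

import numpy as np
from ephysiopy.common.utils import blurImage

im = np.zeros((32, 32))
im[16, 16] = 1.0
sm = blurImage(im, 5, ftype='gaussian')
print(sm.shape, sm.max())   # same shape as im, peak spread over neighbours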

bwperim(bw, n=4)

Finds the perimeter of objects in binary images.

A pixel is part of an object perimeter if its value is one and there is at least one zero-valued pixel in its neighborhood.

By default the neighborhood of a pixel is 4 nearest pixels, but if n is set to 8 the 8 nearest pixels will be considered.

Parameters:

Name Type Description Default
bw array_like

A black-and-white image.

required
n int

Connectivity. Must be 4 or 8. Default is 4.

4

Returns:

Name Type Description
perim array_like

A boolean image.

Source code in ephysiopy/common/utils.py
def bwperim(bw, n=4):
    """
    Finds the perimeter of objects in binary images.

    A pixel is part of an object perimeter if its value is one and there
    is at least one zero-valued pixel in its neighborhood.

    By default the neighborhood of a pixel is 4 nearest pixels, but
    if `n` is set to 8 the 8 nearest pixels will be considered.

    Args:
        bw (array_like): A black-and-white image.
        n (int, optional): Connectivity. Must be 4 or 8. Default is 4.

    Returns:
        perim (array_like): A boolean image.
    """

    if n not in (4, 8):
        raise ValueError('mahotas.bwperim: n must be 4 or 8')
    rows, cols = bw.shape

    # Translate image by one pixel in all directions
    north = np.zeros((rows, cols))
    south = np.zeros((rows, cols))
    west = np.zeros((rows, cols))
    east = np.zeros((rows, cols))

    north[:-1, :] = bw[1:, :]
    south[1:, :] = bw[:-1, :]
    west[:, :-1] = bw[:, 1:]
    east[:, 1:] = bw[:, :-1]
    idx = (north == bw) & \
          (south == bw) & \
          (west == bw) & \
          (east == bw)
    if n == 8:
        north_east = np.zeros((rows, cols))
        north_west = np.zeros((rows, cols))
        south_east = np.zeros((rows, cols))
        south_west = np.zeros((rows, cols))
        north_east[:-1, 1:] = bw[1:, :-1]
        north_west[:-1, :-1] = bw[1:, 1:]
        south_east[1:, 1:] = bw[:-1, :-1]
        south_west[1:, :-1] = bw[:-1, 1:]
        idx &= (north_east == bw) & \
               (south_east == bw) & \
               (south_west == bw) & \
               (north_west == bw)
    return ~idx * bw
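
For example, the perimeter of a 3 x 3 square of ones leaves only the centre pixel unset:

import numpy as np
from ephysiopy.common.utils import bwperim

bw = np.zeros((7, 7))
bw[2:5, 2:5] = 1
print(bwperim(bw, n=4).astype(int))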

count_to(n)

This function is equivalent to hstack((arange(n_i) for n_i in n)). It seems to be faster for some possible inputs and encapsulates a task in a function.

Example

Given n = [0, 0, 3, 0, 0, 2, 0, 2, 1], the result would be [0, 1, 2, 0, 1, 0, 1, 0].

Source code in ephysiopy/common/utils.py
def count_to(n):
    """
    This function is equivalent to hstack((arange(n_i) for n_i in n)).
    It seems to be faster for some possible inputs and encapsulates
    a task in a function.

    Example:
        Given n = [0, 0, 3, 0, 0, 2, 0, 2, 1],
        the result would be [0, 1, 2, 0, 1, 0, 1, 0].
    """
    if n.ndim != 1:
        raise Exception("n is supposed to be 1d array.")

    n_mask = n.astype(bool)
    n_cumsum = np.cumsum(n)
    ret = np.ones(n_cumsum[-1]+1, dtype=int)
    ret[n_cumsum[n_mask]] -= n[n_mask]
    ret[0] -= 1
    return np.cumsum(ret)[:-1]
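
Reproducing the docstring example:

import numpy as np
from ephysiopy.common.utils import count_to

n = np.array([0, 0, 3, 0, 0, 2, 0, 2, 1])
print(count_to(n))   # [0 1 2 0 1 0 1 0]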

get_z_score(x, mean=None, sd=None, axis=0)

Calculate the z-scores for array x based on the mean and standard deviation in that sample, unless stated

Source code in ephysiopy/common/utils.py
def get_z_score(x: np.ndarray,
                mean=None,
                sd=None,
                axis=0) -> np.ndarray:
    '''
    Calculate the z-scores for array x based on the mean
    and standard deviation in that sample, unless stated
    '''
    if mean is None:
        mean = np.mean(x, axis=axis)
    if sd is None:
        sd = np.std(x, axis=axis)
    return (x - mean) / sd
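
For example, standardising a vector against its own sample statistics:

import numpy as np
from ephysiopy.common.utils import get_z_score

x = np.array([1.0, 2.0, 3.0, 4.0])
print(get_z_score(x))   # zero mean, unit standard deviation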

polar(x, y, deg=False)

Converts from rectangular coordinates to polar ones.

Parameters:

Name Type Description Default
x, y (array_like, list_like

The x and y coordinates.

required
deg bool

If True, the angle is returned in degrees; otherwise in radians.

False

Returns:

Name Type Description
p array_like

The polar version of x and y.

Source code in ephysiopy/common/utils.py
def polar(x, y, deg=False):
    """
    Converts from rectangular coordinates to polar ones.

    Args:
        x, y (array_like, list_like): The x and y coordinates.
        deg (bool): If True, the angle is returned in degrees;
            otherwise in radians.

    Returns:
        p (array_like): The polar version of x and y.
    """
    if deg:
        return np.hypot(x, y), 180.0 * np.arctan2(y, x) / np.pi
    else:
        return np.hypot(x, y), np.arctan2(y, x)

rect(r, w, deg=False)

Convert from polar (r,w) to rectangular (x,y) x = r cos(w) y = r sin(w)

Source code in ephysiopy/common/utils.py
def rect(r, w, deg=False):
    """
    Convert from polar (r,w) to rectangular (x,y)
    x = r cos(w)
    y = r sin(w)
    """
    # radian if deg=0; degree if deg=1
    if deg:
        w = np.pi * w / 180.0
    return r * np.cos(w), r * np.sin(w)
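
polar and rect are inverses of each other; a round-trip sketch:

import numpy as np
from ephysiopy.common.utils import polar, rect

r, w = polar(3.0, 4.0)   # radians by default
x, y = rect(r, w)
print(r, w, x, y)        # 5.0, ~0.927, 3.0, 4.0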

repeat_ind(n)

Examples:

>>> n = np.array([0, 0, 3, 0, 0, 2, 0, 2, 1])
>>> repeat_ind(n)
array([2, 2, 2, 5, 5, 7, 7, 8])

The input specifies how many times to repeat the given index. It is equivalent to something like this:

hstack((zeros(n_i,dtype=int)+i for i, n_i in enumerate(n)))

But this version seems to be faster, and probably scales better. At any rate, it encapsulates a task in a function.

Source code in ephysiopy/common/utils.py
def repeat_ind(n: np.array):
    """
    Examples:
        >>> n = np.array([0, 0, 3, 0, 0, 2, 0, 2, 1])
        >>> repeat_ind(n)
        array([2, 2, 2, 5, 5, 7, 7, 8])

    The input specifies how many times to repeat the given index.
    It is equivalent to something like this:

        hstack((zeros(n_i,dtype=int)+i for i, n_i in enumerate(n)))

    But this version seems to be faster, and probably scales better.
    At any rate, it encapsulates a task in a function.
    """
    if n.ndim != 1:
        raise Exception("n is supposed to be 1d array.")

    res = [[idx]*a for idx, a in enumerate(n) if a != 0]
    return np.concatenate(res)

smooth(x, window_len=9, window='hanning')

Smooth the data using a window with requested size.

This method is based on the convolution of a scaled window with the signal. The signal is prepared by introducing reflected copies of the signal (with the window size) in both ends so that transient parts are minimized in the beginning and end part of the output signal.

Parameters:

Name Type Description Default
x array_like

The input signal.

required
window_len int

The length of the smoothing window.

9
window str

The type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'. 'flat' window will produce a moving average smoothing.

'hanning'

Returns:

Name Type Description
out array_like

The smoothed signal.

Example

>>> t = np.arange(-2, 2, 0.1)
>>> x = np.sin(t) + np.random.randn(len(t)) * 0.1
>>> y = smooth(x)

See Also

numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve, scipy.signal.lfilter

Notes

The window parameter could be the window itself if an array instead of a string.

Source code in ephysiopy/common/utils.py
def smooth(x, window_len=9, window='hanning'):
    """
    Smooth the data using a window with requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.

    Args:
        x (array_like): The input signal.
        window_len (int): The length of the smoothing window.
        window (str): The type of window from 'flat', 'hanning', 'hamming', 
            'bartlett', 'blackman'. 'flat' window will produce a moving average 
            smoothing.

    Returns:
        out (array_like): The smoothed signal.

    Example:
        >>> t = np.arange(-2, 2, 0.1)
        >>> x = np.sin(t) + np.random.randn(len(t)) * 0.1
        >>> y = smooth(x)

    See Also:
        numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman,
        numpy.convolve, scipy.signal.lfilter

    Notes:
        The window parameter could be the window itself if an array instead of
        a string.
    """

    if isinstance(x, list):
        x = np.array(x)

    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if len(x) < window_len:
        print("length of x: ", len(x))
        print("window_len: ", window_len)
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len < 3:
        return x

    if (window_len % 2) == 0:
        window_len = window_len + 1

    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError(
            "Window is one of 'flat', 'hanning', "
            "'hamming', 'bartlett', 'blackman'")

    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)  # e.g. np.hanning(window_len)
    y = cnv.convolve(x, w/w.sum(), normalize_kernel=False, boundary='extend')
    # return the smoothed signal
    return y
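
Following the docstring example with explicit numpy calls:

import numpy as np
from ephysiopy.common.utils import smooth

t = np.arange(-2, 2, 0.1)
x = np.sin(t) + np.random.randn(len(t)) * 0.1
y = smooth(x)
print(len(x) == len(y))   # output has the same length as the input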

Axona input/output

ClusterSession

Bases: object

Loads all the cut file data and timestamps from the data associated with the *.set filename given to __init__

Meant to be a method-replica of the KiloSortSession class but really both should inherit from the same meta-class

Source code in ephysiopy/axona/axonaIO.py
class ClusterSession(object):
    '''
    Loads all the cut file data and timestamps from the data 
    associated with the *.set filename given to __init__

    Meant to be a method-replica of the KiloSortSession class
    but really both should inherit from the same meta-class
    '''
    def __init__(self, fname_root) -> None:
        fname_root = Path(fname_root)
        assert fname_root.suffix == ".set"
        assert fname_root.exists()
        self.fname_root = fname_root

        self.cluster_id = None
        self.spk_clusters = None
        self.spk_times = None
        self.good_clusters = {}

    def load(self):
        pname = self.fname_root.parent
        pattern = re.compile(r'^'+str(
            self.fname_root.with_suffix(""))+r'_[1-9][0-9]\.cut$')
        pattern1 = re.compile(r'^'+str(
            self.fname_root.with_suffix(""))+r'_[1-9]\.cut$')
        cut_files = sorted(list(f for f in pname.iterdir()
                           if pattern.search(str(f)) or
                           pattern1.search(str(f))))
        # extract the clusters from each cut file
        # get the corresponding tetrode files
        tet_files = [str(c.with_suffix("")) for c in cut_files]
        tet_files = [t[::-1].replace("_", ".", 1)[::-1] for t in tet_files]
        tetrode_clusters = {}
        for fname in tet_files:
            T = IO(fname)
            idx = fname.rfind(".")
            tetnum = int(fname[idx+1:])
            cut = T.getCut(tetnum)
            tetrode_clusters[tetnum] = cut

        self.good_clusters = tetrode_clusters
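
A hedged usage sketch (the .set path is hypothetical and must exist on disk, alongside the matching cut and tetrode files):

from ephysiopy.axona.axonaIO import ClusterSession

cs = ClusterSession("/path/to/trial.set")   # hypothetical recording
cs.load()
print(cs.good_clusters.keys())              # cut data keyed by tetrode number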

EEG

Bases: IO

Processes eeg data collected with the Axona recording system

Parameters

filename_root : str
    The fully qualified filename without the suffix
egf : int
    Whether to read the 'eeg' file or the 'egf' file. 0 is False, 1 is True
eeg_file : int
    If more than one eeg channel was recorded from then they are numbered
    from 1 onwards i.e. trial.eeg, trial.eeg1, trial.eeg2 etc.
    This number specifies which one to read

Source code in ephysiopy/axona/axonaIO.py
class EEG(IO):
    """
    Processes eeg data collected with the Axona recording system

    Parameters
    ----------
    filename_root : str
        The fully qualified filename without the suffix
    egf: int
        Whether to read the 'eeg' file or the 'egf' file. 0 is False, 1 is True
    eeg_file: int
        If more than one eeg channel was recorded from then they are numbered
        from 1 onwards, i.e. trial.eeg, trial.eeg1, trial.eeg2 etc.
        This number specifies which one to load

    """

    def __init__(self, filename_root: Path, eeg_file=1, egf=0):
        self.showfigs = 0
        filename_root = Path(os.path.splitext(filename_root)[0])
        self.filename_root = filename_root
        if egf == 0:
            if eeg_file == 1:
                eeg_suffix = ".eeg"
            else:
                eeg_suffix = ".eeg" + str(eeg_file)
        elif egf == 1:
            if eeg_file == 1:
                eeg_suffix = ".egf"
            else:
                eeg_suffix = ".egf" + str(eeg_file)
        self.header = self.getHeader(
            self.filename_root.with_suffix(eeg_suffix))
        self.eeg = self.getData(filename_root.with_suffix(eeg_suffix))["eeg"]
        # sometimes the eeg record is longer than reported in
        # the 'num_EEG_samples'
        # value of the header so eeg record should be truncated
        # to match 'num_EEG_samples'
        # TODO: this could be taken care of in the IO base class
        if egf:
            self.eeg = self.eeg[0: int(self.header["num_EGF_samples"])]
        else:
            self.eeg = self.eeg[0: int(self.header["num_EEG_samples"])]
        self.sample_rate = int(self.getHeaderVal(self.header, "sample_rate"))
        set_header = self.getHeader(self.filename_root.with_suffix(".set"))
        eeg_ch = int(set_header["EEG_ch_1"]) - 1
        eeg_gain = int(set_header["gain_ch_" + str(eeg_ch)])
        # EEG polarity is determined by the "mode_ch_n" key in the setfile
        # where n is the channel # for the eeg. The possibles values to these
        # keys are as follows:
        # 0 = Signal
        # 1 = Ref
        # 2 = -Signal
        # 3 = -Ref
        # 4 = Sig-Ref
        # 5 = Ref-Sig
        # 6 = grounded
        # So if the EEG has been recorded with -Signal (2) then the recorded
        # polarity is inverted with respect to that in the brain
        eeg_mode = int(set_header["mode_ch_" + set_header["EEG_ch_1"]])
        polarity = 1  # ensure it always has a value
        if eeg_mode == 2:
            polarity = -1
        ADC_mv = float(set_header["ADC_fullscale_mv"])
        scaling = (ADC_mv / 1000.0) * eeg_gain
        self.scaling = scaling
        self.gain = eeg_gain
        self.polarity = polarity
        denom = 128.0
        self.sig = (self.eeg / denom) * scaling * polarity  # eeg in microvolts
        self.EEGphase = None
        # x1 / x2 are the lower and upper limits of the eeg filter
        self.x1 = 6
        self.x2 = 12
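
A minimal usage sketch (paths hypothetical); mytrial.eeg and mytrial.set must sit side by side so the gain and polarity scaling can be read from the .set file:

from ephysiopy.axona.axonaIO import EEG

lfp = EEG("/home/robin/Data/mytrial")            # reads mytrial.eeg
print(lfp.sample_rate)                           # e.g. 250 for .eeg files
print(lfp.sig[:5])                               # gain- and polarity-corrected signal
hi_res = EEG("/home/robin/Data/mytrial", egf=1)  # reads mytrial.egf instead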

IO

Bases: object

Axona data I/O. Also reads .clu files generated from KlustaKwik

Parameters

filename_root : str
    The fully-qualified filename

Source code in ephysiopy/axona/axonaIO.py
class IO(object):
    """
    Axona data I/O. Also reads .clu files generated from KlustaKwik

    Parameters
    ----------
    filename_root : str
        The fully-qualified filename
    """

    tetrode_files = dict.fromkeys(
        ["." + str(i) for i in range(1, 17)],
        # ts is a big-endian 32-bit integer
        # waveform is 50 signed 8-bit ints (a signed byte)
        [("ts", ">i"), ("waveform", "50b")]
    )
    other_files = {
        ".pos": [("ts", ">i"), ("pos", ">8h")],
        ".eeg": [("eeg", "=b")],
        ".eeg2": [("eeg", "=b")],
        ".egf": [("eeg", "int16")],
        ".egf2": [("eeg", "int16")],
        ".inp": [("ts", ">i4"), ("type", ">b"), ("value", ">2b")],
        ".log": [("state", "S3"), ("ts", ">i")],
        ".stm": [("ts", ">i")],
    }

    # this only works in >= Python3.5
    axona_files = {**other_files, **tetrode_files}

    def __init__(self, filename_root: Path = ""):
        self.filename_root = filename_root

    def getData(self, filename_root: str) -> np.ndarray:
        """
        Returns the data part of an Axona data file i.e. from "data_start" to
        "data_end"

        Parameters
        ----------
        filename_root : str
            Fully qualified path name to the data file

        Returns
        -------
        output : ndarray
            The data part of whatever file was fed in
        """
        n_samps = -1
        fType = os.path.splitext(filename_root)[1]
        if fType in self.axona_files:
            header = self.getHeader(filename_root)
            for key in header.keys():
                if len(fType) > 2:
                    if fnmatch.fnmatch(key, "num_*_samples"):
                        n_samps = int(header[key])
                else:
                    if key.startswith("num_spikes"):
                        n_samps = int(header[key]) * 4
            f = open(filename_root, "rb")
            data = f.read()
            st = data.find(b"data_start") + len("data_start")
            f.seek(st)
            dt = np.dtype(self.axona_files[fType])
            a = np.fromfile(f, dtype=dt, count=n_samps)
            f.close()
        else:
            raise IOError("File not in list of recognised Axona files")
        return a

    def getCluCut(self, tet: int) -> np.ndarray:
        """
        Load a clu file and return as an array of integers

        Parameters
        ----------
        tet : int
            The tetrode the clu file relates to

        Returns
        -------
        out : ndarray
            Data read from the clu file
        """
        filename_root = self.filename_root.with_suffix("." + "clu." + str(tet))
        if os.path.exists(filename_root):
            dt = np.dtype([("data", "<i")])
            clu_data = np.loadtxt(filename_root, dtype=dt)
            return clu_data["data"][1::]  # first entry is num of clusters
        else:
            return None

    def getCut(self, tet: int) -> list:
        """
        Returns the cut file as a list of integers

        Parameters
        ----------
        tet : int
            The tetrode the cut file relates to

        Returns
        -------
        out : ndarray
            The data read from the cut file
        """
        a = []
        filename_root = Path(os.path.splitext(
            self.filename_root)[0] + "_" + str(tet) + ".cut")

        if not os.path.exists(filename_root):
            cut = self.getCluCut(tet)
            if cut is not None:
                return cut - 1  # clusters 1 indexed in clu
            return cut
        with open(filename_root, "r") as f:
            cut_data = f.read()
        tmp = cut_data.split("spikes: ")
        tmp1 = tmp[1].split("\n")
        cut = tmp1[1:]
        for line in cut:
            m = line.split()
            for i in m:
                a.append(int(i))
        return a

    def setHeader(self, filename_root: str, header: dataclass):
        """
        Writes out the header to the specified file

        Parameters
        ------------
        filename_root : str
            A fully qualified path to a file with the relevant suffix at
            the end (e.g. ".set", ".pos" or whatever)

        header : dataclass
            See ephysiopy.axona.file_headers
        """
        with open(filename_root, "w") as f:
            with redirect_stdout(f):
                header.print()
            f.write("data_start")
            f.write("\r\n")
            f.write("data_end")
            f.write("\r\n")

    def setCut(self, filename_root: str,
               cut_header: dataclass,
               cut_data: np.array):
        fpath = Path(filename_root)
        n_clusters = len(np.unique(cut_data))
        cluster_entries = make_cluster_cut_entries(n_clusters)
        with open(filename_root, "w") as f:
            with redirect_stdout(f):
                cut_header.print()
            print(cluster_entries, file=f)
            print(f"Exact_cut_for: {fpath.stem}    spikes: {len(cut_data)}",
                  file=f)
            for num in cut_data:
                f.write(str(num))
                f.write(" ")

    def setData(self, filename_root: str, data: np.array):
        """
        Writes Axona format data to the given filename

        Parameters
        ----------
        filename_root : str
            The fully qualified filename including the suffix

        data : ndarray
            The data that will be saved
        """
        fType = os.path.splitext(filename_root)[1]
        if fType in self.axona_files:
            f = open(filename_root, "rb+")
            d = f.read()
            st = d.find(b"data_start") + len("data_start")
            f.seek(st)
            data.tofile(f)
            f.close()
            f = open(filename_root, "a")
            f.write("\r\n")
            f.write("data_end")
            f.write("\r\n")
            f.close()

    def getHeader(self, filename_root: str) -> dict:
        """
        Reads and returns the header of a specified data file as a dictionary

        Parameters
        ----------
        filename_root : str
            Fully qualified filename of Axona type

        Returns
        -------
        headerDict : dict
            key - value pairs of the header part of an Axona type file
        """
        with open(filename_root, "rb") as f:
            data = f.read()
        if os.path.splitext(filename_root)[1] != ".set":
            st = data.find(b"data_start") + len("data_start")
            header = data[0: st - len("data_start") - 2]
        else:
            header = data
        headerDict = {}
        lines = header.splitlines()
        for line in lines:
            line = str(line.decode("ISO-8859-1")).rstrip()
            line = line.split(" ", 1)
            try:
                headerDict[line[0]] = line[1]
            except IndexError:
                headerDict[line[0]] = ""
        return headerDict

    def getHeaderVal(self, header: dict, key: str) -> int:
        """
        Get a value from the header as an int

        Parameters
        ----------
        header : dict
            The header dictionary to read
        key : str
            The key to look up

        Returns
        -------
        value : int
            The value of `key` as an int
        """
        tmp = header[key]
        val = tmp.split(" ")
        val = val[0].split(".")
        val = int(val[0])
        return val
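
A minimal sketch of the header/data round trip (paths hypothetical): getHeader returns the header as a dict of key-value pairs and getData the structured array between "data_start" and "data_end":

from pathlib import Path
from ephysiopy.axona.axonaIO import IO

io = IO(Path("/home/robin/Data/mytrial"))  # hypothetical basename
header = io.getHeader("/home/robin/Data/mytrial.pos")
print(header.get("num_pos_samples"))
pos = io.getData("/home/robin/Data/mytrial.pos")
print(pos["ts"][:5], pos["pos"].shape)  # fields from the ".pos" dtype above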

getCluCut(tet)

Load a clu file and return as an array of integers

Parameters

tet : int
    The tetrode the clu file relates to

Returns

out : ndarray
    Data read from the clu file

Source code in ephysiopy/axona/axonaIO.py
def getCluCut(self, tet: int) -> np.ndarray:
    """
    Load a clu file and return as an array of integers

    Parameters
    ----------
    tet : int
        The tetrode the clu file relates to

    Returns
    -------
    out : ndarray
        Data read from the clu file
    """
    filename_root = self.filename_root.with_suffix("." + "clu." + str(tet))
    if os.path.exists(filename_root):
        dt = np.dtype([("data", "<i")])
        clu_data = np.loadtxt(filename_root, dtype=dt)
        return clu_data["data"][1::]  # first entry is num of clusters
    else:
        return None

getCut(tet)

Returns the cut file as a list of integers

Parameters

tet : int
    The tetrode the cut file relates to

Returns

out : ndarray
    The data read from the cut file

Source code in ephysiopy/axona/axonaIO.py
def getCut(self, tet: int) -> list:
    """
    Returns the cut file as a list of integers

    Parameters
    ----------
    tet : int
        The tetrode the cut file relates to

    Returns
    -------
    out : ndarray
        The data read from the cut file
    """
    a = []
    filename_root = Path(os.path.splitext(
        self.filename_root)[0] + "_" + str(tet) + ".cut")

    if not os.path.exists(filename_root):
        cut = self.getCluCut(tet)
        if cut is not None:
            return cut - 1  # clusters 1 indexed in clu
        return cut
    with open(filename_root, "r") as f:
        cut_data = f.read()
    tmp = cut_data.split("spikes: ")
    tmp1 = tmp[1].split("\n")
    cut = tmp1[1:]
    for line in cut:
        m = line.split()
        for i in m:
            a.append(int(i))
    return a

getData(filename_root)

Returns the data part of an Axona data file i.e. from "data_start" to "data_end"

Parameters

filename_root : str
    Fully qualified path name to the data file

Returns

output : ndarray
    The data part of whatever file was fed in

Source code in ephysiopy/axona/axonaIO.py
def getData(self, filename_root: str) -> np.ndarray:
    """
    Returns the data part of an Axona data file i.e. from "data_start" to
    "data_end"

    Parameters
    ----------
    filename_root : str
        Fully qualified path name to the data file

    Returns
    -------
    output : ndarray
        The data part of whatever file was fed in
    """
    n_samps = -1
    fType = os.path.splitext(filename_root)[1]
    if fType in self.axona_files:
        header = self.getHeader(filename_root)
        for key in header.keys():
            if len(fType) > 2:
                if fnmatch.fnmatch(key, "num_*_samples"):
                    n_samps = int(header[key])
            else:
                if key.startswith("num_spikes"):
                    n_samps = int(header[key]) * 4
        f = open(filename_root, "rb")
        data = f.read()
        st = data.find(b"data_start") + len("data_start")
        f.seek(st)
        dt = np.dtype(self.axona_files[fType])
        a = np.fromfile(f, dtype=dt, count=n_samps)
        f.close()
    else:
        raise IOError("File not in list of recognised Axona files")
    return a

getHeader(filename_root)

Reads and returns the header of a specified data file as a dictionary

Parameters

filename_root : str
    Fully qualified filename of Axona type

Returns

headerDict : dict
    key - value pairs of the header part of an Axona type file

Source code in ephysiopy/axona/axonaIO.py
def getHeader(self, filename_root: str) -> dict:
    """
    Reads and returns the header of a specified data file as a dictionary

    Parameters
    ----------
    filename_root : str
        Fully qualified filename of Axona type

    Returns
    -------
    headerDict : dict
        key - value pairs of the header part of an Axona type file
    """
    with open(filename_root, "rb") as f:
        data = f.read()
    if os.path.splitext(filename_root)[1] != ".set":
        st = data.find(b"data_start") + len("data_start")
        header = data[0: st - len("data_start") - 2]
    else:
        header = data
    headerDict = {}
    lines = header.splitlines()
    for line in lines:
        line = str(line.decode("ISO-8859-1")).rstrip()
        line = line.split(" ", 1)
        try:
            headerDict[line[0]] = line[1]
        except IndexError:
            headerDict[line[0]] = ""
    return headerDict

getHeaderVal(header, key)

Get a value from the header as an int

Parameters

header : dict
    The header dictionary to read
key : str
    The key to look up

Returns

value : int
    The value of key as an int

Source code in ephysiopy/axona/axonaIO.py
def getHeaderVal(self, header: dict, key: str) -> int:
    """
    Get a value from the header as an int

    Parameters
    ----------
    header : dict
        The header dictionary to read
    key : str
        The key to look up

    Returns
    -------
    value : int
        The value of `key` as an int
    """
    tmp = header[key]
    val = tmp.split(" ")
    val = val[0].split(".")
    val = int(val[0])
    return val
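
Because getHeaderVal splits on the first space and then drops any decimal part, a value such as "250.0 hz" comes back as the integer 250. A quick sketch with a hand-made header dict:

from ephysiopy.axona.axonaIO import IO

io = IO()
header = {"sample_rate": "250.0 hz", "duration": "600"}
print(io.getHeaderVal(header, "sample_rate"))  # 250 - unit text and decimals dropped
print(io.getHeaderVal(header, "duration"))     # 600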

setData(filename_root, data)

Writes Axona format data to the given filename

Parameters

filename_root : str
    The fully qualified filename including the suffix

data : ndarray
    The data that will be saved

Source code in ephysiopy/axona/axonaIO.py
def setData(self, filename_root: str, data: np.array):
    """
    Writes Axona format data to the given filename

    Parameters
    ----------
    filename_root : str
        The fully qualified filename including the suffix

    data : ndarray
        The data that will be saved
    """
    fType = os.path.splitext(filename_root)[1]
    if fType in self.axona_files:
        f = open(filename_root, "rb+")
        d = f.read()
        st = d.find(b"data_start") + len("data_start")
        f.seek(st)
        data.tofile(f)
        f.close()
        f = open(filename_root, "a")
        f.write("\r\n")
        f.write("data_end")
        f.write("\r\n")
        f.close()

setHeader(filename_root, header)

Writes out the header to the specified file

Parameters

filename_root : str
    A fully qualified path to a file with the relevant suffix at the end
    (e.g. ".set", ".pos" or whatever)

header : dataclass
    See ephysiopy.axona.file_headers

Source code in ephysiopy/axona/axonaIO.py
def setHeader(self, filename_root: str, header: dataclass):
    """
    Writes out the header to the specified file

    Parameters
    ------------
    filename_root : str
        A fully qualified path to a file with the relevant suffix at
        the end (e.g. ".set", ".pos" or whatever)

    header : dataclass
        See ephysiopy.axona.file_headers
    """
    with open(filename_root, "w") as f:
        with redirect_stdout(f):
            header.print()
        f.write("data_start")
        f.write("\r\n")
        f.write("data_end")
        f.write("\r\n")

Pos

Bases: IO

Processes position data recorded with the Axona recording system

Parameters

filename_root : str
    The basename of the file, i.e. mytrial as opposed to mytrial.pos

Notes

Currently the only arg that does anything is 'cm' which will convert the xy data to cm, assuming that the pixels per metre value has been set correctly

Source code in ephysiopy/axona/axonaIO.py
class Pos(IO):
    """
    Processes position data recorded with the Axona recording system

    Parameters
    ----------
    filename_root : str
        The basename of the file, i.e. mytrial as opposed to mytrial.pos

    Notes
    -----
    Currently the only arg that does anything is 'cm' which will convert
    the xy data to cm, assuming that the pixels per metre value has been
    set correctly
    """

    def __init__(self, filename_root: Path, *args, **kwargs):
        filename_root = Path(filename_root)
        if filename_root.suffix == ".set":
            filename_root = Path(os.path.splitext(filename_root)[0])
        self.filename_root = filename_root
        self.header = self.getHeader(filename_root.with_suffix(".pos"))
        self.setheader = self.getHeader(filename_root.with_suffix(".set"))
        self.posProcessed = False
        posData = self.getData(filename_root.with_suffix(".pos"))
        self.nLEDs = 1
        if self.setheader is not None:
            self.nLEDs = sum(
                [
                    self.getHeaderVal(self.setheader, "colactive_1"),
                    self.getHeaderVal(self.setheader, "colactive_2"),
                ]
            )
        if self.nLEDs == 1:
            self.led_pos = np.ma.MaskedArray(posData["pos"][:, 0:2])
            self.led_pix = np.ma.MaskedArray([posData["pos"][:, 4]])
        if self.nLEDs == 2:
            self.led_pos = np.ma.MaskedArray(posData["pos"][:, 0:4])
            self.led_pix = np.ma.MaskedArray(posData["pos"][:, 4:6])
        self.led_pos = np.ma.masked_equal(self.led_pos, 1023)
        self.led_pix = np.ma.masked_equal(self.led_pix, 1023)
        self.ts = np.array(posData["ts"])
        # led_pos has shape (n_pos_samples, n_cols) so its length is the
        # number of position samples
        self.npos = len(self.led_pos)
        self.xy = np.ones([2, self.npos]) * np.nan
        self.dir = np.ones([self.npos]) * np.nan
        self.dir_disp = np.ones([self.npos]) * np.nan
        self.speed = np.ones([self.npos]) * np.nan
        self.pos_sample_rate = self.getHeaderVal(self.header, "sample_rate")
        self._ppm = None
        if "cm" in kwargs:
            self.cm = kwargs["cm"]
        else:
            self.cm = False

    @property
    def ppm(self):
        if self._ppm is None:
            self._ppm = self.getHeaderVal(self.header, "pixels_per_metre")
        return self._ppm

    @ppm.setter
    def ppm(self, value):
        self._ppm = value
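
A minimal usage sketch (path hypothetical); mytrial.pos and mytrial.set must both be present, since the number of tracked LEDs is read from the .set file:

from ephysiopy.axona.axonaIO import Pos

p = Pos("/home/robin/Data/mytrial", cm=True)
print(p.nLEDs, p.npos, p.pos_sample_rate)
# columns are masked where the tracker reported the bad value 1023
x, y = p.led_pos[:, 0], p.led_pos[:, 1]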

Stim

Bases: dict, IO

Processes the stimulation data recorded using Axona

Parameters

filename_root : str
    The fully qualified filename without the suffix

Source code in ephysiopy/axona/axonaIO.py
class Stim(dict, IO):
    """
    Processes the stimulation data recorded using Axona

    Parameters
    ----------
    filename_root : str
        The fully qualified filename without the suffix
    """

    def __init__(self, filename_root: Path, *args, **kwargs):
        self.update(*args, **kwargs)
        filename_root = Path(os.path.splitext(filename_root)[0])
        self.filename_root = filename_root
        stmData = self.getData(filename_root.with_suffix(".stm"))
        times = stmData["ts"]
        stmHdr = self.getHeader(filename_root.with_suffix(".stm"))
        for k, v in stmHdr.items():
            self.__setitem__(k, v)
        tb = int(self["timebase"].split(" ")[0])
        self.timebase = tb
        times = times / tb
        self.__setitem__("ttl_timestamps", times * 1000)  # in ms
        # the 'duration' value in the header of the .stm file
        # is not correct so we need to read this from the .set
        # file and update
        setHdr = self.getHeader(filename_root.with_suffix(".set"))
        stim_duration = [setHdr[k] for k in setHdr.keys() if 'stim_pwidth' in k][0]
        stim_duration = int(stim_duration)
        stim_duration = stim_duration / 1000  # in ms now
        self.__setitem__('stim_duration', stim_duration)

    def update(self, *args, **kwargs):
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def __getitem__(self, key):
        val = dict.__getitem__(self, key)
        return val

    def __setitem__(self, key, val):
        dict.__setitem__(self, key, val)
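
A minimal usage sketch (path hypothetical); because Stim subclasses dict, the parsed header keys and derived values are all available by subscripting:

from ephysiopy.axona.axonaIO import Stim

stm = Stim("/home/robin/Data/mytrial")
print(stm["ttl_timestamps"][:5])  # stimulation times in ms
print(stm["stim_duration"])       # pulse width in ms, taken from the .set file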

Tetrode

Bases: IO

Processes tetrode files recorded with the Axona recording system

Mostly this class deals with interpolating tetrode and position timestamps and getting indices for particular clusters.

Parameters

filename_root : str
    The fully qualified name of the file without its suffix
tetrode : int
    The number of the tetrode
volts : bool, optional
    Whether to convert the data values to volts. Default True

Source code in ephysiopy/axona/axonaIO.py
class Tetrode(IO):
    """
    Processes tetrode files recorded with the Axona recording system

    Mostly this class deals with interpolating tetrode and position timestamps
    and getting indices for particular clusters.

    Parameters
    ----------
    filename_root : str
        The fully qualified name of the file without its suffix
    tetrode : int
        The number of the tetrode
    volts : bool, optional
        Whether to convert the data values to volts. Default True
    """

    def __init__(self, filename_root: Path, tetrode, volts=True):
        filename_root = Path(filename_root)
        if filename_root.suffix == ".set":
            filename_root = Path(os.path.splitext(filename_root)[0])
        self.filename_root = filename_root
        self.tetrode = tetrode
        self.volts = volts
        self.header = self.getHeader(
            self.filename_root.with_suffix("." + str(tetrode)))
        data = self.getData(filename_root.with_suffix("." + str(tetrode)))
        self.spk_ts = data["ts"][::4]
        self.nChans = self.getHeaderVal(self.header, "num_chans")
        self.samples = self.getHeaderVal(self.header, "samples_per_spike")
        self.nSpikes = self.getHeaderVal(self.header, "num_spikes")
        self.duration = self.getHeaderVal(self.header, "duration")
        self.posSampleRate = self.getHeaderVal(
            self.getHeader(
                self.filename_root.with_suffix(".pos")), "sample_rate"
        )
        self.waveforms = data["waveform"].reshape(
            self.nSpikes, self.nChans, self.samples
        )
        del data
        if volts:
            set_header = self.getHeader(self.filename_root.with_suffix(".set"))
            gains = np.zeros(4)
            st = (tetrode - 1) * 4
            for i, g in enumerate(np.arange(st, st + 4)):
                gains[i] = int(set_header["gain_ch_" + str(g)])
            ADC_mv = int(set_header["ADC_fullscale_mv"])
            scaling = (ADC_mv / 1000.0) / gains
            self.scaling = scaling
            self.gains = gains
            self.waveforms = (self.waveforms / 128.0) * scaling[:, np.newaxis]
        self.timebase = self.getHeaderVal(self.header, "timebase")
        cut = np.array(self.getCut(self.tetrode), dtype=int)
        self.cut = cut
        self.clusters = np.unique(self.cut)
        self.pos_samples = None

    def getSpkTS(self):
        """
        Return all the timestamps for all the spikes on the tetrode
        """
        return np.ma.compressed(self.spk_ts)

    def getClustTS(self, cluster: int = None):
        """
        Returns the timestamps for a cluster on the tetrode

        Parameters
        ----------
        cluster : int
            The cluster whose timestamps we want

        Returns
        -------
        clustTS : ndarray
            The timestamps

        Notes
        -----
        If None is supplied as input then the timestamps for all clusters
        are returned, i.e. getSpkTS() is called
        """
        clustTS = None
        if cluster is None:
            clustTS = self.getSpkTS()
        else:
            if self.cut is None:
                cut = np.array(self.getCut(self.tetrode), dtype=int)
                self.cut = cut
            if self.cut is not None:
                clustTS = np.ma.compressed(self.spk_ts[self.cut == cluster])
        return clustTS

    def getPosSamples(self):
        """
        Returns the pos samples at which the spikes were captured
        """
        self.pos_samples = np.floor(
            self.getSpkTS() / float(self.timebase) * self.posSampleRate
        ).astype(int)
        return np.ma.compressed(self.pos_samples)

    def getClustSpks(self, cluster: int):
        """
        Returns the waveforms of `cluster`

        Parameters
        ----------
        cluster : int
            The cluster whose waveforms we want

        Returns
        -------
        waveforms : ndarray
            The waveforms on all 4 electrodes of the tetrode so the shape of
            the returned array is [nClusterSpikes, 4, 50]
        """
        if self.cut is None:
            self.getClustTS(cluster)
        return self.waveforms[self.cut == cluster, :, :]

    def getClustIdx(self, cluster: int):
        """
        Get the indices of the position samples corresponding to the cluster

        Parameters
        ----------
        cluster : int
            The cluster whose position indices we want

        Returns
        -------
        pos_samples : ndarray
            The indices of the position samples, dtype is int
        """
        if self.cut is None:
            cut = np.array(self.getCut(self.tetrode), dtype=int)
            self.cut = cut
            if self.cut is None:
                return None
        if self.pos_samples is None:
            self.getPosSamples()  # sets self.pos_samples
        return self.pos_samples[self.cut == cluster].astype(int)

    def getUniqueClusters(self):
        """
        Returns the unique clusters
        """
        if self.cut is None:
            cut = np.array(self.getCut(self.tetrode), dtype=int)
            self.cut = cut
        return np.unique(self.cut)
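
A minimal usage sketch (path hypothetical) pulling out one cluster's data; the cut file mytrial_1.cut (or a matching .clu file) must exist for the cluster lookups to work:

from ephysiopy.axona.axonaIO import Tetrode

T = Tetrode("/home/robin/Data/mytrial", tetrode=1, volts=True)
print(T.getUniqueClusters())
waves = T.getClustSpks(cluster=1)  # shape [nClusterSpikes, 4, 50]
idx = T.getClustIdx(cluster=1)     # position-sample index of each spike
ts = T.getClustTS(cluster=1)       # timestamps in timebase units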

getClustIdx(cluster)

Get the indices of the position samples corresponding to the cluster

Parameters

cluster : int
    The cluster whose position indices we want

Returns

pos_samples : ndarray
    The indices of the position samples, dtype is int

Source code in ephysiopy/axona/axonaIO.py
def getClustIdx(self, cluster: int):
    """
    Get the indices of the position samples corresponding to the cluster

    Parameters
    ----------
    cluster : int
        The cluster whose position indices we want

    Returns
    -------
    pos_samples : ndarray
        The indices of the position samples, dtype is int
    """
    if self.cut is None:
        cut = np.array(self.getCut(self.tetrode), dtype=int)
        self.cut = cut
        if self.cut is None:
            return None
    if self.pos_samples is None:
        self.getPosSamples()  # sets self.pos_samples
    return self.pos_samples[self.cut == cluster].astype(int)

getClustSpks(cluster)

Returns the waveforms of cluster

Parameters

cluster : int
    The cluster whose waveforms we want

Returns

waveforms : ndarray
    The waveforms on all 4 electrodes of the tetrode so the shape of the
    returned array is [nClusterSpikes, 4, 50]

Source code in ephysiopy/axona/axonaIO.py
def getClustSpks(self, cluster: int):
    """
    Returns the waveforms of `cluster`

    Parameters
    ----------
    cluster : int
        The cluster whose waveforms we want

    Returns
    -------
    waveforms : ndarray
        The waveforms on all 4 electrodes of the tetrode so the shape of
        the returned array is [nClusterSpikes, 4, 50]
    """
    if self.cut is None:
        self.getClustTS(cluster)
    return self.waveforms[self.cut == cluster, :, :]

getClustTS(cluster=None)

Returns the timestamps for a cluster on the tetrode

Parameters

cluster : int
    The cluster whose timestamps we want

Returns

clustTS : ndarray
    The timestamps

Notes

If None is supplied as input then the timestamps for all clusters are
returned, i.e. getSpkTS() is called

Source code in ephysiopy/axona/axonaIO.py
def getClustTS(self, cluster: int = None):
    """
    Returns the timestamps for a cluster on the tetrode

    Parameters
    ----------
    cluster : int
        The cluster whose timestamps we want

    Returns
    -------
    clustTS : ndarray
        The timestamps

    Notes
    -----
    If None is supplied as input then the timestamps for all clusters
    are returned, i.e. getSpkTS() is called
    """
    clustTS = None
    if cluster is None:
        clustTS = self.getSpkTS()
    else:
        if self.cut is None:
            cut = np.array(self.getCut(self.tetrode), dtype=int)
            self.cut = cut
        if self.cut is not None:
            clustTS = np.ma.compressed(self.spk_ts[self.cut == cluster])
    return clustTS

getPosSamples()

Returns the pos samples at which the spikes were captured

Source code in ephysiopy/axona/axonaIO.py
def getPosSamples(self):
    """
    Returns the pos samples at which the spikes were captured
    """
    self.pos_samples = np.floor(
        self.getSpkTS() / float(self.timebase) * self.posSampleRate
    ).astype(int)
    return np.ma.compressed(self.pos_samples)

getSpkTS()

Return all the timestamps for all the spikes on the tetrode

Source code in ephysiopy/axona/axonaIO.py
def getSpkTS(self):
    """
    Return all the timestamps for all the spikes on the tetrode
    """
    return np.ma.compressed(self.spk_ts)

getUniqueClusters()

Returns the unique clusters

Source code in ephysiopy/axona/axonaIO.py
def getUniqueClusters(self):
    """
    Returns the unique clusters
    """
    if self.cut is None:
        cut = np.array(self.getCut(self.tetrode), dtype=int)
        self.cut = cut
    return np.unique(self.cut)

Conversion code

OpenEphys to Axona

OE2Axona

Bases: object

Converts openephys data into Axona files

Example workflow:

You have recorded some openephys data using the binary format leading to a directory structure something like this:

M4643_2023-07-21_11-52-02
├── Record Node 101
│ ├── experiment1
│ │ └── recording1
│ │     ├── continuous
│ │     │ └── Acquisition_Board-100.Rhythm Data
│ │     │     ├── amplitudes.npy
│ │     │     ├── channel_map.npy
│ │     │     ├── channel_positions.npy
│ │     │     ├── cluster_Amplitude.tsv
│ │     │     ├── cluster_ContamPct.tsv
│ │     │     ├── cluster_KSLabel.tsv
│ │     │     ├── continuous.dat
│ │     │     ├── params.py
│ │     │     ├── pc_feature_ind.npy
│ │     │     ├── pc_features.npy
│ │     │     ├── phy.log
│ │     │     ├── rez.mat
│ │     │     ├── similar_templates.npy
│ │     │     ├── spike_clusters.npy
│ │     │     ├── spike_templates.npy
│ │     │     ├── spike_times.npy
│ │     │     ├── template_feature_ind.npy
│ │     │     ├── template_features.npy
│ │     │     ├── templates_ind.npy
│ │     │     ├── templates.npy
│ │     │     ├── whitening_mat_inv.npy
│ │     │     └── whitening_mat.npy
│ │     ├── events
│ │     │ ├── Acquisition_Board-100.Rhythm Data
│ │     │ │ └── TTL
│ │     │ │     ├── full_words.npy
│ │     │ │     ├── sample_numbers.npy
│ │     │ │     ├── states.npy
│ │     │ │     └── timestamps.npy
│ │     │ └── MessageCenter
│ │     │     ├── sample_numbers.npy
│ │     │     ├── text.npy
│ │     │     └── timestamps.npy
│ │     ├── structure.oebin
│ │     └── sync_messages.txt
│ └── settings.xml
└── Record Node 104
    ├── experiment1
    │ └── recording1
    │     ├── continuous
    │     │ └── TrackMe-103.TrackingNode
    │     │     ├── continuous.dat
    │     │     ├── sample_numbers.npy
    │     │     └── timestamps.npy
    │     ├── events
    │     │ ├── MessageCenter
    │     │ │ ├── sample_numbers.npy
    │     │ │ ├── text.npy
    │     │ │ └── timestamps.npy
    │     │ └── TrackMe-103.TrackingNode
    │     │     └── TTL
    │     │         ├── full_words.npy
    │     │         ├── sample_numbers.npy
    │     │         ├── states.npy
    │     │         └── timestamps.npy
    │     ├── structure.oebin
    │     └── sync_messages.txt
    └── settings.xml

The binary data file is called "continuous.dat" in the continuous/Acquisition_Board-100.Rhythm Data folder. There is also a collection of files resulting from a KiloSort session in that directory.

Run the conversion code like so:

>>> from ephysiopy.format_converters.OE_Axona import OE2Axona
>>> from pathlib import Path
>>> nChannels = 64
>>> apData = Path("M4643_2023-07-21_11-52-02/Record Node 101/experiment1/recording1/continuous/Acquisition_Board-100.Rhythm Data")
>>> OE = OE2Axona(Path("M4643_2023-07-21_11-52-02"), path2APData=apData, channels=nChannels)
>>> OE.getOEData()

The last command will attempt to load position data and also load up something called a TemplateModel (from the package phylib) which should grab a handle to the neural data. If that doesn't throw any errors then try:

>>> OE.exportPos()

There are a few arguments you can provide the exportPos() function - see the docstring for it below. Basically, it calls a function called convertPosData(xy, xyts) where xy is the xy data with shape nsamples x 2 and xyts is a vector of timestamps. So if the call to exportPos() fails, you could try calling convertPosData() directly which returns axona formatted position data. If the variable returned from convertPosData() is called axona_pos_data then you can call the function:

writePos2AxonaFormat(pos_header, axona_pos_data)

Providing the pos_header to it - see the last half of the exportPos function for how to create and modify the pos_header as that will need to have user-specific information added to it.

>>> OE.convertTemplateDataToAxonaTetrode()

This is the main function for creating the tetrode files. It has an optional argument called max_n_waves which is used to limit the maximum number of spikes that make up a cluster. This defaults to 2000, which means that if a cluster has 12000 spikes, 2000 spikes are randomly drawn from those 12000 (without replacement) and saved to a tetrode file. This is mostly a time-saving device: if you have 250 clusters and many consist of tens of thousands of spikes, processing that data will take a long time.

>>> OE.exportLFP()

This will save either a .eeg or .egf file depending on the arguments. Check the docstring for how to change what channel is chosen for the LFP etc.

>>> OE.exportSetFile()

This should save the .set file with all the metadata for the trial.
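
Putting the steps above together, a complete conversion might look like the following sketch (the directory names come from the example tree above and will differ for your recording):

from pathlib import Path
from ephysiopy.format_converters.OE_Axona import OE2Axona

base = Path("M4643_2023-07-21_11-52-02")
apData = base / "Record Node 101/experiment1/recording1/continuous/Acquisition_Board-100.Rhythm Data"
OE = OE2Axona(base, path2APData=apData, channels=64)
OE.getOEData()
OE.exportPos(ppm=300, jumpmax=100)
OE.convertTemplateDataToAxonaTetrode(max_n_waves=2000)
OE.exportLFP(channel=1, lfp_type="eeg")
OE.exportSetFile()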

Source code in ephysiopy/format_converters/OE_Axona.py
class OE2Axona(object):
    """
    Converts openephys data into Axona files

    Example workflow:

    You have recorded some openephys data using the binary
    format leading to a directory structure something like this:

    M4643_2023-07-21_11-52-02
    ├── Record Node 101
    │ ├── experiment1
    │ │ └── recording1
    │ │     ├── continuous
    │ │     │ └── Acquisition_Board-100.Rhythm Data
    │ │     │     ├── amplitudes.npy
    │ │     │     ├── channel_map.npy
    │ │     │     ├── channel_positions.npy
    │ │     │     ├── cluster_Amplitude.tsv
    │ │     │     ├── cluster_ContamPct.tsv
    │ │     │     ├── cluster_KSLabel.tsv
    │ │     │     ├── continuous.dat
    │ │     │     ├── params.py
    │ │     │     ├── pc_feature_ind.npy
    │ │     │     ├── pc_features.npy
    │ │     │     ├── phy.log
    │ │     │     ├── rez.mat
    │ │     │     ├── similar_templates.npy
    │ │     │     ├── spike_clusters.npy
    │ │     │     ├── spike_templates.npy
    │ │     │     ├── spike_times.npy
    │ │     │     ├── template_feature_ind.npy
    │ │     │     ├── template_features.npy
    │ │     │     ├── templates_ind.npy
    │ │     │     ├── templates.npy
    │ │     │     ├── whitening_mat_inv.npy
    │ │     │     └── whitening_mat.npy
    │ │     ├── events
    │ │     │ ├── Acquisition_Board-100.Rhythm Data
    │ │     │ │ └── TTL
    │ │     │ │     ├── full_words.npy
    │ │     │ │     ├── sample_numbers.npy
    │ │     │ │     ├── states.npy
    │ │     │ │     └── timestamps.npy
    │ │     │ └── MessageCenter
    │ │     │     ├── sample_numbers.npy
    │ │     │     ├── text.npy
    │ │     │     └── timestamps.npy
    │ │     ├── structure.oebin
    │ │     └── sync_messages.txt
    │ └── settings.xml
    └── Record Node 104
        ├── experiment1
        │ └── recording1
        │     ├── continuous
        │     │ └── TrackMe-103.TrackingNode
        │     │     ├── continuous.dat
        │     │     ├── sample_numbers.npy
        │     │     └── timestamps.npy
        │     ├── events
        │     │ ├── MessageCenter
        │     │ │ ├── sample_numbers.npy
        │     │ │ ├── text.npy
        │     │ │ └── timestamps.npy
        │     │ └── TrackMe-103.TrackingNode
        │     │     └── TTL
        │     │         ├── full_words.npy
        │     │         ├── sample_numbers.npy
        │     │         ├── states.npy
        │     │         └── timestamps.npy
        │     ├── structure.oebin
        │     └── sync_messages.txt
        └── settings.xml

    The binary data file is called "continuous.dat" in the
    continuous/Acquisition_Board-100.Rhythm Data folder. There
    is also a collection of files resulting from a KiloSort session
    in that directory.

    Run the conversion code like so:

    >>> from ephysiopy.format_converters.OE_Axona import OE2Axona
    >>> from pathlib import Path
    >>> nChannels = 64
    >>> apData = Path("M4643_2023-07-21_11-52-02/Record Node 101/experiment1/recording1/continuous/Acquisition_Board-100.Rhythm Data")
    >>> OE = OE2Axona(Path("M4643_2023-07-21_11-52-02"), path2APData=apData, channels=nChannels)
    >>> OE.getOEData()

    The last command will attempt to load position data and also load up
    something called a TemplateModel (from the package phylib) which
    should grab a handle to the neural data. If that doesn't throw
    any errors then try:

    >>> OE.exportPos()

    There are a few arguments you can provide the exportPos() function - see
    the docstring for it below. Basically, it calls a function called
    convertPosData(xy, xyts) where xy is the xy data with shape nsamples x 2
    and xyts is a vector of timestamps. So if the call to exportPos() fails, you
    could try calling convertPosData() directly which returns axona formatted 
    position data. If the variable returned from convertPosData() is called axona_pos_data
    then you can call the function:

    writePos2AxonaFormat(pos_header, axona_pos_data)

    Providing the pos_header to it - see the last half of the exportPos function
    for how to create and modify the pos_header as that will need to have
    user-specific information added to it.

    >>> OE.convertTemplateDataToAxonaTetrode()

    This is the main function for creating the tetrode files. It has an optional
    argument called max_n_waves which is used to limit the maximum number of spikes
    that make up a cluster. This defaults to 2000, which means that if a
    cluster has 12000 spikes, 2000 spikes are randomly drawn from those 12000
    (without replacement) and saved to a tetrode file. This is mostly a
    time-saving device: if you have 250 clusters and many consist of tens of
    thousands of spikes, processing that data will take a long time.

    >>> OE.exportLFP()

    This will save either a .eeg or .egf file depending on the arguments. Check the
    docstring for how to change what channel is chosen for the LFP etc.

    >>> OE.exportSetFile()

    This should save the .set file with all the metadata for the trial.

    """

    def __init__(self,
                 pname: Path,
                 path2APData: Path = None,
                 pos_sample_rate: int = 50,
                 channels: int = 0,
                 **kwargs):
        """
        Args:
            pname (Path): The base directory of the openephys recording.
                e.g. '/home/robin/Data/M4643_2023-07-21_11-52-02'
            path2APData (Path, optional): Path to AP data. Defaults to None.
            pos_sample_rate (int, optional): Position sample rate. Defaults to 50.
            channels (int, optional): Number of channels. Defaults to 0.
            **kwargs: Variable length argument list.
        """
        pname = Path(pname)
        assert pname.exists()
        self.pname: Path = pname
        self.path2APdata: Path = path2APData
        self.pos_sample_rate: int = pos_sample_rate
        # 'experiment_1.nwb'
        self.experiment_name: Path = Path(
            kwargs["experiment_name"]) if "experiment_name" in kwargs \
            else self.pname
        self.recording_name = None  # will become 'recording1' etc
        self.OE_data = None  # becomes instance of io.recording.OpenEphysBase
        self._settings = None  # will become an instance of OESettings.Settings
        # Create a basename for Axona file names
        # e.g.'/home/robin/Data/experiment_1'
        # that we can append '.pos' or '.eeg' or whatever onto
        self.axona_root_name = self.experiment_name
        # need to instantiated now for later
        self.AxonaData = axonaIO.IO(self.axona_root_name.name + ".pos")
        # THIS IS TEMPORARY AND WILL BE MORE USER-SPECIFIABLE IN THE FUTURE
        # it is used to scale the spikes
        self.hp_gain = 500
        self.lp_gain = 15000
        self.bitvolts = 0.195
        # if left as None some default values for the next 3 params are loaded
        #  from top-level __init__.py
        # these are only used in self.__filterLFP__
        self.fs = None
        # if lfp_channel is set to None then the .set file will reflect that
        #  no EEG was recorded
        # this should mean that you can load data into Tint without a .eeg file
        self.lfp_channel = kwargs.get("lfp_channel", 1)
        self.lfp_lowcut = None
        self.lfp_highcut = None
        # set the tetrodes to record from
        # defaults to 1 through 4 - see self.makeSetData below
        self.tetrodes = ["1", "2", "3", "4"]
        self.channel_count = channels

    def resample(self, data, src_rate=30, dst_rate=50, axis=0):
        """
        Resamples data using polyphase filtering
        """
        denom = np.gcd(dst_rate, src_rate)
        new_data = signal.resample_poly(
            data, dst_rate // denom, src_rate // denom, axis)
        return new_data

    @property
    def settings(self):
        """
        Loads the settings data from the settings.xml file
        """
        if self._settings is None:
            self._settings = OESettings.Settings(self.pname)
        return self._settings

    @settings.setter
    def settings(self, value):
        self._settings = value

    def getOEData(self) -> OpenEphysBase:
        """
        Loads the openephys recording under self.pname and returns an
        OpenEphysBase instance holding the data relevant for converting
        to Axona file formats.

        Note that the default recording name has changed in different
        versions of OE from 'recording0' to 'recording1'.
        """
        OE_data = OpenEphysBase(self.pname)
        try:
            OE_data.load_pos_data(sample_rate=self.pos_sample_rate)
            # It's likely that spikes have been collected after the last
            # position sample
            # due to buffering issues I can't be bothered to resolve.
            # Get the last pos
            # timestamps here and check that spikes don't go beyond
            #  this when writing data
            # out later
            # Also the pos and spike data timestamps almost never start at
            #  0 as the user
            # usually acquires data for a while before recording.
            # Grab the first timestamp
            # here with a view to subtracting this from everything (including
            # the spike data)
            # and figuring out what to keep later
            first_pos_ts = OE_data.PosCalcs.xyTS[0]
            last_pos_ts = OE_data.PosCalcs.xyTS[-1]
            self.first_pos_ts = first_pos_ts
            self.last_pos_ts = last_pos_ts
        except Exception:
            OE_data.load_neural_data()  # will create TemplateModel instance
            self.first_pos_ts = 0
            # NB self.OE_data is not assigned until below, so use the local
            self.last_pos_ts = OE_data.template_model.duration
        print(f"First pos ts: {self.first_pos_ts}")
        print(f"Last pos ts: {self.last_pos_ts}")
        self.OE_data = OE_data
        if self.path2APdata is None:
            self.path2APdata = self.OE_data.path2APdata
        # extract number of channels from settings
        for item in self.settings.record_nodes.items():
            if "Rhythm Data" in item[1].name:
                self.channel_count = int(item[1].channel_count)
        return OE_data

    def exportSetFile(self, **kwargs):
        """
        Wrapper for makeSetData below
        """
        print("Exporting set file data...")
        self.makeSetData(**kwargs)
        print("Done exporting set file.")

    def exportPos(self, ppm=300, jumpmax=100, as_text=False, **kwargs):
        #
        # Step 1) Deal with the position data first:
        #
        # Grab the settings of the pos tracker and do some post-processing
        # on the position
        # data (discard jumpy data, do some smoothing etc)
        self.settings.parse()
        if not self.OE_data:
            self.getOEData()
        if not self.OE_data.PosCalcs:
            self.OE_data.load_pos_data(sample_rate=self.pos_sample_rate)
        print("Post-processing position data...")
        self.OE_data.PosCalcs.jumpmax = jumpmax
        self.OE_data.PosCalcs.tracker_params["AxonaBadValue"] = 1023
        self.OE_data.PosCalcs.postprocesspos(
            self.OE_data.PosCalcs.tracker_params)
        xy = self.OE_data.PosCalcs.xy.T
        xyTS = self.OE_data.PosCalcs.xyTS  # in seconds
        xyTS = xyTS * self.pos_sample_rate
        # extract some values from PosCalcs or overrides given
        # when calling this method
        ppm = self.OE_data.PosCalcs.ppm or ppm
        sample_rate = self.OE_data.PosCalcs.sample_rate or kwargs["sample_rate"]
        if as_text is True:
            print("Beginning export of position data to text format...")
            pos_file_name = str(self.axona_root_name) + ".txt"
            np.savetxt(pos_file_name, xy, fmt="%1.u")
            print("Completed export of position data")
            return
        # Do the upsampling of both xy and the timestamps
        print("Beginning export of position data to Axona format...")
        axona_pos_data = self.convertPosData(xy, xyTS)
        # make sure pos data length is same as duration * num_samples
        axona_pos_data = axona_pos_data[
            0: int(self.last_pos_ts - self.first_pos_ts) * self.pos_sample_rate
        ]
        # Create an empty header for the pos data
        from ephysiopy.axona.file_headers import PosHeader

        pos_header = PosHeader()
        tracker_params = self.OE_data.PosCalcs.tracker_params
        min_xy = np.floor(np.min(xy, 0)).astype(int).data
        max_xy = np.ceil(np.max(xy, 0)).astype(int).data
        pos_header.pos["min_x"] = pos_header.pos["window_min_x"] = str(
            tracker_params["LeftBorder"]) if "LeftBorder" in \
            tracker_params.keys() else str(min_xy[0])
        pos_header.pos["min_y"] = pos_header.pos["window_min_y"] = str(
            tracker_params["TopBorder"]) if "TopBorder" in \
            tracker_params.keys() else str(min_xy[1])
        pos_header.pos["max_x"] = pos_header.pos["window_max_x"] = str(
            tracker_params["RightBorder"]) if "RightBorder" in \
            tracker_params.keys() else str(max_xy[0])
        pos_header.pos["max_y"] = pos_header.pos["window_max_y"] = str(
            tracker_params["BottomBorder"]) if "BottomBorder" in \
            tracker_params.keys() else str(max_xy[1])
        pos_header.common["duration"] = str(
            int(self.last_pos_ts - self.first_pos_ts))
        pos_header.pos["pixels_per_metre"] = str(ppm)
        pos_header.pos["num_pos_samples"] = str(len(axona_pos_data))
        pos_header.pos["pixels_per_metre"] = str(ppm)
        pos_header.pos["sample_rate"] = str(sample_rate)

        self.writePos2AxonaFormat(pos_header, axona_pos_data)
        print("Exported position data to Axona format")

    def exportSpikes(self):
        print("Beginning conversion of spiking data...")
        self.convertSpikeData(
            self.OE_data.nwbData["acquisition"]["timeseries"][
                self.recording_name]["spikes"]
        )
        print("Completed exporting spiking data")

    def exportLFP(self, channel: int = 0,
                  lfp_type: str = 'eeg',
                  gain: int = 5000,
                  **kwargs):
        """
        Exports LFP data to file.

        Args:
            channel (int): The channel number.
            lfp_type (str): The type of LFP data. Legal values are 'egf' or 'eeg'.
            gain (int): Multiplier for the LFP data.
        """
        print("Beginning conversion and exporting of LFP data...")
        if not self.settings.processors:
            self.settings.parse()
        from ephysiopy.io.recording import memmapBinaryFile
        try:
            data = memmapBinaryFile(
                Path(self.path2APdata).joinpath("continuous.dat"),
                n_channels=self.channel_count)
            self.makeLFPData(data[channel, :], eeg_type=lfp_type, gain=gain)
            print("Completed exporting LFP data to " + lfp_type + " format")
        except Exception as e:
            print(f"Couldn't load raw data:\n{e}")

    def convertPosData(self, xy: np.array, xy_ts: np.array) -> np.array:
        """
        Performs the conversion of the array parts of the data.

        Note: As well as upsampling the data to the Axona pos sampling rate (50Hz),
        we have to insert some columns into the pos array as Axona format
        expects it like: pos_format: t,x1,y1,x2,y2,numpix1,numpix2
        We can make up some of the info and ignore other bits.
        """
        n_new_pts = int(np.floor((
            self.last_pos_ts - self.first_pos_ts) * self.pos_sample_rate))
        t = xy_ts - self.first_pos_ts
        new_ts = np.linspace(t[0], t[-1], n_new_pts)
        new_x = np.interp(new_ts, t, xy[:, 0])
        new_y = np.interp(new_ts, t, xy[:, 1])
        new_x[np.isnan(new_x)] = 1023
        new_y[np.isnan(new_y)] = 1023
        # Expand the pos bit of the data to make it look like Axona data
        new_pos = np.vstack([new_x, new_y]).T
        new_pos = np.c_[
            new_pos,
            np.ones_like(new_pos) * 1023,
            np.zeros_like(new_pos),
            np.zeros_like(new_pos),
        ]
        new_pos[:, 4] = 40  # just made this value up - it's numpix i think
        new_pos[:, 6] = 40  # same
        # Squeeze this data into Axona pos format array
        dt = self.AxonaData.axona_files[".pos"]
        new_data = np.zeros(n_new_pts, dtype=dt)
        # Timestamps in Axona are pos_samples (monotonic, linear integer)
        new_data["ts"] = new_ts
        new_data["pos"] = new_pos
        return new_data

    def convertTemplateDataToAxonaTetrode(self,
                                          max_n_waves=2000,
                                          **kwargs):
        """
        Converts the data held in a TemplateModel instance into tetrode
        format Axona data files.

        For each cluster there is a single channel on which its waveform
        has the largest (peak) amplitude. The other channels with a strong
        signal may lie on the same tetrode, but KiloSort (or whatever
        sorter was used) can also assign channels that are *not* on the
        same tetrode. For a given cluster we can extract from the
        TemplateModel the 12 channels on which its signal is strongest
        using Model.get_cluster_channels(). If a channel belonging to a
        tetrode is missing from that list, the spikes for that channel
        are zeroed when saved to Axona format.

        Example:
            If cluster 3 has a peak channel of 1 then get_cluster_channels() might look like:
            [ 1,  2,  0,  6, 10, 11,  4,  12,  7,  5,  8,  9]
            Here the cluster has the best signal on 1, then 2, 0 etc, but note that channel 3 isn't in the list. 
            In this case the data for channel 3 will be zeroed when saved to Axona format.

        References:
            1) https://phy.readthedocs.io/en/latest/api/#phyappstemplatetemplatemodel
        """
        # First lets get the datatype for tetrode files as this will be the
        # same for all tetrodes...
        dt = self.AxonaData.axona_files[".1"]
        # Load the TemplateModel
        if "path2APdata" in kwargs.keys():
            self.OE_data.load_neural_data(**kwargs)
        else:
            self.OE_data.load_neural_data()
        model = self.OE_data.template_model
        clusts = model.cluster_ids
        # have to pre-process the channels / clusters to determine
        # which tetrodes clusters belong to - this is based on
        # the 'best' channel for a given cluster
        clusters_channels = OrderedDict(
            (c, model.get_cluster_channels(c)) for c in clusts)
        tetrodes_clusters = OrderedDict(dict.fromkeys(
            range(0, int(self.channel_count/4)), []))
        for tet in tetrodes_clusters.keys():
            # a cluster belongs to the tetrode holding its best channel
            tetrodes_clusters[tet] = [
                clust for clust, chans in clusters_channels.items()
                if int(chans[0] / 4) == tet
            ]
        # limit the number of spikes to max_n_waves in the
        # interests of speed. Randomly select spikes across
        # the period they fire
        rng = np.random.default_rng()

        for i, i_tet_item in enumerate(tetrodes_clusters.items()):
            this_tetrode = i_tet_item[0]
            times_to_sort = []
            new_clusters = []
            new_waves = []
            for clust in tqdm(i_tet_item[1], desc="Tetrode " + str(i+1)):
                clust_chans = model.get_cluster_channels(clust)
                idx = np.logical_and(clust_chans >= this_tetrode * 4,
                                     clust_chans < this_tetrode * 4 + 4)
                # clust_chans is an ordered list of the channels
                # the cluster was most active on. idx has True
                # where there is overlap between that and the
                # currently active tetrode channel numbers (0:4, 4:8
                # or whatever)
                spike_idx = model.get_cluster_spikes(clust)
                # limit the number of spikes to max_n_waves in the
                # interests of speed. Randomly select spikes across
                # the period they fire
                total_n_waves = len(spike_idx)
                max_num_waves = min(max_n_waves, total_n_waves)
                # grab spike times (in seconds) so the random sampling of
                # spikes matches their times
                times = model.spike_times[model.spike_clusters == clust]
                spike_idx_times_subset = rng.choice(
                    (spike_idx, times), max_num_waves, axis=1, replace=False)
                # spike_idx_times_subset is unsorted as it's just been drawn
                # from a random distribution, so sort it now
                spike_idx_times_subset = np.sort(spike_idx_times_subset, 1)
                # split out into spikes and times
                spike_idx_subset = spike_idx_times_subset[0, :].astype(int)
                times = spike_idx_times_subset[1, :]
                waves = model.get_waveforms(spike_idx_subset, clust_chans[idx])
                # Given a spike at time T, Axona takes T-200us and T+800us
                # from the buffer to make up a waveform. From OE we take 30
                # samples, which corresponds to 1 ms of data sampled at
                # 30kHz, and interpolate so the length is 50 samples, as
                # with Axona
                waves = waves[:, 30:60, :]
                # waves go from int16 to float as a result of the resampling
                waves = self.resample(waves.astype(float), axis=1)
                # multiply by bitvolts to get microvolts
                waves = waves * self.bitvolts
                # scale appropriately for Axona and invert as
                # OE seems to be inverted wrt Axona
                waves = waves / (self.hp_gain / 4 / 128.0) * (-1)
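                # worked example (assuming hp_gain=500, bitvolts=0.195):
                # a raw value of 1000 becomes 1000 * 0.195 = 195 uV, then
                # 195 / (500 / 4 / 128.0) ~= 199.7 Axona units, inverted
                # to ~-199.7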
                # check the shape of waves to make sure it has 4
                # channels, if not add some to make it so and make
                # sure they are in the correct order for the tetrode
                ordered_chans = np.argsort(clust_chans[idx])
                if waves.shape[-1] != 4:
                    z = np.zeros(shape=(waves.shape[0], waves.shape[1], 4))
                    z[:, :, ordered_chans] = waves
                    waves = z
                else:
                    waves = waves[:, :, ordered_chans]
                # Axona format tetrode waveforms are nSpikes x 4 x 50
                waves = np.transpose(waves, (0, 2, 1))
                # Append clusters to a list to sort later for saving a
                # cluster/ cut file
                new_clusters.append(np.repeat(clust, len(times)))
                # Axona times are sampled at 96 kHz
                times = times * 96000
                # There is a time for each spike despite the repetition
                # get the indices for sorting
                times_to_sort.append(times)
                # i_clust_data = np.zeros(len(new_times), dtype=dt)
                new_waves.append(waves)
            # Concatenate, order and reshape some of the lists/ arrays
            if times_to_sort:  # apparently can be empty sometimes
                _times = np.concatenate(times_to_sort)
                _waves = np.concatenate(new_waves)
                _clusters = np.concatenate(new_clusters)
                indices = np.argsort(_times)
                sorted_times = _times[indices]
                sorted_waves = _waves[indices]
                sorted_clusts = _clusters[indices]
                output_times = np.repeat(sorted_times, 4)
                output_waves = np.reshape(sorted_waves, [
                    sorted_waves.shape[0] * sorted_waves.shape[1],
                    sorted_waves.shape[2]
                ])
                new_tetrode_data = np.zeros(len(output_times), dtype=dt)
                new_tetrode_data["ts"] = output_times
                new_tetrode_data["waveform"] = output_waves
                header = TetrodeHeader()
                header.common["duration"] = str(int(model.duration))
                header.tetrode_entries["num_spikes"] = str(
                    len(_clusters))
                self.writeTetrodeData(str(i+1), header, new_tetrode_data)
                cut_header = CutHeader()
                self.writeCutData(str(i+1), cut_header, sorted_clusts)

    def convertSpikeData(self, hdf5_tetrode_data: h5py._hl.group.Group):
        """
        Does the spike conversion from OE Spike Sorter format to Axona format tetrode files.

        Args:
            hdf5_tetrode_data (h5py._hl.group.Group): Behaves much like a
                dictionary and can mostly be treated as one. See
                http://docs.h5py.org/en/stable/high/group.html
        """
        # First lets get the datatype for tetrode files as this will be the
        # same for all tetrodes...
        dt = self.AxonaData.axona_files[".1"]
        # ... and a basic header for the tetrode file that use for each
        # tetrode file, changing only the num_spikes value
        header = TetrodeHeader()
        header.common["duration"] = str(
            int(self.last_pos_ts - self.first_pos_ts))

        for key in hdf5_tetrode_data.keys():
            spiking_data = np.array(hdf5_tetrode_data[key].get("data"))
            timestamps = np.array(hdf5_tetrode_data[key].get("timestamps"))
            # check if any of the spiking data is captured before/ after the
            #  first/ last bit of position data
            # if there is then discard this as we potentially have no valid
            # position to align the spike to :(
            idx = np.logical_or(
                timestamps < self.first_pos_ts, timestamps > self.last_pos_ts
            )
            spiking_data = spiking_data[~idx, :, :]
            timestamps = timestamps[~idx]
            # subtract the first pos timestamp from the spiking timestamps
            timestamps = timestamps - self.first_pos_ts
            # get the number of spikes here for use below in the header
            num_spikes = len(timestamps)
            # repeat the timestamps in tetrode multiples ready for Axona export
            new_timestamps = np.repeat(timestamps, 4)
            new_spiking_data = spiking_data.astype(np.float64)
            # Convert to microvolts...
            new_spiking_data = new_spiking_data * self.bitvolts
            # And upsample the spikes...
            new_spiking_data = self.resample(new_spiking_data, 4, 5, -1)
            # ... and scale appropriately for Axona and invert as
            # OE seems to be inverted wrt Axona
            new_spiking_data = new_spiking_data / \
                (self.hp_gain / 4 / 128.0) * (-1)
            # ... scale them to the gains specified somewhere
            #  (not sure where / how to do this yet)
            shp = new_spiking_data.shape
            # then reshape them as Axona wants them a bit differently
            new_spiking_data = np.reshape(
                new_spiking_data, [shp[0] * shp[1], shp[2]])
            # Cap any values outside the range of int8
            new_spiking_data[new_spiking_data < -128] = -128
            new_spiking_data[new_spiking_data > 127] = 127
            # create the new array
            new_tetrode_data = np.zeros(len(new_timestamps), dtype=dt)
            new_tetrode_data["ts"] = new_timestamps * 96000
            new_tetrode_data["waveform"] = new_spiking_data
            # change the header num_spikes field
            header.tetrode_entries["num_spikes"] = str(num_spikes)
            i_tetnum = key.split("electrode")[1]
            print("Exporting tetrode {}".format(i_tetnum))
            self.writeTetrodeData(i_tetnum, header, new_tetrode_data)

    def makeLFPData(
                    self,
                    data: np.array,
                    eeg_type="eeg",
                    gain=5000):
        """
        Downsamples data and saves the result as either an egf or eeg file,
        depending on eeg_type, which can be either 'eeg' or 'egf'.
        gain is the scaling factor.

        Args:
            data (np.array): The data to be downsampled.
                Must have dtype np.int16.
            eeg_type (str): 'eeg' or 'egf'. Defaults to 'eeg'.
            gain (int): The scaling factor. Defaults to 5000.
        """
        if eeg_type == "eeg":
            from ephysiopy.axona.file_headers import EEGHeader

            header = EEGHeader()
            dst_rate = 250
        elif eeg_type == "egf":
            from ephysiopy.axona.file_headers import EGFHeader

            header = EGFHeader()
            dst_rate = 4800
        header.common["duration"] = str(
            int(self.last_pos_ts - self.first_pos_ts))
        print(f"header.common[duration] = {header.common['duration']}")
        _lfp_data = self.resample(data.astype(float), 30000, dst_rate, -1)
        # make sure data is same length as sample_rate * duration
        nsamples = int(dst_rate * int(header.common["duration"]))
        # lfp_data might be shorter than nsamples. If so, fill the 
        # remaining values with zeros
        if len(_lfp_data) < nsamples:
            lfp_data = np.zeros(nsamples)
            lfp_data[0:len(_lfp_data)] = _lfp_data
        else:
            lfp_data = _lfp_data[0:nsamples]
        lfp_data = self.__filterLFP__(lfp_data, dst_rate)
        # convert the data format
        # lfp_data = lfp_data * self.bitvolts # in microvolts

        if eeg_type == "eeg":
            # probably BROKEN
            # lfp_data starts out as int16 (see Parameters above)
            # but gets converted into float64 as part of the
            # resampling/ filtering process
            lfp_data = lfp_data / 32768.0
            lfp_data = lfp_data * gain
            # cap the values at either end...
            lfp_data[lfp_data < -128] = -128
            lfp_data[lfp_data > 127] = 127
            # and convert to int8
            lfp_data = lfp_data.astype(np.int8)

        elif eeg_type == "egf":
            # probably works
            # lfp_data = lfp_data / 256.
            lfp_data = lfp_data.astype(np.int16)

        header.n_samples = str(len(lfp_data))
        self.writeLFP2AxonaFormat(header, lfp_data, eeg_type)

    def makeSetData(self, lfp_channel=4, **kwargs):
        if self.OE_data is None:
            # to get the timestamps for the duration key
            self.getOEData()
        from ephysiopy.axona.file_headers import SetHeader

        header = SetHeader()
        # set some reasonable default values
        from ephysiopy.__about__ import __version__

        header.meta_info["sw_version"] = str(__version__)
        # ADC fullscale mv is 1500 in Axona and 0.195 in OE
        # there is a division by 1000 that happens when processing
        # spike data in Axona that looks like that has already
        # happened in OE. So here the OE 0.195 value is multiplied
        # by 1000 as it will get divided by 1000 later on to get
        # the correct scaling of waveforms/ gains -> mv values
        header.meta_info["ADC_fullscale_mv"] = "195"
        header.meta_info["tracker_version"] = "1.1.0"

        for k, v in header.set_entries.items():
            if "gain" in k:
                header.set_entries[k] = str(self.hp_gain)
            if "collectMask" in k:
                header.set_entries[k] = "0"
            if "EEG_ch_1" in k:
                if lfp_channel is not None:
                    header.set_entries[k] = str(int(lfp_channel))
            if "mode_ch_" in k:
                header.set_entries[k] = "0"
        # iterate again to make sure lfp gain set correctly
        for k, v in header.set_entries.items():
            if lfp_channel is not None:
                if k == "gain_ch_" + str(lfp_channel):
                    header.set_entries[k] = str(self.lp_gain)

        # Based on the data in the electrodes dict of the OESettings
        # instance (self.settings - see __init__)
        # determine which tetrodes we can let Tint load
        # make sure we've parsed the electrodes
        tetrode_count = int(self.channel_count / 4)
        for i in range(1, tetrode_count+1):
            header.set_entries["collectMask_" + str(i)] = "1"
        # if self.lfp_channel is not None:
        #     for chan in self.tetrodes:
        #         key = "collectMask_" + str(chan)
        #         header.set_entries[key] = "1"
        header.set_entries["colactive_1"] = "1"
        header.set_entries["colactive_2"] = "0"
        header.set_entries["colactive_3"] = "0"
        header.set_entries["colactive_4"] = "0"
        header.set_entries["colmap_algorithm"] = "1"
        header.set_entries["duration"] = str(
            int(self.last_pos_ts - self.first_pos_ts))
        self.writeSetData(header)

    def __filterLFP__(self, data: np.array, sample_rate: int):
        from scipy.signal import filtfilt, firwin

        if self.fs is None:
            from ephysiopy import fs

            self.fs = fs
        if self.lfp_lowcut is None:
            from ephysiopy import lfp_lowcut

            self.lfp_lowcut = lfp_lowcut
        if self.lfp_highcut is None:
            from ephysiopy import lfp_highcut

            self.lfp_highcut = lfp_highcut
        nyq = sample_rate / 2.0
        lowcut = self.lfp_lowcut / nyq
        highcut = self.lfp_highcut / nyq
        if highcut >= 1.0:
            highcut = 1.0 - np.finfo(float).eps
        if lowcut <= 0.0:
            lowcut = np.finfo(float).eps
        b = firwin(sample_rate + 1, [lowcut, highcut],
                   window="black", pass_zero=False)
        y = filtfilt(b, [1], data.ravel(), padtype="odd")
        return y

    def writeLFP2AxonaFormat(
                             self,
                             header: dataclass,
                             data: np.array,
                             eeg_type="eeg"):
        self.AxonaData.setHeader(str(self.axona_root_name) + "." + eeg_type, header)
        self.AxonaData.setData(str(self.axona_root_name) + "." + eeg_type, data)

    def writePos2AxonaFormat(self, header: dataclass, data: np.array):
        self.AxonaData.setHeader(str(self.axona_root_name) + ".pos", header)
        self.AxonaData.setData(str(self.axona_root_name) + ".pos", data)

    def writeTetrodeData(self, itet: str, header: dataclass, data: np.array):
        self.AxonaData.setHeader(str(self.axona_root_name) + "." + itet, header)
        self.AxonaData.setData(str(self.axona_root_name) + "." + itet, data)

    def writeSetData(self, header: dataclass):
        self.AxonaData.setHeader(str(self.axona_root_name) + ".set", header)

    def writeCutData(self, itet: str, header: dataclass, data: np.array):
        self.AxonaData.setCut(str(self.axona_root_name) + "_" + str(itet) + ".cut",
                              header, data)

settings property writable

Loads the settings data from the settings.xml file

__init__(pname, path2APData=None, pos_sample_rate=50, channels=0, **kwargs)

Parameters:

    pname (Path): The base directory of the openephys recording,
        e.g. '/home/robin/Data/M4643_2023-07-21_11-52-02'. Required.
    path2APData (Path): Path to AP data. Defaults to None.
    pos_sample_rate (int): Position sample rate. Defaults to 50.
    channels (int): Number of channels. Defaults to 0.
    **kwargs: Variable length argument list.
Source code in ephysiopy/format_converters/OE_Axona.py
def __init__(self,
             pname: Path,
             path2APData: Path = None,
             pos_sample_rate: int = 50,
             channels: int = 0,
             **kwargs):
    """
    Args:
        pname (Path): The base directory of the openephys recording.
            e.g. '/home/robin/Data/M4643_2023-07-21_11-52-02'
        path2APData (Path, optional): Path to AP data. Defaults to None.
        pos_sample_rate (int, optional): Position sample rate. Defaults to 50.
        channels (int, optional): Number of channels. Defaults to 0.
        **kwargs: Variable length argument list.
    """
    pname = Path(pname)
    assert pname.exists()
    self.pname: Path = pname
    self.path2APdata: Path = path2APData
    self.pos_sample_rate: int = pos_sample_rate
    # 'experiment_1.nwb'
    self.experiment_name: Path = Path(kwargs["experiment_name"]) \
        if "experiment_name" in kwargs else self.pname
    self.recording_name = None  # will become 'recording1' etc
    self.OE_data = None  # becomes instance of io.recording.OpenEphysBase
    self._settings = None  # will become an instance of OESettings.Settings
    # Create a basename for Axona file names
    # e.g.'/home/robin/Data/experiment_1'
    # that we can append '.pos' or '.eeg' or whatever onto
    self.axona_root_name = self.experiment_name
    # need to instantiated now for later
    self.AxonaData = axonaIO.IO(self.axona_root_name.name + ".pos")
    # THIS IS TEMPORARY AND WILL BE MORE USER-SPECIFIABLE IN THE FUTURE
    # it is used to scale the spikes
    self.hp_gain = 500
    self.lp_gain = 15000
    self.bitvolts = 0.195
    # if left as None some default values for the next 3 params are loaded
    #  from top-level __init__.py
    # these are only used in self.__filterLFP__
    self.fs = None
    # if lfp_channel is set to None then the .set file will reflect that
    #  no EEG was recorded
    # this should mean that you can load data into Tint without a .eeg file
    self.lfp_channel = kwargs.get("lfp_channel", 1)
    self.lfp_lowcut = None
    self.lfp_highcut = None
    # set the tetrodes to record from
    # defaults to 1 through 4 - see self.makeSetData below
    self.tetrodes = ["1", "2", "3", "4"]
    self.channel_count = channels
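
A minimal usage sketch. Hedged: the class name OE2Axona is inferred from the module name by analogy with OE2Numpy below, and the recording path is hypothetical; only methods documented on this page are used:

from pathlib import Path
from ephysiopy.format_converters.OE_Axona import OE2Axona

converter = OE2Axona(Path("/home/robin/Data/M4643_2023-07-21_11-52-02"),
                     channels=16)      # 16 channels -> tetrodes 1-4
converter.getOEData()                  # load pos/neural data, set timestamps
converter.exportSetFile()              # write the .set file
converter.exportLFP(channel=0)         # write the .eeg file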

convertPosData(xy, xy_ts)

Performs the conversion of the array parts of the data.

Note: As well as upsampling the data to the Axona pos sampling rate (50 Hz), we have to insert some columns into the pos array, as the Axona format expects it like: pos_format: t,x1,y1,x2,y2,numpix1,numpix2. We can make up some of the info and ignore other bits.
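
The note is easiest to see with a toy example. This is a standalone sketch with synthetic data, not the library API:

import numpy as np

# irregularly sampled position data (times in seconds)
xy_ts = np.array([0.0, 0.035, 0.061, 0.102, 0.140])
xy = np.column_stack([np.linspace(0, 4, 5), np.linspace(10, 14, 5)])

pos_sample_rate = 50  # Hz, the Axona pos rate
n_new = int(np.floor((xy_ts[-1] - xy_ts[0]) * pos_sample_rate))
new_ts = np.linspace(xy_ts[0], xy_ts[-1], n_new)
new_x = np.interp(new_ts, xy_ts, xy[:, 0])
new_y = np.interp(new_ts, xy_ts, xy[:, 1])

# pad out to the Axona layout x1,y1,x2,y2,...: 1023 marks the absent
# second LED and 40 is the made-up numpix value used above
new_pos = np.column_stack([
    new_x, new_y,
    np.full(n_new, 1023), np.full(n_new, 1023),
    np.full(n_new, 40), np.zeros(n_new),
    np.full(n_new, 40), np.zeros(n_new),
])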

convertSpikeData(hdf5_tetrode_data)

Does the spike conversion from OE Spike Sorter format to Axona format tetrode files.

Parameters:

    hdf5_tetrode_data (h5py._hl.group.Group): Behaves much like a
        dictionary and can mostly be treated as one. See
        http://docs.h5py.org/en/stable/high/group.html. Required.
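
The timestamp handling described above, as a standalone sketch (all values hypothetical):

import numpy as np

first_pos_ts, last_pos_ts = 12.0, 612.0          # seconds
spike_ts = np.array([5.0, 20.5, 300.0, 700.0])   # seconds

# keep only spikes that fall within the position record
keep = (spike_ts >= first_pos_ts) & (spike_ts <= last_pos_ts)
spike_ts = spike_ts[keep] - first_pos_ts   # re-zero on first pos sample

# one timestamp per tetrode channel, in 96 kHz Axona clock ticks
axona_ts = np.repeat(spike_ts, 4) * 96000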

convertTemplateDataToAxonaTetrode(max_n_waves=2000, **kwargs)

Converts the data held in a TemplateModel instance into tetrode format Axona data files.

For each cluster there is a single channel on which its waveform has the largest (peak) amplitude. The other channels with a strong signal may lie on the same tetrode, but KiloSort (or whatever sorter was used) can also assign channels that are not on the same tetrode. For a given cluster we can extract from the TemplateModel the 12 channels on which its signal is strongest using Model.get_cluster_channels(). If a channel belonging to a tetrode is missing from that list, the spikes for that channel are zeroed when saved to Axona format.

Example

If cluster 3 has a peak channel of 1 then get_cluster_channels() might look like: [ 1, 2, 0, 6, 10, 11, 4, 12, 7, 5, 8, 9] Here the cluster has the best signal on 1, then 2, 0 etc, but note that channel 3 isn't in the list. In this case the data for channel 3 will be zeroed when saved to Axona format.

References

1) https://phy.readthedocs.io/en/latest/api/#phyappstemplatetemplatemodel
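
A toy illustration of the channel bookkeeping described above, using plain numpy and a made-up channel list rather than a real TemplateModel:

import numpy as np

# hypothetical get_cluster_channels() output for one cluster
clust_chans = np.array([1, 2, 0, 6, 10, 11, 4, 12, 7, 5, 8, 9])

best_chan = clust_chans[0]     # the peak channel
tetrode = best_chan // 4       # -> tetrode 0 (channels 0-3)

# tetrode channels present in the list: 0, 1 and 2 but not 3
idx = (clust_chans >= tetrode * 4) & (clust_chans < tetrode * 4 + 4)
present = clust_chans[idx]

# waveforms come back as nSpikes x nSamples x nPresentChannels; any
# missing tetrode channel (here channel 3) stays zero-filled
waves = np.random.randn(5, 50, len(present))
z = np.zeros((5, 50, 4))
z[:, :, present % 4] = waves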

exportLFP(channel=0, lfp_type='eeg', gain=5000, **kwargs)

Exports LFP data to file.

Parameters:

    channel (int): The channel number. Defaults to 0.
    lfp_type (str): The type of LFP data. Legal values are 'egf' or 'eeg'.
        Defaults to 'eeg'.
    gain (int): Multiplier for the LFP data. Defaults to 5000.
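
Internally this memory-maps the raw continuous.dat file and hands one channel to makeLFPData. A hedged sketch of the same idea with plain numpy; the channel count is hypothetical, and the sample-major layout of Open Ephys binary files is an assumption that memmapBinaryFile handles for you:

import numpy as np

n_channels = 64                  # hypothetical
raw = np.memmap("continuous.dat", dtype=np.int16, mode="r")
# assumed layout: interleaved sample-major, i.e. samples x channels
raw = raw.reshape(-1, n_channels).T
lfp_channel = raw[0, :]          # one channel, as passed to makeLFPData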

exportSetFile(**kwargs)

Wrapper for makeSetData below

Source code in ephysiopy/format_converters/OE_Axona.py
def exportSetFile(self, **kwargs):
    """
    Wrapper for makeSetData below
    """
    print("Exporting set file data...")
    self.makeSetData(**kwargs)
    print("Done exporting set file.")

getOEData()

Loads the OpenEphys recording rooted at pname and returns an OpenEphysBase instance containing the data relevant for converting to Axona file formats.
Source code in ephysiopy/format_converters/OE_Axona.py
def getOEData(self) -> OpenEphysBase:
    """
    Loads the nwb file names in filename_root and returns a dict
    containing some of the nwb data relevant for converting to Axona file formats.

    Args:
        filename_root (str): Fully qualified name of the nwb file.
        recording_name (str): The name of the recording in the nwb file. Note that
            the default has changed in different versions of OE from 'recording0'
            to 'recording1'.
    """
    OE_data = OpenEphysBase(self.pname)
    try:
        OE_data.load_pos_data(sample_rate=self.pos_sample_rate)
        # It's likely that spikes have been collected after the last
        # position sample
        # due to buffering issues I can't be bothered to resolve.
        # Get the last pos
        # timestamps here and check that spikes don't go beyond
        #  this when writing data
        # out later
        # Also the pos and spike data timestamps almost never start at
        #  0 as the user
        # usually acquires data for a while before recording.
        # Grab the first timestamp
        # here with a view to subtracting this from everything (including
        # the spike data)
        # and figuring out what to keep later
        first_pos_ts = OE_data.PosCalcs.xyTS[0]
        last_pos_ts = OE_data.PosCalcs.xyTS[-1]
        self.first_pos_ts = first_pos_ts
        self.last_pos_ts = last_pos_ts
    except Exception:
        OE_data.load_neural_data()  # will create TemplateModel instance
        self.first_pos_ts = 0
        # self.OE_data is not set yet at this point; use the local instance
        self.last_pos_ts = OE_data.template_model.duration
    print(f"First pos ts: {self.first_pos_ts}")
    print(f"Last pos ts: {self.last_pos_ts}")
    self.OE_data = OE_data
    if self.path2APdata is None:
        self.path2APdata = self.OE_data.path2APdata
    # extract number of channels from settings
    for item in self.settings.record_nodes.items():
        if "Rhythm Data" in item[1].name:
            self.channel_count = int(item[1].channel_count)
    return OE_data

makeLFPData(data, eeg_type='eeg', gain=5000)

Downsamples data and saves the result as either an egf or eeg file, depending on eeg_type, which can be either 'eeg' or 'egf'. gain is the scaling factor.

Parameters:

    data (np.array): The data to be downsampled. Must have dtype np.int16.
        Required.
    eeg_type (str): 'eeg' or 'egf'. Defaults to 'eeg'.
    gain (int): The scaling factor. Defaults to 5000.
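
The downsampling reduces the 30 kHz raw signal to the Axona LFP rates (250 Hz for .eeg, 4800 Hz for .egf). A standalone sketch of the same polyphase approach, including the int8 conversion used for .eeg files:

import numpy as np
from scipy import signal

src_rate, dst_rate = 30000, 250       # raw rate -> .eeg rate
raw = np.random.randn(src_rate * 2)   # two seconds of fake data

denom = np.gcd(dst_rate, src_rate)    # 250
lfp = signal.resample_poly(raw, dst_rate // denom, src_rate // denom)
assert len(lfp) == dst_rate * 2

# .eeg output is int8: normalise, apply gain, clip, cast
lfp = np.clip(lfp / 32768.0 * 5000, -128, 127).astype(np.int8)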

resample(data, src_rate=30, dst_rate=50, axis=0)

Resamples data using polyphase filtering (scipy.signal.resample_poly)

Source code in ephysiopy/format_converters/OE_Axona.py
def resample(self, data, src_rate=30, dst_rate=50, axis=0):
    """
    Resamples data using polyphase filtering (scipy.signal.resample_poly)
    """
    denom = np.gcd(dst_rate, src_rate)
    new_data = signal.resample_poly(
        data, dst_rate / denom, src_rate / denom, axis)
    return new_data
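
For example, the default 30 -> 50 conversion reduces to an upsample-by-5, downsample-by-3 polyphase filter:

import numpy as np
from scipy import signal

x = np.arange(300.0)                    # 10 s of data at 30 Hz
denom = np.gcd(50, 30)                  # 10
y = signal.resample_poly(x, 50 // denom, 30 // denom)  # up=5, down=3
print(len(x), len(y))                   # 300 500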

OpenEphys to numpy

OE2Numpy

Bases: object

Converts openephys data recorded in the nwb format into numpy files

NB Only exports the LFP and TTL files at the moment

Source code in ephysiopy/format_converters/OE_numpy.py
class OE2Numpy(object):
    """
    Converts openephys data recorded in the nwb format into numpy files

    NB Only exports the LFP and TTL files at the moment
    """
    def __init__(self, filename_root: str):
        self.filename_root = filename_root  # /home/robin/Data/experiment_1.nwb
        self.dirname = os.path.dirname(filename_root)  # '/home/robin/Data'
        # 'experiment_1.nwb'
        self.experiment_name = os.path.basename(self.filename_root)
        self.recording_name = None  # will become 'recording1' etc
        self.OE_data = None  # becomes OpenEphysBase instance
        self._settings = None  # will become an instance of OESettings.Settings
        self.fs = None
        self.lfp_lowcut = None
        self.lfp_highcut = None

    def resample(self, data, src_rate=30, dst_rate=50, axis=0):
        """
        Resamples data using polyphase filtering (scipy.signal.resample_poly)
        """
        denom = np.gcd(dst_rate, src_rate)
        new_data = signal.resample_poly(
            data, dst_rate/denom, src_rate/denom, axis)
        return new_data

    @property
    def settings(self):
        """
        Loads the settings data from the settings.xml file
        """
        if self._settings is None:
            self._settings = OESettings.Settings(self.dirname)
        return self._settings

    @settings.setter
    def settings(self, value):
        self._settings = value

    def getOEData(
            self, filename_root: str,
            recording_name='recording1') -> OpenEphysNWB:
        """
        Loads the nwb file named in filename_root and returns an
        OpenEphysNWB instance containing some of the nwb data relevant
        for converting to Axona file formats.

        Args:
            filename_root (str): Fully qualified name of the nwb file.
            recording_name (str): The name of the recording in the nwb file.
                Note that the default has changed in different versions of
                OE from 'recording0' to 'recording1'.
        """
        if os.path.isfile(filename_root):
            OE_data = OpenEphysNWB(self.dirname)
            print("Loading nwb data...")
            OE_data.load(
                self.dirname, session_name=self.experiment_name,
                recording_name=recording_name, loadspikes=False, loadraw=True)
            print("Loaded nwb data from: {}".format(filename_root))
            # It's likely that spikes have been collected after the last
            # position sample
            # due to buffering issues I can't be bothered to resolve.
            # Get the last pos
            # timestamps here and check that spikes don't go beyond this
            # when writing data out later
            # Also the pos and spike data timestamps almost never start at
            #  0 as the user
            # usually acquires data for a while before recording.
            # Grab the first timestamp
            # here with a view to subtracting this from everything
            # (including the spike data)
            # and figuring out what to keep later
            try:  # pos might not be present
                first_pos_ts = OE_data.xyTS[0]
                last_pos_ts = OE_data.xyTS[-1]
                self.first_pos_ts = first_pos_ts
                self.last_pos_ts = last_pos_ts
            except Exception:
                print("No position data in nwb file")
            self.recording_name = recording_name
            self.OE_data = OE_data
            return OE_data

    def exportLFP(self, channels: list, output_freq: int):
        print("Beginning conversion and exporting of LFP data...")
        channels = [int(c) for c in channels]
        if not self.settings.processors:
            self.settings.parse()
        if self.settings.fpga_sample_rate is None:
            self.settings.parseProcessor()
        output_name = os.path.join(self.dirname, "lfp.npy")
        output_ts_name = os.path.join(self.dirname, "lfp_timestamps.npy")
        if len(channels) == 1:
            # resample data
            print("Resampling data from {0} to {1} Hz".format(
                self.settings.fpga_sample_rate, output_freq))
            new_data = self.resample(
                self.OE_data.rawData[:, channels],
                self.settings.fpga_sample_rate, output_freq)
            np.save(output_name, new_data, allow_pickle=False)
        if len(channels) > 1:
            print("Resampling data from {0} to {1} Hz".format(
                self.settings.fpga_sample_rate, output_freq))
            new_data = self.resample(
            self.OE_data.rawData[:, channels[0]:channels[-1] + 1],
                self.settings.fpga_sample_rate, output_freq)
            np.save(output_name, new_data, allow_pickle=False)
        nsamples = np.shape(new_data)[0]
        new_ts = np.linspace(self.OE_data.ts[0], self.OE_data.ts[-1], nsamples)
        np.save(output_ts_name, new_ts, allow_pickle=False)
        print("Finished exporting LFP data")

    def exportTTL(self):
        print("Exporting TTL data...")
        ttl_state = self.OE_data.ttl_data
        ttl_ts = self.OE_data.ttl_timestamps
        np.save(os.path.join(
            self.dirname, "ttl_state.npy"), ttl_state, allow_pickle=False)
        np.save(os.path.join(
            self.dirname, "ttl_timestamps.npy"), ttl_ts, allow_pickle=False)
        print("Finished exporting TTL data")

    def exportRaw2Binary(self, output_fname=None):
        if self.OE_data.rawData is None:
            print("Load the data first. See getOEData()")
            return
        if output_fname is None:
            output_fname = os.path.splitext(self.filename_root)[0] + '.bin'
        print(f"Exporting raw data to:\n{output_fname}")
        with open(output_fname, 'wb') as f:
            np.save(f, self.OE_data.rawData)
        print("Finished exporting")

settings property writable

Loads the settings data from the settings.xml file

getOEData(filename_root, recording_name='recording1')

Loads the nwb file named in filename_root and returns an OpenEphysNWB instance containing some of the nwb data relevant for converting to Axona file formats.

Parameters:

    filename_root (str): Fully qualified name of the nwb file. Required.
    recording_name (str): The name of the recording in the nwb file.
        Defaults to 'recording1'.

resample(data, src_rate=30, dst_rate=50, axis=0)

Resamples data using polyphase filtering (scipy.signal.resample_poly)
